diff --git a/cached_reader.go b/cached_reader.go index fe389c5..eb35d1b 100644 --- a/cached_reader.go +++ b/cached_reader.go @@ -5,45 +5,48 @@ import ( ) type cachedReader struct { - buffer *bufio.Reader - cache []byte - cacheCap int - cacheLen int + buffer *bufio.Reader + cache []byte caching bool } func newCachedReader(r *bufio.Reader) *cachedReader { return &cachedReader{ - buffer: r, - cache: make([]byte, 4096), - cacheCap: 4096, - cacheLen: 0, - caching: false, + buffer: r, + cache: make([]byte, 0, 4096), + caching: false, } } func (c *cachedReader) StartCaching() { - c.cacheLen = 0 + c.cache = c.cache[:0] c.caching = true } -func (c *cachedReader) ReadByte() (byte, error) { - if !c.caching { - return c.buffer.ReadByte() - } - b, err := c.buffer.ReadByte() +func (c *cachedReader) ReadByte() (b byte, err error) { + b, err = c.buffer.ReadByte() if err != nil { - return b, err + return } - if c.cacheLen < c.cacheCap { - c.cache[c.cacheLen] = b - c.cacheLen++ + if c.caching { + c.cacheByte(b) } - return b, err + return } func (c *cachedReader) Cache() []byte { - return c.cache[:c.cacheLen] + return c.cache +} + +func (c *cachedReader) CacheWithLimit(n int) []byte { + if n < 1 { + return nil + } + l := len(c.cache) + if n > l { + n = l + } + return c.cache[:n] } func (c *cachedReader) StopCaching() { @@ -55,11 +58,9 @@ func (c *cachedReader) Read(p []byte) (int, error) { if err != nil { return n, err } - if c.caching && c.cacheLen < c.cacheCap { + if c.caching { for i := 0; i < n; i++ { - c.cache[c.cacheLen] = p[i] - c.cacheLen++ - if c.cacheLen >= c.cacheCap { + if !c.cacheByte(p[i]) { break } } @@ -67,3 +68,12 @@ func (c *cachedReader) Read(p []byte) (int, error) { return n, err } +func (c *cachedReader) cacheByte(b byte) bool { + n := len(c.cache) + if n == cap(c.cache) { + return false + } + c.cache = c.cache[:n+1] + c.cache[n] = b + return true +} diff --git a/cached_reader_test.go b/cached_reader_test.go index 8cbfef2..d3a931d 100644 --- a/cached_reader_test.go +++ b/cached_reader_test.go @@ -39,4 +39,19 @@ func TestCaching(t *testing.T) { if !bytes.Equal(cached, []byte("BCDEF")) { t.Fatalf("Incorrect cached buffer value") } + + cached = cachedReader.CacheWithLimit(-1) + if cached != nil { + t.Fatalf("Incorrect cached buffer value") + } + + cached = cachedReader.CacheWithLimit(3) + if !bytes.Equal(cached, []byte("BCD")) { + t.Fatalf("Incorrect cached buffer value") + } + + cached = cachedReader.CacheWithLimit(1000) + if !bytes.Equal(cached, []byte("BCDEF")) { + t.Fatalf("Incorrect cached buffer value") + } } diff --git a/node.go b/node.go index 43b2a8e..f864bc6 100644 --- a/node.go +++ b/node.go @@ -92,6 +92,13 @@ func WithPreserveSpace() OutputOption { } } +// WithoutPreserveSpace will not preserve spaces in output +func WithoutPreserveSpace() OutputOption { + return func(oc *outputConfiguration) { + oc.preserveSpaces = false + } +} + // WithIndentation sets the indentation string used for formatting the output. func WithIndentation(indentation string) OutputOption { return func(oc *outputConfiguration) { @@ -328,7 +335,9 @@ func (n *Node) Write(writer io.Writer, self bool) error { // WriteWithOptions writes xml with given options to given writer. func (n *Node) WriteWithOptions(writer io.Writer, opts ...OutputOption) (err error) { - config := &outputConfiguration{} + config := &outputConfiguration{ + preserveSpaces: true, + } // Set the options for _, opt := range opts { opt(config) @@ -400,11 +409,7 @@ func AddChild(parent, n *Node) { parent.LastChild = n } -// AddSibling adds a new node 'n' as a sibling of a given node 'sibling'. -// Note it is not necessarily true that the new node 'n' would be added -// immediately after 'sibling'. If 'sibling' isn't the last child of its -// parent, then the new node 'n' will be added at the end of the sibling -// chain of their parent. +// AddSibling adds a new node 'n' as a last node of sibling chain for a given node 'sibling'. func AddSibling(sibling, n *Node) { for t := sibling.NextSibling; t != nil; t = t.NextSibling { sibling = t @@ -418,6 +423,19 @@ func AddSibling(sibling, n *Node) { } } +// AddImmediateSibling adds a new node 'n' as immediate sibling a given node 'sibling'. +func AddImmediateSibling(sibling, n *Node) { + n.Parent = sibling.Parent + n.NextSibling = sibling.NextSibling + sibling.NextSibling = n + n.PrevSibling = sibling + if n.NextSibling != nil { + n.NextSibling.PrevSibling = n + } else if n.Parent != nil { + sibling.Parent.LastChild = n + } +} + // RemoveFromTree removes a node and its subtree from the document // tree it is in. If the node is the root of the tree, then it's no-op. func RemoveFromTree(n *Node) { @@ -445,3 +463,15 @@ func RemoveFromTree(n *Node) { n.PrevSibling = nil n.NextSibling = nil } + +// GetRoot returns a root of the tree where 'n' is a node. +func GetRoot(n *Node) *Node { + if n == nil { + return nil + } + root := n + for root.Parent != nil { + root = root.Parent + } + return root +} diff --git a/node_test.go b/node_test.go index 1621766..7831274 100644 --- a/node_test.go +++ b/node_test.go @@ -250,7 +250,7 @@ func TestRemoveFromTree(t *testing.T) { testTrue(t, n != nil) RemoveFromTree(n) verifyNodePointers(t, doc) - testValue(t, doc.OutputXML(false), + testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()), ``) }) @@ -260,7 +260,7 @@ func TestRemoveFromTree(t *testing.T) { testTrue(t, n != nil) RemoveFromTree(n) verifyNodePointers(t, doc) - testValue(t, doc.OutputXML(false), + testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()), ``) }) @@ -270,7 +270,7 @@ func TestRemoveFromTree(t *testing.T) { testTrue(t, n != nil) RemoveFromTree(n) verifyNodePointers(t, doc) - testValue(t, doc.OutputXML(false), + testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()), ``) }) @@ -280,7 +280,7 @@ func TestRemoveFromTree(t *testing.T) { testTrue(t, n != nil) RemoveFromTree(n) verifyNodePointers(t, doc) - testValue(t, doc.OutputXML(false), + testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()), ``) }) @@ -290,7 +290,7 @@ func TestRemoveFromTree(t *testing.T) { testValue(t, procInst.Type, DeclarationNode) RemoveFromTree(procInst) verifyNodePointers(t, doc) - testValue(t, doc.OutputXML(false), + testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()), ``) }) @@ -300,7 +300,7 @@ func TestRemoveFromTree(t *testing.T) { testValue(t, commentNode.Type, CommentNode) RemoveFromTree(commentNode) verifyNodePointers(t, doc) - testValue(t, doc.OutputXML(false), + testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()), ``) }) @@ -308,11 +308,36 @@ func TestRemoveFromTree(t *testing.T) { doc := parseXML() RemoveFromTree(doc) verifyNodePointers(t, doc) - testValue(t, doc.OutputXML(false), + testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()), ``) }) } +func TestAddImmediateSibling(t *testing.T) { + s := ` + + + + + + + + + ` + root, err := Parse(strings.NewReader(s)) + if err != nil { + t.Error(err) + } + + aaa := findNode(root, "AAA") + n := aaa.SelectElement("BBB") + if n == nil { + t.Fatalf("n is nil") + } + AddImmediateSibling(n, &Node{Type: ElementNode, Data: "r"}) + testValue(t, root.OutputXMLWithOptions(WithoutPreserveSpace()), ``) +} + func TestSelectElement(t *testing.T) { s := ` @@ -497,7 +522,6 @@ func TestWriteWithNamespacePrefix(t *testing.T) { } } - func TestQueryWithPrefix(t *testing.T) { s := `ns2:ClientThis is a client fault` doc, _ := Parse(strings.NewReader(s)) @@ -582,7 +606,7 @@ func TestOutputXMLWithSpaceDirect(t *testing.T) { t.Errorf(`expected "%s", obtained "%s"`, expected, g) } - output := html.UnescapeString(doc.OutputXML(true)) + output := html.UnescapeString(doc.OutputXMLWithOptions(WithOutputSelf(), WithoutPreserveSpace())) if strings.Contains(output, "\n") { t.Errorf("the outputted xml contains newlines") } @@ -606,7 +630,7 @@ func TestOutputXMLWithSpaceOverwrittenToPreserve(t *testing.T) { t.Errorf(`expected "%s", obtained "%s"`, expected, g) } - output := html.UnescapeString(doc.OutputXML(true)) + output := html.UnescapeString(doc.OutputXMLWithOptions(WithOutputSelf(), WithoutPreserveSpace())) if strings.Contains(output, "\n") { t.Errorf("the outputted xml contains newlines") } @@ -680,8 +704,8 @@ func TestOutputXMLWithPreserveSpaceOption(t *testing.T) { ` doc, _ := Parse(strings.NewReader(s)) - resultWithSpace := doc.OutputXMLWithOptions(WithPreserveSpace()) - resultWithoutSpace := doc.OutputXMLWithOptions() + resultWithSpace := doc.OutputXMLWithOptions() + resultWithoutSpace := doc.OutputXMLWithOptions(WithoutPreserveSpace()) if !strings.Contains(resultWithSpace, "> Robert <") { t.Errorf("output was not expected. expected %v but got %v", " Robert ", resultWithSpace) } diff --git a/parse.go b/parse.go index 7627f46..d359b50 100644 --- a/parse.go +++ b/parse.go @@ -2,6 +2,7 @@ package xmlquery import ( "bufio" + "bytes" "encoding/xml" "fmt" "io" @@ -39,15 +40,31 @@ func Parse(r io.Reader) (*Node, error) { func ParseWithOptions(r io.Reader, options ParserOptions) (*Node, error) { p := createParser(r) options.apply(p) - for { - _, err := p.parse() - if err == io.EOF { - return p.doc, nil + var err error + for err == nil { + _, err = p.parse() + } + + if err == io.EOF { + // additional check for validity + // according to: https://www.w3.org/TR/xml + // the document MUST contain at least ONE element + valid := false + for doc := p.doc; doc != nil; doc = doc.NextSibling { + for node := doc.FirstChild; node != nil; node = node.NextSibling { + if node.Type == ElementNode { + valid = true + break + } + } } - if err != nil { - return nil, err + if !valid { + return nil, fmt.Errorf("xmlquery: invalid XML document") } + return p.doc, nil } + + return nil, err } type parser struct { @@ -168,7 +185,7 @@ func (p *parser) parse() (*Node, error) { if node.NamespaceURI != "" { if v, ok := p.space2prefix[node.NamespaceURI]; ok { - cached := string(p.reader.Cache()) + cached := string(p.reader.CacheWithLimit(len(v.name) + len(node.Data) + 2)) if strings.HasPrefix(cached, fmt.Sprintf("%s:%s", v.name, node.Data)) || strings.HasPrefix(cached, fmt.Sprintf("<%s:%s", v.name, node.Data)) { node.Prefix = v.name } @@ -228,12 +245,11 @@ func (p *parser) parse() (*Node, error) { } case xml.CharData: // First, normalize the cache... - cached := strings.ToUpper(string(p.reader.Cache())) + cached := bytes.ToUpper(p.reader.CacheWithLimit(9)) nodeType := TextNode - if strings.HasPrefix(cached, " - - - book 2 - - - book 2 - - -` - doc, err := Parse(strings.NewReader(s)) - if err != nil { - t.Fatal(err) - } - list := Find(doc, `/bk:books/bk:book`) - if found, expected := len(list), 2; found != expected { - t.Fatalf("should found %d bk:book but found %d", expected, found) - } -} - func TestDefaultNamespace_2(t *testing.T) { s := ` + + + book 2 + + + book 2 + + +` + doc, err := Parse(strings.NewReader(s)) + if err != nil { + t.Fatal(err) + } + list := Find(doc, `/bk:books/bk:book`) + if found, expected := len(list), 2; found != expected { + t.Fatalf("should found %d bk:book but found %d", expected, found) + } +} + func TestDuplicateNamespaceURL(t *testing.T) { s := ` @@ -290,7 +291,7 @@ func TestParse(t *testing.T) { testAttr(t, findNode(books[1], "title"), "lang", "en") testValue(t, findNode(books[1], "price").InnerText(), "39.95") - testValue(t, books[0].OutputXML(true), `Harry Potter29.99`) + testValue(t, books[0].OutputXMLWithOptions(WithOutputSelf(), WithoutPreserveSpace()), `Harry Potter29.99`) } func TestMissDeclaration(t *testing.T) { @@ -308,6 +309,14 @@ func TestMissDeclaration(t *testing.T) { } } +func TestNonXMLParse(t *testing.T) { + s := `{"a":null}` + doc, err := Parse(strings.NewReader(s)) + if err == nil || doc != nil { + t.Fatal(err) + } +} + func TestMissingNamespace(t *testing.T) { s := ` value 1 @@ -431,7 +440,7 @@ func TestStreamParser_InvalidXPath(t *testing.T) { } func testOutputXML(t *testing.T, msg string, expectedXML string, n *Node) { - if n.OutputXML(true) != expectedXML { + if n.OutputXMLWithOptions(WithOutputSelf(), WithoutPreserveSpace()) != expectedXML { t.Fatalf("%s, expected XML: '%s', actual: '%s'", msg, expectedXML, n.OutputXML(true)) } }