diff --git a/cached_reader.go b/cached_reader.go
index fe389c5..eb35d1b 100644
--- a/cached_reader.go
+++ b/cached_reader.go
@@ -5,45 +5,48 @@ import (
)
type cachedReader struct {
- buffer *bufio.Reader
- cache []byte
- cacheCap int
- cacheLen int
+ buffer *bufio.Reader
+ cache []byte
caching bool
}
func newCachedReader(r *bufio.Reader) *cachedReader {
return &cachedReader{
- buffer: r,
- cache: make([]byte, 4096),
- cacheCap: 4096,
- cacheLen: 0,
- caching: false,
+ buffer: r,
+ cache: make([]byte, 0, 4096),
+ caching: false,
}
}
func (c *cachedReader) StartCaching() {
- c.cacheLen = 0
+ c.cache = c.cache[:0]
c.caching = true
}
-func (c *cachedReader) ReadByte() (byte, error) {
- if !c.caching {
- return c.buffer.ReadByte()
- }
- b, err := c.buffer.ReadByte()
+func (c *cachedReader) ReadByte() (b byte, err error) {
+ b, err = c.buffer.ReadByte()
if err != nil {
- return b, err
+ return
}
- if c.cacheLen < c.cacheCap {
- c.cache[c.cacheLen] = b
- c.cacheLen++
+ if c.caching {
+ c.cacheByte(b)
}
- return b, err
+ return
}
func (c *cachedReader) Cache() []byte {
- return c.cache[:c.cacheLen]
+ return c.cache
+}
+
+func (c *cachedReader) CacheWithLimit(n int) []byte {
+ if n < 1 {
+ return nil
+ }
+ l := len(c.cache)
+ if n > l {
+ n = l
+ }
+ return c.cache[:n]
}
func (c *cachedReader) StopCaching() {
@@ -55,11 +58,9 @@ func (c *cachedReader) Read(p []byte) (int, error) {
if err != nil {
return n, err
}
- if c.caching && c.cacheLen < c.cacheCap {
+ if c.caching {
for i := 0; i < n; i++ {
- c.cache[c.cacheLen] = p[i]
- c.cacheLen++
- if c.cacheLen >= c.cacheCap {
+ if !c.cacheByte(p[i]) {
break
}
}
@@ -67,3 +68,12 @@ func (c *cachedReader) Read(p []byte) (int, error) {
return n, err
}
+func (c *cachedReader) cacheByte(b byte) bool {
+ n := len(c.cache)
+ if n == cap(c.cache) {
+ return false
+ }
+ c.cache = c.cache[:n+1]
+ c.cache[n] = b
+ return true
+}
diff --git a/cached_reader_test.go b/cached_reader_test.go
index 8cbfef2..d3a931d 100644
--- a/cached_reader_test.go
+++ b/cached_reader_test.go
@@ -39,4 +39,19 @@ func TestCaching(t *testing.T) {
if !bytes.Equal(cached, []byte("BCDEF")) {
t.Fatalf("Incorrect cached buffer value")
}
+
+ cached = cachedReader.CacheWithLimit(-1)
+ if cached != nil {
+ t.Fatalf("Incorrect cached buffer value")
+ }
+
+ cached = cachedReader.CacheWithLimit(3)
+ if !bytes.Equal(cached, []byte("BCD")) {
+ t.Fatalf("Incorrect cached buffer value")
+ }
+
+ cached = cachedReader.CacheWithLimit(1000)
+ if !bytes.Equal(cached, []byte("BCDEF")) {
+ t.Fatalf("Incorrect cached buffer value")
+ }
}
diff --git a/node.go b/node.go
index 43b2a8e..f864bc6 100644
--- a/node.go
+++ b/node.go
@@ -92,6 +92,13 @@ func WithPreserveSpace() OutputOption {
}
}
+// WithoutPreserveSpace disables space preservation in the output.
+func WithoutPreserveSpace() OutputOption {
+ return func(oc *outputConfiguration) {
+ oc.preserveSpaces = false
+ }
+}
+
// WithIndentation sets the indentation string used for formatting the output.
func WithIndentation(indentation string) OutputOption {
return func(oc *outputConfiguration) {
@@ -328,7 +335,9 @@ func (n *Node) Write(writer io.Writer, self bool) error {
// WriteWithOptions writes xml with given options to given writer.
func (n *Node) WriteWithOptions(writer io.Writer, opts ...OutputOption) (err error) {
- config := &outputConfiguration{}
+ config := &outputConfiguration{
+ preserveSpaces: true,
+ }
// Set the options
for _, opt := range opts {
opt(config)
@@ -400,11 +409,7 @@ func AddChild(parent, n *Node) {
parent.LastChild = n
}
-// AddSibling adds a new node 'n' as a sibling of a given node 'sibling'.
-// Note it is not necessarily true that the new node 'n' would be added
-// immediately after 'sibling'. If 'sibling' isn't the last child of its
-// parent, then the new node 'n' will be added at the end of the sibling
-// chain of their parent.
+// AddSibling adds a new node 'n' as the last node of the sibling chain of a given node 'sibling'.
func AddSibling(sibling, n *Node) {
for t := sibling.NextSibling; t != nil; t = t.NextSibling {
sibling = t
@@ -418,6 +423,19 @@ func AddSibling(sibling, n *Node) {
}
}
+// AddImmediateSibling adds a new node 'n' as the immediate next sibling of a given node 'sibling'.
+func AddImmediateSibling(sibling, n *Node) {
+ n.Parent = sibling.Parent
+ n.NextSibling = sibling.NextSibling
+ sibling.NextSibling = n
+ n.PrevSibling = sibling
+ if n.NextSibling != nil {
+ n.NextSibling.PrevSibling = n
+ } else if n.Parent != nil {
+ sibling.Parent.LastChild = n
+ }
+}
+
// RemoveFromTree removes a node and its subtree from the document
// tree it is in. If the node is the root of the tree, then it's no-op.
func RemoveFromTree(n *Node) {
@@ -445,3 +463,15 @@ func RemoveFromTree(n *Node) {
n.PrevSibling = nil
n.NextSibling = nil
}
+
+// GetRoot returns the root of the tree that contains node 'n', or nil if 'n' is nil.
+func GetRoot(n *Node) *Node {
+ if n == nil {
+ return nil
+ }
+ root := n
+ for root.Parent != nil {
+ root = root.Parent
+ }
+ return root
+}
diff --git a/node_test.go b/node_test.go
index 1621766..7831274 100644
--- a/node_test.go
+++ b/node_test.go
@@ -250,7 +250,7 @@ func TestRemoveFromTree(t *testing.T) {
testTrue(t, n != nil)
RemoveFromTree(n)
verifyNodePointers(t, doc)
- testValue(t, doc.OutputXML(false),
+ testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
``)
})
@@ -260,7 +260,7 @@ func TestRemoveFromTree(t *testing.T) {
testTrue(t, n != nil)
RemoveFromTree(n)
verifyNodePointers(t, doc)
- testValue(t, doc.OutputXML(false),
+ testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
``)
})
@@ -270,7 +270,7 @@ func TestRemoveFromTree(t *testing.T) {
testTrue(t, n != nil)
RemoveFromTree(n)
verifyNodePointers(t, doc)
- testValue(t, doc.OutputXML(false),
+ testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
``)
})
@@ -280,7 +280,7 @@ func TestRemoveFromTree(t *testing.T) {
testTrue(t, n != nil)
RemoveFromTree(n)
verifyNodePointers(t, doc)
- testValue(t, doc.OutputXML(false),
+ testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
``)
})
@@ -290,7 +290,7 @@ func TestRemoveFromTree(t *testing.T) {
testValue(t, procInst.Type, DeclarationNode)
RemoveFromTree(procInst)
verifyNodePointers(t, doc)
- testValue(t, doc.OutputXML(false),
+ testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
``)
})
@@ -300,7 +300,7 @@ func TestRemoveFromTree(t *testing.T) {
testValue(t, commentNode.Type, CommentNode)
RemoveFromTree(commentNode)
verifyNodePointers(t, doc)
- testValue(t, doc.OutputXML(false),
+ testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
``)
})
@@ -308,11 +308,36 @@ func TestRemoveFromTree(t *testing.T) {
doc := parseXML()
RemoveFromTree(doc)
verifyNodePointers(t, doc)
- testValue(t, doc.OutputXML(false),
+ testValue(t, doc.OutputXMLWithOptions(WithoutPreserveSpace()),
``)
})
}
+func TestAddImmediateSibling(t *testing.T) {
+ s := `
+
+
+
+
+
+
+
+
+ `
+ root, err := Parse(strings.NewReader(s))
+ if err != nil {
+ t.Error(err)
+ }
+
+ aaa := findNode(root, "AAA")
+ n := aaa.SelectElement("BBB")
+ if n == nil {
+ t.Fatalf("n is nil")
+ }
+ AddImmediateSibling(n, &Node{Type: ElementNode, Data: "r"})
+ testValue(t, root.OutputXMLWithOptions(WithoutPreserveSpace()), ``)
+}
+
func TestSelectElement(t *testing.T) {
s := `
@@ -497,7 +522,6 @@ func TestWriteWithNamespacePrefix(t *testing.T) {
}
}
-
func TestQueryWithPrefix(t *testing.T) {
s := `ns2:ClientThis is a client fault`
doc, _ := Parse(strings.NewReader(s))
@@ -582,7 +606,7 @@ func TestOutputXMLWithSpaceDirect(t *testing.T) {
t.Errorf(`expected "%s", obtained "%s"`, expected, g)
}
- output := html.UnescapeString(doc.OutputXML(true))
+ output := html.UnescapeString(doc.OutputXMLWithOptions(WithOutputSelf(), WithoutPreserveSpace()))
if strings.Contains(output, "\n") {
t.Errorf("the outputted xml contains newlines")
}
@@ -606,7 +630,7 @@ func TestOutputXMLWithSpaceOverwrittenToPreserve(t *testing.T) {
t.Errorf(`expected "%s", obtained "%s"`, expected, g)
}
- output := html.UnescapeString(doc.OutputXML(true))
+ output := html.UnescapeString(doc.OutputXMLWithOptions(WithOutputSelf(), WithoutPreserveSpace()))
if strings.Contains(output, "\n") {
t.Errorf("the outputted xml contains newlines")
}
@@ -680,8 +704,8 @@ func TestOutputXMLWithPreserveSpaceOption(t *testing.T) {
`
doc, _ := Parse(strings.NewReader(s))
- resultWithSpace := doc.OutputXMLWithOptions(WithPreserveSpace())
- resultWithoutSpace := doc.OutputXMLWithOptions()
+ resultWithSpace := doc.OutputXMLWithOptions()
+ resultWithoutSpace := doc.OutputXMLWithOptions(WithoutPreserveSpace())
if !strings.Contains(resultWithSpace, "> Robert <") {
t.Errorf("output was not expected. expected %v but got %v", " Robert ", resultWithSpace)
}
diff --git a/parse.go b/parse.go
index 7627f46..d359b50 100644
--- a/parse.go
+++ b/parse.go
@@ -2,6 +2,7 @@ package xmlquery
import (
"bufio"
+ "bytes"
"encoding/xml"
"fmt"
"io"
@@ -39,15 +40,31 @@ func Parse(r io.Reader) (*Node, error) {
func ParseWithOptions(r io.Reader, options ParserOptions) (*Node, error) {
p := createParser(r)
options.apply(p)
- for {
- _, err := p.parse()
- if err == io.EOF {
- return p.doc, nil
+ var err error
+ for err == nil {
+ _, err = p.parse()
+ }
+
+ if err == io.EOF {
+ // additional check for validity
+ // according to: https://www.w3.org/TR/xml
+ // the document MUST contain at least ONE element
+ valid := false
+ for doc := p.doc; doc != nil; doc = doc.NextSibling {
+ for node := doc.FirstChild; node != nil; node = node.NextSibling {
+ if node.Type == ElementNode {
+ valid = true
+ break
+ }
+ }
}
- if err != nil {
- return nil, err
+ if !valid {
+ return nil, fmt.Errorf("xmlquery: invalid XML document")
}
+ return p.doc, nil
}
+
+ return nil, err
}
type parser struct {
@@ -168,7 +185,7 @@ func (p *parser) parse() (*Node, error) {
if node.NamespaceURI != "" {
if v, ok := p.space2prefix[node.NamespaceURI]; ok {
- cached := string(p.reader.Cache())
+ cached := string(p.reader.CacheWithLimit(len(v.name) + len(node.Data) + 2))
if strings.HasPrefix(cached, fmt.Sprintf("%s:%s", v.name, node.Data)) || strings.HasPrefix(cached, fmt.Sprintf("<%s:%s", v.name, node.Data)) {
node.Prefix = v.name
}
@@ -228,12 +245,11 @@ func (p *parser) parse() (*Node, error) {
}
case xml.CharData:
// First, normalize the cache...
- cached := strings.ToUpper(string(p.reader.Cache()))
+ cached := bytes.ToUpper(p.reader.CacheWithLimit(9))
nodeType := TextNode
- if strings.HasPrefix(cached, "
-
-
- book 2
-
-
- book 2
-
-
-`
- doc, err := Parse(strings.NewReader(s))
- if err != nil {
- t.Fatal(err)
- }
- list := Find(doc, `/bk:books/bk:book`)
- if found, expected := len(list), 2; found != expected {
- t.Fatalf("should found %d bk:book but found %d", expected, found)
- }
-}
-
func TestDefaultNamespace_2(t *testing.T) {
s := `