Extended stanzas work now.
authorChris Jones <chris@cjones.org>
Sat, 31 Dec 2011 11:39:23 -0700
changeset 38 2839fece923e
parent 37 fbda8e925fdf
child 39 4a06f7ccfa84
Extended stanzas work now.
roster.go
roster_test.go
stream.go
structs.go
structs_test.go
xmpp.go
xmpp_test.go
--- a/roster.go	Sat Dec 31 10:11:01 2011 -0700
+++ b/roster.go	Sat Dec 31 11:39:23 2011 -0700
@@ -6,19 +6,12 @@
 
 import (
 	"fmt"
-	"io"
 	"os"
 	"xml"
 )
 
 // This file contains support for roster management, RFC 3921, Section 7.
 
-type RosterIq struct {
-	Iq
-	Query RosterQuery
-}
-var _ ExtendedStanza = &RosterIq{}
-
 // Roster query/result
 type RosterQuery struct {
 	// Should always be NsRoster, "query"
@@ -36,40 +29,31 @@
 	Group []string
 }
 
-func (riq *RosterIq) MarshalXML() ([]byte, os.Error) {
-	return marshalXML(riq)
-}
-
-func (riq *RosterIq) InnerMarshal(w io.Writer) os.Error {
-	return xml.Marshal(w, riq.Query)
-}
-
 // Implicitly becomes part of NewClient's extStanza arg.
-func rosterStanza(name *xml.Name) ExtendedStanza {
-	return &RosterIq{}
+func newRosterQuery(name *xml.Name) interface{} {
+	return &RosterQuery{}
 }
 
 // Synchronously fetch this entity's roster from the server and cache
 // that information.
 func (cl *Client) fetchRoster() os.Error {
-	iq := &RosterIq{Iq: Iq{From: cl.Jid.String(), Id: <- cl.Id,
-		Type: "get"}, Query: RosterQuery{XMLName:
-			xml.Name{Local: "query", Space: NsRoster}}}
+	iq := &Iq{From: cl.Jid.String(), Id: <- cl.Id, Type: "get",
+		Nested: RosterQuery{XMLName: xml.Name{Local: "query",
+			Space: NsRoster}}}
 	ch := make(chan os.Error)
 	f := func(st Stanza) bool {
-		iq, ok := st.(*RosterIq)
-		if !ok {
-			ch <- os.NewError(fmt.Sprintf(
-				"Roster query result not iq: %v", st))
-			return false
-		}
 		if iq.Type == "error" {
 			ch <- iq.Error
 			return false
 		}
-		q := iq.Query
-		cl.roster = make(map[string] *RosterItem, len(q.Item))
-		for _, item := range(q.Item) {
+		rq, ok := st.XNested().(*RosterQuery)
+		if !ok {
+			ch <- os.NewError(fmt.Sprintf(
+				"Roster query result not query: %v", st))
+			return false
+		}
+		cl.roster = make(map[string] *RosterItem, len(rq.Item))
+		for _, item := range(rq.Item) {
 			cl.roster[item.Jid] = &item
 		}
 		ch <- nil
--- a/roster_test.go	Sat Dec 31 10:11:01 2011 -0700
+++ b/roster_test.go	Sat Dec 31 11:39:23 2011 -0700
@@ -5,6 +5,8 @@
 package xmpp
 
 import (
+	"reflect"
+	"strings"
 	"testing"
 	"xml"
 )
@@ -12,14 +14,39 @@
 // This is mostly just tests of the roster data structures.
 
 func TestRosterIqMarshal(t *testing.T) {
-	iq := &RosterIq{Iq: Iq{From: "from", Lang: "en"}, Query:
+	iq := &Iq{From: "from", Lang: "en", Nested:
 		RosterQuery{XMLName: xml.Name{Space: NsRoster, Local:
 				"query"}, Item: []RosterItem{}}}
-	var s Stanza = iq
-	if _, ok := s.(ExtendedStanza) ; !ok {
-		t.Errorf("Not an ExtendedStanza")
-	}
 	exp := `<iq from="from" xml:lang="en"><query xmlns="` +
 		NsRoster + `"></query></iq>`
 	assertMarshal(t, exp, iq)
 }
+
+func TestRosterIqUnmarshal(t *testing.T) {
+	str := `<iq from="from" xml:lang="en"><query xmlns="` +
+		NsRoster + `"><item jid="a@b.c"/></query></iq>`
+	r := strings.NewReader(str)
+	var st Stanza = &Iq{}
+	xml.Unmarshal(r, st)
+	err := parseExtended(st, newRosterQuery)
+	if err != nil {
+		t.Fatalf("parseExtended: %v", err)
+	}
+	assertEquals(t, "iq", st.XName())
+	assertEquals(t, "from", st.XFrom())
+	assertEquals(t, "en", st.XLang())
+	nested := st.XNested()
+	if nested == nil {
+		t.Fatalf("nested nil")
+	}
+	rq, ok := nested.(*RosterQuery)
+	if !ok {
+		t.Fatalf("nested not RosterQuery: %v",
+			reflect.TypeOf(nested))
+	}
+	if len(rq.Item) != 1 {
+		t.Fatalf("Wrong # items: %v", rq.Item)
+	}
+	item := rq.Item[0]
+	assertEquals(t, "a@b.c", item.Jid)
+}
--- a/stream.go	Sat Dec 31 10:11:01 2011 -0700
+++ b/stream.go	Sat Dec 31 11:39:23 2011 -0700
@@ -12,7 +12,6 @@
 
 import (
 	"big"
-	"bytes"
 	"crypto/md5"
 	"crypto/rand"
 	"crypto/tls"
@@ -82,7 +81,7 @@
 }
 
 func readXml(r io.Reader, ch chan<- interface{},
-	extStanza map[string] func(*xml.Name) ExtendedStanza) {
+	extStanza map[string] func(*xml.Name) interface{}) {
 	if debug {
 		pr, pw := io.Pipe()
 		go tee(r, pw, "S: ")
@@ -150,15 +149,12 @@
 		// namespace that's registered with one of our
 		// extensions. If so, we need to re-unmarshal into an
 		// object of the correct type.
-		if st, ok := obj.(Stanza) ; ok && st.XChild() != nil {
-			name := st.XChild().XMLName
+		if st, ok := obj.(Stanza) ; ok && st.generic() != nil {
+			name := st.generic().XMLName
 			ns := name.Space
 			con := extStanza[ns]
 			if con != nil {
-				obj = con(&name)
-				xmlStr, _ := marshalXML(st)
-				r := bytes.NewBuffer(xmlStr)
-				err = xml.Unmarshal(r, obj)
+				err = parseExtended(st, con)
 				if err != nil {
 					log.Printf("ext unmarshal: %v",
 						err)
@@ -172,6 +168,38 @@
 	}
 }
 
+func parseExtended(st Stanza, con func(*xml.Name) interface{}) os.Error {
+	name := st.generic().XMLName
+	nested := con(&name)
+
+	// Now parse the stanza's innerxml to find the string that we
+	// can unmarshal this nested element from.
+	reader := strings.NewReader(st.innerxml())
+	p := xml.NewParser(reader)
+	var start *xml.StartElement
+	for {
+		t, err := p.Token()
+		if err != nil {
+			return err
+		}
+		if se, ok := t.(xml.StartElement) ; ok {
+			if se.Name.Space == name.Space {
+				start = &se
+				break
+			}
+		}
+	}
+
+	// Unmarshal the nested element and stuff it back into the
+	// stanza.
+	err := p.Unmarshal(nested, start)
+	if err != nil {
+		return err
+	}
+	st.setNested(nested)
+	return nil
+}
+
 func writeXml(w io.Writer, ch <-chan interface{}) {
 	if debug {
 		pr, pw := io.Pipe()
@@ -527,6 +555,7 @@
 	return response
 }
 
+// BUG(cjyar) This should use iq.nested rather than iq.generic.
 // Send a request to bind a resource. RFC 3920, section 7.
 func (cl *Client) bind(bind *Generic) {
 	res := cl.Jid.Resource
@@ -542,7 +571,7 @@
 			log.Println("Resource binding failed")
 			return false
 		}
-		bind := st.XChild()
+		bind := st.generic()
 		if bind == nil {
 			log.Println("nil resource bind")
 			return false
--- a/structs.go	Sat Dec 31 10:11:01 2011 -0700
+++ b/structs.go	Sat Dec 31 11:39:23 2011 -0700
@@ -94,15 +94,12 @@
 	// A nested error element, if any.
 	XError() *Error
 	// A (non-error) nested element, if any.
-	XChild() *Generic
+	XNested() interface{}
+	setNested(interface{})
+	generic() *Generic
 	innerxml() string
 }
 
-type ExtendedStanza interface {
-	Stanza
-	InnerMarshal(io.Writer) os.Error
-}
-
 // message stanza
 type Message struct {
 	To string `xml:"attr"`
@@ -116,10 +113,10 @@
 	Body *Generic
 	Thread *Generic
 	Any *Generic
+	Nested interface{}
 }
 var _ xml.Marshaler = &Message{}
 var _ Stanza = &Message{}
-var _ ExtendedStanza = &Message{}
 
 // presence stanza
 type Presence struct {
@@ -134,10 +131,10 @@
 	Status *Generic
 	Priority *Generic
 	Any *Generic
+	Nested interface{}
 }
 var _ xml.Marshaler = &Presence{}
 var _ Stanza = &Presence{}
-var _ ExtendedStanza = &Presence{}
 
 // iq stanza
 type Iq struct {
@@ -149,6 +146,7 @@
 	Innerxml string `xml:"innerxml"`
 	Error *Error
 	Any *Generic
+	Nested interface{}
 }
 var _ xml.Marshaler = &Iq{}
 var _ Stanza = &Iq{}
@@ -295,23 +293,15 @@
 		writeField(buf, "xml:lang", st.XLang())
 	}
 	buf.WriteString(">")
-	if ext, ok := st.(ExtendedStanza) ; ok {
-		if st.XError() != nil {
-			bytes, _ := st.XError().MarshalXML()
-			buf.WriteString(string(bytes))
-		}
-		err := ext.InnerMarshal(buf)
-		if err != nil {
-			return nil, err
-		}
-	} else {
-		inner := st.innerxml()
-		if inner == "" {
-			xml.Marshal(buf, st.XChild())
-		} else {
-			buf.WriteString(st.innerxml())
-		}
+
+	if st.XNested() != nil {
+		xml.Marshal(buf, st.XNested())
+	} else if st.generic() != nil {
+		xml.Marshal(buf, st.generic())
+	} else if st.innerxml() != "" {
+		buf.WriteString(st.innerxml())
 	}
+
 	buf.WriteString("</")
 	buf.WriteString(st.XName())
 	buf.WriteString(">")
@@ -363,7 +353,15 @@
 	return m.Error
 }
 
-func (m *Message) XChild() *Generic {
+func (m *Message) XNested() interface{} {
+	return m.Nested
+}
+
+func (m *Message) setNested(n interface{}) {
+	m.Nested = n
+}
+
+func (m *Message) generic() *Generic {
 	return m.Any
 }
 
@@ -419,7 +417,15 @@
 	return p.Error
 }
 
-func (p *Presence) XChild() *Generic {
+func (p *Presence) XNested() interface{} {
+	return p.Nested
+}
+
+func (p *Presence) setNested(n interface{}) {
+	p.Nested = n
+}
+
+func (p *Presence) generic() *Generic {
 	return p.Any
 }
 
@@ -475,7 +481,15 @@
 	return iq.Error
 }
 
-func (iq *Iq) XChild() *Generic {
+func (iq *Iq) XNested() interface{} {
+	return iq.Nested
+}
+
+func (iq *Iq) setNested(n interface{}) {
+	iq.Nested = n
+}
+
+func (iq *Iq) generic() *Generic {
 	return iq.Any
 }
 
--- a/structs_test.go	Sat Dec 31 10:11:01 2011 -0700
+++ b/structs_test.go	Sat Dec 31 11:39:23 2011 -0700
@@ -106,11 +106,11 @@
 	if st.XError() != nil {
 		t.Errorf("iq: error %v", st.XError())
 	}
-	if st.XChild() == nil {
+	if st.generic() == nil {
 		t.Errorf("iq: nil child")
 	}
-	assertEquals(t, "foo", st.XChild().XMLName.Local)
-	assertEquals(t, "text", st.XChild().Chardata)
+	assertEquals(t, "foo", st.generic().XMLName.Local)
+	assertEquals(t, "text", st.generic().Chardata)
 
 	str = `<message to="alice" from="bob"/>`
 	st, err = ParseStanza(str)
@@ -125,8 +125,8 @@
 	if st.XError() != nil {
 		t.Errorf("message: error %v", st.XError())
 	}
-	if st.XChild() != nil {
-		t.Errorf("message: child %v", st.XChild())
+	if st.generic() != nil {
+		t.Errorf("message: child %v", st.generic())
 	}
 
 	str = `<presence/>`
--- a/xmpp.go	Sat Dec 31 10:11:01 2011 -0700
+++ b/xmpp.go	Sat Dec 31 11:39:23 2011 -0700
@@ -84,7 +84,7 @@
 // send operation to Client.Out will block until negotiation (resource
 // binding) is complete.
 func NewClient(jid *JID, password string,
-	extStanza map[string] func(*xml.Name) ExtendedStanza) (*Client, os.Error) {
+	extStanza map[string] func(*xml.Name) interface{}) (*Client, os.Error) {
 	// Resolve the domain in the JID.
 	_, srvs, err := net.LookupSRV(clientSrv, "tcp", jid.Domain)
 	if err != nil {
@@ -122,9 +122,9 @@
 	cl.Id = idCh
 
 	if extStanza == nil {
-		extStanza = make(map[string] func(*xml.Name) ExtendedStanza)
+		extStanza = make(map[string] func(*xml.Name) interface{})
 	}
-	extStanza[NsRoster] = rosterStanza
+	extStanza[NsRoster] = newRosterQuery
 
 	// Start the unique id generator.
 	go makeIds(idCh)
@@ -165,7 +165,7 @@
 }
 
 func startXmlReader(r io.Reader,
-	extStanza map[string] func(*xml.Name) ExtendedStanza) <-chan interface{} {
+	extStanza map[string] func(*xml.Name) interface{}) <-chan interface{} {
 	ch := make(chan interface{})
 	go readXml(r, ch, extStanza)
 	return ch
--- a/xmpp_test.go	Sat Dec 31 10:11:01 2011 -0700
+++ b/xmpp_test.go	Sat Dec 31 11:39:23 2011 -0700
@@ -16,7 +16,7 @@
 func TestReadError(t *testing.T) {
 	r := strings.NewReader(`<stream:error><bad-foo/></stream:error>`)
 	ch := make(chan interface{})
-	go readXml(r, ch, make(map[string] func(*xml.Name) ExtendedStanza))
+	go readXml(r, ch, make(map[string] func(*xml.Name) interface{}))
 	x := <- ch
 	se, ok := x.(*streamError)
 	if !ok {
@@ -32,7 +32,7 @@
 		`<text xml:lang="en" xmlns="` + NsStreams +
 		`">Error text</text></stream:error>`)
 	ch = make(chan interface{})
-	go readXml(r, ch, make(map[string] func(*xml.Name) ExtendedStanza))
+	go readXml(r, ch, make(map[string] func(*xml.Name) interface{}))
 	x = <- ch
 	se, ok = x.(*streamError)
 	if !ok {
@@ -50,7 +50,7 @@
 		`xmlns="jabber:client" xmlns:stream="` + NsStream +
 		`" version="1.0">`)
 	ch := make(chan interface{})
-	go readXml(r, ch, make(map[string] func(*xml.Name) ExtendedStanza))
+	go readXml(r, ch, make(map[string] func(*xml.Name) interface{}))
 	x := <- ch
 	ss, ok := x.(*stream)
 	if !ok {