Added a capability to use extensions. There are still some bugs with
authorChris Jones <chris@cjones.org>
Fri, 30 Dec 2011 21:49:00 -0700
changeset 36 9fe022261dcc
parent 35 569833f08780
child 37 fbda8e925fdf
Added a capability to use extensions. There are still some bugs with marshaling involving receiver functions on embedded structs.
Makefile
examples/interact.go
roster.go
roster_test.go
stream.go
structs.go
xmpp.go
xmpp_test.go
--- a/Makefile	Fri Dec 30 18:25:08 2011 -0700
+++ b/Makefile	Fri Dec 30 21:49:00 2011 -0700
@@ -7,6 +7,7 @@
 TARG=cjyar/xmpp
 GOFILES=\
 	xmpp.go \
+	roster.go \
 	stream.go \
 	structs.go \
 
--- a/examples/interact.go	Fri Dec 30 18:25:08 2011 -0700
+++ b/examples/interact.go	Fri Dec 30 21:49:00 2011 -0700
@@ -24,7 +24,7 @@
 		os.Exit(2)
 	}
 
-	c, err := xmpp.NewClient(&jid, *pw)
+	c, err := xmpp.NewClient(&jid, *pw, nil)
 	if err != nil {
 		log.Fatalf("NewClient(%v): %v", jid, err)
 	}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/roster.go	Fri Dec 30 21:49:00 2011 -0700
@@ -0,0 +1,91 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xmpp
+
+import (
+	"fmt"
+	"io"
+	"os"
+	"xml"
+)
+
+// This file contains support for roster management, RFC 3921, Section 7.
+
+type RosterIq struct {
+	Iq
+	Query RosterQuery
+}
+var _ ExtendedStanza = &RosterIq{}
+
+// Roster query/result
+type RosterQuery struct {
+	// Should always be NsRoster, "query"
+	XMLName xml.Name
+	Item []RosterItem
+}
+
+// See RFC 3921, Section 7.1.
+type RosterItem struct {
+	// Should always be "item"
+	XMLName xml.Name
+	Jid string `xml:"attr"`
+	Subscription string `xml:"attr"`
+	Name string `xml:"attr"`
+	Group []string
+}
+
+func (riq *RosterIq) InnerMarshal(w io.Writer) os.Error {
+	return xml.Marshal(w, riq.Query)
+}
+
+// Implicitly becomes part of NewClient's extStanza arg.
+func rosterStanza(name *xml.Name) ExtendedStanza {
+	return &RosterIq{}
+}
+
+// Synchronously fetch this entity's roster from the server and cache
+// that information.
+func (cl *Client) fetchRoster() os.Error {
+	iq := &RosterIq{Iq: Iq{From: cl.Jid.String(), Id: <- cl.Id,
+		Type: "get"}, Query: RosterQuery{XMLName:
+			xml.Name{Local: "query", Space: NsRoster}}}
+	ch := make(chan os.Error)
+	f := func(st Stanza) bool {
+		iq, ok := st.(*RosterIq)
+		if !ok {
+			ch <- os.NewError(fmt.Sprintf(
+				"Roster query result not iq: %v", st))
+			return false
+		}
+		if iq.Type == "error" {
+			ch <- iq.Error
+			return false
+		}
+		q := iq.Query
+		cl.roster = make(map[string] *RosterItem, len(q.Item))
+		for _, item := range(q.Item) {
+			cl.roster[item.Jid] = &item
+		}
+		ch <- nil
+		return false
+	}
+	cl.HandleStanza(iq.Id, f)
+	cl.Out <- iq
+	// Wait for f to complete.
+	return <- ch
+}
+
+// BUG(cjyar) The roster isn't actually updated when things change.
+
+// Returns the current roster of other entities which this one has a
+// relationship with. Changes to the roster will be signaled by an
+// appropriate Iq appearing on Client.In. See RFC 3921, Section 7.4.
+func (cl *Client) Roster() map[string] *RosterItem {
+	r := make(map[string] *RosterItem)
+	for key, val := range(cl.roster) {
+		r[key] = val
+	}
+	return r
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/roster_test.go	Fri Dec 30 21:49:00 2011 -0700
@@ -0,0 +1,25 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package xmpp
+
+import (
+	"testing"
+	"xml"
+)
+
+// This is mostly just tests of the roster data structures.
+
+func TestRosterIqMarshal(t *testing.T) {
+	iq := &RosterIq{Iq: Iq{From: "from", Lang: "en"}, Query:
+		RosterQuery{XMLName: xml.Name{Space: NsRoster, Local:
+				"query"}, Item: []RosterItem{}}}
+	var s Stanza = iq
+	if _, ok := s.(ExtendedStanza) ; !ok {
+		t.Errorf("Not an ExtendedStanza")
+	}
+	exp := `<iq from="from" xml:lang="en"><query xmlns="` +
+		NsRoster + `"></query></iq>`
+	assertMarshal(t, exp, iq)
+}
--- a/stream.go	Fri Dec 30 18:25:08 2011 -0700
+++ b/stream.go	Fri Dec 30 21:49:00 2011 -0700
@@ -12,6 +12,7 @@
 
 import (
 	"big"
+	"bytes"
 	"crypto/md5"
 	"crypto/rand"
 	"crypto/tls"
@@ -80,7 +81,8 @@
 	}
 }
 
-func readXml(r io.Reader, ch chan<- interface{}) {
+func readXml(r io.Reader, ch chan<- interface{},
+	extStanza map[string] func(*xml.Name) ExtendedStanza) {
 	if debug {
 		pr, pw := io.Pipe()
 		go tee(r, pw, "S: ")
@@ -144,9 +146,26 @@
 			break
 		}
 
-		// BUG(cjyar) If it's a Stanza, use reflection to
-		// search for any Generic elements and fill in
-		// their attributes.
+		// If it's a Stanza, we check its "Any" element for a
+		// namespace that's registered with one of our
+		// extensions. If so, we need to re-unmarshal into an
+		// object of the correct type.
+		if st, ok := obj.(Stanza) ; ok && st.XChild() != nil {
+			name := st.XChild().XMLName
+			ns := name.Space
+			con := extStanza[ns]
+			if con != nil {
+				obj = con(&name)
+				xmlStr, _ := marshalXML(st)
+				r := bytes.NewBuffer(xmlStr)
+				err = xml.Unmarshal(r, &obj)
+				if err != nil {
+					log.Printf("ext unmarshal: %v",
+						err)
+					break
+				}
+			}
+		}
 
 		// Put it on the channel.
 		ch <- obj
--- a/structs.go	Fri Dec 30 18:25:08 2011 -0700
+++ b/structs.go	Fri Dec 30 21:49:00 2011 -0700
@@ -95,6 +95,11 @@
 	XError() *Error
 	// A (non-error) nested element, if any.
 	XChild() *Generic
+	innerxml() string
+}
+
+type ExtendedStanza interface {
+	InnerMarshal(io.Writer) os.Error
 }
 
 // message stanza
@@ -104,6 +109,7 @@
 	Id string `xml:"attr"`
 	Type string `xml:"attr"`
 	Lang string `xml:"attr"`
+	Innerxml string `xml:"innerxml"`
 	Error *Error
 	Subject *Generic
 	Body *Generic
@@ -112,6 +118,7 @@
 }
 var _ xml.Marshaler = &Message{}
 var _ Stanza = &Message{}
+var _ ExtendedStanza = &Message{}
 
 // presence stanza
 type Presence struct {
@@ -120,6 +127,7 @@
 	Id string `xml:"attr"`
 	Type string `xml:"attr"`
 	Lang string `xml:"attr"`
+	Innerxml string `xml:"innerxml"`
 	Error *Error
 	Show *Generic
 	Status *Generic
@@ -128,6 +136,7 @@
 }
 var _ xml.Marshaler = &Presence{}
 var _ Stanza = &Presence{}
+var _ ExtendedStanza = &Presence{}
 
 // iq stanza
 type Iq struct {
@@ -136,30 +145,13 @@
 	Id string `xml:"attr"`
 	Type string `xml:"attr"`
 	Lang string `xml:"attr"`
+	Innerxml string `xml:"innerxml"`
 	Error *Error
 	Any *Generic
-	Query *RosterQuery
 }
 var _ xml.Marshaler = &Iq{}
 var _ Stanza = &Iq{}
 
-// Roster query/result
-type RosterQuery struct {
-	// Should always be query in the NsRoster namespace
-	XMLName xml.Name
-	Item []RosterItem
-}
-
-// See RFC 3921, Section 7.1.
-type RosterItem struct {
-	// Should always be "item"
-	XMLName xml.Name
-	Jid string `xml:"attr"`
-	Subscription string `xml:"attr"`
-	Name string `xml:"attr"`
-	Group []string
-}
-
 // Describes an XMPP stanza error. See RFC 3920, Section 9.3.
 type Error struct {
 	// The error type attribute.
@@ -282,8 +274,6 @@
 		u.XMLName.Local)
 }
 
-// BUG(cjyar) This is fragile. We should find a way to use go's native
-// XML marshaling.
 func marshalXML(st Stanza) ([]byte, os.Error) {
 	buf := bytes.NewBuffer(nil)
 	buf.WriteString("<")
@@ -304,15 +294,22 @@
 		writeField(buf, "xml:lang", st.XLang())
 	}
 	buf.WriteString(">")
-	if st.XError() != nil {
-		bytes, _ := st.XError().MarshalXML()
-		buf.WriteString(string(bytes))
-	}
-	if st.XChild() != nil {
-		xml.Marshal(buf, st.XChild())
-	}
-	if iq, ok := st.(*Iq) ; ok && iq.Query != nil {
-		xml.Marshal(buf, iq.Query)
+	if ext, ok := st.(ExtendedStanza) ; ok {
+		if st.XError() != nil {
+			bytes, _ := st.XError().MarshalXML()
+			buf.WriteString(string(bytes))
+		}
+		err := ext.InnerMarshal(buf)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		inner := st.innerxml()
+		if inner == "" {
+			xml.Marshal(buf, st.XChild())
+		} else {
+			buf.WriteString(st.innerxml())
+		}
 	}
 	buf.WriteString("</")
 	buf.WriteString(st.XName())
@@ -369,10 +366,30 @@
 	return m.Any
 }
 
+func (m *Message) innerxml() string {
+	return m.Innerxml
+}
+
 func (m *Message) MarshalXML() ([]byte, os.Error) {
 	return marshalXML(m)
 }
 
+func (m *Message) InnerMarshal(w io.Writer) os.Error {
+	err := xml.Marshal(w, m.Subject)
+	if err != nil {
+		return err
+	}
+	err = xml.Marshal(w, m.Body)
+	if err != nil {
+		return err
+	}
+	err = xml.Marshal(w, m.Thread)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
 func (p *Presence) XName() string {
 	return "presence"
 }
@@ -405,10 +422,30 @@
 	return p.Any
 }
 
+func (p *Presence) innerxml() string {
+	return p.Innerxml
+}
+
 func (p *Presence) MarshalXML() ([]byte, os.Error) {
 	return marshalXML(p)
 }
 
+func (p *Presence) InnerMarshal(w io.Writer) os.Error {
+	err := xml.Marshal(w, p.Show)
+	if err != nil {
+		return err
+	}
+	err = xml.Marshal(w, p.Status)
+	if err != nil {
+		return err
+	}
+	err = xml.Marshal(w, p.Priority)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
 func (iq *Iq) XName() string {
 	return "iq"
 }
@@ -441,6 +478,10 @@
 	return iq.Any
 }
 
+func (iq *Iq) innerxml() string {
+	return iq.Innerxml
+}
+
 func (iq *Iq) MarshalXML() ([]byte, os.Error) {
 	return marshalXML(iq)
 }
--- a/xmpp.go	Fri Dec 30 18:25:08 2011 -0700
+++ b/xmpp.go	Fri Dec 30 21:49:00 2011 -0700
@@ -83,7 +83,8 @@
 // has completed. The negotiation will occur asynchronously, and any
 // send operation to Client.Out will block until negotiation (resource
 // binding) is complete.
-func NewClient(jid *JID, password string) (*Client, os.Error) {
+func NewClient(jid *JID, password string,
+	extStanza map[string] func(*xml.Name) ExtendedStanza) (*Client, os.Error) {
 	// Resolve the domain in the JID.
 	_, srvs, err := net.LookupSRV(clientSrv, "tcp", jid.Domain)
 	if err != nil {
@@ -120,6 +121,11 @@
 	idCh := make(chan string)
 	cl.Id = idCh
 
+	if extStanza == nil {
+		extStanza = make(map[string] func(*xml.Name) ExtendedStanza)
+	}
+	extStanza[NsRoster] = rosterStanza
+
 	// Start the unique id generator.
 	go makeIds(idCh)
 
@@ -127,7 +133,7 @@
 	tlsr, tlsw := cl.startTransport()
 
 	// Start the reader and writers that convert to and from XML.
-	xmlIn := startXmlReader(tlsr)
+	xmlIn := startXmlReader(tlsr, extStanza)
 	cl.xmlOut = startXmlWriter(tlsw)
 
 	// Start the XMPP stream handler which filters stream-level
@@ -158,9 +164,10 @@
 	return inr, outw
 }
 
-func startXmlReader(r io.Reader) <-chan interface{} {
+func startXmlReader(r io.Reader,
+	extStanza map[string] func(*xml.Name) ExtendedStanza) <-chan interface{} {
 	ch := make(chan interface{})
-	go readXml(r, ch)
+	go readXml(r, ch, extStanza)
 	return ch
 }
 
@@ -287,54 +294,3 @@
 	}
 	return nil
 }
-
-// Synchronously fetch this entity's roster from the server and cache
-// that information.
-func (cl *Client) fetchRoster() os.Error {
-	iq := &Iq{From: cl.Jid.String(), Id: <- cl.Id, Type: "get",
-		Query: &RosterQuery{XMLName: xml.Name{Local: "query",
-			Space: NsRoster}}}
-	ch := make(chan os.Error)
-	f := func(st Stanza) bool {
-		iq, ok := st.(*Iq)
-		if !ok {
-			ch <- os.NewError(fmt.Sprintf(
-				"Roster query result not iq: %v", st))
-			return false
-		}
-		if iq.Type == "error" {
-			ch <- iq.Error
-			return false
-		}
-		q := iq.Query
-		if q == nil {
-			ch <- os.NewError(fmt.Sprintf(
-				"Roster query result nil query: %v",
-				iq))
-			return false
-		}
-		cl.roster = make(map[string] *RosterItem, len(q.Item))
-		for _, item := range(q.Item) {
-			cl.roster[item.Jid] = &item
-		}
-		ch <- nil
-		return false
-	}
-	cl.HandleStanza(iq.Id, f)
-	cl.Out <- iq
-	// Wait for f to complete.
-	return <- ch
-}
-
-// BUG(cjyar) The roster isn't actually updated when things change.
-
-// Returns the current roster of other entities which this one has a
-// relationship with. Changes to the roster will be signaled by an
-// appropriate Iq appearing on Client.In. See RFC 3921, Section 7.4.
-func (cl *Client) Roster() map[string] *RosterItem {
-	r := make(map[string] *RosterItem)
-	for key, val := range(cl.roster) {
-		r[key] = val
-	}
-	return r
-}
--- a/xmpp_test.go	Fri Dec 30 18:25:08 2011 -0700
+++ b/xmpp_test.go	Fri Dec 30 21:49:00 2011 -0700
@@ -16,7 +16,7 @@
 func TestReadError(t *testing.T) {
 	r := strings.NewReader(`<stream:error><bad-foo/></stream:error>`)
 	ch := make(chan interface{})
-	go readXml(r, ch)
+	go readXml(r, ch, make(map[string] func(*xml.Name) ExtendedStanza))
 	x := <- ch
 	se, ok := x.(*streamError)
 	if !ok {
@@ -32,7 +32,7 @@
 		`<text xml:lang="en" xmlns="` + NsStreams +
 		`">Error text</text></stream:error>`)
 	ch = make(chan interface{})
-	go readXml(r, ch)
+	go readXml(r, ch, make(map[string] func(*xml.Name) ExtendedStanza))
 	x = <- ch
 	se, ok = x.(*streamError)
 	if !ok {
@@ -50,7 +50,7 @@
 		`xmlns="jabber:client" xmlns:stream="` + NsStream +
 		`" version="1.0">`)
 	ch := make(chan interface{})
-	go readXml(r, ch)
+	go readXml(r, ch, make(map[string] func(*xml.Name) ExtendedStanza))
 	x := <- ch
 	ss, ok := x.(*stream)
 	if !ok {