# HG changeset patch # User Chris Jones # Date 1380394937 21600 # Node ID bbd4166df95d45cda9fcbff5d90d8523cb7a9183 # Parent 69c5b4382e39efea2669e9cd339a61e91afe3c26 Simplified the API: There's only one constructor, and it does everything necessary to initiate the stream. StartSession() and Roster.Update() have both been eliminated. diff -r 69c5b4382e39 -r bbd4166df95d TODO.txt --- a/TODO.txt Sun Sep 22 17:43:34 2013 -0500 +++ b/TODO.txt Sat Sep 28 13:02:17 2013 -0600 @@ -24,3 +24,10 @@ Callback doesn't need to return bool. It shouldn't affect what's given to the client. + +Don't keep the password in memory once we're done with it. + +Rename extension.StanzaHandlers to something like StanzaTypes. + +Think about how to gracefully shutdown. Probably have a Close() +function. diff -r 69c5b4382e39 -r bbd4166df95d example/interact.go --- a/example/interact.go Sun Sep 22 17:43:34 2013 -0500 +++ b/example/interact.go Sat Sep 28 13:02:17 2013 -0600 @@ -42,7 +42,7 @@ } tlsConf := tls.Config{InsecureSkipVerify: true} - c, err := xmpp.NewClient(&jid, *pw, tlsConf, nil) + c, err := xmpp.NewClient(&jid, *pw, tlsConf, nil, xmpp.Presence{}, nil) if err != nil { log.Fatalf("NewClient(%v): %v", jid, err) } @@ -55,11 +55,6 @@ fmt.Println("done reading") }(c.Recv) - err = c.StartSession(&xmpp.Presence{}) - if err != nil { - log.Fatalf("StartSession: %v", err) - } - c.Roster.Update() roster := c.Roster.Get() fmt.Printf("%d roster entries:\n", len(roster)) for i, entry := range roster { diff -r 69c5b4382e39 -r bbd4166df95d xmpp/filter_test.go --- a/xmpp/filter_test.go Sun Sep 22 17:43:34 2013 -0500 +++ b/xmpp/filter_test.go Sat Sep 28 13:02:17 2013 -0600 @@ -34,6 +34,7 @@ func filterN(numFilts int, t *testing.T) { add := make(chan Filter) in := make(chan Stanza) + defer close(in) out := make(chan Stanza) go filterMgr(add, in, out) for i := 0; i < numFilts; i++ { diff -r 69c5b4382e39 -r bbd4166df95d xmpp/layer1.go --- a/xmpp/layer1.go Sun Sep 22 17:43:34 2013 -0500 +++ b/xmpp/layer1.go Sat Sep 28 13:02:17 2013 -0600 @@ -19,13 +19,13 @@ } func startLayer1(sock net.Conn, recvWriter io.WriteCloser, - sendReader io.ReadCloser) *layer1 { + sendReader io.ReadCloser, status <-chan Status) *layer1 { l1 := layer1{sock: sock} recvSocks := make(chan net.Conn) l1.recvSocks = recvSocks sendSocks := make(chan net.Conn, 1) l1.sendSocks = sendSocks - go recvTransport(recvSocks, recvWriter) + go recvTransport(recvSocks, recvWriter, status) go sendTransport(sendSocks, sendReader) recvSocks <- sock sendSocks <- sock @@ -50,12 +50,19 @@ l1.recvSocks <- l1.sock } -func recvTransport(socks <-chan net.Conn, w io.WriteCloser) { +func recvTransport(socks <-chan net.Conn, w io.WriteCloser, + status <-chan Status) { + defer w.Close() var sock net.Conn p := make([]byte, 1024) for { select { + case stat := <-status: + if stat == StatusShutdown { + return + } + case sock = <-socks: default: } @@ -72,12 +79,12 @@ } } Warn.Logf("recvTransport: %s", err) - break + return } nw, err := w.Write(p[:nr]) if nw < nr { Warn.Logf("recvTransport: %s", err) - break + return } } } diff -r 69c5b4382e39 -r bbd4166df95d xmpp/layer3.go --- a/xmpp/layer3.go Sun Sep 22 17:43:34 2013 -0500 +++ b/xmpp/layer3.go Sat Sep 28 13:02:17 2013 -0600 @@ -5,14 +5,12 @@ import ( "encoding/xml" - "fmt" ) // Callback to handle a stanza with a particular id. type callback struct { id string - // Return true means pass this to the application - f func(Stanza) bool + f func(Stanza) } // Receive XMPP stanzas from the client and send them on to the @@ -22,22 +20,18 @@ // inappropriate into our negotiations with the server. The control // channel controls this loop's activity. func sendStream(sendXml chan<- interface{}, recvXmpp <-chan Stanza, - control <-chan sendCmd) { + status <-chan Status) { defer close(sendXml) var input <-chan Stanza for { select { - case cmd := <-control: - switch cmd { - case sendDeny: + case stat := <-status: + switch stat { + default: input = nil - case sendAllow: + case StatusRunning: input = recvXmpp - case sendAbort: - return - default: - panic(fmt.Sprintf("unknown cmd %d", cmd)) } case x, ok := <-input: if !ok { @@ -53,13 +47,23 @@ } // Receive XMLish structures, handle all the stream-related ones, and -// send XMPP stanzas on to the client. -func (cl *Client) recvStream(recvXml <-chan interface{}, sendXmpp chan<- Stanza) { +// send XMPP stanzas on to the client once the connection is running. +func (cl *Client) recvStream(recvXml <-chan interface{}, sendXmpp chan<- Stanza, + status <-chan Status) { defer close(sendXmpp) + defer cl.statmgr.close() - handlers := make(map[string]func(Stanza) bool) + handlers := make(map[string]func(Stanza)) + doSend := false for { select { + case stat := <-status: + switch stat { + default: + doSend = false + case StatusRunning: + doSend = true + } case h := <-cl.handlers: handlers[h.id] = h.f case x, ok := <-recvXml: @@ -78,14 +82,13 @@ case *auth: cl.handleSasl(obj) case Stanza: - send := true id := obj.GetHeader().Id if handlers[id] != nil { f := handlers[id] delete(handlers, id) - send = f(obj) + f(obj) } - if send { + if doSend { sendXmpp <- obj } default: @@ -97,7 +100,7 @@ func (cl *Client) handleStreamError(se *streamError) { Info.Logf("Received stream error: %v", se) - cl.inputControl <- sendAbort + cl.setStatus(StatusShutdown) } func (cl *Client) handleFeatures(fe *Features) { @@ -123,6 +126,8 @@ func (cl *Client) handleTls(t *starttls) { cl.layer1.startTls(&cl.tlsConfig) + cl.setStatus(StatusConnectedTls) + // Now re-send the initial handshake message to start the new // session. cl.sendXml <- &stream{To: cl.Jid.Domain, Version: XMPPVersion} @@ -137,14 +142,13 @@ } msg := &Iq{Header: Header{Type: "set", Id: NextId(), Nested: []interface{}{bindReq}}} - f := func(st Stanza) bool { + f := func(st Stanza) { iq, ok := st.(*Iq) if !ok { Warn.Log("non-iq response") } if iq.Type == "error" { Warn.Log("Resource binding failed") - return false } var bindRepl *bindIq for _, ele := range iq.Nested { @@ -155,22 +159,18 @@ } if bindRepl == nil { Warn.Logf("Bad bind reply: %#v", iq) - return false } jidStr := bindRepl.Jid if jidStr == nil || *jidStr == "" { Warn.Log("Can't bind empty resource") - return false } jid := new(JID) if err := jid.Set(*jidStr); err != nil { Warn.Logf("Can't parse JID %s: %s", *jidStr, err) - return false } cl.Jid = *jid Info.Logf("Bound resource: %s", cl.Jid.String()) - cl.bindDone() - return false + cl.setStatus(StatusBound) } cl.SetCallback(msg.Id, f) cl.sendXml <- msg @@ -182,7 +182,7 @@ // available on the normal Client.Recv channel. The callback must not // read from that channel, as deliveries on it cannot proceed until // the handler returns true or false. -func (cl *Client) SetCallback(id string, f func(Stanza) bool) { +func (cl *Client) SetCallback(id string, f func(Stanza)) { h := &callback{id: id, f: f} cl.handlers <- h } diff -r 69c5b4382e39 -r bbd4166df95d xmpp/roster.go --- a/xmpp/roster.go Sun Sep 22 17:43:34 2013 -0500 +++ b/xmpp/roster.go Sat Sep 28 13:02:17 2013 -0600 @@ -22,15 +22,9 @@ Group []string } -type rosterCb struct { - id string - cb func() -} - type Roster struct { Extension get chan []RosterItem - callbacks chan rosterCb toServer chan Stanza } @@ -41,24 +35,21 @@ func (r *Roster) rosterMgr(upd <-chan Stanza) { roster := make(map[string]RosterItem) - waits := make(map[string]func()) var snapshot []RosterItem + var get chan<- []RosterItem for { select { + case get <- snapshot: + case stan, ok := <-upd: if !ok { return } - hdr := stan.GetHeader() - if f := waits[hdr.Id]; f != nil { - delete(waits, hdr.Id) - f() - } iq, ok := stan.(*Iq) if !ok { continue } - if iq.Type != "result" { + if iq.Type != "result" && iq.Type != "set" { continue } var rq *RosterQuery @@ -78,9 +69,7 @@ for _, ri := range roster { snapshot = append(snapshot, ri) } - case r.get <- snapshot: - case cb := <-r.callbacks: - waits[cb.id] = cb.cb + get = r.get } } } @@ -119,33 +108,22 @@ r.StanzaHandlers[rName] = reflect.TypeOf(RosterQuery{}) r.RecvFilter, r.SendFilter = r.makeFilters() r.get = make(chan []RosterItem) - r.callbacks = make(chan rosterCb) r.toServer = make(chan Stanza) return &r } // Return the most recent snapshot of the roster status. This is // updated automatically as roster updates are received from the -// server, but especially in response to calls to Update(). +// server. This function may block immediately after the XMPP +// connection has been established, until the first roster update is +// received from the server. func (r *Roster) Get() []RosterItem { return <-r.get } -// Synchronously fetch this entity's roster from the server and cache -// that information. The client can access the roster by watching for -// RosterQuery objects or by calling Get(). -func (r *Roster) Update() { +// Asynchronously fetch this entity's roster from the server. +func (r *Roster) update() { iq := &Iq{Header: Header{Type: "get", Id: NextId(), Nested: []interface{}{RosterQuery{}}}} - waitchan := make(chan int) - done := func() { - close(waitchan) - } - r.waitFor(iq.Id, done) r.toServer <- iq - <-waitchan } - -func (r *Roster) waitFor(id string, cb func()) { - r.callbacks <- rosterCb{id: id, cb: cb} -} diff -r 69c5b4382e39 -r bbd4166df95d xmpp/sasl.go --- a/xmpp/sasl.go Sun Sep 22 17:43:34 2013 -0500 +++ b/xmpp/sasl.go Sat Sep 28 13:02:17 2013 -0600 @@ -52,6 +52,7 @@ case "failure": Info.Log("SASL authentication failed") case "success": + cl.setStatus(StatusAuthenticated) Info.Log("Sasl authentication succeeded") cl.Features = nil ss := &stream{To: cl.Jid.Domain, Version: XMPPVersion} diff -r 69c5b4382e39 -r bbd4166df95d xmpp/status.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/xmpp/status.go Sat Sep 28 13:02:17 2013 -0600 @@ -0,0 +1,101 @@ +// Track the current status of the connection to the server. + +package xmpp + +import ( + "fmt" +) + +type statmgr struct { + newStatus chan Status + newlistener chan chan Status +} + +func newStatmgr(client chan<- Status) *statmgr { + s := statmgr{} + s.newStatus = make(chan Status) + s.newlistener = make(chan chan Status) + go s.manager(client) + return &s +} + +func (s *statmgr) manager(client chan<- Status) { + // We handle this specially, in case the client doesn't read + // our final status message. + defer func() { + if client != nil { + select { + case client <- StatusShutdown: + default: + } + close(client) + } + }() + + stat := StatusUnconnected + listeners := []chan Status{} + for { + select { + case stat = <-s.newStatus: + for _, l := range listeners { + sendToListener(l, stat) + } + if client != nil && stat != StatusShutdown { + client <- stat + } + case l, ok := <-s.newlistener: + if !ok { + return + } + defer close(l) + sendToListener(l, stat) + listeners = append(listeners, l) + } + } +} + +func sendToListener(listen chan Status, stat Status) { + for { + select { + case <-listen: + case listen <- stat: + return + } + } +} + +func (cl *Client) setStatus(stat Status) { + cl.statmgr.setStatus(stat) +} + +func (s *statmgr) setStatus(stat Status) { + s.newStatus <- stat +} + +func (s *statmgr) newListener() <-chan Status { + l := make(chan Status, 1) + s.newlistener <- l + return l +} + +func (s *statmgr) close() { + close(s.newlistener) +} + +func (s *statmgr) awaitStatus(waitFor Status) error { + // BUG(chris): This routine leaks one channel each time it's + // called. Listeners are never removed. + l := s.newListener() + for current := range l { + if current == waitFor { + return nil + } + if current == StatusShutdown { + break + } + if current > waitFor { + return nil + } + } + return fmt.Errorf("shut down waiting for status change") +} diff -r 69c5b4382e39 -r bbd4166df95d xmpp/status_test.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/xmpp/status_test.go Sat Sep 28 13:02:17 2013 -0600 @@ -0,0 +1,63 @@ +package xmpp + +import ( + "testing" + "time" +) + +func TestStatusListen(t *testing.T) { + sm := newStatmgr(nil) + l := sm.newListener() + stat, ok := <-l + if !ok { + t.Error() + } else if stat != StatusUnconnected { + t.Errorf("got %d", stat) + } + + sm.setStatus(StatusConnected) + stat, ok = <-l + if !ok { + t.Error() + } else if stat != StatusConnected { + t.Errorf("got %d", stat) + } + + sm.setStatus(StatusBound) + stat, ok = <-l + if !ok { + t.Error() + } else if stat != StatusBound { + t.Errorf("got %d", stat) + } + + sm.setStatus(StatusShutdown) + stat = <-l + if stat != StatusShutdown { + t.Errorf("got %d", stat) + } +} + +func TestAwaitStatus(t *testing.T) { + sm := newStatmgr(nil) + + syncCh := make(chan int) + + go func() { + sm.setStatus(StatusConnected) + sm.setStatus(StatusBound) + time.Sleep(100 * time.Millisecond) + syncCh <- 0 + }() + + err := sm.awaitStatus(StatusBound) + if err != nil { + t.Fatal(err) + } + select { + case <-syncCh: + t.Fatal("didn't wait") + default: + } + <-syncCh +} diff -r 69c5b4382e39 -r bbd4166df95d xmpp/xmpp.go --- a/xmpp/xmpp.go Sun Sep 22 17:43:34 2013 -0500 +++ b/xmpp/xmpp.go Sat Sep 28 13:02:17 2013 -0600 @@ -36,20 +36,36 @@ clientSrv = "xmpp-client" ) -// Flow control for preventing sending stanzas until negotiation has -// completed. -type sendCmd int +// Status of the connection. +type Status int const ( - sendAllowConst = iota - sendDenyConst - sendAbortConst + statusUnconnected = iota + statusConnected + statusConnectedTls + statusAuthenticated + statusBound + statusRunning + statusShutdown ) var ( - sendAllow sendCmd = sendAllowConst - sendDeny sendCmd = sendDenyConst - sendAbort sendCmd = sendAbortConst + // The client has not yet connected, or it has been + // disconnected from the server. + StatusUnconnected Status = statusUnconnected + // Initial connection established. + StatusConnected Status = statusConnected + // Like StatusConnected, but with TLS. + StatusConnectedTls Status = statusConnectedTls + // Authentication succeeded. + StatusAuthenticated Status = statusAuthenticated + // Resource binding complete. + StatusBound Status = statusBound + // Session has started and normal message traffic can be sent + // and received. + StatusRunning Status = statusRunning + // The session has closed, or is in the process of closing. + StatusShutdown Status = statusShutdown ) // A filter can modify the XMPP traffic to or from the remote @@ -74,14 +90,12 @@ // The client in a client-server XMPP connection. type Client struct { - // This client's JID. This will be updated asynchronously by - // the time StartSession() returns. + // This client's full JID, including resource Jid JID password string saslExpected string authDone bool handlers chan *callback - inputControl chan sendCmd // Incoming XMPP stanzas from the remote will be published on // this channel. Information which is used by this library to // set up the XMPP stream will not appear here. @@ -90,34 +104,51 @@ // channel. Send chan<- Stanza sendXml chan<- interface{} + statmgr *statmgr // The client's roster is also known as the buddy list. It's // the set of contacts which are known to this JID, or which // this JID is known to. Roster Roster - // Features advertised by the remote. This will be updated - // asynchronously as new features are received throughout the - // connection process. It should not be updated once - // StartSession() returns. + // Features advertised by the remote. Features *Features sendFilterAdd, recvFilterAdd chan Filter - // Allows the user to override the TLS configuration. - tlsConfig tls.Config - layer1 *layer1 + tlsConfig tls.Config + layer1 *layer1 } -// Connect to the appropriate server and authenticate as the given JID -// with the given password. This function will return as soon as a TCP -// connection has been established, but before XMPP stream negotiation -// has completed. The negotiation will occur asynchronously, and any -// send operation to Client.Send will block until negotiation -// (resource binding) is complete. The caller must immediately start -// reading from Client.Recv. -func NewClient(jid *JID, password string, tlsconf tls.Config, exts []Extension) (*Client, error) { +// Creates an XMPP client identified by the given JID, authenticating +// with the provided password and TLS config. Zero or more extensions +// may be specified. The initial presence will be broadcast. If status +// is non-nil, connection progress information will be sent on it. +func NewClient(jid *JID, password string, tlsconf tls.Config, exts []Extension, + pr Presence, status chan<- Status) (*Client, error) { + // Include the mandatory extensions. roster := newRosterExt() exts = append(exts, roster.Extension) exts = append(exts, bindExt) + cl := new(Client) + cl.Roster = *roster + cl.password = password + cl.Jid = *jid + cl.handlers = make(chan *callback, 100) + cl.tlsConfig = tlsconf + cl.sendFilterAdd = make(chan Filter) + cl.recvFilterAdd = make(chan Filter) + cl.statmgr = newStatmgr(status) + + extStanza := make(map[xml.Name]reflect.Type) + for _, ext := range exts { + for k, v := range ext.StanzaHandlers { + if _, ok := extStanza[k]; ok { + return nil, fmt.Errorf("duplicate handler %s", + k) + } + extStanza[k] = v + } + } + // Resolve the domain in the JID. _, srvs, err := net.LookupSRV(clientSrv, "tcp", jid.Domain) if err != nil { @@ -145,32 +176,13 @@ if tcp == nil { return nil, err } - - cl := new(Client) - cl.Roster = *roster - cl.password = password - cl.Jid = *jid - cl.handlers = make(chan *callback, 100) - cl.inputControl = make(chan sendCmd) - cl.tlsConfig = tlsconf - cl.sendFilterAdd = make(chan Filter) - cl.recvFilterAdd = make(chan Filter) - - extStanza := make(map[xml.Name]reflect.Type) - for _, ext := range exts { - for k, v := range ext.StanzaHandlers { - if _, ok := extStanza[k]; ok { - return nil, fmt.Errorf("duplicate handler %s", - k) - } - extStanza[k] = v - } - } + cl.setStatus(StatusConnected) // Start the transport handler, initially unencrypted. recvReader, recvWriter := io.Pipe() sendReader, sendWriter := io.Pipe() - cl.layer1 = startLayer1(tcp, recvWriter, sendReader) + cl.layer1 = startLayer1(tcp, recvWriter, sendReader, + cl.statmgr.newListener()) // Start the reader and writer that convert to and from XML. recvXmlCh := make(chan interface{}) @@ -182,12 +194,12 @@ // Start the reader and writer that convert between XML and // XMPP stanzas. recvRawXmpp := make(chan Stanza) - go cl.recvStream(recvXmlCh, recvRawXmpp) + go cl.recvStream(recvXmlCh, recvRawXmpp, cl.statmgr.newListener()) sendRawXmpp := make(chan Stanza) - go sendStream(sendXmlCh, sendRawXmpp, cl.inputControl) + go sendStream(sendXmlCh, sendRawXmpp, cl.statmgr.newListener()) - // Start the manager for the filters that can modify what the - // app sees. + // Start the managers for the filters that can modify what the + // app sees or sends. recvFiltXmpp := make(chan Stanza) cl.Recv = recvFiltXmpp go filterMgr(cl.recvFilterAdd, recvRawXmpp, recvFiltXmpp) @@ -204,6 +216,44 @@ hsOut := &stream{To: jid.Domain, Version: XMPPVersion} cl.sendXml <- hsOut + // Wait until resource binding is complete. + if err := cl.statmgr.awaitStatus(StatusBound); err != nil { + return nil, err + } + + // Initialize the session. + id := NextId() + iq := &Iq{Header: Header{To: cl.Jid.Domain, Id: id, Type: "set", + Nested: []interface{}{Generic{XMLName: xml.Name{Space: NsSession, Local: "session"}}}}} + ch := make(chan error) + f := func(st Stanza) { + iq, ok := st.(*Iq) + if !ok { + Warn.Log("iq reply not iq; can't start session") + ch <- errors.New("bad session start reply") + } + if iq.Type == "error" { + Warn.Logf("Can't start session: %v", iq) + ch <- iq.Error + } + ch <- nil + } + cl.SetCallback(id, f) + cl.sendXml <- iq + // Now wait until the callback is called. + if err := <-ch; err != nil { + return nil, err + } + + // This allows the client to receive stanzas. + cl.setStatus(StatusRunning) + + // Request the roster. + cl.Roster.update() + + // Send the initial presence. + cl.Send <- &pr + return cl, nil } @@ -236,48 +286,3 @@ Debug.Log(buf) } } - -// bindDone is called when we've finished resource binding (and all -// the negotiations that precede it). Now we can start accepting -// traffic from the app. -func (cl *Client) bindDone() { - cl.inputControl <- sendAllow -} - -// Start an XMPP session. A typical XMPP client should call this -// immediately after creating the Client in order to start the session -// and broadcast an initial presence. The presence can be as simple as -// a newly-initialized Presence struct. See RFC 3921, Section -// 3. After calling this, a normal client should call Roster.Update(). -func (cl *Client) StartSession(pr *Presence) error { - id := NextId() - iq := &Iq{Header: Header{To: cl.Jid.Domain, Id: id, Type: "set", - Nested: []interface{}{Generic{XMLName: xml.Name{Space: NsSession, Local: "session"}}}}} - ch := make(chan error) - f := func(st Stanza) bool { - iq, ok := st.(*Iq) - if !ok { - Warn.Log("iq reply not iq; can't start session") - ch <- errors.New("bad session start reply") - return false - } - if iq.Type == "error" { - Warn.Logf("Can't start session: %v", iq) - ch <- iq.Error - return false - } - ch <- nil - return false - } - cl.SetCallback(id, f) - cl.Send <- iq - - // Now wait until the callback is called. - if err := <-ch; err != nil { - return err - } - if pr != nil { - cl.Send <- pr - } - return nil -}