diff -r 626c390682fc -r b4bd77d58a3e xmpp/xmpp.go --- a/xmpp/xmpp.go Sun Feb 09 09:50:38 2014 -0700 +++ b/xmpp/xmpp.go Sun Feb 09 09:52:28 2014 -0700 @@ -7,14 +7,13 @@ package xmpp import ( - "bytes" "crypto/tls" "encoding/xml" - "errors" "fmt" "io" "net" "reflect" + "sync" ) const ( @@ -36,38 +35,6 @@ clientSrv = "xmpp-client" ) -// Status of the connection. -type Status int - -const ( - statusUnconnected = iota - statusConnected - statusConnectedTls - statusAuthenticated - statusBound - statusRunning - statusShutdown -) - -var ( - // The client has not yet connected, or it has been - // disconnected from the server. - StatusUnconnected Status = statusUnconnected - // Initial connection established. - StatusConnected Status = statusConnected - // Like StatusConnected, but with TLS. - StatusConnectedTls Status = statusConnectedTls - // Authentication succeeded. - StatusAuthenticated Status = statusAuthenticated - // Resource binding complete. - StatusBound Status = statusBound - // Session has started and normal message traffic can be sent - // and received. - StatusRunning Status = statusRunning - // The session has closed, or is in the process of closing. - StatusShutdown Status = statusShutdown -) - // A filter can modify the XMPP traffic to or from the remote // server. It's part of an Extension. The filter function will be // called in a new goroutine, so it doesn't need to return. The filter @@ -76,10 +43,9 @@ // Extensions can add stanza filters and/or new XML element types. type Extension struct { - // Maps from an XML namespace to a function which constructs a - // structure to hold the contents of stanzas in that - // namespace. - StanzaHandlers map[xml.Name]reflect.Type + // Maps from an XML name to a structure which holds stanza + // contents with that name. + StanzaTypes map[xml.Name]reflect.Type // If non-nil, will be called once to start the filter // running. RecvFilter intercepts incoming messages on their // way from the remote server to the application; SendFilter @@ -101,9 +67,10 @@ // set up the XMPP stream will not appear here. Recv <-chan Stanza // Outgoing XMPP stanzas to the server should be sent to this - // channel. + // channel. The application should not close this channel; + // rather, call Close(). Send chan<- Stanza - sendXml chan<- interface{} + sendRaw chan<- interface{} statmgr *statmgr // The client's roster is also known as the buddy list. It's // the set of contacts which are known to this JID, or which @@ -114,6 +81,8 @@ sendFilterAdd, recvFilterAdd chan Filter tlsConfig tls.Config layer1 *layer1 + error chan error + shutdownOnce sync.Once } // Creates an XMPP client identified by the given JID, authenticating @@ -123,34 +92,8 @@ func NewClient(jid *JID, password string, tlsconf tls.Config, exts []Extension, pr Presence, status chan<- Status) (*Client, error) { - // Include the mandatory extensions. - roster := newRosterExt() - exts = append(exts, roster.Extension) - exts = append(exts, bindExt) - - cl := new(Client) - cl.Roster = *roster - cl.password = password - cl.Jid = *jid - cl.handlers = make(chan *callback, 100) - cl.tlsConfig = tlsconf - cl.sendFilterAdd = make(chan Filter) - cl.recvFilterAdd = make(chan Filter) - cl.statmgr = newStatmgr(status) - - extStanza := make(map[xml.Name]reflect.Type) - for _, ext := range exts { - for k, v := range ext.StanzaHandlers { - if _, ok := extStanza[k]; ok { - return nil, fmt.Errorf("duplicate handler %s", - k) - } - extStanza[k] = v - } - } - // Resolve the domain in the JID. - _, srvs, err := net.LookupSRV(clientSrv, "tcp", jid.Domain) + _, srvs, err := net.LookupSRV(clientSrv, "tcp", jid.Domain()) if err != nil { return nil, fmt.Errorf("LookupSrv %s: %v", jid.Domain, err) } @@ -176,20 +119,75 @@ if tcp == nil { return nil, err } + + return newClient(tcp, jid, password, tlsconf, exts, pr, status) +} + +// Connect to the specified host and port. This is otherwise identical +// to NewClient. +func NewClientFromHost(jid *JID, password string, tlsconf tls.Config, + exts []Extension, pr Presence, status chan<- Status, host string, + port int) (*Client, error) { + + addrStr := fmt.Sprintf("%s:%d", host, port) + addr, err := net.ResolveTCPAddr("tcp", addrStr) + if err != nil { + return nil, err + } + tcp, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return nil, err + } + + return newClient(tcp, jid, password, tlsconf, exts, pr, status) +} + +func newClient(tcp *net.TCPConn, jid *JID, password string, tlsconf tls.Config, + exts []Extension, pr Presence, status chan<- Status) (*Client, error) { + + // Include the mandatory extensions. + roster := newRosterExt() + exts = append(exts, roster.Extension) + exts = append(exts, bindExt) + + cl := new(Client) + cl.Roster = *roster + cl.password = password + cl.Jid = *jid + cl.handlers = make(chan *callback, 100) + cl.tlsConfig = tlsconf + cl.sendFilterAdd = make(chan Filter) + cl.recvFilterAdd = make(chan Filter) + cl.statmgr = newStatmgr(status) + cl.error = make(chan error, 1) + + extStanza := make(map[xml.Name]reflect.Type) + for _, ext := range exts { + for k, v := range ext.StanzaTypes { + if _, ok := extStanza[k]; ok { + return nil, fmt.Errorf("duplicate handler %s", + k) + } + extStanza[k] = v + } + } + + // The thing that called this made a TCP connection, so now we + // can signal that it's connected. cl.setStatus(StatusConnected) // Start the transport handler, initially unencrypted. recvReader, recvWriter := io.Pipe() sendReader, sendWriter := io.Pipe() - cl.layer1 = startLayer1(tcp, recvWriter, sendReader, + cl.layer1 = cl.startLayer1(tcp, recvWriter, sendReader, cl.statmgr.newListener()) // Start the reader and writer that convert to and from XML. recvXmlCh := make(chan interface{}) - go recvXml(recvReader, recvXmlCh, extStanza) + go cl.recvXml(recvReader, recvXmlCh, extStanza) sendXmlCh := make(chan interface{}) - cl.sendXml = sendXmlCh - go sendXml(sendWriter, sendXmlCh) + cl.sendRaw = sendXmlCh + go cl.sendXml(sendWriter, sendXmlCh) // Start the reader and writer that convert between XML and // XMPP stanzas. @@ -213,36 +211,37 @@ } // Initial handshake. - hsOut := &stream{To: jid.Domain, Version: XMPPVersion} - cl.sendXml <- hsOut + hsOut := &stream{To: jid.Domain(), Version: XMPPVersion} + cl.sendRaw <- hsOut // Wait until resource binding is complete. if err := cl.statmgr.awaitStatus(StatusBound); err != nil { - return nil, err + return nil, cl.getError(err) } + // Forget about the password, for paranoia's sake. + cl.password = "" + // Initialize the session. id := NextId() - iq := &Iq{Header: Header{To: cl.Jid.Domain, Id: id, Type: "set", + iq := &Iq{Header: Header{To: JID(cl.Jid.Domain()), Id: id, Type: "set", Nested: []interface{}{Generic{XMLName: xml.Name{Space: NsSession, Local: "session"}}}}} ch := make(chan error) f := func(st Stanza) { iq, ok := st.(*Iq) if !ok { - Warn.Log("iq reply not iq; can't start session") - ch <- errors.New("bad session start reply") + ch <- fmt.Errorf("bad session start reply: %#v", st) } if iq.Type == "error" { - Warn.Logf("Can't start session: %v", iq) - ch <- iq.Error + ch <- fmt.Errorf("Can't start session: %v", iq.Error) } ch <- nil } cl.SetCallback(id, f) - cl.sendXml <- iq + cl.sendRaw <- iq // Now wait until the callback is called. if err := <-ch; err != nil { - return nil, err + return nil, cl.getError(err) } // This allows the client to receive stanzas. @@ -254,35 +253,45 @@ // Send the initial presence. cl.Send <- &pr - return cl, nil + return cl, cl.getError(nil) +} + +func (cl *Client) Close() { + // Shuts down the receivers: + cl.setStatus(StatusShutdown) + + // Shuts down the senders: + cl.shutdownOnce.Do(func() { close(cl.Send) }) } -func tee(r io.Reader, w io.Writer, prefix string) { - defer func(w io.Writer) { - if c, ok := w.(io.Closer); ok { - c.Close() - } - }(w) - - buf := bytes.NewBuffer([]uint8(prefix)) - for { - var c [1]byte - n, _ := r.Read(c[:]) - if n == 0 { - break - } - n, _ = w.Write(c[:n]) - if n == 0 { - break - } - buf.Write(c[:n]) - if c[0] == '\n' || c[0] == '>' { - Debug.Log(buf) - buf = bytes.NewBuffer([]uint8(prefix)) - } - } - leftover := buf.String() - if leftover != "" { - Debug.Log(buf) +// If there's a buffered error in the channel, return it. Otherwise, +// return what was passed to us. The idea is that the error in the +// channel probably preceded (and caused) the one that's passed as an +// argument here. +func (cl *Client) getError(err1 error) error { + select { + case err0 := <-cl.error: + return err0 + default: + return err1 } } + +// Register an error that happened in the internals somewhere. If +// there's already an error in the channel, discard the newer one in +// favor of the older. +func (cl *Client) setError(err error) { + defer cl.Close() + defer cl.setStatus(StatusError) + + if len(cl.error) > 0 { + return + } + // If we're in a race between two calls to this function, + // trying to set the "first" error, just arbitrarily let one + // of them win. + select { + case cl.error <- err: + default: + } +}