xmpp/xmpp.go
changeset 183 b4bd77d58a3e
parent 180 3010996c1928
child 185 ba8a4ae40e13
--- a/xmpp/xmpp.go	Sun Feb 09 09:50:38 2014 -0700
+++ b/xmpp/xmpp.go	Sun Feb 09 09:52:28 2014 -0700
@@ -7,14 +7,13 @@
 package xmpp
 
 import (
-	"bytes"
 	"crypto/tls"
 	"encoding/xml"
-	"errors"
 	"fmt"
 	"io"
 	"net"
 	"reflect"
+	"sync"
 )
 
 const (
@@ -36,38 +35,6 @@
 	clientSrv = "xmpp-client"
 )
 
-// Status of the connection.
-type Status int
-
-const (
-	statusUnconnected = iota
-	statusConnected
-	statusConnectedTls
-	statusAuthenticated
-	statusBound
-	statusRunning
-	statusShutdown
-)
-
-var (
-	// The client has not yet connected, or it has been
-	// disconnected from the server.
-	StatusUnconnected Status = statusUnconnected
-	// Initial connection established.
-	StatusConnected Status = statusConnected
-	// Like StatusConnected, but with TLS.
-	StatusConnectedTls Status = statusConnectedTls
-	// Authentication succeeded.
-	StatusAuthenticated Status = statusAuthenticated
-	// Resource binding complete.
-	StatusBound Status = statusBound
-	// Session has started and normal message traffic can be sent
-	// and received.
-	StatusRunning Status = statusRunning
-	// The session has closed, or is in the process of closing.
-	StatusShutdown Status = statusShutdown
-)
-
 // A filter can modify the XMPP traffic to or from the remote
 // server. It's part of an Extension. The filter function will be
 // called in a new goroutine, so it doesn't need to return. The filter
@@ -76,10 +43,9 @@
 
 // Extensions can add stanza filters and/or new XML element types.
 type Extension struct {
-	// Maps from an XML namespace to a function which constructs a
-	// structure to hold the contents of stanzas in that
-	// namespace.
-	StanzaHandlers map[xml.Name]reflect.Type
+	// Maps from an XML name to a structure which holds stanza
+	// contents with that name.
+	StanzaTypes map[xml.Name]reflect.Type
 	// If non-nil, will be called once to start the filter
 	// running. RecvFilter intercepts incoming messages on their
 	// way from the remote server to the application; SendFilter
@@ -101,9 +67,10 @@
 	// set up the XMPP stream will not appear here.
 	Recv <-chan Stanza
 	// Outgoing XMPP stanzas to the server should be sent to this
-	// channel.
+	// channel. The application should not close this channel;
+	// rather, call Close().
 	Send    chan<- Stanza
-	sendXml chan<- interface{}
+	sendRaw chan<- interface{}
 	statmgr *statmgr
 	// The client's roster is also known as the buddy list. It's
 	// the set of contacts which are known to this JID, or which
@@ -114,6 +81,8 @@
 	sendFilterAdd, recvFilterAdd chan Filter
 	tlsConfig                    tls.Config
 	layer1                       *layer1
+	error                        chan error
+	shutdownOnce                 sync.Once
 }
 
 // Creates an XMPP client identified by the given JID, authenticating
@@ -123,34 +92,8 @@
 func NewClient(jid *JID, password string, tlsconf tls.Config, exts []Extension,
 	pr Presence, status chan<- Status) (*Client, error) {
 
-	// Include the mandatory extensions.
-	roster := newRosterExt()
-	exts = append(exts, roster.Extension)
-	exts = append(exts, bindExt)
-
-	cl := new(Client)
-	cl.Roster = *roster
-	cl.password = password
-	cl.Jid = *jid
-	cl.handlers = make(chan *callback, 100)
-	cl.tlsConfig = tlsconf
-	cl.sendFilterAdd = make(chan Filter)
-	cl.recvFilterAdd = make(chan Filter)
-	cl.statmgr = newStatmgr(status)
-
-	extStanza := make(map[xml.Name]reflect.Type)
-	for _, ext := range exts {
-		for k, v := range ext.StanzaHandlers {
-			if _, ok := extStanza[k]; ok {
-				return nil, fmt.Errorf("duplicate handler %s",
-					k)
-			}
-			extStanza[k] = v
-		}
-	}
-
 	// Resolve the domain in the JID.
-	_, srvs, err := net.LookupSRV(clientSrv, "tcp", jid.Domain)
+	_, srvs, err := net.LookupSRV(clientSrv, "tcp", jid.Domain())
 	if err != nil {
 		return nil, fmt.Errorf("LookupSrv %s: %v", jid.Domain, err)
 	}
@@ -176,20 +119,75 @@
 	if tcp == nil {
 		return nil, err
 	}
+
+	return newClient(tcp, jid, password, tlsconf, exts, pr, status)
+}
+
+// Connect to the specified host and port. This is otherwise identical
+// to NewClient.
+func NewClientFromHost(jid *JID, password string, tlsconf tls.Config,
+	exts []Extension, pr Presence, status chan<- Status, host string,
+	port int) (*Client, error) {
+
+	addrStr := fmt.Sprintf("%s:%d", host, port)
+	addr, err := net.ResolveTCPAddr("tcp", addrStr)
+	if err != nil {
+		return nil, err
+	}
+	tcp, err := net.DialTCP("tcp", nil, addr)
+	if err != nil {
+		return nil, err
+	}
+
+	return newClient(tcp, jid, password, tlsconf, exts, pr, status)
+}
+
+func newClient(tcp *net.TCPConn, jid *JID, password string, tlsconf tls.Config,
+	exts []Extension, pr Presence, status chan<- Status) (*Client, error) {
+
+	// Include the mandatory extensions.
+	roster := newRosterExt()
+	exts = append(exts, roster.Extension)
+	exts = append(exts, bindExt)
+
+	cl := new(Client)
+	cl.Roster = *roster
+	cl.password = password
+	cl.Jid = *jid
+	cl.handlers = make(chan *callback, 100)
+	cl.tlsConfig = tlsconf
+	cl.sendFilterAdd = make(chan Filter)
+	cl.recvFilterAdd = make(chan Filter)
+	cl.statmgr = newStatmgr(status)
+	cl.error = make(chan error, 1)
+
+	extStanza := make(map[xml.Name]reflect.Type)
+	for _, ext := range exts {
+		for k, v := range ext.StanzaTypes {
+			if _, ok := extStanza[k]; ok {
+				return nil, fmt.Errorf("duplicate handler %s",
+					k)
+			}
+			extStanza[k] = v
+		}
+	}
+
+	// The thing that called this made a TCP connection, so now we
+	// can signal that it's connected.
 	cl.setStatus(StatusConnected)
 
 	// Start the transport handler, initially unencrypted.
 	recvReader, recvWriter := io.Pipe()
 	sendReader, sendWriter := io.Pipe()
-	cl.layer1 = startLayer1(tcp, recvWriter, sendReader,
+	cl.layer1 = cl.startLayer1(tcp, recvWriter, sendReader,
 		cl.statmgr.newListener())
 
 	// Start the reader and writer that convert to and from XML.
 	recvXmlCh := make(chan interface{})
-	go recvXml(recvReader, recvXmlCh, extStanza)
+	go cl.recvXml(recvReader, recvXmlCh, extStanza)
 	sendXmlCh := make(chan interface{})
-	cl.sendXml = sendXmlCh
-	go sendXml(sendWriter, sendXmlCh)
+	cl.sendRaw = sendXmlCh
+	go cl.sendXml(sendWriter, sendXmlCh)
 
 	// Start the reader and writer that convert between XML and
 	// XMPP stanzas.
@@ -213,36 +211,37 @@
 	}
 
 	// Initial handshake.
-	hsOut := &stream{To: jid.Domain, Version: XMPPVersion}
-	cl.sendXml <- hsOut
+	hsOut := &stream{To: jid.Domain(), Version: XMPPVersion}
+	cl.sendRaw <- hsOut
 
 	// Wait until resource binding is complete.
 	if err := cl.statmgr.awaitStatus(StatusBound); err != nil {
-		return nil, err
+		return nil, cl.getError(err)
 	}
 
+	// Forget about the password, for paranoia's sake.
+	cl.password = ""
+
 	// Initialize the session.
 	id := NextId()
-	iq := &Iq{Header: Header{To: cl.Jid.Domain, Id: id, Type: "set",
+	iq := &Iq{Header: Header{To: JID(cl.Jid.Domain()), Id: id, Type: "set",
 		Nested: []interface{}{Generic{XMLName: xml.Name{Space: NsSession, Local: "session"}}}}}
 	ch := make(chan error)
 	f := func(st Stanza) {
 		iq, ok := st.(*Iq)
 		if !ok {
-			Warn.Log("iq reply not iq; can't start session")
-			ch <- errors.New("bad session start reply")
+			ch <- fmt.Errorf("bad session start reply: %#v", st)
 		}
 		if iq.Type == "error" {
-			Warn.Logf("Can't start session: %v", iq)
-			ch <- iq.Error
+			ch <- fmt.Errorf("Can't start session: %v", iq.Error)
 		}
 		ch <- nil
 	}
 	cl.SetCallback(id, f)
-	cl.sendXml <- iq
+	cl.sendRaw <- iq
 	// Now wait until the callback is called.
 	if err := <-ch; err != nil {
-		return nil, err
+		return nil, cl.getError(err)
 	}
 
 	// This allows the client to receive stanzas.
@@ -254,35 +253,45 @@
 	// Send the initial presence.
 	cl.Send <- &pr
 
-	return cl, nil
+	return cl, cl.getError(nil)
+}
+
+func (cl *Client) Close() {
+	// Shuts down the receivers:
+	cl.setStatus(StatusShutdown)
+
+	// Shuts down the senders:
+	cl.shutdownOnce.Do(func() { close(cl.Send) })
 }
 
-func tee(r io.Reader, w io.Writer, prefix string) {
-	defer func(w io.Writer) {
-		if c, ok := w.(io.Closer); ok {
-			c.Close()
-		}
-	}(w)
-
-	buf := bytes.NewBuffer([]uint8(prefix))
-	for {
-		var c [1]byte
-		n, _ := r.Read(c[:])
-		if n == 0 {
-			break
-		}
-		n, _ = w.Write(c[:n])
-		if n == 0 {
-			break
-		}
-		buf.Write(c[:n])
-		if c[0] == '\n' || c[0] == '>' {
-			Debug.Log(buf)
-			buf = bytes.NewBuffer([]uint8(prefix))
-		}
-	}
-	leftover := buf.String()
-	if leftover != "" {
-		Debug.Log(buf)
+// If there's a buffered error in the channel, return it. Otherwise,
+// return what was passed to us. The idea is that the error in the
+// channel probably preceded (and caused) the one that's passed as an
+// argument here.
+func (cl *Client) getError(err1 error) error {
+	select {
+	case err0 := <-cl.error:
+		return err0
+	default:
+		return err1
 	}
 }
+
+// Register an error that happened in the internals somewhere. If
+// there's already an error in the channel, discard the newer one in
+// favor of the older.
+func (cl *Client) setError(err error) {
+	defer cl.Close()
+	defer cl.setStatus(StatusError)
+
+	if len(cl.error) > 0 {
+		return
+	}
+	// If we're in a race between two calls to this function,
+	// trying to set the "first" error, just arbitrarily let one
+	// of them win.
+	select {
+	case cl.error <- err:
+	default:
+	}
+}