Added TLS negotiation.
authorChris Jones <chris@cjones.org>
Mon, 26 Dec 2011 18:07:14 -0700 (2011-12-27)
changeset 10 f38b0ee7b1c1
parent 9 4fe926b03827
child 11 48be1ae93fd4
Added TLS negotiation.
Makefile
stream.go
structs.go
xmpp.go
--- a/Makefile	Mon Dec 26 14:36:41 2011 -0700
+++ b/Makefile	Mon Dec 26 18:07:14 2011 -0700
@@ -7,6 +7,7 @@
 TARG=cjyar/xmpp
 GOFILES=\
 	xmpp.go \
+	stream.go \
 	structs.go \
 
 include $(GOROOT)/src/Make.pkg
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/stream.go	Mon Dec 26 18:07:14 2011 -0700
@@ -0,0 +1,241 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file contains the three layers of processing for the
+// communication with the server: transport (where TLS happens), XML
+// (where strings are converted to go structures), and Stream (where
+// we respond to XMPP events on behalf of the library client).
+
+package xmpp
+
+import (
+	"crypto/tls"
+	"io"
+	"log"
+	"net"
+	"os"
+	"time"
+	"xml"
+)
+
+func (cl *Client) readTransport(w io.Writer) {
+	defer tryClose(cl.socket, w)
+	cl.socket.SetReadTimeout(1e8)
+	p := make([]byte, 1024)
+	for {
+		if cl.socket == nil {
+			cl.waitForSocket()
+		}
+		nr, err := cl.socket.Read(p)
+		if nr == 0 {
+			if errno, ok := err.(*net.OpError) ; ok {
+				if errno.Timeout() {
+					continue
+				}
+			}
+			log.Printf("read: %s", err.String())
+			break
+		}
+		nw, err := w.Write(p[:nr])
+		if nw < nr {
+			log.Println("read: %s", err.String())
+			break
+		}
+	}
+}
+
+func (cl *Client) writeTransport(r io.Reader) {
+	defer tryClose(r, cl.socket)
+	p := make([]byte, 1024)
+	for {
+		nr, err := r.Read(p)
+		if nr == 0 {
+			log.Printf("write: %s", err.String())
+			break
+		}
+		nw, err := cl.socket.Write(p[:nr])
+		if nw < nr {
+			log.Println("write: %s", err.String())
+			break
+		}
+	}
+}
+
+func readXml(r io.Reader, ch chan<- interface{}) {
+	if debug {
+		pr, pw := io.Pipe()
+		go tee(r, pw, "S: ")
+		r = pr
+	}
+	defer tryClose(r, ch)
+
+	p := xml.NewParser(r)
+	for {
+		// Sniff the next token on the stream.
+		t, err := p.Token()
+		if t == nil {
+			if err != os.EOF {
+				log.Printf("read: %v", err)
+			}
+			break
+		}
+		var se xml.StartElement
+		var ok bool
+		if se, ok = t.(xml.StartElement) ; !ok {
+			continue
+		}
+
+		// Allocate the appropriate structure for this token.
+		var obj interface{}
+		switch se.Name.Space + " " + se.Name.Local {
+		case nsStream + " stream":
+			st, err := parseStream(se)
+			if err != nil {
+				log.Printf("unmarshal stream: %v",
+					err)
+				break
+			}
+			ch <- st
+			continue
+		case "stream error", nsStream + " error":
+			obj = &StreamError{}
+		case nsStream + " features":
+			obj = &Features{}
+		case nsTLS + " proceed", nsTLS + " failure":
+			obj = &starttls{}
+		default:
+			obj = &Unrecognized{}
+			log.Printf("Ignoring unrecognized: %s %s\n",
+				se.Name.Space, se.Name.Local)
+		}
+
+		// Read the complete XML stanza.
+		err = p.Unmarshal(obj, &se)
+		if err != nil {
+			log.Printf("unmarshal: %v", err)
+			break
+		}
+
+		// Put it on the channel.
+		ch <- obj
+	}
+}
+
+func writeXml(w io.Writer, ch <-chan interface{}) {
+	if debug {
+		pr, pw := io.Pipe()
+		go tee(pr, w, "C: ")
+		w = pw
+	}
+	defer tryClose(w, ch)
+
+	for obj := range ch {
+		err := xml.Marshal(w, obj)
+		if err != nil {
+			log.Printf("write: %v", err)
+			break
+		}
+	}
+}
+
+func writeText(w io.Writer, ch <-chan *string) {
+	if debug {
+		pr, pw := io.Pipe()
+		go tee(pr, w, "C: ")
+		w = pw
+	}
+	defer tryClose(w, ch)
+
+	for str := range ch {
+		_, err := w.Write([]byte(*str))
+		if err != nil {
+			log.Printf("writeStr: %v", err)
+			break
+		}
+	}
+}
+
+func (cl *Client) readStream(srvIn <-chan interface{}, srvOut, cliOut chan<- interface{}) {
+	defer tryClose(srvIn, cliOut)
+
+	for x := range srvIn {
+		switch obj := x.(type) {
+		case *Stream:
+			handleStream(obj)
+		case *Features:
+			handleFeatures(obj, srvOut)
+		case *starttls:
+			cl.handleTls(obj)
+		default:
+			cliOut <- x
+		}
+	}
+}
+
+func writeStream(srvOut chan<- interface{}, cliIn <-chan interface{}) {
+	defer tryClose(srvOut, cliIn)
+
+	for x := range cliIn {
+		srvOut <- x
+	}
+}
+
+func handleStream(ss *Stream) {
+}
+
+func handleFeatures(fe *Features, srvOut chan<- interface{}) {
+	if fe.Starttls != nil {
+		start := &starttls{XMLName: xml.Name{Space: nsTLS,
+			Local: "starttls"}}
+		srvOut <- start
+	}
+}
+
+// readTransport() is running concurrently. We need to stop it,
+// negotiate TLS, then start it again. It calls waitForSocket() in
+// its inner loop; see below.
+func (cl *Client) handleTls(t *starttls) {
+	tcp := cl.socket
+
+	// Set the socket to nil, and wait for the reader routine to
+	// signal that it's paused.
+	cl.socket = nil
+	cl.socketSync.Add(1)
+	cl.socketSync.Wait()
+
+	// Negotiate TLS with the server.
+	tls := tls.Client(tcp, nil)
+
+	// Make the TLS connection available to the reader, and wait
+	// for it to signal that it's working again.
+	cl.socketSync.Add(1)
+	cl.socket = tls
+	cl.socketSync.Wait()
+
+	// Reset the read timeout on the (underlying) socket so the
+	// reader doesn't get woken up unnecessarily.
+	tcp.SetReadTimeout(0)
+
+	log.Println("TLS negotiation succeeded.")
+
+	// Now re-send the initial handshake message to start the new
+	// session.
+	hsOut := &Stream{To: cl.Jid.Domain, Version: Version}
+	cl.xmlOut <- hsOut
+}
+
+// Synchronize with handleTls(). Called from readTransport() when
+// cl.socket is nil.
+func (cl *Client) waitForSocket() {
+	// Signal that we've stopped reading from the socket.
+	cl.socketSync.Done()
+
+	// Wait until the socket is available again.
+	for cl.socket == nil {
+		time.Sleep(1e8)
+	}
+
+	// Signal that we're going back to the read loop.
+	cl.socketSync.Done()
+}
--- a/structs.go	Mon Dec 26 14:36:41 2011 -0700
+++ b/structs.go	Mon Dec 26 18:07:14 2011 -0700
@@ -17,14 +17,6 @@
 	"xml"
 )
 
-const (
-	// Version of RFC 3920 that we implement.
-	Version = "1.0"
-	nsStreams = "urn:ietf:params:xml:ns:xmpp-streams"
-	nsStream = "http://etherx.jabber.org/streams"
-	nsTLS = "urn:ietf:params:xml:ns:xmpp-tls"
-)
-
 // JID represents an entity that can communicate with other
 // entities. It looks like node@domain/resource. Node and resource are
 // sometimes optional.
@@ -67,11 +59,12 @@
 var _ xml.Marshaler = &errText{}
 
 type Features struct {
-	Starttls starttls
+	Starttls *starttls
 	Mechanisms mechs
 }
 
 type starttls struct {
+	XMLName xml.Name
 	required *string
 }
 
--- a/xmpp.go	Mon Dec 26 14:36:41 2011 -0700
+++ b/xmpp.go	Mon Dec 26 18:07:14 2011 -0700
@@ -10,22 +10,35 @@
 	"bytes"
 	"fmt"
 	"io"
-	"log"
 	"net"
 	"os"
-	"xml"
+	"sync"
 )
 
 const (
+	// Version of RFC 3920 that we implement.
+	Version = "1.0"
+
+	// Various XML namespaces.
+	nsStreams = "urn:ietf:params:xml:ns:xmpp-streams"
+	nsStream = "http://etherx.jabber.org/streams"
+	nsTLS = "urn:ietf:params:xml:ns:xmpp-tls"
+
+	// DNS SRV names
 	serverSrv = "xmpp-server"
 	clientSrv = "xmpp-client"
-	debug = false
+
+	debug = true
 )
 
 // The client in a client-server XMPP connection.
 type Client struct {
+	Jid JID
+	socket net.Conn
+	socketSync sync.WaitGroup
 	In <-chan interface{}
 	Out chan<- interface{}
+	xmlOut chan<- interface{}
 	TextOut chan<- *string
 }
 var _ io.Closer = &Client{}
@@ -60,27 +73,29 @@
 		return nil, err
 	}
 
+	cl := new(Client)
+	cl.Jid = *jid
+	cl.socket = tcp
+
 	// Start the transport handler, initially unencrypted.
-	tlsr, tlsw := startTransport(tcp)
+	tlsr, tlsw := cl.startTransport()
 
 	// Start the reader and writers that convert to and from XML.
 	xmlIn := startXmlReader(tlsr)
-	xmlOut := startXmlWriter(tlsw)
+	cl.xmlOut = startXmlWriter(tlsw)
 	textOut := startTextWriter(tlsw)
 
 	// Start the XMPP stream handler which filters stream-level
 	// events and responds to them.
-	clIn := startStreamReader(xmlIn)
-	clOut := startStreamWriter(xmlOut)
+	clIn := cl.startStreamReader(xmlIn, cl.xmlOut)
+	clOut := startStreamWriter(cl.xmlOut)
 
 	// Initial handshake.
 	hsOut := &Stream{To: jid.Domain, Version: Version}
-	xmlOut <- hsOut
+	cl.xmlOut <- hsOut
 
 	// TODO Wait for initialization to finish.
 
-	// Make the Client and init its fields.
-	cl := new(Client)
 	cl.In = clIn
 	cl.Out = clOut
 	cl.TextOut = textOut
@@ -93,27 +108,11 @@
 	return nil
 }
 
-func startTransport(tcp io.ReadWriter) (io.Reader, io.Writer) {
-	f := func(r io.Reader, w io.Writer, dir string) {
-		defer tryClose(r, w)
-		p := make([]byte, 1024)
-		for {
-			nr, err := r.Read(p)
-			if nr == 0 {
-				log.Printf("%s: %s", dir, err.String())
-				break
-			}
-			nw, err := w.Write(p[:nr])
-			if nw < nr {
-				log.Println("%s: %s", dir, err.String())
-				break
-			}
-		}
-	}
+func (cl *Client) startTransport() (io.Reader, io.Writer) {
 	inr, inw := io.Pipe()
 	outr, outw := io.Pipe()
-	go f(tcp, inw, "read")
-	go f(outr, tcp, "write")
+	go cl.readTransport(inw)
+	go cl.writeTransport(outr)
 	return inr, outw
 }
 
@@ -135,9 +134,9 @@
 	return ch
 }
 
-func startStreamReader(xmlIn <-chan interface{}) <-chan interface{} {
+func (cl *Client) startStreamReader(xmlIn <-chan interface{}, srvOut chan<- interface{}) <-chan interface{} {
 	ch := make(chan interface{})
-	go readStream(xmlIn, ch)
+	go cl.readStream(xmlIn, srvOut, ch)
 	return ch
 }
 
@@ -147,114 +146,6 @@
 	return ch
 }
 
-func readXml(r io.Reader, ch chan<- interface{}) {
-	if debug {
-		pr, pw := io.Pipe()
-		go tee(r, pw, "S: ")
-		r = pr
-	}
-	defer tryClose(r, ch)
-
-	p := xml.NewParser(r)
-	for {
-		// Sniff the next token on the stream.
-		t, err := p.Token()
-		if t == nil {
-			if err != os.EOF {
-				log.Printf("read: %v", err)
-			}
-			break
-		}
-		var se xml.StartElement
-		var ok bool
-		if se, ok = t.(xml.StartElement) ; !ok {
-			continue
-		}
-
-		// Allocate the appropriate structure for this token.
-		var obj interface{}
-		switch se.Name.Space + " " + se.Name.Local {
-		case nsStream + " stream":
-			st, err := parseStream(se)
-			if err != nil {
-				log.Printf("unmarshal stream: %v",
-					err)
-				break
-			}
-			ch <- st
-			continue
-		case "stream error", nsStream + " error":
-			obj = &StreamError{}
-		case nsStream + " features":
-			obj = &Features{}
-		default:
-			obj = &Unrecognized{}
-			log.Printf("Ignoring unrecognized: %s %s\n",
-				se.Name.Space, se.Name.Local)
-		}
-
-		// Read the complete XML stanza.
-		err = p.Unmarshal(obj, &se)
-		if err != nil {
-			log.Printf("unmarshal: %v", err)
-			break
-		}
-
-		// Put it on the channel.
-		ch <- obj
-	}
-}
-
-func writeXml(w io.Writer, ch <-chan interface{}) {
-	if debug {
-		pr, pw := io.Pipe()
-		go tee(pr, w, "C: ")
-		w = pw
-	}
-	defer tryClose(w, ch)
-
-	for obj := range ch {
-		err := xml.Marshal(w, obj)
-		if err != nil {
-			log.Printf("write: %v", err)
-			break
-		}
-	}
-}
-
-func writeText(w io.Writer, ch <-chan *string) {
-	if debug {
-		pr, pw := io.Pipe()
-		go tee(pr, w, "C: ")
-		w = pw
-	}
-	defer tryClose(w, ch)
-
-	for str := range ch {
-		_, err := w.Write([]byte(*str))
-		if err != nil {
-			log.Printf("writeStr: %v", err)
-			break
-		}
-	}
-}
-
-func readStream(srvIn <-chan interface{}, cliOut chan<- interface{}) {
-	defer tryClose(srvIn, cliOut)
-
-	for x := range srvIn {
-		cliOut <- x
-	}
-}
-
-func writeStream(srvOut chan<- interface{}, cliIn <-chan interface{}) {
-	defer tryClose(srvOut, cliIn)
-
-	for x := range cliIn {
-		srvOut <- x
-	}
-}
-
 func tee(r io.Reader, w io.Writer, prefix string) {
 	defer tryClose(r, w)