# HG changeset patch # User Chris Jones # Date 1324948034 25200 # Node ID f38b0ee7b1c14f8569973ca369969478b1d54309 # Parent 4fe926b03827d5b7620d67558ada4aa164324abb Added TLS negotiation. diff -r 4fe926b03827 -r f38b0ee7b1c1 Makefile --- a/Makefile Mon Dec 26 14:36:41 2011 -0700 +++ b/Makefile Mon Dec 26 18:07:14 2011 -0700 @@ -7,6 +7,7 @@ TARG=cjyar/xmpp GOFILES=\ xmpp.go \ + stream.go \ structs.go \ include $(GOROOT)/src/Make.pkg diff -r 4fe926b03827 -r f38b0ee7b1c1 stream.go --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/stream.go Mon Dec 26 18:07:14 2011 -0700 @@ -0,0 +1,241 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file contains the three layers of processing for the +// communication with the server: transport (where TLS happens), XML +// (where strings are converted to go structures), and Stream (where +// we respond to XMPP events on behalf of the library client). + +package xmpp + +import ( + "crypto/tls" + "io" + "log" + "net" + "os" + "time" + "xml" +) + +func (cl *Client) readTransport(w io.Writer) { + defer tryClose(cl.socket, w) + cl.socket.SetReadTimeout(1e8) + p := make([]byte, 1024) + for { + if cl.socket == nil { + cl.waitForSocket() + } + nr, err := cl.socket.Read(p) + if nr == 0 { + if errno, ok := err.(*net.OpError) ; ok { + if errno.Timeout() { + continue + } + } + log.Printf("read: %s", err.String()) + break + } + nw, err := w.Write(p[:nr]) + if nw < nr { + log.Println("read: %s", err.String()) + break + } + } +} + +func (cl *Client) writeTransport(r io.Reader) { + defer tryClose(r, cl.socket) + p := make([]byte, 1024) + for { + nr, err := r.Read(p) + if nr == 0 { + log.Printf("write: %s", err.String()) + break + } + nw, err := cl.socket.Write(p[:nr]) + if nw < nr { + log.Println("write: %s", err.String()) + break + } + } +} + +func readXml(r io.Reader, ch chan<- interface{}) { + if debug { + pr, pw := io.Pipe() + go tee(r, pw, "S: ") + r = pr + } + defer tryClose(r, ch) + + p := xml.NewParser(r) + for { + // Sniff the next token on the stream. + t, err := p.Token() + if t == nil { + if err != os.EOF { + log.Printf("read: %v", err) + } + break + } + var se xml.StartElement + var ok bool + if se, ok = t.(xml.StartElement) ; !ok { + continue + } + + // Allocate the appropriate structure for this token. + var obj interface{} + switch se.Name.Space + " " + se.Name.Local { + case nsStream + " stream": + st, err := parseStream(se) + if err != nil { + log.Printf("unmarshal stream: %v", + err) + break + } + ch <- st + continue + case "stream error", nsStream + " error": + obj = &StreamError{} + case nsStream + " features": + obj = &Features{} + case nsTLS + " proceed", nsTLS + " failure": + obj = &starttls{} + default: + obj = &Unrecognized{} + log.Printf("Ignoring unrecognized: %s %s\n", + se.Name.Space, se.Name.Local) + } + + // Read the complete XML stanza. + err = p.Unmarshal(obj, &se) + if err != nil { + log.Printf("unmarshal: %v", err) + break + } + + // Put it on the channel. + ch <- obj + } +} + +func writeXml(w io.Writer, ch <-chan interface{}) { + if debug { + pr, pw := io.Pipe() + go tee(pr, w, "C: ") + w = pw + } + defer tryClose(w, ch) + + for obj := range ch { + err := xml.Marshal(w, obj) + if err != nil { + log.Printf("write: %v", err) + break + } + } +} + +func writeText(w io.Writer, ch <-chan *string) { + if debug { + pr, pw := io.Pipe() + go tee(pr, w, "C: ") + w = pw + } + defer tryClose(w, ch) + + for str := range ch { + _, err := w.Write([]byte(*str)) + if err != nil { + log.Printf("writeStr: %v", err) + break + } + } +} + +func (cl *Client) readStream(srvIn <-chan interface{}, srvOut, cliOut chan<- interface{}) { + defer tryClose(srvIn, cliOut) + + for x := range srvIn { + switch obj := x.(type) { + case *Stream: + handleStream(obj) + case *Features: + handleFeatures(obj, srvOut) + case *starttls: + cl.handleTls(obj) + default: + cliOut <- x + } + } +} + +func writeStream(srvOut chan<- interface{}, cliIn <-chan interface{}) { + defer tryClose(srvOut, cliIn) + + for x := range cliIn { + srvOut <- x + } +} + +func handleStream(ss *Stream) { +} + +func handleFeatures(fe *Features, srvOut chan<- interface{}) { + if fe.Starttls != nil { + start := &starttls{XMLName: xml.Name{Space: nsTLS, + Local: "starttls"}} + srvOut <- start + } +} + +// readTransport() is running concurrently. We need to stop it, +// negotiate TLS, then start it again. It calls waitForSocket() in +// its inner loop; see below. +func (cl *Client) handleTls(t *starttls) { + tcp := cl.socket + + // Set the socket to nil, and wait for the reader routine to + // signal that it's paused. + cl.socket = nil + cl.socketSync.Add(1) + cl.socketSync.Wait() + + // Negotiate TLS with the server. + tls := tls.Client(tcp, nil) + + // Make the TLS connection available to the reader, and wait + // for it to signal that it's working again. + cl.socketSync.Add(1) + cl.socket = tls + cl.socketSync.Wait() + + // Reset the read timeout on the (underlying) socket so the + // reader doesn't get woken up unnecessarily. + tcp.SetReadTimeout(0) + + log.Println("TLS negotiation succeeded.") + + // Now re-send the initial handshake message to start the new + // session. + hsOut := &Stream{To: cl.Jid.Domain, Version: Version} + cl.xmlOut <- hsOut +} + +// Synchronize with handleTls(). Called from readTransport() when +// cl.socket is nil. +func (cl *Client) waitForSocket() { + // Signal that we've stopped reading from the socket. + cl.socketSync.Done() + + // Wait until the socket is available again. + for cl.socket == nil { + time.Sleep(1e8) + } + + // Signal that we're going back to the read loop. + cl.socketSync.Done() +} diff -r 4fe926b03827 -r f38b0ee7b1c1 structs.go --- a/structs.go Mon Dec 26 14:36:41 2011 -0700 +++ b/structs.go Mon Dec 26 18:07:14 2011 -0700 @@ -17,14 +17,6 @@ "xml" ) -const ( - // Version of RFC 3920 that we implement. - Version = "1.0" - nsStreams = "urn:ietf:params:xml:ns:xmpp-streams" - nsStream = "http://etherx.jabber.org/streams" - nsTLS = "urn:ietf:params:xml:ns:xmpp-tls" -) - // JID represents an entity that can communicate with other // entities. It looks like node@domain/resource. Node and resource are // sometimes optional. @@ -67,11 +59,12 @@ var _ xml.Marshaler = &errText{} type Features struct { - Starttls starttls + Starttls *starttls Mechanisms mechs } type starttls struct { + XMLName xml.Name required *string } diff -r 4fe926b03827 -r f38b0ee7b1c1 xmpp.go --- a/xmpp.go Mon Dec 26 14:36:41 2011 -0700 +++ b/xmpp.go Mon Dec 26 18:07:14 2011 -0700 @@ -10,22 +10,35 @@ "bytes" "fmt" "io" - "log" "net" "os" - "xml" + "sync" ) const ( + // Version of RFC 3920 that we implement. + Version = "1.0" + + // Various XML namespaces. + nsStreams = "urn:ietf:params:xml:ns:xmpp-streams" + nsStream = "http://etherx.jabber.org/streams" + nsTLS = "urn:ietf:params:xml:ns:xmpp-tls" + + // DNS SRV names serverSrv = "xmpp-server" clientSrv = "xmpp-client" - debug = false + + debug = true ) // The client in a client-server XMPP connection. type Client struct { + Jid JID + socket net.Conn + socketSync sync.WaitGroup In <-chan interface{} Out chan<- interface{} + xmlOut chan<- interface{} TextOut chan<- *string } var _ io.Closer = &Client{} @@ -60,27 +73,29 @@ return nil, err } + cl := new(Client) + cl.Jid = *jid + cl.socket = tcp + // Start the transport handler, initially unencrypted. - tlsr, tlsw := startTransport(tcp) + tlsr, tlsw := cl.startTransport() // Start the reader and writers that convert to and from XML. xmlIn := startXmlReader(tlsr) - xmlOut := startXmlWriter(tlsw) + cl.xmlOut = startXmlWriter(tlsw) textOut := startTextWriter(tlsw) // Start the XMPP stream handler which filters stream-level // events and responds to them. - clIn := startStreamReader(xmlIn) - clOut := startStreamWriter(xmlOut) + clIn := cl.startStreamReader(xmlIn, cl.xmlOut) + clOut := startStreamWriter(cl.xmlOut) // Initial handshake. hsOut := &Stream{To: jid.Domain, Version: Version} - xmlOut <- hsOut + cl.xmlOut <- hsOut // TODO Wait for initialization to finish. - // Make the Client and init its fields. - cl := new(Client) cl.In = clIn cl.Out = clOut cl.TextOut = textOut @@ -93,27 +108,11 @@ return nil } -func startTransport(tcp io.ReadWriter) (io.Reader, io.Writer) { - f := func(r io.Reader, w io.Writer, dir string) { - defer tryClose(r, w) - p := make([]byte, 1024) - for { - nr, err := r.Read(p) - if nr == 0 { - log.Printf("%s: %s", dir, err.String()) - break - } - nw, err := w.Write(p[:nr]) - if nw < nr { - log.Println("%s: %s", dir, err.String()) - break - } - } - } +func (cl *Client) startTransport() (io.Reader, io.Writer) { inr, inw := io.Pipe() outr, outw := io.Pipe() - go f(tcp, inw, "read") - go f(outr, tcp, "write") + go cl.readTransport(inw) + go cl.writeTransport(outr) return inr, outw } @@ -135,9 +134,9 @@ return ch } -func startStreamReader(xmlIn <-chan interface{}) <-chan interface{} { +func (cl *Client) startStreamReader(xmlIn <-chan interface{}, srvOut chan<- interface{}) <-chan interface{} { ch := make(chan interface{}) - go readStream(xmlIn, ch) + go cl.readStream(xmlIn, srvOut, ch) return ch } @@ -147,114 +146,6 @@ return ch } -func readXml(r io.Reader, ch chan<- interface{}) { - if debug { - pr, pw := io.Pipe() - go tee(r, pw, "S: ") - r = pr - } - defer tryClose(r, ch) - - p := xml.NewParser(r) - for { - // Sniff the next token on the stream. - t, err := p.Token() - if t == nil { - if err != os.EOF { - log.Printf("read: %v", err) - } - break - } - var se xml.StartElement - var ok bool - if se, ok = t.(xml.StartElement) ; !ok { - continue - } - - // Allocate the appropriate structure for this token. - var obj interface{} - switch se.Name.Space + " " + se.Name.Local { - case nsStream + " stream": - st, err := parseStream(se) - if err != nil { - log.Printf("unmarshal stream: %v", - err) - break - } - ch <- st - continue - case "stream error", nsStream + " error": - obj = &StreamError{} - case nsStream + " features": - obj = &Features{} - default: - obj = &Unrecognized{} - log.Printf("Ignoring unrecognized: %s %s\n", - se.Name.Space, se.Name.Local) - } - - // Read the complete XML stanza. - err = p.Unmarshal(obj, &se) - if err != nil { - log.Printf("unmarshal: %v", err) - break - } - - // Put it on the channel. - ch <- obj - } -} - -func writeXml(w io.Writer, ch <-chan interface{}) { - if debug { - pr, pw := io.Pipe() - go tee(pr, w, "C: ") - w = pw - } - defer tryClose(w, ch) - - for obj := range ch { - err := xml.Marshal(w, obj) - if err != nil { - log.Printf("write: %v", err) - break - } - } -} - -func writeText(w io.Writer, ch <-chan *string) { - if debug { - pr, pw := io.Pipe() - go tee(pr, w, "C: ") - w = pw - } - defer tryClose(w, ch) - - for str := range ch { - _, err := w.Write([]byte(*str)) - if err != nil { - log.Printf("writeStr: %v", err) - break - } - } -} - -func readStream(srvIn <-chan interface{}, cliOut chan<- interface{}) { - defer tryClose(srvIn, cliOut) - - for x := range srvIn { - cliOut <- x - } -} - -func writeStream(srvOut chan<- interface{}, cliIn <-chan interface{}) { - defer tryClose(srvOut, cliIn) - - for x := range cliIn { - srvOut <- x - } -} - func tee(r io.Reader, w io.Writer, prefix string) { defer tryClose(r, w)