Added TLS negotiation.
--- a/Makefile Mon Dec 26 14:36:41 2011 -0700
+++ b/Makefile Mon Dec 26 18:07:14 2011 -0700
@@ -7,6 +7,7 @@
TARG=cjyar/xmpp
GOFILES=\
xmpp.go \
+ stream.go \
structs.go \
include $(GOROOT)/src/Make.pkg
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/stream.go Mon Dec 26 18:07:14 2011 -0700
@@ -0,0 +1,241 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This file contains the three layers of processing for the
+// communication with the server: transport (where TLS happens), XML
+// (where strings are converted to go structures), and Stream (where
+// we respond to XMPP events on behalf of the library client).
+
+package xmpp
+
+import (
+ "crypto/tls"
+ "io"
+ "log"
+ "net"
+ "os"
+ "time"
+ "xml"
+)
+
+func (cl *Client) readTransport(w io.Writer) {
+ defer tryClose(cl.socket, w)
+ cl.socket.SetReadTimeout(1e8)
+ p := make([]byte, 1024)
+ for {
+ if cl.socket == nil {
+ cl.waitForSocket()
+ }
+ nr, err := cl.socket.Read(p)
+ if nr == 0 {
+ if errno, ok := err.(*net.OpError) ; ok {
+ if errno.Timeout() {
+ continue
+ }
+ }
+ log.Printf("read: %s", err.String())
+ break
+ }
+ nw, err := w.Write(p[:nr])
+ if nw < nr {
+ log.Println("read: %s", err.String())
+ break
+ }
+ }
+}
+
+func (cl *Client) writeTransport(r io.Reader) {
+ defer tryClose(r, cl.socket)
+ p := make([]byte, 1024)
+ for {
+ nr, err := r.Read(p)
+ if nr == 0 {
+ log.Printf("write: %s", err.String())
+ break
+ }
+ nw, err := cl.socket.Write(p[:nr])
+ if nw < nr {
+ log.Println("write: %s", err.String())
+ break
+ }
+ }
+}
+
+func readXml(r io.Reader, ch chan<- interface{}) {
+ if debug {
+ pr, pw := io.Pipe()
+ go tee(r, pw, "S: ")
+ r = pr
+ }
+ defer tryClose(r, ch)
+
+ p := xml.NewParser(r)
+ for {
+ // Sniff the next token on the stream.
+ t, err := p.Token()
+ if t == nil {
+ if err != os.EOF {
+ log.Printf("read: %v", err)
+ }
+ break
+ }
+ var se xml.StartElement
+ var ok bool
+ if se, ok = t.(xml.StartElement) ; !ok {
+ continue
+ }
+
+ // Allocate the appropriate structure for this token.
+ var obj interface{}
+ switch se.Name.Space + " " + se.Name.Local {
+ case nsStream + " stream":
+ st, err := parseStream(se)
+ if err != nil {
+ log.Printf("unmarshal stream: %v",
+ err)
+ break
+ }
+ ch <- st
+ continue
+ case "stream error", nsStream + " error":
+ obj = &StreamError{}
+ case nsStream + " features":
+ obj = &Features{}
+ case nsTLS + " proceed", nsTLS + " failure":
+ obj = &starttls{}
+ default:
+ obj = &Unrecognized{}
+ log.Printf("Ignoring unrecognized: %s %s\n",
+ se.Name.Space, se.Name.Local)
+ }
+
+ // Read the complete XML stanza.
+ err = p.Unmarshal(obj, &se)
+ if err != nil {
+ log.Printf("unmarshal: %v", err)
+ break
+ }
+
+ // Put it on the channel.
+ ch <- obj
+ }
+}
+
+func writeXml(w io.Writer, ch <-chan interface{}) {
+ if debug {
+ pr, pw := io.Pipe()
+ go tee(pr, w, "C: ")
+ w = pw
+ }
+ defer tryClose(w, ch)
+
+ for obj := range ch {
+ err := xml.Marshal(w, obj)
+ if err != nil {
+ log.Printf("write: %v", err)
+ break
+ }
+ }
+}
+
+func writeText(w io.Writer, ch <-chan *string) {
+ if debug {
+ pr, pw := io.Pipe()
+ go tee(pr, w, "C: ")
+ w = pw
+ }
+ defer tryClose(w, ch)
+
+ for str := range ch {
+ _, err := w.Write([]byte(*str))
+ if err != nil {
+ log.Printf("writeStr: %v", err)
+ break
+ }
+ }
+}
+
+func (cl *Client) readStream(srvIn <-chan interface{}, srvOut, cliOut chan<- interface{}) {
+ defer tryClose(srvIn, cliOut)
+
+ for x := range srvIn {
+ switch obj := x.(type) {
+ case *Stream:
+ handleStream(obj)
+ case *Features:
+ handleFeatures(obj, srvOut)
+ case *starttls:
+ cl.handleTls(obj)
+ default:
+ cliOut <- x
+ }
+ }
+}
+
+func writeStream(srvOut chan<- interface{}, cliIn <-chan interface{}) {
+ defer tryClose(srvOut, cliIn)
+
+ for x := range cliIn {
+ srvOut <- x
+ }
+}
+
+func handleStream(ss *Stream) {
+}
+
+func handleFeatures(fe *Features, srvOut chan<- interface{}) {
+ if fe.Starttls != nil {
+ start := &starttls{XMLName: xml.Name{Space: nsTLS,
+ Local: "starttls"}}
+ srvOut <- start
+ }
+}
+
+// readTransport() is running concurrently. We need to stop it,
+// negotiate TLS, then start it again. It calls waitForSocket() in
+// its inner loop; see below.
+func (cl *Client) handleTls(t *starttls) {
+ tcp := cl.socket
+
+ // Set the socket to nil, and wait for the reader routine to
+ // signal that it's paused.
+ cl.socket = nil
+ cl.socketSync.Add(1)
+ cl.socketSync.Wait()
+
+ // Negotiate TLS with the server.
+ tls := tls.Client(tcp, nil)
+
+ // Make the TLS connection available to the reader, and wait
+ // for it to signal that it's working again.
+ cl.socketSync.Add(1)
+ cl.socket = tls
+ cl.socketSync.Wait()
+
+ // Reset the read timeout on the (underlying) socket so the
+ // reader doesn't get woken up unnecessarily.
+ tcp.SetReadTimeout(0)
+
+ log.Println("TLS negotiation succeeded.")
+
+ // Now re-send the initial handshake message to start the new
+ // session.
+ hsOut := &Stream{To: cl.Jid.Domain, Version: Version}
+ cl.xmlOut <- hsOut
+}
+
+// Synchronize with handleTls(). Called from readTransport() when
+// cl.socket is nil.
+func (cl *Client) waitForSocket() {
+ // Signal that we've stopped reading from the socket.
+ cl.socketSync.Done()
+
+ // Wait until the socket is available again.
+ for cl.socket == nil {
+ time.Sleep(1e8)
+ }
+
+ // Signal that we're going back to the read loop.
+ cl.socketSync.Done()
+}
--- a/structs.go Mon Dec 26 14:36:41 2011 -0700
+++ b/structs.go Mon Dec 26 18:07:14 2011 -0700
@@ -17,14 +17,6 @@
"xml"
)
-const (
- // Version of RFC 3920 that we implement.
- Version = "1.0"
- nsStreams = "urn:ietf:params:xml:ns:xmpp-streams"
- nsStream = "http://etherx.jabber.org/streams"
- nsTLS = "urn:ietf:params:xml:ns:xmpp-tls"
-)
-
// JID represents an entity that can communicate with other
// entities. It looks like node@domain/resource. Node and resource are
// sometimes optional.
@@ -67,11 +59,12 @@
var _ xml.Marshaler = &errText{}
type Features struct {
- Starttls starttls
+ Starttls *starttls
Mechanisms mechs
}
type starttls struct {
+ XMLName xml.Name
required *string
}
--- a/xmpp.go Mon Dec 26 14:36:41 2011 -0700
+++ b/xmpp.go Mon Dec 26 18:07:14 2011 -0700
@@ -10,22 +10,35 @@
"bytes"
"fmt"
"io"
- "log"
"net"
"os"
- "xml"
+ "sync"
)
const (
+ // Version of RFC 3920 that we implement.
+ Version = "1.0"
+
+ // Various XML namespaces.
+ nsStreams = "urn:ietf:params:xml:ns:xmpp-streams"
+ nsStream = "http://etherx.jabber.org/streams"
+ nsTLS = "urn:ietf:params:xml:ns:xmpp-tls"
+
+ // DNS SRV names
serverSrv = "xmpp-server"
clientSrv = "xmpp-client"
- debug = false
+
+ debug = true
)
// The client in a client-server XMPP connection.
type Client struct {
+ Jid JID
+ socket net.Conn
+ socketSync sync.WaitGroup
In <-chan interface{}
Out chan<- interface{}
+ xmlOut chan<- interface{}
TextOut chan<- *string
}
var _ io.Closer = &Client{}
@@ -60,27 +73,29 @@
return nil, err
}
+ cl := new(Client)
+ cl.Jid = *jid
+ cl.socket = tcp
+
// Start the transport handler, initially unencrypted.
- tlsr, tlsw := startTransport(tcp)
+ tlsr, tlsw := cl.startTransport()
// Start the reader and writers that convert to and from XML.
xmlIn := startXmlReader(tlsr)
- xmlOut := startXmlWriter(tlsw)
+ cl.xmlOut = startXmlWriter(tlsw)
textOut := startTextWriter(tlsw)
// Start the XMPP stream handler which filters stream-level
// events and responds to them.
- clIn := startStreamReader(xmlIn)
- clOut := startStreamWriter(xmlOut)
+ clIn := cl.startStreamReader(xmlIn, cl.xmlOut)
+ clOut := startStreamWriter(cl.xmlOut)
// Initial handshake.
hsOut := &Stream{To: jid.Domain, Version: Version}
- xmlOut <- hsOut
+ cl.xmlOut <- hsOut
// TODO Wait for initialization to finish.
- // Make the Client and init its fields.
- cl := new(Client)
cl.In = clIn
cl.Out = clOut
cl.TextOut = textOut
@@ -93,27 +108,11 @@
return nil
}
-func startTransport(tcp io.ReadWriter) (io.Reader, io.Writer) {
- f := func(r io.Reader, w io.Writer, dir string) {
- defer tryClose(r, w)
- p := make([]byte, 1024)
- for {
- nr, err := r.Read(p)
- if nr == 0 {
- log.Printf("%s: %s", dir, err.String())
- break
- }
- nw, err := w.Write(p[:nr])
- if nw < nr {
- log.Println("%s: %s", dir, err.String())
- break
- }
- }
- }
+func (cl *Client) startTransport() (io.Reader, io.Writer) {
inr, inw := io.Pipe()
outr, outw := io.Pipe()
- go f(tcp, inw, "read")
- go f(outr, tcp, "write")
+ go cl.readTransport(inw)
+ go cl.writeTransport(outr)
return inr, outw
}
@@ -135,9 +134,9 @@
return ch
}
-func startStreamReader(xmlIn <-chan interface{}) <-chan interface{} {
+func (cl *Client) startStreamReader(xmlIn <-chan interface{}, srvOut chan<- interface{}) <-chan interface{} {
ch := make(chan interface{})
- go readStream(xmlIn, ch)
+ go cl.readStream(xmlIn, srvOut, ch)
return ch
}
@@ -147,114 +146,6 @@
return ch
}
-func readXml(r io.Reader, ch chan<- interface{}) {
- if debug {
- pr, pw := io.Pipe()
- go tee(r, pw, "S: ")
- r = pr
- }
- defer tryClose(r, ch)
-
- p := xml.NewParser(r)
- for {
- // Sniff the next token on the stream.
- t, err := p.Token()
- if t == nil {
- if err != os.EOF {
- log.Printf("read: %v", err)
- }
- break
- }
- var se xml.StartElement
- var ok bool
- if se, ok = t.(xml.StartElement) ; !ok {
- continue
- }
-
- // Allocate the appropriate structure for this token.
- var obj interface{}
- switch se.Name.Space + " " + se.Name.Local {
- case nsStream + " stream":
- st, err := parseStream(se)
- if err != nil {
- log.Printf("unmarshal stream: %v",
- err)
- break
- }
- ch <- st
- continue
- case "stream error", nsStream + " error":
- obj = &StreamError{}
- case nsStream + " features":
- obj = &Features{}
- default:
- obj = &Unrecognized{}
- log.Printf("Ignoring unrecognized: %s %s\n",
- se.Name.Space, se.Name.Local)
- }
-
- // Read the complete XML stanza.
- err = p.Unmarshal(obj, &se)
- if err != nil {
- log.Printf("unmarshal: %v", err)
- break
- }
-
- // Put it on the channel.
- ch <- obj
- }
-}
-
-func writeXml(w io.Writer, ch <-chan interface{}) {
- if debug {
- pr, pw := io.Pipe()
- go tee(pr, w, "C: ")
- w = pw
- }
- defer tryClose(w, ch)
-
- for obj := range ch {
- err := xml.Marshal(w, obj)
- if err != nil {
- log.Printf("write: %v", err)
- break
- }
- }
-}
-
-func writeText(w io.Writer, ch <-chan *string) {
- if debug {
- pr, pw := io.Pipe()
- go tee(pr, w, "C: ")
- w = pw
- }
- defer tryClose(w, ch)
-
- for str := range ch {
- _, err := w.Write([]byte(*str))
- if err != nil {
- log.Printf("writeStr: %v", err)
- break
- }
- }
-}
-
-func readStream(srvIn <-chan interface{}, cliOut chan<- interface{}) {
- defer tryClose(srvIn, cliOut)
-
- for x := range srvIn {
- cliOut <- x
- }
-}
-
-func writeStream(srvOut chan<- interface{}, cliIn <-chan interface{}) {
- defer tryClose(srvOut, cliIn)
-
- for x := range cliIn {
- srvOut <- x
- }
-}
-
func tee(r io.Reader, w io.Writer, prefix string) {
defer tryClose(r, w)