stream.go
changeset 10 f38b0ee7b1c1
child 11 48be1ae93fd4
equal deleted inserted replaced
9:4fe926b03827 10:f38b0ee7b1c1
       
     1 // Copyright 2011 The Go Authors.  All rights reserved.
       
     2 // Use of this source code is governed by a BSD-style
       
     3 // license that can be found in the LICENSE file.
       
     4 
       
     5 // This file contains the three layers of processing for the
       
     6 // communication with the server: transport (where TLS happens), XML
       
     7 // (where strings are converted to Go structures), and Stream (where
       
     8 // we respond to XMPP events on behalf of the library client).
       
     9 
       
    10 package xmpp
       
    11 
       
    12 import (
       
    13 	"crypto/tls"
       
    14 	"io"
       
    15 	"log"
       
    16 	"net"
       
    17 	"os"
       
    18 	"time"
       
    19 	"xml"
       
    20 )
       
    21 
       
    22 func (cl *Client) readTransport(w io.Writer) {
       
    23 	defer tryClose(cl.socket, w)
       
    24 	cl.socket.SetReadTimeout(1e8)
       
    25 	p := make([]byte, 1024)
       
    26 	for {
       
    27 		if cl.socket == nil {
       
    28 			cl.waitForSocket()
       
    29 		}
       
    30 		nr, err := cl.socket.Read(p)
       
    31 		if nr == 0 {
       
    32 			if errno, ok := err.(*net.OpError) ; ok {
       
    33 				if errno.Timeout() {
       
    34 					continue
       
    35 				}
       
    36 			}
       
    37 			log.Printf("read: %s", err.String())
       
    38 			break
       
    39 		}
       
    40 		nw, err := w.Write(p[:nr])
       
    41 		if nw < nr {
       
    42 			log.Println("read: %s", err.String())
       
    43 			break
       
    44 		}
       
    45 	}
       
    46 }
       
    47 
       
    48 func (cl *Client) writeTransport(r io.Reader) {
       
    49 	defer tryClose(r, cl.socket)
       
    50 	p := make([]byte, 1024)
       
    51 	for {
       
    52 		nr, err := r.Read(p)
       
    53 		if nr == 0 {
       
    54 			log.Printf("write: %s", err.String())
       
    55 			break
       
    56 		}
       
    57 		nw, err := cl.socket.Write(p[:nr])
       
    58 		if nw < nr {
       
    59 			log.Println("write: %s", err.String())
       
    60 			break
       
    61 		}
       
    62 	}
       
    63 }
       
    64 
       
    65 func readXml(r io.Reader, ch chan<- interface{}) {
       
    66 	if debug {
       
    67 		pr, pw := io.Pipe()
       
    68 		go tee(r, pw, "S: ")
       
    69 		r = pr
       
    70 	}
       
    71 	defer tryClose(r, ch)
       
    72 
       
    73 	p := xml.NewParser(r)
       
    74 	for {
       
    75 		// Sniff the next token on the stream.
       
    76 		t, err := p.Token()
       
    77 		if t == nil {
       
    78 			if err != os.EOF {
       
    79 				log.Printf("read: %v", err)
       
    80 			}
       
    81 			break
       
    82 		}
       
    83 		var se xml.StartElement
       
    84 		var ok bool
       
    85 		if se, ok = t.(xml.StartElement) ; !ok {
       
    86 			continue
       
    87 		}
       
    88 
       
    89 		// Allocate the appropriate structure for this token.
       
    90 		var obj interface{}
       
    91 		switch se.Name.Space + " " + se.Name.Local {
       
    92 		case nsStream + " stream":
       
    93 			st, err := parseStream(se)
       
    94 			if err != nil {
       
    95 				log.Printf("unmarshal stream: %v",
       
    96 					err)
       
    97 				break
       
    98 			}
       
    99 			ch <- st
       
   100 			continue
       
   101 		case "stream error", nsStream + " error":
       
   102 			obj = &StreamError{}
       
   103 		case nsStream + " features":
       
   104 			obj = &Features{}
       
   105 		case nsTLS + " proceed", nsTLS + " failure":
       
   106 			obj = &starttls{}
       
   107 		default:
       
   108 			obj = &Unrecognized{}
       
   109 			log.Printf("Ignoring unrecognized: %s %s\n",
       
   110 				se.Name.Space, se.Name.Local)
       
   111 		}
       
   112 
       
   113 		// Read the complete XML stanza.
       
   114 		err = p.Unmarshal(obj, &se)
       
   115 		if err != nil {
       
   116 			log.Printf("unmarshal: %v", err)
       
   117 			break
       
   118 		}
       
   119 
       
   120 		// Put it on the channel.
       
   121 		ch <- obj
       
   122 	}
       
   123 }
       
   124 
       
   125 func writeXml(w io.Writer, ch <-chan interface{}) {
       
   126 	if debug {
       
   127 		pr, pw := io.Pipe()
       
   128 		go tee(pr, w, "C: ")
       
   129 		w = pw
       
   130 	}
       
   131 	defer tryClose(w, ch)
       
   132 
       
   133 	for obj := range ch {
       
   134 		err := xml.Marshal(w, obj)
       
   135 		if err != nil {
       
   136 			log.Printf("write: %v", err)
       
   137 			break
       
   138 		}
       
   139 	}
       
   140 }
       
   141 
       
   142 func writeText(w io.Writer, ch <-chan *string) {
       
   143 	if debug {
       
   144 		pr, pw := io.Pipe()
       
   145 		go tee(pr, w, "C: ")
       
   146 		w = pw
       
   147 	}
       
   148 	defer tryClose(w, ch)
       
   149 
       
   150 	for str := range ch {
       
   151 		_, err := w.Write([]byte(*str))
       
   152 		if err != nil {
       
   153 			log.Printf("writeStr: %v", err)
       
   154 			break
       
   155 		}
       
   156 	}
       
   157 }
       
   158 
       
   159 func (cl *Client) readStream(srvIn <-chan interface{}, srvOut, cliOut chan<- interface{}) {
       
   160 	defer tryClose(srvIn, cliOut)
       
   161 
       
   162 	for x := range srvIn {
       
   163 		switch obj := x.(type) {
       
   164 		case *Stream:
       
   165 			handleStream(obj)
       
   166 		case *Features:
       
   167 			handleFeatures(obj, srvOut)
       
   168 		case *starttls:
       
   169 			cl.handleTls(obj)
       
   170 		default:
       
   171 			cliOut <- x
       
   172 		}
       
   173 	}
       
   174 }
       
   175 
       
   176 func writeStream(srvOut chan<- interface{}, cliIn <-chan interface{}) {
       
   177 	defer tryClose(srvOut, cliIn)
       
   178 
       
   179 	for x := range cliIn {
       
   180 		srvOut <- x
       
   181 	}
       
   182 }
       
   183 
       
   184 func handleStream(ss *Stream) {
       
   185 }
       
   186 
       
   187 func handleFeatures(fe *Features, srvOut chan<- interface{}) {
       
   188 	if fe.Starttls != nil {
       
   189 		start := &starttls{XMLName: xml.Name{Space: nsTLS,
       
   190 			Local: "starttls"}}
       
   191 		srvOut <- start
       
   192 	}
       
   193 }
       
   194 
       
   195 // readTransport() is running concurrently. We need to stop it,
       
   196 // negotiate TLS, then start it again. It calls waitForSocket() in
       
   197 // its inner loop; see below.
       
   198 func (cl *Client) handleTls(t *starttls) {
       
   199 	tcp := cl.socket
       
   200 
       
   201 	// Set the socket to nil, and wait for the reader routine to
       
   202 	// signal that it's paused.
       
   203 	cl.socket = nil
       
   204 	cl.socketSync.Add(1)
       
   205 	cl.socketSync.Wait()
       
   206 
       
   207 	// Negotiate TLS with the server.
       
   208 	tls := tls.Client(tcp, nil)
       
   209 
       
   210 	// Make the TLS connection available to the reader, and wait
       
   211 	// for it to signal that it's working again.
       
   212 	cl.socketSync.Add(1)
       
   213 	cl.socket = tls
       
   214 	cl.socketSync.Wait()
       
   215 
       
   216 	// Reset the read timeout on the (underlying) socket so the
       
   217 	// reader doesn't get woken up unnecessarily.
       
   218 	tcp.SetReadTimeout(0)
       
   219 
       
   220 	log.Println("TLS negotiation succeeded.")
       
   221 
       
   222 	// Now re-send the initial handshake message to start the new
       
   223 	// session.
       
   224 	hsOut := &Stream{To: cl.Jid.Domain, Version: Version}
       
   225 	cl.xmlOut <- hsOut
       
   226 }
       
   227 
       
   228 // Synchronize with handleTls(). Called from readTransport() when
       
   229 // cl.socket is nil.
       
   230 func (cl *Client) waitForSocket() {
       
   231 	// Signal that we've stopped reading from the socket.
       
   232 	cl.socketSync.Done()
       
   233 
       
   234 	// Wait until the socket is available again.
       
   235 	for cl.socket == nil {
       
   236 		time.Sleep(1e8)
       
   237 	}
       
   238 
       
   239 	// Signal that we're going back to the read loop.
       
   240 	cl.socketSync.Done()
       
   241 }