xmpp/layer1.go
changeset 148 b1b4900eee5b
parent 147 d7679d991b17
child 153 bbd4166df95d
--- a/xmpp/layer1.go	Sun Sep 15 13:09:26 2013 -0600
+++ b/xmpp/layer1.go	Sun Sep 15 16:18:20 2013 -0600
@@ -4,50 +4,113 @@
 package xmpp
 
 import (
+	"crypto/tls"
 	"io"
 	"net"
 	"time"
 )
 
-func (cl *Client) recvTransport(w io.WriteCloser) {
+var l1interval = time.Second
+
+type layer1 struct {
+	sock      net.Conn
+	recvSocks chan<- net.Conn
+	sendSocks chan net.Conn
+}
+
+func startLayer1(sock net.Conn, recvWriter io.WriteCloser,
+	sendReader io.ReadCloser) *layer1 {
+	l1 := layer1{sock: sock}
+	recvSocks := make(chan net.Conn)
+	l1.recvSocks = recvSocks
+	sendSocks := make(chan net.Conn, 1)
+	l1.sendSocks = sendSocks
+	go recvTransport(recvSocks, recvWriter)
+	go sendTransport(sendSocks, sendReader)
+	recvSocks <- sock
+	sendSocks <- sock
+	return &l1
+}
+
+func (l1 *layer1) startTls(conf *tls.Config) {
+	sendSockToSender := func(sock net.Conn) {
+		for {
+			select {
+			case <-l1.sendSocks:
+			case l1.sendSocks <- sock:
+				return
+			}
+		}
+	}
+
+	sendSockToSender(nil)
+	l1.recvSocks <- nil
+	l1.sock = tls.Client(l1.sock, conf)
+	sendSockToSender(l1.sock)
+	l1.recvSocks <- l1.sock
+}
+
+func recvTransport(socks <-chan net.Conn, w io.WriteCloser) {
 	defer w.Close()
+	var sock net.Conn
 	p := make([]byte, 1024)
 	for {
-		if cl.socket == nil {
-			cl.waitForSocket()
+		select {
+		case sock = <-socks:
+		default:
 		}
-		cl.socket.SetReadDeadline(time.Now().Add(time.Second))
-		nr, err := cl.socket.Read(p)
-		if nr == 0 {
-			if errno, ok := err.(*net.OpError); ok {
-				if errno.Timeout() {
-					continue
+
+		if sock == nil {
+			time.Sleep(l1interval)
+		} else {
+			sock.SetReadDeadline(time.Now().Add(l1interval))
+			nr, err := sock.Read(p)
+			if nr == 0 {
+				if errno, ok := err.(*net.OpError); ok {
+					if errno.Timeout() {
+						continue
+					}
 				}
+				Warn.Logf("recvTransport: %s", err)
+				break
 			}
-			Warn.Logf("read: %s", err)
-			break
-		}
-		nw, err := w.Write(p[:nr])
-		if nw < nr {
-			Warn.Logf("read: %s", err)
-			break
+			nw, err := w.Write(p[:nr])
+			if nw < nr {
+				Warn.Logf("recvTransport: %s", err)
+				break
+			}
 		}
 	}
 }
 
-func (cl *Client) sendTransport(r io.Reader) {
-	defer cl.socket.Close()
+func sendTransport(socks <-chan net.Conn, r io.Reader) {
+	var sock net.Conn
 	p := make([]byte, 1024)
 	for {
 		nr, err := r.Read(p)
 		if nr == 0 {
-			Warn.Logf("write: %s", err)
+			Warn.Logf("sendTransport: %s", err)
 			break
 		}
-		nw, err := cl.socket.Write(p[:nr])
-		if nw < nr {
-			Warn.Logf("write: %s", err)
-			break
+		for nr > 0 {
+			select {
+			case sock = <-socks:
+				if sock != nil {
+					defer sock.Close()
+				}
+			default:
+			}
+
+			if sock == nil {
+				time.Sleep(l1interval)
+			} else {
+				nw, err := sock.Write(p[:nr])
+				nr -= nw
+				if nr != 0 {
+					Warn.Logf("write: %s", err)
+					break
+				}
+			}
 		}
 	}
 }