Use gorilla websocket in proxy-go too

Trac: 32465
This commit is contained in:
Arlo Breault 2019-11-20 19:33:28 -05:00
parent 7557e96a8d
commit 30b5ef8a9e
5 changed files with 128 additions and 134 deletions

View file

@ -0,0 +1,89 @@
package websocketconn
import (
"io"
"log"
"sync"
"time"
"github.com/gorilla/websocket"
)
// An abstraction that makes an underlying WebSocket connection look like an
// io.ReadWriteCloser.
type WebSocketConn struct {
Ws *websocket.Conn
r io.Reader
}
// Implements io.Reader.
func (conn *WebSocketConn) Read(b []byte) (n int, err error) {
var opCode int
if conn.r == nil {
// New message
var r io.Reader
for {
if opCode, r, err = conn.Ws.NextReader(); err != nil {
return
}
if opCode != websocket.BinaryMessage && opCode != websocket.TextMessage {
continue
}
conn.r = r
break
}
}
n, err = conn.r.Read(b)
if err == io.EOF {
// Message finished
conn.r = nil
err = nil
}
return
}
// Implements io.Writer.
func (conn *WebSocketConn) Write(b []byte) (n int, err error) {
var w io.WriteCloser
if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil {
return
}
if n, err = w.Write(b); err != nil {
return
}
err = w.Close()
return
}
// Implements io.Closer.
func (conn *WebSocketConn) Close() error {
// Ignore any error in trying to write a Close frame.
_ = conn.Ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Second))
return conn.Ws.Close()
}
// Create a new WebSocketConn.
func NewWebSocketConn(ws *websocket.Conn) WebSocketConn {
var conn WebSocketConn
conn.Ws = ws
return conn
}
// Copy from WebSocket to socket and vice versa.
func CopyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) {
var wg sync.WaitGroup
copyer := func(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
defer wg.Done()
if _, err := io.Copy(dst, src); err != nil {
log.Printf("io.Copy inside CopyLoop generated an error: %v", err)
}
dst.Close()
src.Close()
}
wg.Add(2)
go copyer(c1, c2)
go copyer(c2, c1)
wg.Wait()
}

View file

@ -0,0 +1,30 @@
package websocketconn
import (
"net"
"testing"
. "github.com/smartystreets/goconvey/convey"
)
func TestWebsocketConn(t *testing.T) {
Convey("CopyLoop", t, func() {
c1, s1 := net.Pipe()
c2, s2 := net.Pipe()
go CopyLoop(s1, s2)
go func() {
bytes := []byte("Hello!")
c1.Write(bytes)
}()
bytes := make([]byte, 6)
n, err := c2.Read(bytes)
So(n, ShouldEqual, 6)
So(err, ShouldEqual, nil)
So(bytes, ShouldResemble, []byte("Hello!"))
s1.Close()
// Check that copy loop has closed other connection
_, err = s2.Write(bytes)
So(err, ShouldNotBeNil)
})
}

View file

@ -374,23 +374,4 @@ func TestUtilityFuncs(t *testing.T) {
sid2 := genSessionID() sid2 := genSessionID()
So(sid1, ShouldNotEqual, sid2) So(sid1, ShouldNotEqual, sid2)
}) })
Convey("CopyLoop", t, func() {
c1, s1 := net.Pipe()
c2, s2 := net.Pipe()
go CopyLoop(s1, s2)
go func() {
bytes := []byte("Hello!")
c1.Write(bytes)
}()
bytes := make([]byte, 6)
n, err := c2.Read(bytes)
So(n, ShouldEqual, 6)
So(err, ShouldEqual, nil)
So(bytes, ShouldResemble, []byte("Hello!"))
s1.Close()
//Check that copy loop has closed other connection
_, err = s2.Write(bytes)
So(err, ShouldNotBeNil)
})
} }

View file

@ -21,8 +21,9 @@ import (
"git.torproject.org/pluggable-transports/snowflake.git/common/messages" "git.torproject.org/pluggable-transports/snowflake.git/common/messages"
"git.torproject.org/pluggable-transports/snowflake.git/common/safelog" "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
"github.com/gorilla/websocket"
"github.com/pion/webrtc" "github.com/pion/webrtc"
"golang.org/x/net/websocket"
) )
const defaultBrokerURL = "https://snowflake-broker.bamsoftware.com/" const defaultBrokerURL = "https://snowflake-broker.bamsoftware.com/"
@ -239,22 +240,6 @@ func (b *Broker) sendAnswer(sid string, pc *webrtc.PeerConnection) error {
return nil return nil
} }
func CopyLoop(c1 net.Conn, c2 net.Conn) {
var wg sync.WaitGroup
copyer := func(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
defer wg.Done()
if _, err := io.Copy(dst, src); err != nil {
log.Printf("io.Copy inside CopyLoop generated an error: %v", err)
}
dst.Close()
src.Close()
}
wg.Add(2)
go copyer(c1, c2)
go copyer(c2, c1)
wg.Wait()
}
// We pass conn.RemoteAddr() as an additional parameter, rather than calling // We pass conn.RemoteAddr() as an additional parameter, rather than calling
// conn.RemoteAddr() inside this function, as a workaround for a hang that // conn.RemoteAddr() inside this function, as a workaround for a hang that
// otherwise occurs inside of conn.pc.RemoteDescription() (called by // otherwise occurs inside of conn.pc.RemoteDescription() (called by
@ -279,15 +264,15 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
log.Printf("no remote address given in websocket") log.Printf("no remote address given in websocket")
} }
wsConn, err := websocket.Dial(u.String(), "", relayURL) ws, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil { if err != nil {
log.Printf("error dialing relay: %s", err) log.Printf("error dialing relay: %s", err)
return return
} }
wsConn := websocketconn.NewWebSocketConn(ws)
log.Printf("connected to relay") log.Printf("connected to relay")
defer wsConn.Close() defer wsConn.Close()
wsConn.PayloadType = websocket.BinaryFrame websocketconn.CopyLoop(conn, &wsConn)
CopyLoop(conn, wsConn)
log.Printf("datachannelHandler ends") log.Printf("datachannelHandler ends")
} }

View file

@ -15,12 +15,12 @@ import (
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"strings" "strings"
"sync"
"syscall" "syscall"
"time" "time"
pt "git.torproject.org/pluggable-transports/goptlib.git" pt "git.torproject.org/pluggable-transports/goptlib.git"
"git.torproject.org/pluggable-transports/snowflake.git/common/safelog" "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2" "golang.org/x/net/http2"
@ -50,97 +50,6 @@ additional HTTP listener on port 80 to work with ACME.
flag.PrintDefaults() flag.PrintDefaults()
} }
// An abstraction that makes an underlying WebSocket connection look like an
// io.ReadWriteCloser.
type webSocketConn struct {
Ws *websocket.Conn
r io.Reader
}
// Implements io.Reader.
func (conn *webSocketConn) Read(b []byte) (n int, err error) {
var opCode int
if conn.r == nil {
// New message
var r io.Reader
for {
if opCode, r, err = conn.Ws.NextReader(); err != nil {
return
}
if opCode != websocket.BinaryMessage && opCode != websocket.TextMessage {
continue
}
conn.r = r
break
}
}
n, err = conn.r.Read(b)
if err == io.EOF {
// Message finished
conn.r = nil
err = nil
}
return
}
// Implements io.Writer.
func (conn *webSocketConn) Write(b []byte) (n int, err error) {
var w io.WriteCloser
if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil {
return
}
if n, err = w.Write(b); err != nil {
return
}
err = w.Close()
return
}
// Implements io.Closer.
func (conn *webSocketConn) Close() error {
// Ignore any error in trying to write a Close frame.
_ = conn.Ws.WriteControl(websocket.CloseMessage, []byte{}, time.Now().Add(time.Second))
return conn.Ws.Close()
}
// Create a new webSocketConn.
func newWebSocketConn(ws *websocket.Conn) webSocketConn {
var conn webSocketConn
conn.Ws = ws
return conn
}
// Copy from WebSocket to socket and vice versa.
func proxy(local *net.TCPConn, conn *webSocketConn) {
var wg sync.WaitGroup
wg.Add(2)
go func() {
if _, err := io.Copy(conn, local); err != nil {
log.Printf("error copying ORPort to WebSocket %v", err)
}
if err := local.CloseRead(); err != nil {
log.Printf("error closing read after copying ORPort to WebSocket %v", err)
}
conn.Close()
wg.Done()
}()
go func() {
if _, err := io.Copy(local, conn); err != nil {
log.Printf("error copying WebSocket to ORPort")
}
if err := local.CloseWrite(); err != nil {
log.Printf("error closing write after copying WebSocket to ORPort %v", err)
}
conn.Close()
wg.Done()
}()
wg.Wait()
}
// Return an address string suitable to pass into pt.DialOr. // Return an address string suitable to pass into pt.DialOr.
func clientAddr(clientIPParam string) string { func clientAddr(clientIPParam string) string {
if clientIPParam == "" { if clientIPParam == "" {
@ -166,8 +75,8 @@ func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
conn := newWebSocketConn(ws) wsConn := websocketconn.NewWebSocketConn(ws)
defer conn.Close() defer wsConn.Close()
// Pass the address of client as the remote address of incoming connection // Pass the address of client as the remote address of incoming connection
clientIPParam := r.URL.Query().Get("client_ip") clientIPParam := r.URL.Query().Get("client_ip")
@ -184,7 +93,7 @@ func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
defer or.Close() defer or.Close()
proxy(or, &conn) websocketconn.CopyLoop(or, &wsConn)
} }
func initServer(addr *net.TCPAddr, func initServer(addr *net.TCPAddr,