Stop using custom websocket library in server

Trac: 31028
This commit is contained in:
Arlo Breault 2019-10-16 21:00:13 -04:00
parent 300a23c6a0
commit c417fd5599
2 changed files with 53 additions and 37 deletions

View file

@ -26,8 +26,8 @@ install:
- go get -u github.com/keroserene/go-webrtc - go get -u github.com/keroserene/go-webrtc
- go get -u github.com/pion/webrtc - go get -u github.com/pion/webrtc
- go get -u github.com/dchest/uniuri - go get -u github.com/dchest/uniuri
- go get -u github.com/gorilla/websocket
- go get -u git.torproject.org/pluggable-transports/goptlib.git - go get -u git.torproject.org/pluggable-transports/goptlib.git
- go get -u git.torproject.org/pluggable-transports/websocket.git/websocket
- go get -u google.golang.org/appengine - go get -u google.golang.org/appengine
- go get -u golang.org/x/crypto/acme/autocert - go get -u golang.org/x/crypto/acme/autocert
- go get -u golang.org/x/net/http2 - go get -u golang.org/x/net/http2

View file

@ -21,7 +21,7 @@ import (
pt "git.torproject.org/pluggable-transports/goptlib.git" pt "git.torproject.org/pluggable-transports/goptlib.git"
"git.torproject.org/pluggable-transports/snowflake.git/common/safelog" "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
"git.torproject.org/pluggable-transports/websocket.git/websocket" "github.com/gorilla/websocket"
"golang.org/x/crypto/acme/autocert" "golang.org/x/crypto/acme/autocert"
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
@ -53,50 +53,60 @@ additional HTTP listener on port 80 to work with ACME.
// An abstraction that makes an underlying WebSocket connection look like an // An abstraction that makes an underlying WebSocket connection look like an
// io.ReadWriteCloser. // io.ReadWriteCloser.
type webSocketConn struct { type webSocketConn struct {
Ws *websocket.WebSocket Ws *websocket.Conn
messageBuf []byte r io.Reader
} }
// Implements io.Reader. // Implements io.Reader.
func (conn *webSocketConn) Read(b []byte) (n int, err error) { func (conn *webSocketConn) Read(b []byte) (n int, err error) {
for len(conn.messageBuf) == 0 { var opCode int
var m websocket.Message if conn.r == nil {
m, err = conn.Ws.ReadMessage() // New message
if err != nil { var r io.Reader
return for {
if opCode, r, err = conn.Ws.NextReader(); err != nil {
return
}
if opCode != websocket.BinaryMessage && opCode != websocket.TextMessage {
continue
}
conn.r = r
break
} }
if m.Opcode == 8 {
err = io.EOF
return
}
if m.Opcode != 2 {
err = fmt.Errorf("got non-binary opcode %d", m.Opcode)
return
}
conn.messageBuf = m.Payload
} }
n = copy(b, conn.messageBuf) n, err = conn.r.Read(b)
conn.messageBuf = conn.messageBuf[n:] if err != nil {
if err == io.EOF {
// Message finished
conn.r = nil
err = nil
}
}
return return
} }
// Implements io.Writer. // Implements io.Writer.
func (conn *webSocketConn) Write(b []byte) (int, error) { func (conn *webSocketConn) Write(b []byte) (n int, err error) {
err := conn.Ws.WriteMessage(2, b) var w io.WriteCloser
return len(b), err if w, err = conn.Ws.NextWriter(websocket.BinaryMessage); err != nil {
return
}
if n, err = w.Write(b); err != nil {
return
}
err = w.Close()
return
} }
// Implements io.Closer. // Implements io.Closer.
func (conn *webSocketConn) Close() error { func (conn *webSocketConn) Close() error {
// Ignore any error in trying to write a Close frame. return conn.Ws.Close()
_ = conn.Ws.WriteFrame(8, nil)
return conn.Ws.Conn.Close()
} }
// Create a new webSocketConn. // Create a new webSocketConn.
func newWebSocketConn(ws *websocket.WebSocket) webSocketConn { func newWebSocketConn(ws *websocket.Conn) webSocketConn {
var conn webSocketConn var conn webSocketConn
conn.Ws = ws conn.Ws = ws
return conn return conn
@ -145,16 +155,22 @@ func clientAddr(clientIPParam string) string {
return (&net.TCPAddr{IP: clientIP, Port: 1, Zone: ""}).String() return (&net.TCPAddr{IP: clientIP, Port: 1, Zone: ""}).String()
} }
func webSocketHandler(ws *websocket.WebSocket) { var upgrader = websocket.Upgrader{}
// Undo timeouts on HTTP request handling.
if err := ws.Conn.SetDeadline(time.Time{}); err != nil { type HTTPHandler struct{}
log.Printf("unable to set deadlines with error: %v", err)
func (handler *HTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
log.Println(err)
return
} }
conn := newWebSocketConn(ws) conn := newWebSocketConn(ws)
defer conn.Close() defer conn.Close()
// Pass the address of client as the remote address of incoming connection // Pass the address of client as the remote address of incoming connection
clientIPParam := ws.Request().URL.Query().Get("client_ip") clientIPParam := r.URL.Query().Get("client_ip")
addr := clientAddr(clientIPParam) addr := clientAddr(clientIPParam)
if addr == "" { if addr == "" {
statsChannel <- false statsChannel <- false
@ -162,7 +178,6 @@ func webSocketHandler(ws *websocket.WebSocket) {
statsChannel <- true statsChannel <- true
} }
or, err := pt.DialOr(&ptInfo, addr, ptMethodName) or, err := pt.DialOr(&ptInfo, addr, ptMethodName)
if err != nil { if err != nil {
log.Printf("failed to connect to ORPort: %s", err) log.Printf("failed to connect to ORPort: %s", err)
return return
@ -185,11 +200,12 @@ func initServer(addr *net.TCPAddr,
return nil, fmt.Errorf("cannot listen on port %d; configure a port using ServerTransportListenAddr", addr.Port) return nil, fmt.Errorf("cannot listen on port %d; configure a port using ServerTransportListenAddr", addr.Port)
} }
var config websocket.Config upgrader.CheckOrigin = func(r *http.Request) bool { return true }
config.MaxMessageSize = maxMessageSize
var handler HTTPHandler
server := &http.Server{ server := &http.Server{
Addr: addr.String(), Addr: addr.String(),
Handler: config.Handler(webSocketHandler), Handler: &handler,
ReadTimeout: requestTimeout, ReadTimeout: requestTimeout,
} }
// We need to override server.TLSConfig.GetCertificate--but first // We need to override server.TLSConfig.GetCertificate--but first