diff --git a/client/lib/connwrapper.go b/client/lib/connwrapper.go new file mode 100644 index 0000000..f8d0614 --- /dev/null +++ b/client/lib/connwrapper.go @@ -0,0 +1,61 @@ +package snowflake_client + +import ( + "errors" + "io" + "net" + "time" +) + +type ReadWriteCloserPreservesBoundary interface { + io.ReadWriteCloser + MessageBoundaryPreserved() +} + +var errENOSYS = errors.New("not implemented") + +func newPacketConnWrapper(localAddr, remoteAddr net.Addr, rwc ReadWriteCloserPreservesBoundary) net.PacketConn { + return &packetConnWrapper{ + ReadWriteCloserPreservesBoundary: rwc, + remoteAddr: remoteAddr, + localAddr: localAddr, + } +} + +type packetConnWrapper struct { + ReadWriteCloserPreservesBoundary + remoteAddr net.Addr + localAddr net.Addr +} + +func (pcw *packetConnWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = pcw.Read(p) + if err != nil { + return 0, nil, err + } + return n, pcw.remoteAddr, nil +} + +func (pcw *packetConnWrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return pcw.Write(p) +} + +func (pcw *packetConnWrapper) Close() error { + return pcw.ReadWriteCloserPreservesBoundary.Close() +} + +func (pcw *packetConnWrapper) LocalAddr() net.Addr { + return pcw.localAddr +} + +func (pcw *packetConnWrapper) SetDeadline(t time.Time) error { + return errENOSYS +} + +func (pcw *packetConnWrapper) SetReadDeadline(t time.Time) error { + return errENOSYS +} + +func (pcw *packetConnWrapper) SetWriteDeadline(t time.Time) error { + return errENOSYS +} diff --git a/client/lib/rendezvous.go b/client/lib/rendezvous.go index da06027..20fd34b 100644 --- a/client/lib/rendezvous.go +++ b/client/lib/rendezvous.go @@ -21,6 +21,7 @@ import ( "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/nat" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/turbotunnel" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util" ) @@ -251,6 +252,7 @@ type WebRTCDialer struct { eventLogger event.SnowflakeEventReceiver proxy *url.URL + clientID turbotunnel.ClientID } // Deprecated: Use NewWebRTCDialerWithNatPolicyAndEventsAndProxy instead @@ -281,7 +283,6 @@ func NewWebRTCDialerWithEventsAndProxy(broker *BrokerChannel, iceServers []webrt ) } -// NewWebRTCDialerWithNatPolicyAndEventsAndProxy constructs a new WebRTCDialer. func NewWebRTCDialerWithNatPolicyAndEventsAndProxy( broker *BrokerChannel, natPolicy *NATPolicy, @@ -289,6 +290,27 @@ func NewWebRTCDialerWithNatPolicyAndEventsAndProxy( max int, eventLogger event.SnowflakeEventReceiver, proxy *url.URL, +) *WebRTCDialer { + return newWebRTCDialerWithNatPolicyAndEventsAndProxyAndClientID( + broker, + natPolicy, + iceServers, + max, + eventLogger, + proxy, + turbotunnel.NewClientID(), + ) +} + +// NewWebRTCDialerWithNatPolicyAndEventsAndProxy constructs a new WebRTCDialer. +func newWebRTCDialerWithNatPolicyAndEventsAndProxyAndClientID( + broker *BrokerChannel, + natPolicy *NATPolicy, + iceServers []webrtc.ICEServer, + max int, + eventLogger event.SnowflakeEventReceiver, + proxy *url.URL, + clientID turbotunnel.ClientID, ) *WebRTCDialer { config := webrtc.Configuration{ ICEServers: iceServers, @@ -302,6 +324,7 @@ func NewWebRTCDialerWithNatPolicyAndEventsAndProxy( eventLogger: eventLogger, proxy: proxy, + clientID: clientID, } } @@ -309,9 +332,7 @@ func NewWebRTCDialerWithNatPolicyAndEventsAndProxy( func (w WebRTCDialer) Catch() (*WebRTCPeer, error) { // TODO: [#25591] Fetch ICE server information from Broker. // TODO: [#25596] Consider TURN servers here too. - return NewWebRTCPeerWithNatPolicyAndEventsAndProxy( - w.webrtcConfig, w.BrokerChannel, w.natPolicy, w.eventLogger, w.proxy, - ) + return NewWebRTCPeerWithNatPolicyAndEventsProxyAndClientID(w.webrtcConfig, w.BrokerChannel, w.natPolicy, w.eventLogger, w.proxy, w.clientID) } // GetMax returns the maximum number of snowflakes to collect. diff --git a/client/lib/snowflake.go b/client/lib/snowflake.go index f1a3bad..939e593 100644 --- a/client/lib/snowflake.go +++ b/client/lib/snowflake.go @@ -32,6 +32,7 @@ import ( "math/rand" "net" "net/url" + "os" "strings" "time" @@ -42,6 +43,7 @@ import ( "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/nat" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/packetpadding" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/turbotunnel" ) @@ -163,7 +165,10 @@ func NewSnowflakeClient(config ClientConfig) (*Transport, error) { max = config.Max } eventsLogger := event.NewSnowflakeEventDispatcher() - transport := &Transport{dialer: NewWebRTCDialerWithNatPolicyAndEventsAndProxy(broker, natPolicy, iceServers, max, eventsLogger, config.CommunicationProxy), eventDispatcher: eventsLogger} + transport := &Transport{ + dialer: NewWebRTCDialerWithNatPolicyAndEventsAndProxy(broker, natPolicy, iceServers, max, eventsLogger, config.CommunicationProxy), + eventDispatcher: eventsLogger, + } return transport, nil } @@ -324,13 +329,11 @@ func parseIceServers(addresses []string) []webrtc.ICEServer { // over. The net.PacketConn successively connects through Snowflake proxies // pulled from snowflakes. func newSession(snowflakes SnowflakeCollector) (net.PacketConn, *smux.Session, error) { - clientID := turbotunnel.NewClientID() - // We build a persistent KCP session on a sequence of ephemeral WebRTC // connections. This dialContext tells RedialPacketConn how to get a new // WebRTC connection when the previous one dies. Inside each WebRTC - // connection, we use encapsulationPacketConn to encode packets into a - // stream. + // connection, KCP packets are sent and received, one-to-one, in data + // channel messages. dialContext := func(ctx context.Context) (net.PacketConn, error) { log.Printf("redialing on same connection") // Obtain an available WebRTC remote. May block. @@ -339,17 +342,12 @@ func newSession(snowflakes SnowflakeCollector) (net.PacketConn, *smux.Session, e return nil, errors.New("handler: Received invalid Snowflake") } log.Println("---- Handler: snowflake assigned ----") - // Send the magic Turbo Tunnel token. - _, err := conn.Write(turbotunnel.Token[:]) - if err != nil { - return nil, err - } - // Send ClientID prefix. - _, err = conn.Write(clientID[:]) - if err != nil { - return nil, err - } - return newEncapsulationPacketConn(dummyAddr{}, dummyAddr{}, conn), nil + + packetConnWrapper := newPacketConnWrapper(dummyAddr{}, dummyAddr{}, + packetpadding.NewPaddableConnection(conn, + packetpadding.New())) + + return packetConnWrapper, nil } pconn := turbotunnel.NewRedialPacketConn(dummyAddr{}, dummyAddr{}, dialContext) @@ -375,6 +373,14 @@ func newSession(snowflakes SnowflakeCollector) (net.PacketConn, *smux.Session, e 0, // default resend 1, // nc=1 => congestion window off ) + if os.Getenv("SNOWFLAKE_TEST_KCP_FAST3MODE") == "1" { + conn.SetNoDelay( + 1, + 10, + 2, + 1, + ) + } // On the KCP connection we overlay an smux session and stream. smuxConfig := smux.DefaultConfig() smuxConfig.Version = 2 diff --git a/client/lib/webrtc.go b/client/lib/webrtc.go index 9d803a2..ed745d9 100644 --- a/client/lib/webrtc.go +++ b/client/lib/webrtc.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "encoding/hex" "errors" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" "io" "log" "net" @@ -18,6 +19,7 @@ import ( "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/proxy" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/turbotunnel" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util" ) @@ -43,46 +45,63 @@ type WebRTCPeer struct { bytesLogger bytesLogger eventsLogger event.SnowflakeEventReceiver proxy *url.URL + + clientID turbotunnel.ClientID } // Deprecated: Use NewWebRTCPeerWithNatPolicyAndEventsAndProxy Instead. -func NewWebRTCPeer( +func newWebRTCPeer( config *webrtc.Configuration, broker *BrokerChannel, ) (*WebRTCPeer, error) { - return NewWebRTCPeerWithNatPolicyAndEventsAndProxy( + return newWebRTCPeerWithNatPolicyAndEventsAndProxy( config, broker, nil, nil, nil, ) } // Deprecated: Use NewWebRTCPeerWithNatPolicyAndEventsAndProxy Instead. -func NewWebRTCPeerWithEvents( +func newWebRTCPeerWithEvents( config *webrtc.Configuration, broker *BrokerChannel, eventsLogger event.SnowflakeEventReceiver, ) (*WebRTCPeer, error) { - return NewWebRTCPeerWithNatPolicyAndEventsAndProxy( + return newWebRTCPeerWithNatPolicyAndEventsAndProxy( config, broker, nil, eventsLogger, nil, ) } // Deprecated: Use NewWebRTCPeerWithNatPolicyAndEventsAndProxy Instead. -func NewWebRTCPeerWithEventsAndProxy( +func newWebRTCPeerWithEventsAndProxy( config *webrtc.Configuration, broker *BrokerChannel, eventsLogger event.SnowflakeEventReceiver, proxy *url.URL, ) (*WebRTCPeer, error) { - return NewWebRTCPeerWithNatPolicyAndEventsAndProxy( + return newWebRTCPeerWithNatPolicyAndEventsAndProxy( config, broker, nil, eventsLogger, proxy, ) } +func newWebRTCPeerWithNatPolicyAndEventsAndProxy( + config *webrtc.Configuration, + broker *BrokerChannel, natPolicy *NATPolicy, eventsLogger event.SnowflakeEventReceiver, proxy *url.URL, +) (*WebRTCPeer, error) { + return NewWebRTCPeerWithNatPolicyAndEventsProxyAndClientID( + config, + broker, + natPolicy, + eventsLogger, + proxy, + turbotunnel.ClientID{}, + ) +} + // NewWebRTCPeerWithNatPolicyAndEventsAndProxy constructs // a WebRTC PeerConnection to a snowflake proxy. // // The creation of the peer handles the signaling to the Snowflake broker, including // the exchange of SDP information, the creation of a PeerConnection, and the establishment // of a DataChannel to the Snowflake proxy. -func NewWebRTCPeerWithNatPolicyAndEventsAndProxy( - config *webrtc.Configuration, broker *BrokerChannel, natPolicy *NATPolicy, - eventsLogger event.SnowflakeEventReceiver, proxy *url.URL, +// clientID is the hinted ID for the connection. +func NewWebRTCPeerWithNatPolicyAndEventsProxyAndClientID(config *webrtc.Configuration, + broker *BrokerChannel, natPolicy *NATPolicy, eventsLogger event.SnowflakeEventReceiver, proxy *url.URL, + clientID turbotunnel.ClientID, ) (*WebRTCPeer, error) { if eventsLogger == nil { eventsLogger = event.NewSnowflakeEventDispatcher() @@ -106,6 +125,7 @@ func NewWebRTCPeerWithNatPolicyAndEventsAndProxy( connection.eventsLogger = eventsLogger connection.proxy = proxy + connection.clientID = clientID err := connection.connect(config, broker, natPolicy) if err != nil { @@ -296,9 +316,18 @@ func (c *WebRTCPeer) preparePeerConnection( log.Printf("NewPeerConnection ERROR: %s", err) return err } - ordered := true + ordered := false + var maxRetransmission uint16 = 0 + connectionMetadata := messages.ClientConnectionMetadata{ClientID: c.clientID[:]} + encodedMetadata, err := connectionMetadata.EncodeConnectionMetadata() + if err != nil { + return err + } + protocol := encodedMetadata dataChannelOptions := &webrtc.DataChannelInit{ - Ordered: &ordered, + Ordered: &ordered, + Protocol: &protocol, + MaxRetransmits: &maxRetransmission, } // We must create the data channel before creating an offer // https://github.com/pion/webrtc/wiki/Release-WebRTC@v3.0.0#a-data-channel-is-no-longer-implicitly-created-with-a-peerconnection @@ -383,3 +412,5 @@ func (c *WebRTCPeer) cleanup() { } } } + +func (c *WebRTCPeer) MessageBoundaryPreserved() {} diff --git a/client/snowflake.go b/client/snowflake.go index 648481f..fb6bbb7 100644 --- a/client/snowflake.go +++ b/client/snowflake.go @@ -271,7 +271,11 @@ func main() { switch methodName { case "snowflake": // TODO: Be able to recover when SOCKS dies. - ln, err := pt.ListenSocks("tcp", "127.0.0.1:0") + listenAddr := "127.0.0.1:0" + if forcedListenAddr := os.Getenv("SNOWFLAKE_TEST_FORCELISTENADDR"); forcedListenAddr != "" { + listenAddr = forcedListenAddr + } + ln, err := pt.ListenSocks("tcp", listenAddr) if err != nil { pt.CmethodError(methodName, err.Error()) break diff --git a/common/messages/client.go b/common/messages/client.go index da6359e..030e052 100644 --- a/common/messages/client.go +++ b/common/messages/client.go @@ -5,10 +5,13 @@ package messages import ( "bytes" + "encoding/base64" "encoding/json" "fmt" - "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/bridgefingerprint" + "github.com/fxamacker/cbor" + + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/bridgefingerprint" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/nat" ) @@ -149,3 +152,34 @@ func DecodeClientPollResponse(data []byte) (*ClientPollResponse, error) { return &message, nil } + +// ClientConnectionMetadata is a struct that contains metadata about a snowflake connection between client and server +// It will be sent from the client to the proxy in WebRTC data channel protocol string +// The proxy will then send the metadata to the server in the protocol get parameter of the WebSocket connection +type ClientConnectionMetadata struct { + ClientID []byte `json:"client_id"` +} + +func (meta *ClientConnectionMetadata) EncodeConnectionMetadata() (string, error) { + jsonData, err := cbor.Marshal(meta, cbor.CanonicalEncOptions()) + if err != nil { + return "", err + } + + return base64.RawURLEncoding.EncodeToString(jsonData), nil +} + +func DecodeConnectionMetadata(data string) (*ClientConnectionMetadata, error) { + decodedData, err := base64.RawURLEncoding.DecodeString(data) + if err != nil { + return nil, err + } + + var meta ClientConnectionMetadata + err = cbor.Unmarshal(decodedData, &meta) + if err != nil { + return nil, err + } + + return &meta, nil +} diff --git a/common/packetpadding/conn.go b/common/packetpadding/conn.go new file mode 100644 index 0000000..b214e70 --- /dev/null +++ b/common/packetpadding/conn.go @@ -0,0 +1,46 @@ +package packetpadding + +import ( + "io" +) + +type ReadWriteCloserPreservesBoundary interface { + io.ReadWriteCloser + MessageBoundaryPreserved() +} + +type PaddableConnection interface { + ReadWriteCloserPreservesBoundary +} + +func NewPaddableConnection(rwc ReadWriteCloserPreservesBoundary, padding PacketPaddingContainer) PaddableConnection { + return &paddableConnection{ + ReadWriteCloserPreservesBoundary: rwc, + padding: padding, + } +} + +type paddableConnection struct { + ReadWriteCloserPreservesBoundary + padding PacketPaddingContainer +} + +func (c *paddableConnection) Write(p []byte) (n int, err error) { + dataLen := len(p) + if _, err = c.ReadWriteCloserPreservesBoundary.Write(c.padding.Pack(p, 0)); err != nil { + return 0, err + } + return dataLen, nil +} + +func (c *paddableConnection) Read(p []byte) (n int, err error) { + if n, err = c.ReadWriteCloserPreservesBoundary.Read(p); err != nil { + return 0, err + } + + payload, _ := c.padding.Unpack(p[:n]) + if payload != nil { + copy(p, payload) + } + return len(payload), nil +} diff --git a/common/packetpadding/container.go b/common/packetpadding/container.go new file mode 100644 index 0000000..318ba2e --- /dev/null +++ b/common/packetpadding/container.go @@ -0,0 +1,52 @@ +package packetpadding + +import "encoding/binary" + +func New() PacketPaddingContainer { + return packetPaddingContainer{} +} + +type packetPaddingContainer struct { +} + +func (c packetPaddingContainer) Pack(data_OWNERSHIP_RELINQUISHED []byte, paddingLength int) []byte { + data := append(data_OWNERSHIP_RELINQUISHED, make([]byte, paddingLength)...) + dataLength := len(data_OWNERSHIP_RELINQUISHED) + data = binary.BigEndian.AppendUint16(data, uint16(dataLength)) + return data +} + +func (c packetPaddingContainer) Pad(paddingLength int) []byte { + if assertPaddingLengthIsNotNegative := paddingLength < 0; assertPaddingLengthIsNotNegative { + return nil + } + switch paddingLength { + case 0: + return []byte{} + case 1: + return []byte{0} + case 2: + return []byte{0, 0} + default: + return append(make([]byte, paddingLength-2), byte(paddingLength>>8), byte(paddingLength)) + } + +} + +func (c packetPaddingContainer) Unpack(wrappedData_OWNERSHIP_RELINQUISHED []byte) ([]byte, int) { + dataLength := len(wrappedData_OWNERSHIP_RELINQUISHED) + if dataLength < 2 { + return nil, dataLength + } + + dataLen := int(binary.BigEndian.Uint16(wrappedData_OWNERSHIP_RELINQUISHED[dataLength-2:])) + if dataLen > 2047 { + return nil, 0 + } + paddingLength := dataLength - dataLen - 2 + if paddingLength < 0 { + return nil, paddingLength + } + + return wrappedData_OWNERSHIP_RELINQUISHED[:dataLen], paddingLength +} diff --git a/common/packetpadding/containerIfce.go b/common/packetpadding/containerIfce.go new file mode 100644 index 0000000..84fac0c --- /dev/null +++ b/common/packetpadding/containerIfce.go @@ -0,0 +1,34 @@ +package packetpadding + +// PacketPaddingContainer is an interface that defines methods to pad packets +// with a given number of bytes, and to unpack the padding from a padded packet. +// The packet format is as follows if the desired output length is greater than +// 2 bytes: +// | data | padding | data length | +// The data length is a 16-bit big-endian integer that represents the length of +// the data in bytes. +// If the desired output length is 2 bytes or less, the packet format is as +// follows: +// | padding | +// No payload will be included in the packet. +type PacketPaddingContainer interface { + // Pack pads the given data with the given number of bytes, and appends the + // length of the data to the end of the data. The returned byte slice + // contains the padded data. + // This generates a packet with a length of + // len(data_OWNERSHIP_RELINQUISHED) + padding + 2 + // @param data_OWNERSHIP_RELINQUISHED - The payload, this reference is consumed and should not be used after this call. + // @param padding - The number of padding bytes to add to the data. + Pack(data_OWNERSHIP_RELINQUISHED []byte, paddingLength int) []byte + + // Unpack extracts the data and padding from the given padded data. It + // returns the data and the number of padding bytes. + // the data may be nil. + // @param wrappedData_OWNERSHIP_RELINQUISHED - The packet, this reference is consumed and should not be used after this call. + Unpack(wrappedData_OWNERSHIP_RELINQUISHED []byte) ([]byte, int) + + // Pad returns a padding packet of padding length. + // If the padding length is less than 0, nil is returned. + // @param padding - The number of padding bytes to add to the data. + Pad(paddingLength int) []byte +} diff --git a/common/packetpadding/container_test.go b/common/packetpadding/container_test.go new file mode 100644 index 0000000..e68dcc8 --- /dev/null +++ b/common/packetpadding/container_test.go @@ -0,0 +1,113 @@ +package packetpadding_test + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" + + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/packetpadding" +) + +func TestPacketPaddingContainer(t *testing.T) { + Convey("Given a PacketPaddingContainer", t, func() { + container := packetpadding.New() + + Convey("When packing data with padding", func() { + data := []byte("testdata") + paddingLength := 4 + packedData := container.Pack(data, paddingLength) + + Convey("The packed data should have the correct length", func() { + expectedLength := len(data) + paddingLength + 2 + So(len(packedData), ShouldEqual, expectedLength) + }) + + Convey("When unpacking the packed data", func() { + unpackedData, unpackedPaddingLength := container.Unpack(packedData) + + Convey("The unpacked data should match the original data", func() { + So(string(unpackedData), ShouldEqual, string(data)) + }) + + Convey("The unpacked padding length should match the original padding length", func() { + So(unpackedPaddingLength, ShouldEqual, paddingLength) + }) + }) + }) + + Convey("When packing empty data with padding", func() { + data := []byte("") + paddingLength := 4 + packedData := container.Pack(data, paddingLength) + + Convey("The packed data should have the correct length", func() { + expectedLength := len(data) + paddingLength + 2 + So(len(packedData), ShouldEqual, expectedLength) + }) + + Convey("When unpacking the packed data", func() { + unpackedData, unpackedPaddingLength := container.Unpack(packedData) + + Convey("The unpacked data should match the original data", func() { + So(string(unpackedData), ShouldEqual, string(data)) + }) + + Convey("The unpacked padding length should match the original padding length", func() { + So(unpackedPaddingLength, ShouldEqual, paddingLength) + }) + }) + }) + + Convey("When packing data with zero padding", func() { + data := []byte("testdata") + paddingLength := 0 + packedData := container.Pack(data, paddingLength) + + Convey("The packed data should have the correct length", func() { + expectedLength := len(data) + paddingLength + 2 + So(len(packedData), ShouldEqual, expectedLength) + }) + + Convey("When unpacking the packed data", func() { + unpackedData, unpackedPaddingLength := container.Unpack(packedData) + + Convey("The unpacked data should match the original data", func() { + So(string(unpackedData), ShouldEqual, string(data)) + }) + + Convey("The unpacked padding length should match the original padding length", func() { + So(unpackedPaddingLength, ShouldEqual, paddingLength) + }) + }) + }) + + Convey("When padding data", func() { + Convey("With a positive padding length", func() { + padLength := 3 + padData := container.Pad(padLength) + + Convey("The padded data should have the correct length", func() { + So(len(padData), ShouldEqual, padLength) + }) + }) + + Convey("With a zero padding length", func() { + padLength := 0 + padData := container.Pad(padLength) + + Convey("The padded data should be empty", func() { + So(len(padData), ShouldEqual, 0) + }) + }) + + Convey("With a negative padding length", func() { + padLength := -1 + padData := container.Pad(padLength) + + Convey("The padded data should be nil", func() { + So(padData, ShouldBeNil) + }) + }) + }) + }) +} diff --git a/common/websocketconn/websocketconn.go b/common/websocketconn/websocketconn.go index e5256df..0178f3d 100644 --- a/common/websocketconn/websocketconn.go +++ b/common/websocketconn/websocketconn.go @@ -41,6 +41,8 @@ func (conn *Conn) SetDeadline(t time.Time) error { return err } +func (conn *Conn) MessageBoundaryPreserved() {} + func readLoop(w io.Writer, ws *websocket.Conn) error { var buf [2048]byte for { diff --git a/go.mod b/go.mod index 780a46d..9860065 100644 --- a/go.mod +++ b/go.mod @@ -47,6 +47,7 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudflare/circl v1.3.7 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fxamacker/cbor v1.5.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gopherjs/gopherjs v1.17.2 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect @@ -77,6 +78,7 @@ require ( github.com/tjfoc/gmsm v1.4.1 // indirect github.com/txthinking/runnergroup v0.0.0-20210608031112-152c7c4432bf // indirect github.com/wlynxg/anet v0.0.5 // indirect + github.com/x448/float16 v0.8.4 // indirect golang.org/x/mod v0.18.0 // indirect golang.org/x/sync v0.11.0 // indirect golang.org/x/text v0.22.0 // indirect diff --git a/go.sum b/go.sum index 3ada462..e3f4bf2 100644 --- a/go.sum +++ b/go.sum @@ -44,6 +44,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/fxamacker/cbor v1.5.1 h1:XjQWBgdmQyqimslUh5r4tUGmoqzHmBFQOImkWGi2awg= +github.com/fxamacker/cbor v1.5.1/go.mod h1:3aPGItF174ni7dDzd6JZ206H8cmr4GDNBGpPa971zsU= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc= @@ -162,6 +164,8 @@ github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301 h1:d/Wr/Vl/wiJHc github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301/go.mod h1:ntmMHL/xPq1WLeKiw8p/eRATaae6PiVRNipHFJxI8PM= github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xtaci/kcp-go/v5 v5.6.8 h1:jlI/0jAyjoOjT/SaGB58s4bQMJiNS41A2RKzR6TMWeI= github.com/xtaci/kcp-go/v5 v5.6.8/go.mod h1:oE9j2NVqAkuKO5o8ByKGch3vgVX3BNf8zqP8JiGq0bM= github.com/xtaci/lossyconn v0.0.0-20190602105132-8df528c0c9ae h1:J0GxkO96kL4WF+AIT3M4mfUVinOCPgf2uUWYFUzN0sM= diff --git a/proxy/lib/snowflake.go b/proxy/lib/snowflake.go index 4ce5164..a67946a 100644 --- a/proxy/lib/snowflake.go +++ b/proxy/lib/snowflake.go @@ -35,6 +35,7 @@ import ( "net" "net/http" "net/url" + "os" "strings" "sync" "time" @@ -342,7 +343,7 @@ func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteAddr net.Ad relayURL = sf.RelayURL } - wsConn, err := connectToRelay(relayURL, remoteAddr) + wsConn, err := connectToRelay(relayURL, remoteAddr, conn.GetConnectionProtocol()) if err != nil { log.Print(err) return @@ -353,7 +354,11 @@ func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteAddr net.Ad log.Printf("datachannelHandler ends") } -func connectToRelay(relayURL string, remoteAddr net.Addr) (*websocketconn.Conn, error) { +func connectToRelay( + relayURL string, + remoteAddr net.Addr, + webrtcConnProtocol string, +) (*websocketconn.Conn, error) { u, err := url.Parse(relayURL) if err != nil { return nil, fmt.Errorf("invalid relay url: %s", err) @@ -369,6 +374,12 @@ func connectToRelay(relayURL string, remoteAddr net.Addr) (*websocketconn.Conn, log.Printf("no remote address given in websocket") } + { + q := u.Query() + q.Set("protocol", webrtcConnProtocol) + u.RawQuery = q.Encode() + } + ws, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { return nil, fmt.Errorf("error dialing relay: %s = %s", u.String(), err) @@ -449,7 +460,7 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer( close(dataChan) pr, pw := io.Pipe() - conn := newWebRTCConn(pc, dc, pr, sf.bytesLogger) + conn := newWebRTCConn(pc, dc, pr, sf.bytesLogger, dc.Protocol()) dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) @@ -461,7 +472,7 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer( }) dc.OnOpen(func() { - log.Printf("Data Channel %s-%d open\n", dc.Label(), dc.ID()) + log.Printf("Data Channel %s-%d;%s open\n", dc.Label(), dc.ID(), dc.Protocol()) sf.EventDispatcher.OnNewSnowflakeEvent(event.EventOnProxyClientConnected{}) if sf.OutboundAddress != "" { @@ -835,6 +846,11 @@ func (sf *SnowflakeProxy) Stop() { func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) error { log.Printf("Checking our NAT type, contacting NAT check probe server at \"%v\"...", probeURL) + if os.Getenv("SNOWFLAKE_TEST_ASSUMEUNRESTRICTED") != "" { + currentNATType = NATUnrestricted + return nil + } + probe, err := newSignalingServer(probeURL) if err != nil { return fmt.Errorf("Error parsing url: %w", err) diff --git a/proxy/lib/webrtcconn.go b/proxy/lib/webrtcconn.go index b09ff3d..1c8bf08 100644 --- a/proxy/lib/webrtcconn.go +++ b/proxy/lib/webrtcconn.go @@ -41,9 +41,14 @@ type webRTCConn struct { cancelTimeoutLoop context.CancelFunc bytesLogger bytesLogger + + // protocol reflect the protocol field in the channel opening + // message of Data Channel Establishment Protocol. + // In snowflake it is used to transmit connection metadata. + protocol string } -func newWebRTCConn(pc *webrtc.PeerConnection, dc *webrtc.DataChannel, pr *io.PipeReader, bytesLogger bytesLogger) *webRTCConn { +func newWebRTCConn(pc *webrtc.PeerConnection, dc *webrtc.DataChannel, pr *io.PipeReader, bytesLogger bytesLogger, protocol string) *webRTCConn { conn := &webRTCConn{pc: pc, dc: dc, pr: pr, bytesLogger: bytesLogger} conn.isClosing = false conn.activity = make(chan struct{}, 100) @@ -51,6 +56,7 @@ func newWebRTCConn(pc *webrtc.PeerConnection, dc *webrtc.DataChannel, pr *io.Pip conn.inactivityTimeout = 30 * time.Second ctx, cancel := context.WithCancel(context.Background()) conn.cancelTimeoutLoop = cancel + conn.protocol = protocol go conn.timeoutLoop(ctx) return conn } @@ -137,6 +143,10 @@ func (c *webRTCConn) SetWriteDeadline(t time.Time) error { return fmt.Errorf("SetWriteDeadline not implemented") } +func (c *webRTCConn) GetConnectionProtocol() string { + return c.protocol +} + func remoteIPFromSDP(str string) net.IP { // Look for remote IP in "a=candidate" attribute fields // https://tools.ietf.org/html/rfc5245#section-15.1 diff --git a/server/lib/http.go b/server/lib/http.go index 403aeb1..26ecaab 100644 --- a/server/lib/http.go +++ b/server/lib/http.go @@ -1,13 +1,10 @@ package snowflake_server import ( - "bufio" - "bytes" "crypto/hmac" "crypto/rand" "crypto/sha256" "encoding/binary" - "fmt" "io" "log" "net" @@ -17,7 +14,8 @@ import ( "github.com/gorilla/websocket" - "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/encapsulation" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/packetpadding" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/turbotunnel" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/websocketconn" ) @@ -32,7 +30,7 @@ const requestTimeout = 10 * time.Second const clientMapTimeout = 1 * time.Minute // How big to make the map of ClientIDs to IP addresses. The map is used in -// turbotunnelMode to store a reasonable IP address for a client session that +// turboTunnelUDPLikeMode to store a reasonable IP address for a client session that // may outlive any single WebSocket connection. const clientIDAddrMapCapacity = 98304 @@ -108,47 +106,25 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Pass the address of client as the remote address of incoming connection clientIPParam := r.URL.Query().Get("client_ip") addr := clientAddr(clientIPParam) + protocol := r.URL.Query().Get("protocol") - var token [len(turbotunnel.Token)]byte - _, err = io.ReadFull(conn, token[:]) - if err != nil { - // Don't bother logging EOF: that happens with an unused - // connection, which clients make frequently as they maintain a - // pool of proxies. - if err != io.EOF { - log.Printf("reading token: %v", err) - } - return - } - - switch { - case bytes.Equal(token[:], turbotunnel.Token[:]): - err = handler.turbotunnelMode(conn, addr) - default: - // We didn't find a matching token, which means that we are - // dealing with a client that doesn't know about such things. - // Close the conn as we no longer support the old - // one-session-per-WebSocket mode. - log.Println("Received unsupported oneshot connection") - return - } - if err != nil { + err = handler.turboTunnelUDPLikeMode(conn, addr, protocol) + if err != nil && err != io.EOF { log.Println(err) return } } -// turbotunnelMode handles clients that sent turbotunnel.Token at the start of -// their stream. These clients expect to send and receive encapsulated packets, -// with a long-lived session identified by ClientID. -func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error { - // Read the ClientID prefix. Every packet encapsulated in this WebSocket - // connection pertains to the same ClientID. - var clientID turbotunnel.ClientID - _, err := io.ReadFull(conn, clientID[:]) +func (handler *httpHandler) turboTunnelUDPLikeMode(conn *websocketconn.Conn, addr net.Addr, protocol string) error { + // Read the ClientID from the WebRTC data channel protocol string. Every + // packet received on this WebSocket connection pertains to the same + // ClientID. + clientID := turbotunnel.ClientID{} + metaData, err := messages.DecodeConnectionMetadata(protocol) if err != nil { - return fmt.Errorf("reading ClientID: %w", err) + return err } + copy(clientID[:], metaData.ClientID[:]) // Store a short-term mapping from the ClientID to the client IP // address attached to this WebSocket connection. tor will want us to @@ -167,8 +143,10 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error wg.Add(2) done := make(chan struct{}) - // The remainder of the WebSocket stream consists of encapsulated - // packets. We read them one by one and feed them into the + connPaddable := packetpadding.NewPaddableConnection(conn, packetpadding.New()) + + // The remainder of the WebSocket stream consists of packets, one packet + // per WebSocket message. We read them one by one and feed them into the // QueuePacketConn on which kcp.ServeConn was set up, which eventually // leads to KCP-level sessions in the acceptSessions function. go func() { @@ -176,11 +154,9 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error defer close(done) // Signal the write loop to finish var p [2048]byte for { - n, err := encapsulation.ReadData(conn, p[:]) - if err == io.ErrShortBuffer { - err = nil - } + n, err := connPaddable.Read(p[:]) if err != nil { + log.Println(err) return } pconn.QueueIncoming(p[:n], clientID) @@ -192,10 +168,6 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error go func() { defer wg.Done() defer conn.Close() // Signal the read loop to finish - - // Buffer encapsulation.WriteData operations to keep length - // prefixes in the same send as the data that follows. - bw := bufio.NewWriter(conn) for { select { case <-done: @@ -204,12 +176,10 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error if !ok { return } - _, err := encapsulation.WriteData(bw, p) + _, err := connPaddable.Write(p) pconn.Restore(p) - if err == nil { - err = bw.Flush() - } if err != nil { + log.Println(err) return } } diff --git a/server/lib/snowflake.go b/server/lib/snowflake.go index bcf9dd6..b158f03 100644 --- a/server/lib/snowflake.go +++ b/server/lib/snowflake.go @@ -41,6 +41,7 @@ import ( "log" "net" "net/http" + "os" "sync" "time" @@ -253,7 +254,7 @@ func (l *SnowflakeListener) acceptSessions(ln *kcp.Listener) error { return err } // Permit coalescing the payloads of consecutive sends. - conn.SetStreamMode(true) + conn.SetStreamMode(false) // Set the maximum send and receive window sizes to a high number // Removes KCP bottlenecks: https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40026 conn.SetWindowSize(WindowSize, WindowSize) @@ -265,6 +266,14 @@ func (l *SnowflakeListener) acceptSessions(ln *kcp.Listener) error { 0, // default resend 1, // nc=1 => congestion window off ) + if os.Getenv("SNOWFLAKE_TEST_KCP_FAST3MODE") == "1" { + conn.SetNoDelay( + 1, + 10, + 2, + 1, + ) + } go func() { defer conn.Close() err := l.acceptStreams(conn)