diff --git a/client/lib/packetIDConnClient.go b/client/lib/packetIDConnClient.go new file mode 100644 index 0000000..15b6c51 --- /dev/null +++ b/client/lib/packetIDConnClient.go @@ -0,0 +1,109 @@ +package snowflake_client + +import ( + "io" + "log" + "net" + "time" + + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/turbotunnel" +) + +const ( + packetClientIDConn_StateNew = iota + packetClientIDConn_StateConnectionIDAcknowledged +) + +type ClientID = turbotunnel.ClientID + +func newPacketClientIDConn(ClientID ClientID, transport io.ReadWriter) *packetClientIDConn { + return &packetClientIDConn{ + state: packetClientIDConn_StateNew, + ConnID: ClientID, + transport: transport, + } +} + +type packetClientIDConn struct { + state int + ConnID ClientID + transport io.ReadWriter +} + +func (c *packetClientIDConn) Write(p []byte) (int, error) { + switch c.state { + case packetClientIDConn_StateConnectionIDAcknowledged: + packet := make([]byte, len(p)+1) + packet[0] = 0xff + copy(packet[1:], p) + _, err := c.transport.Write(packet) + if err != nil { + return 0, err + } + return len(p), nil + case packetClientIDConn_StateNew: + packet := make([]byte, len(p)+1+len(c.ConnID)) + packet[0] = 0xfe + copy(packet[1:], c.ConnID[:]) + copy(packet[1+len(c.ConnID):], p) + _, err := c.transport.Write(packet) + if err != nil { + return 0, err + } + return len(p), nil + default: + panic("invalid state") + } +} + +func (c *packetClientIDConn) Read(p []byte) (int, error) { + n, err := c.transport.Read(p) + if err != nil { + return 0, err + } + if p[0] == 0xff { + c.state = packetClientIDConn_StateConnectionIDAcknowledged + return copy(p, p[1:n]), nil + } else { + log.Println("discarded unknown packet") + } + return 0, nil +} + +type packetConnWrapper struct { + io.ReadWriter + remoteAddr net.Addr + localAddr net.Addr +} + +func (pcw *packetConnWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + n, err = pcw.Read(p) + if err != nil { + return 0, nil, err + } + return n, pcw.remoteAddr, nil +} + +func (pcw *packetConnWrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return pcw.Write(p) +} + +func (pcw *packetConnWrapper) Close() error { + return nil +} + +func (pcw *packetConnWrapper) LocalAddr() net.Addr { + return pcw.localAddr +} + +func (pcw *packetConnWrapper) SetDeadline(t time.Time) error { + return nil +} + +func (pcw *packetConnWrapper) SetReadDeadline(t time.Time) error { + return nil +} + +func (pcw *packetConnWrapper) SetWriteDeadline(t time.Time) error { + return nil +} diff --git a/client/lib/snowflake.go b/client/lib/snowflake.go index f1a3bad..7442ce9 100644 --- a/client/lib/snowflake.go +++ b/client/lib/snowflake.go @@ -339,6 +339,16 @@ func newSession(snowflakes SnowflakeCollector) (net.PacketConn, *smux.Session, e return nil, errors.New("handler: Received invalid Snowflake") } log.Println("---- Handler: snowflake assigned ----") + log.Printf("activeTransportMode = %c \n", conn.activeTransportMode) + if conn.activeTransportMode == 'u' { + packetIDConn := newPacketClientIDConn(clientID, conn) + packetConnWrapper := &packetConnWrapper{ + ReadWriter: packetIDConn, + remoteAddr: dummyAddr{}, + localAddr: dummyAddr{}, + } + return packetConnWrapper, nil + } // Send the magic Turbo Tunnel token. _, err := conn.Write(turbotunnel.Token[:]) if err != nil { @@ -363,7 +373,7 @@ func newSession(snowflakes SnowflakeCollector) (net.PacketConn, *smux.Session, e return nil, nil, err } // Permit coalescing the payloads of consecutive sends. - conn.SetStreamMode(true) + conn.SetStreamMode(false) // Set the maximum send and receive window sizes to a high number // Removes KCP bottlenecks: https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40026 conn.SetWindowSize(WindowSize, WindowSize) diff --git a/client/lib/webrtc.go b/client/lib/webrtc.go index 9d803a2..85b18c9 100644 --- a/client/lib/webrtc.go +++ b/client/lib/webrtc.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "encoding/hex" "errors" + "fmt" "io" "log" "net" @@ -43,6 +44,8 @@ type WebRTCPeer struct { bytesLogger bytesLogger eventsLogger event.SnowflakeEventReceiver proxy *url.URL + + activeTransportMode byte } // Deprecated: Use NewWebRTCPeerWithNatPolicyAndEventsAndProxy Instead. @@ -191,6 +194,7 @@ func (c *WebRTCPeer) connect( ) error { log.Println(c.id, " connecting...") + c.activeTransportMode = 'u' err := c.preparePeerConnection(config, broker.keepLocalAddresses) localDescription := c.pc.LocalDescription() c.eventsLogger.OnNewSnowflakeEvent(event.EventOnOfferCreated{ @@ -297,8 +301,17 @@ func (c *WebRTCPeer) preparePeerConnection( return err } ordered := true + var maxRetransmission *uint16 + if c.activeTransportMode == 'u' { + ordered = false + maxRetransmissionVal := uint16(0) + maxRetransmission = &maxRetransmissionVal + } + protocol := fmt.Sprintf("%c", c.activeTransportMode) dataChannelOptions := &webrtc.DataChannelInit{ - Ordered: &ordered, + Ordered: &ordered, + Protocol: &protocol, + MaxRetransmits: maxRetransmission, } // We must create the data channel before creating an offer // https://github.com/pion/webrtc/wiki/Release-WebRTC@v3.0.0#a-data-channel-is-no-longer-implicitly-created-with-a-peerconnection diff --git a/proxy/lib/snowflake.go b/proxy/lib/snowflake.go index 7bd1aaf..c6c667b 100644 --- a/proxy/lib/snowflake.go +++ b/proxy/lib/snowflake.go @@ -343,7 +343,7 @@ func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteAddr net.Ad relayURL = sf.RelayURL } - wsConn, err := connectToRelay(relayURL, remoteAddr) + wsConn, err := connectToRelay(relayURL, remoteAddr, conn.GetConnectionProtocol()) if err != nil { log.Print(err) return @@ -354,7 +354,11 @@ func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteAddr net.Ad log.Printf("datachannelHandler ends") } -func connectToRelay(relayURL string, remoteAddr net.Addr) (*websocketconn.Conn, error) { +func connectToRelay( + relayURL string, + remoteAddr net.Addr, + webrtcConnProtocol string, +) (*websocketconn.Conn, error) { u, err := url.Parse(relayURL) if err != nil { return nil, fmt.Errorf("invalid relay url: %s", err) @@ -370,6 +374,12 @@ func connectToRelay(relayURL string, remoteAddr net.Addr) (*websocketconn.Conn, log.Printf("no remote address given in websocket") } + if webrtcConnProtocol != "" { + q := u.Query() + q.Set("protocol", webrtcConnProtocol) + u.RawQuery = q.Encode() + } + ws, _, err := websocket.DefaultDialer.Dial(u.String(), nil) if err != nil { return nil, fmt.Errorf("error dialing relay: %s = %s", u.String(), err) @@ -451,6 +461,7 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer( pr, pw := io.Pipe() conn := newWebRTCConn(pc, dc, pr, sf.bytesLogger) + conn.SetConnectionProtocol(dc.Protocol()) dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) diff --git a/proxy/lib/webrtcconn.go b/proxy/lib/webrtcconn.go index b09ff3d..f2381bf 100644 --- a/proxy/lib/webrtcconn.go +++ b/proxy/lib/webrtcconn.go @@ -41,6 +41,8 @@ type webRTCConn struct { cancelTimeoutLoop context.CancelFunc bytesLogger bytesLogger + + protocol string } func newWebRTCConn(pc *webrtc.PeerConnection, dc *webrtc.DataChannel, pr *io.PipeReader, bytesLogger bytesLogger) *webRTCConn { @@ -137,6 +139,14 @@ func (c *webRTCConn) SetWriteDeadline(t time.Time) error { return fmt.Errorf("SetWriteDeadline not implemented") } +func (c *webRTCConn) SetConnectionProtocol(protocol string) { + c.protocol = protocol +} + +func (c *webRTCConn) GetConnectionProtocol() string { + return c.protocol +} + func remoteIPFromSDP(str string) net.IP { // Look for remote IP in "a=candidate" attribute fields // https://tools.ietf.org/html/rfc5245#section-15.1 diff --git a/server/lib/http.go b/server/lib/http.go index 403aeb1..a667f7b 100644 --- a/server/lib/http.go +++ b/server/lib/http.go @@ -108,6 +108,16 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Pass the address of client as the remote address of incoming connection clientIPParam := r.URL.Query().Get("client_ip") addr := clientAddr(clientIPParam) + clientTransport := r.URL.Query().Get("protocol") + + if clientTransport == "u" { + err = handler.turboTunnelUDPLikeMode(conn, addr) + if err != nil && err != io.EOF { + log.Println(err) + return + } + return + } var token [len(turbotunnel.Token)]byte _, err = io.ReadFull(conn, token[:]) @@ -221,6 +231,61 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error return nil } +func (handler *httpHandler) turboTunnelUDPLikeMode(conn net.Conn, addr net.Addr) error { + packetConnIDCon := packetConnIDConnServer{Conn: conn} + var packet [1600]byte + n, err := packetConnIDCon.Read(packet[:]) + if err != nil { + return fmt.Errorf("reading ClientID: %v", err) + } + clientID, err := packetConnIDCon.GetClientID() + if err != nil { + return fmt.Errorf("reading ClientID: %v", err) + } + clientIDAddrMap.Set(clientID, addr) + + pconn := handler.lookupPacketConn(clientID) + pconn.QueueIncoming(packet[:n], clientID) + var wg sync.WaitGroup + wg.Add(2) + done := make(chan struct{}) + go func() { + defer wg.Done() + defer close(done) // Signal the write loop to finish + for { + n, err := packetConnIDCon.Read(packet[:]) + if err != nil { + log.Println(err) + return + } + pconn.QueueIncoming(packet[:n], clientID) + } + }() + go func() { + defer wg.Done() + defer conn.Close() // Signal the read loop to finish + for { + select { + case <-done: + return + case p, ok := <-pconn.OutgoingQueue(clientID): + if !ok { + return + } + _, err := packetConnIDCon.Write(p) + pconn.Restore(p) + if err != nil { + log.Println(err) + return + } + } + } + }() + + wg.Wait() + return nil +} + // ClientMapAddr is a string that represents a connecting client. type ClientMapAddr string diff --git a/server/lib/packetIDConnServer.go b/server/lib/packetIDConnServer.go new file mode 100644 index 0000000..feca1fa --- /dev/null +++ b/server/lib/packetIDConnServer.go @@ -0,0 +1,52 @@ +package snowflake_server + +import ( + "errors" + "net" + + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/turbotunnel" +) + +type ConnID = turbotunnel.ClientID + +type packetConnIDConnServer struct { + // This net.Conn must preserve message boundaries. + net.Conn + connID ConnID + clientIDReceived bool +} + +var ErrClientIDNotReceived = errors.New("ClientID not received") + +func (p *packetConnIDConnServer) GetClientID() (ConnID, error) { + if !p.clientIDReceived { + return p.connID, ErrClientIDNotReceived + } + return p.connID, nil +} + +func (p *packetConnIDConnServer) Read(buf []byte) (n int, err error) { + n, err = p.Conn.Read(buf) + if err != nil { + return + } + switch buf[0] { + case 0xfe: + p.clientIDReceived = true + copy(p.connID[:], buf[1:9]) + copy(buf[0:], buf[9:]) + return n - 9, nil + case 0xff: + copy(buf[0:], buf[1:]) + return n - 1, nil + } + return 0, nil +} + +func (p *packetConnIDConnServer) Write(buf []byte) (n int, err error) { + n, err = p.Conn.Write(append([]byte{0xff}, buf...)) + if err != nil { + return 0, err + } + return len(buf) - 1, nil +} diff --git a/server/lib/snowflake.go b/server/lib/snowflake.go index bcf9dd6..d7d0c40 100644 --- a/server/lib/snowflake.go +++ b/server/lib/snowflake.go @@ -253,7 +253,7 @@ func (l *SnowflakeListener) acceptSessions(ln *kcp.Listener) error { return err } // Permit coalescing the payloads of consecutive sends. - conn.SetStreamMode(true) + conn.SetStreamMode(false) // Set the maximum send and receive window sizes to a high number // Removes KCP bottlenecks: https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40026 conn.SetWindowSize(WindowSize, WindowSize)