Add UDP Like transport mode to snowflake

This commit is contained in:
Shelikhoo 2023-12-12 14:43:30 +00:00 committed by WofWca
parent fa122efb61
commit 457c4fbf15
8 changed files with 275 additions and 5 deletions

View file

@ -0,0 +1,109 @@
package snowflake_client
import (
"io"
"log"
"net"
"time"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/turbotunnel"
)
const (
packetClientIDConn_StateNew = iota
packetClientIDConn_StateConnectionIDAcknowledged
)
type ClientID = turbotunnel.ClientID
func newPacketClientIDConn(ClientID ClientID, transport io.ReadWriter) *packetClientIDConn {
return &packetClientIDConn{
state: packetClientIDConn_StateNew,
ConnID: ClientID,
transport: transport,
}
}
type packetClientIDConn struct {
state int
ConnID ClientID
transport io.ReadWriter
}
func (c *packetClientIDConn) Write(p []byte) (int, error) {
switch c.state {
case packetClientIDConn_StateConnectionIDAcknowledged:
packet := make([]byte, len(p)+1)
packet[0] = 0xff
copy(packet[1:], p)
_, err := c.transport.Write(packet)
if err != nil {
return 0, err
}
return len(p), nil
case packetClientIDConn_StateNew:
packet := make([]byte, len(p)+1+len(c.ConnID))
packet[0] = 0xfe
copy(packet[1:], c.ConnID[:])
copy(packet[1+len(c.ConnID):], p)
_, err := c.transport.Write(packet)
if err != nil {
return 0, err
}
return len(p), nil
default:
panic("invalid state")
}
}
func (c *packetClientIDConn) Read(p []byte) (int, error) {
n, err := c.transport.Read(p)
if err != nil {
return 0, err
}
if p[0] == 0xff {
c.state = packetClientIDConn_StateConnectionIDAcknowledged
return copy(p, p[1:n]), nil
} else {
log.Println("discarded unknown packet")
}
return 0, nil
}
type packetConnWrapper struct {
io.ReadWriter
remoteAddr net.Addr
localAddr net.Addr
}
func (pcw *packetConnWrapper) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
n, err = pcw.Read(p)
if err != nil {
return 0, nil, err
}
return n, pcw.remoteAddr, nil
}
func (pcw *packetConnWrapper) WriteTo(p []byte, addr net.Addr) (n int, err error) {
return pcw.Write(p)
}
func (pcw *packetConnWrapper) Close() error {
return nil
}
func (pcw *packetConnWrapper) LocalAddr() net.Addr {
return pcw.localAddr
}
func (pcw *packetConnWrapper) SetDeadline(t time.Time) error {
return nil
}
func (pcw *packetConnWrapper) SetReadDeadline(t time.Time) error {
return nil
}
func (pcw *packetConnWrapper) SetWriteDeadline(t time.Time) error {
return nil
}

View file

@ -339,6 +339,16 @@ func newSession(snowflakes SnowflakeCollector) (net.PacketConn, *smux.Session, e
return nil, errors.New("handler: Received invalid Snowflake") return nil, errors.New("handler: Received invalid Snowflake")
} }
log.Println("---- Handler: snowflake assigned ----") log.Println("---- Handler: snowflake assigned ----")
log.Printf("activeTransportMode = %c \n", conn.activeTransportMode)
if conn.activeTransportMode == 'u' {
packetIDConn := newPacketClientIDConn(clientID, conn)
packetConnWrapper := &packetConnWrapper{
ReadWriter: packetIDConn,
remoteAddr: dummyAddr{},
localAddr: dummyAddr{},
}
return packetConnWrapper, nil
}
// Send the magic Turbo Tunnel token. // Send the magic Turbo Tunnel token.
_, err := conn.Write(turbotunnel.Token[:]) _, err := conn.Write(turbotunnel.Token[:])
if err != nil { if err != nil {
@ -363,7 +373,7 @@ func newSession(snowflakes SnowflakeCollector) (net.PacketConn, *smux.Session, e
return nil, nil, err return nil, nil, err
} }
// Permit coalescing the payloads of consecutive sends. // Permit coalescing the payloads of consecutive sends.
conn.SetStreamMode(true) conn.SetStreamMode(false)
// Set the maximum send and receive window sizes to a high number // Set the maximum send and receive window sizes to a high number
// Removes KCP bottlenecks: https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40026 // Removes KCP bottlenecks: https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40026
conn.SetWindowSize(WindowSize, WindowSize) conn.SetWindowSize(WindowSize, WindowSize)

View file

@ -4,6 +4,7 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt"
"io" "io"
"log" "log"
"net" "net"
@ -43,6 +44,8 @@ type WebRTCPeer struct {
bytesLogger bytesLogger bytesLogger bytesLogger
eventsLogger event.SnowflakeEventReceiver eventsLogger event.SnowflakeEventReceiver
proxy *url.URL proxy *url.URL
activeTransportMode byte
} }
// Deprecated: Use NewWebRTCPeerWithNatPolicyAndEventsAndProxy Instead. // Deprecated: Use NewWebRTCPeerWithNatPolicyAndEventsAndProxy Instead.
@ -191,6 +194,7 @@ func (c *WebRTCPeer) connect(
) error { ) error {
log.Println(c.id, " connecting...") log.Println(c.id, " connecting...")
c.activeTransportMode = 'u'
err := c.preparePeerConnection(config, broker.keepLocalAddresses) err := c.preparePeerConnection(config, broker.keepLocalAddresses)
localDescription := c.pc.LocalDescription() localDescription := c.pc.LocalDescription()
c.eventsLogger.OnNewSnowflakeEvent(event.EventOnOfferCreated{ c.eventsLogger.OnNewSnowflakeEvent(event.EventOnOfferCreated{
@ -297,8 +301,17 @@ func (c *WebRTCPeer) preparePeerConnection(
return err return err
} }
ordered := true ordered := true
var maxRetransmission *uint16
if c.activeTransportMode == 'u' {
ordered = false
maxRetransmissionVal := uint16(0)
maxRetransmission = &maxRetransmissionVal
}
protocol := fmt.Sprintf("%c", c.activeTransportMode)
dataChannelOptions := &webrtc.DataChannelInit{ dataChannelOptions := &webrtc.DataChannelInit{
Ordered: &ordered, Ordered: &ordered,
Protocol: &protocol,
MaxRetransmits: maxRetransmission,
} }
// We must create the data channel before creating an offer // We must create the data channel before creating an offer
// https://github.com/pion/webrtc/wiki/Release-WebRTC@v3.0.0#a-data-channel-is-no-longer-implicitly-created-with-a-peerconnection // https://github.com/pion/webrtc/wiki/Release-WebRTC@v3.0.0#a-data-channel-is-no-longer-implicitly-created-with-a-peerconnection

View file

@ -343,7 +343,7 @@ func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteAddr net.Ad
relayURL = sf.RelayURL relayURL = sf.RelayURL
} }
wsConn, err := connectToRelay(relayURL, remoteAddr) wsConn, err := connectToRelay(relayURL, remoteAddr, conn.GetConnectionProtocol())
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return return
@ -354,7 +354,11 @@ func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteAddr net.Ad
log.Printf("datachannelHandler ends") log.Printf("datachannelHandler ends")
} }
func connectToRelay(relayURL string, remoteAddr net.Addr) (*websocketconn.Conn, error) { func connectToRelay(
relayURL string,
remoteAddr net.Addr,
webrtcConnProtocol string,
) (*websocketconn.Conn, error) {
u, err := url.Parse(relayURL) u, err := url.Parse(relayURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid relay url: %s", err) return nil, fmt.Errorf("invalid relay url: %s", err)
@ -370,6 +374,12 @@ func connectToRelay(relayURL string, remoteAddr net.Addr) (*websocketconn.Conn,
log.Printf("no remote address given in websocket") log.Printf("no remote address given in websocket")
} }
if webrtcConnProtocol != "" {
q := u.Query()
q.Set("protocol", webrtcConnProtocol)
u.RawQuery = q.Encode()
}
ws, _, err := websocket.DefaultDialer.Dial(u.String(), nil) ws, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("error dialing relay: %s = %s", u.String(), err) return nil, fmt.Errorf("error dialing relay: %s = %s", u.String(), err)
@ -451,6 +461,7 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
pr, pw := io.Pipe() pr, pw := io.Pipe()
conn := newWebRTCConn(pc, dc, pr, sf.bytesLogger) conn := newWebRTCConn(pc, dc, pr, sf.bytesLogger)
conn.SetConnectionProtocol(dc.Protocol())
dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold) dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)

View file

@ -41,6 +41,8 @@ type webRTCConn struct {
cancelTimeoutLoop context.CancelFunc cancelTimeoutLoop context.CancelFunc
bytesLogger bytesLogger bytesLogger bytesLogger
protocol string
} }
func newWebRTCConn(pc *webrtc.PeerConnection, dc *webrtc.DataChannel, pr *io.PipeReader, bytesLogger bytesLogger) *webRTCConn { func newWebRTCConn(pc *webrtc.PeerConnection, dc *webrtc.DataChannel, pr *io.PipeReader, bytesLogger bytesLogger) *webRTCConn {
@ -137,6 +139,14 @@ func (c *webRTCConn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline not implemented") return fmt.Errorf("SetWriteDeadline not implemented")
} }
func (c *webRTCConn) SetConnectionProtocol(protocol string) {
c.protocol = protocol
}
func (c *webRTCConn) GetConnectionProtocol() string {
return c.protocol
}
func remoteIPFromSDP(str string) net.IP { func remoteIPFromSDP(str string) net.IP {
// Look for remote IP in "a=candidate" attribute fields // Look for remote IP in "a=candidate" attribute fields
// https://tools.ietf.org/html/rfc5245#section-15.1 // https://tools.ietf.org/html/rfc5245#section-15.1

View file

@ -108,6 +108,16 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Pass the address of client as the remote address of incoming connection // Pass the address of client as the remote address of incoming connection
clientIPParam := r.URL.Query().Get("client_ip") clientIPParam := r.URL.Query().Get("client_ip")
addr := clientAddr(clientIPParam) addr := clientAddr(clientIPParam)
clientTransport := r.URL.Query().Get("protocol")
if clientTransport == "u" {
err = handler.turboTunnelUDPLikeMode(conn, addr)
if err != nil && err != io.EOF {
log.Println(err)
return
}
return
}
var token [len(turbotunnel.Token)]byte var token [len(turbotunnel.Token)]byte
_, err = io.ReadFull(conn, token[:]) _, err = io.ReadFull(conn, token[:])
@ -221,6 +231,61 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error
return nil return nil
} }
func (handler *httpHandler) turboTunnelUDPLikeMode(conn net.Conn, addr net.Addr) error {
packetConnIDCon := packetConnIDConnServer{Conn: conn}
var packet [1600]byte
n, err := packetConnIDCon.Read(packet[:])
if err != nil {
return fmt.Errorf("reading ClientID: %v", err)
}
clientID, err := packetConnIDCon.GetClientID()
if err != nil {
return fmt.Errorf("reading ClientID: %v", err)
}
clientIDAddrMap.Set(clientID, addr)
pconn := handler.lookupPacketConn(clientID)
pconn.QueueIncoming(packet[:n], clientID)
var wg sync.WaitGroup
wg.Add(2)
done := make(chan struct{})
go func() {
defer wg.Done()
defer close(done) // Signal the write loop to finish
for {
n, err := packetConnIDCon.Read(packet[:])
if err != nil {
log.Println(err)
return
}
pconn.QueueIncoming(packet[:n], clientID)
}
}()
go func() {
defer wg.Done()
defer conn.Close() // Signal the read loop to finish
for {
select {
case <-done:
return
case p, ok := <-pconn.OutgoingQueue(clientID):
if !ok {
return
}
_, err := packetConnIDCon.Write(p)
pconn.Restore(p)
if err != nil {
log.Println(err)
return
}
}
}
}()
wg.Wait()
return nil
}
// ClientMapAddr is a string that represents a connecting client. // ClientMapAddr is a string that represents a connecting client.
type ClientMapAddr string type ClientMapAddr string

View file

@ -0,0 +1,52 @@
package snowflake_server
import (
"errors"
"net"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/turbotunnel"
)
type ConnID = turbotunnel.ClientID
type packetConnIDConnServer struct {
// This net.Conn must preserve message boundaries.
net.Conn
connID ConnID
clientIDReceived bool
}
var ErrClientIDNotReceived = errors.New("ClientID not received")
func (p *packetConnIDConnServer) GetClientID() (ConnID, error) {
if !p.clientIDReceived {
return p.connID, ErrClientIDNotReceived
}
return p.connID, nil
}
func (p *packetConnIDConnServer) Read(buf []byte) (n int, err error) {
n, err = p.Conn.Read(buf)
if err != nil {
return
}
switch buf[0] {
case 0xfe:
p.clientIDReceived = true
copy(p.connID[:], buf[1:9])
copy(buf[0:], buf[9:])
return n - 9, nil
case 0xff:
copy(buf[0:], buf[1:])
return n - 1, nil
}
return 0, nil
}
func (p *packetConnIDConnServer) Write(buf []byte) (n int, err error) {
n, err = p.Conn.Write(append([]byte{0xff}, buf...))
if err != nil {
return 0, err
}
return len(buf) - 1, nil
}

View file

@ -253,7 +253,7 @@ func (l *SnowflakeListener) acceptSessions(ln *kcp.Listener) error {
return err return err
} }
// Permit coalescing the payloads of consecutive sends. // Permit coalescing the payloads of consecutive sends.
conn.SetStreamMode(true) conn.SetStreamMode(false)
// Set the maximum send and receive window sizes to a high number // Set the maximum send and receive window sizes to a high number
// Removes KCP bottlenecks: https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40026 // Removes KCP bottlenecks: https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40026
conn.SetWindowSize(WindowSize, WindowSize) conn.SetWindowSize(WindowSize, WindowSize)