From 6ec0025e93c45c59862f7644c48fa9d719ff8c57 Mon Sep 17 00:00:00 2001 From: Shelikhoo Date: Wed, 19 Feb 2025 15:47:36 +0000 Subject: [PATCH] Add broker and server side rejection based on proxy version --- broker/broker.go | 18 ++++++++ broker/ipc.go | 45 +++++++++++++++++-- common/messages/proxy.go | 17 +++++--- server/lib/http.go | 94 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 166 insertions(+), 8 deletions(-) diff --git a/broker/broker.go b/broker/broker.go index 482f0d2..1b407e9 100644 --- a/broker/broker.go +++ b/broker/broker.go @@ -47,6 +47,7 @@ type BrokerContext struct { bridgeList BridgeListHolderFileBased allowedRelayPattern string presumedPatternForLegacyClient string + minProxyVersion string } func (ctx *BrokerContext) GetBridgeInfo(fingerprint bridgefingerprint.Fingerprint) (BridgeInfo, error) { @@ -57,6 +58,20 @@ func NewBrokerContext( metricsLogger *log.Logger, allowedRelayPattern, presumedPatternForLegacyClient string, +) *BrokerContext { + return NewBrokerContextWithMinProxyVersion( + metricsLogger, + allowedRelayPattern, + presumedPatternForLegacyClient, + "1.3", + ) +} + +func NewBrokerContextWithMinProxyVersion( + metricsLogger *log.Logger, + allowedRelayPattern, + presumedPatternForLegacyClient string, + minProxyVersion string, ) *BrokerContext { snowflakes := new(SnowflakeHeap) heap.Init(snowflakes) @@ -87,6 +102,7 @@ func NewBrokerContext( bridgeList: bridgeListHolder, allowedRelayPattern: allowedRelayPattern, presumedPatternForLegacyClient: presumedPatternForLegacyClient, + minProxyVersion: minProxyVersion, } } @@ -204,6 +220,7 @@ func main() { var disableGeoip bool var metricsFilename string var unsafeLogging bool + var minProxyVersion string flag.StringVar(&acmeEmail, "acme-email", "", "optional contact email for Let's Encrypt notifications") flag.StringVar(&acmeHostnamesCommas, "acme-hostnames", "", "comma-separated hostnames for TLS certificate") @@ -222,6 +239,7 @@ func main() { flag.BoolVar(&disableGeoip, "disable-geoip", false, "don't use geoip for stats collection") flag.StringVar(&metricsFilename, "metrics-log", "", "path to metrics logging output") flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs from being scrubbed") + flag.StringVar(&minProxyVersion, "min-proxy-version", "1.3", "the minimum version of the Snowflake proxy that the broker will accept") flag.Parse() var err error diff --git a/broker/ipc.go b/broker/ipc.go index bde72b0..aed5a02 100644 --- a/broker/ipc.go +++ b/broker/ipc.go @@ -5,11 +5,13 @@ import ( "encoding/hex" "fmt" "log" + "strconv" + "strings" "time" - "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/bridgefingerprint" - "github.com/prometheus/client_golang/prometheus" + + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/bridgefingerprint" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" ) @@ -66,12 +68,49 @@ func (i *IPC) Debug(_ interface{}, response *string) error { return nil } +func versionCompare(versionL, versionR string) int { + versionSpliter := func(version string) (int64, int64) { + s := strings.Split(version, ".") + if len(s) != 2 { + return -1, -1 + } + major, err := strconv.ParseInt(s[0], 10, 64) + if err != nil { + return -1, -1 + } + minor, err := strconv.ParseInt(s[1], 10, 64) + if err != nil { + return -1, -1 + } + return major, minor + } + versionLMajor, versionLMinor := versionSpliter(versionL) + versionRMajor, versionRMinor := versionSpliter(versionR) + if versionLMajor > versionRMajor { + return -1 + } else if versionLMajor == versionRMajor { + if versionLMinor > versionRMinor { + return -1 + } else if versionLMinor == versionRMinor { + return 0 + } else { + return 1 + } + } else { + return 1 + } +} + func (i *IPC) ProxyPolls(arg messages.Arg, response *[]byte) error { - sid, proxyType, natType, clients, relayPattern, relayPatternSupported, err := messages.DecodeProxyPollRequestWithRelayPrefix(arg.Body) + sid, proxyType, natType, clients, relayPattern, relayPatternSupported, version, err := messages.DecodeProxyPollRequestWithRelayPrefixAndReturnVersion(arg.Body) if err != nil { return messages.ErrBadRequest } + if versionCompare(i.ctx.minProxyVersion, version) < 0 { + return messages.ErrBadRequest + } + if !relayPatternSupported { i.ctx.metrics.lock.Lock() i.ctx.metrics.proxyPollWithoutRelayURLExtension++ diff --git a/common/messages/proxy.go b/common/messages/proxy.go index 6fe02be..9075c0f 100644 --- a/common/messages/proxy.go +++ b/common/messages/proxy.go @@ -13,7 +13,7 @@ import ( ) const ( - version = "1.3" + version = "1.4" ProxyUnknown = "unknown" ) @@ -124,11 +124,18 @@ func DecodeProxyPollRequest(data []byte) (sid string, proxyType string, natType return } -// Decodes a poll message from a snowflake proxy and returns the -// sid, proxy type, nat type and clients of the proxy on success -// and an error if it failed func DecodeProxyPollRequestWithRelayPrefix(data []byte) ( sid string, proxyType string, natType string, clients int, relayPrefix string, relayPrefixAware bool, err error) { + sid, proxyType, natType, clients, relayPrefix, relayPrefixAware, _, err = DecodeProxyPollRequestWithRelayPrefixAndReturnVersion(data) + return +} + +// Decodes a poll message from a snowflake proxy and returns the +// sid, proxy type, nat type and clients, version of the proxy on success +// and an error if it failed +func DecodeProxyPollRequestWithRelayPrefixAndReturnVersion(data []byte) ( + sid string, proxyType string, natType string, clients int, + relayPrefix string, relayPrefixAware bool, version string, err error) { var message ProxyPollRequest err = json.Unmarshal(data, &message) @@ -169,7 +176,7 @@ func DecodeProxyPollRequestWithRelayPrefix(data []byte) ( acceptedRelayPattern = *message.AcceptedRelayPattern } return message.Sid, message.Type, message.NAT, message.Clients, - acceptedRelayPattern, message.AcceptedRelayPattern != nil, nil + acceptedRelayPattern, message.AcceptedRelayPattern != nil, message.Version, nil } type ProxyPollResponse struct { diff --git a/server/lib/http.go b/server/lib/http.go index 85593e2..e6cfb77 100644 --- a/server/lib/http.go +++ b/server/lib/http.go @@ -1,10 +1,13 @@ package snowflake_server import ( + "bufio" "crypto/hmac" "crypto/rand" "crypto/sha256" "encoding/binary" + "fmt" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/encapsulation" "io" "log" "net" @@ -108,6 +111,14 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { addr := clientAddr(clientIPParam) protocol := r.URL.Query().Get("protocol") + if protocol == "" { + err = handler.turbotunnelMode(conn, addr) + if err != nil && err != io.EOF { + log.Println(err) + return + } + } + err = handler.turboTunnelUDPLikeMode(conn, addr, protocol) if err != nil && err != io.EOF { log.Println(err) @@ -115,6 +126,89 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +// turbotunnelMode handles clients that sent turbotunnel.Token at the start of +// their stream. These clients expect to send and receive encapsulated packets, +// with a long-lived session identified by ClientID. +func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error { + // Read the ClientID prefix. Every packet encapsulated in this WebSocket + // connection pertains to the same ClientID. + var clientID turbotunnel.ClientID + _, err := io.ReadFull(conn, clientID[:]) + if err != nil { + return fmt.Errorf("reading ClientID: %w", err) + } + + // Store a short-term mapping from the ClientID to the client IP + // address attached to this WebSocket connection. tor will want us to + // provide a client IP address when we call pt.DialOr. But a KCP session + // does not necessarily correspond to any single IP address--it's + // composed of packets that are carried in possibly multiple WebSocket + // streams. We apply the heuristic that the IP address of the most + // recent WebSocket connection that has had to do with a session, at the + // time the session is established, is the IP address that should be + // credited for the entire KCP session. + clientIDAddrMap.Set(clientID, addr) + + pconn := handler.lookupPacketConn(clientID) + + var wg sync.WaitGroup + wg.Add(2) + done := make(chan struct{}) + + // The remainder of the WebSocket stream consists of encapsulated + // packets. We read them one by one and feed them into the + // QueuePacketConn on which kcp.ServeConn was set up, which eventually + // leads to KCP-level sessions in the acceptSessions function. + go func() { + defer wg.Done() + defer close(done) // Signal the write loop to finish + var p [2048]byte + for { + n, err := encapsulation.ReadData(conn, p[:]) + if err == io.ErrShortBuffer { + err = nil + } + if err != nil { + return + } + pconn.QueueIncoming(p[:n], clientID) + } + }() + + // At the same time, grab packets addressed to this ClientID and + // encapsulate them into the downstream. + go func() { + defer wg.Done() + defer conn.Close() // Signal the read loop to finish + + // Buffer encapsulation.WriteData operations to keep length + // prefixes in the same send as the data that follows. + bw := bufio.NewWriter(conn) + for { + select { + case <-done: + return + case p, ok := <-pconn.OutgoingQueue(clientID): + if !ok { + return + } + _, err := encapsulation.WriteData(bw, p) + pconn.Restore(p) + if err == nil { + err = bw.Flush() + } + if err != nil { + return + } + } + } + }() + + wg.Wait() + + return nil +} + func (handler *httpHandler) turboTunnelUDPLikeMode(conn *websocketconn.Conn, addr net.Addr, protocol string) error { // Read the ClientID from the WebRTC data channel protocol string. Every // packet received on this WebSocket connection pertains to the same