Use multiple parallel KCP state machines in the server.

To distribute CPU load.

https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40200
This commit is contained in:
David Fifield 2022-10-01 11:43:29 -06:00
parent 53e381e45d
commit c6fabb212d
2 changed files with 73 additions and 27 deletions

View file

@ -3,6 +3,10 @@ package snowflake_server
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/binary"
"fmt" "fmt"
"io" "io"
"log" "log"
@ -49,9 +53,45 @@ var upgrader = websocket.Upgrader{
var clientIDAddrMap = newClientIDMap(clientIDAddrMapCapacity) var clientIDAddrMap = newClientIDMap(clientIDAddrMapCapacity)
type httpHandler struct { type httpHandler struct {
// pconn is the adapter layer between stream-oriented WebSocket // pconns is the adapter layer between stream-oriented WebSocket
// connections and the packet-oriented KCP layer. // connections and the packet-oriented KCP layer. There are multiple of
pconn *turbotunnel.QueuePacketConn // these, corresponding to the multiple kcp.ServeConn in
// Transport.Listen. Clients are assigned to a particular instance by a
// hash of the ClientID, in order to distribute KCP processing load
// across CPU cores.
pconns []*turbotunnel.QueuePacketConn
// clientIDLookupKey is a secret key used to tweak the hash-based
// assignment of ClientID to pconn, in order to avoid manipulation of
// hash assignments.
clientIDLookupKey []byte
}
// newHTTPHandler builds the http.Handler that exchanges encapsulated packets
// over incoming WebSocket connections.
//
// It creates numInstances QueuePacketConns, each constructed with localAddr
// and clientMapTimeout, and fills clientIDLookupKey with 16 bytes from
// crypto/rand. It panics if the random source returns an error.
func newHTTPHandler(localAddr net.Addr, numInstances int) *httpHandler {
	handler := &httpHandler{
		pconns:            make([]*turbotunnel.QueuePacketConn, 0, numInstances),
		clientIDLookupKey: make([]byte, 16),
	}
	// One packet conn per requested instance.
	for remaining := numInstances; remaining > 0; remaining-- {
		handler.pconns = append(handler.pconns, turbotunnel.NewQueuePacketConn(localAddr, clientMapTimeout))
	}
	// The lookup key is generated once per handler.
	if _, err := rand.Read(handler.clientIDLookupKey); err != nil {
		panic(err)
	}
	return handler
}
// lookupPacketConn returns the element of pconns that corresponds to client ID,
// according to the hash-based mapping.
func (handler *httpHandler) lookupPacketConn(clientID turbotunnel.ClientID) *turbotunnel.QueuePacketConn {
s := hmac.New(sha256.New, handler.clientIDLookupKey).Sum(clientID[:])
return handler.pconns[binary.LittleEndian.Uint64(s)%uint64(len(handler.pconns))]
} }
func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
@ -82,7 +122,7 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch { switch {
case bytes.Equal(token[:], turbotunnel.Token[:]): case bytes.Equal(token[:], turbotunnel.Token[:]):
err = turbotunnelMode(conn, addr, handler.pconn) err = handler.turbotunnelMode(conn, addr)
default: default:
// We didn't find a matching token, which means that we are // We didn't find a matching token, which means that we are
// dealing with a client that doesn't know about such things. // dealing with a client that doesn't know about such things.
@ -100,7 +140,7 @@ func (handler *httpHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// turbotunnelMode handles clients that sent turbotunnel.Token at the start of // turbotunnelMode handles clients that sent turbotunnel.Token at the start of
// their stream. These clients expect to send and receive encapsulated packets, // their stream. These clients expect to send and receive encapsulated packets,
// with a long-lived session identified by ClientID. // with a long-lived session identified by ClientID.
func turbotunnelMode(conn net.Conn, addr net.Addr, pconn *turbotunnel.QueuePacketConn) error { func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error {
// Read the ClientID prefix. Every packet encapsulated in this WebSocket // Read the ClientID prefix. Every packet encapsulated in this WebSocket
// connection pertains to the same ClientID. // connection pertains to the same ClientID.
var clientID turbotunnel.ClientID var clientID turbotunnel.ClientID
@ -120,6 +160,8 @@ func turbotunnelMode(conn net.Conn, addr net.Addr, pconn *turbotunnel.QueuePacke
// credited for the entire KCP session. // credited for the entire KCP session.
clientIDAddrMap.Set(clientID, addr) clientIDAddrMap.Set(clientID, addr)
pconn := handler.lookupPacketConn(clientID)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
done := make(chan struct{}) done := make(chan struct{})

View file

@ -55,6 +55,11 @@ const (
WindowSize = 65535 WindowSize = 65535
// StreamSize controls the maximum amount of in flight data between a client and server. // StreamSize controls the maximum amount of in flight data between a client and server.
StreamSize = 1048576 //1MB StreamSize = 1048576 //1MB
// numKCPInstances is the number of parallel KCP state machines to run.
// Clients are assigned to a particular KCP instance by a hash of their
// ClientID.
numKCPInstances = 2
) )
// Transport is a structure with methods that conform to the Go PT v2.1 API // Transport is a structure with methods that conform to the Go PT v2.1 API
@ -76,17 +81,13 @@ func (t *Transport) Listen(addr net.Addr) (*SnowflakeListener, error) {
addr: addr, addr: addr,
queue: make(chan net.Conn, 65534), queue: make(chan net.Conn, 65534),
closed: make(chan struct{}), closed: make(chan struct{}),
ln: make([]*kcp.Listener, 0, numKCPInstances),
} }
handler := httpHandler{ handler := newHTTPHandler(addr, numKCPInstances)
// pconn is shared among all connections to this server. It
// overlays packet-based client sessions on top of ephemeral
// WebSocket connections.
pconn: turbotunnel.NewQueuePacketConn(addr, clientMapTimeout),
}
server := &http.Server{ server := &http.Server{
Addr: addr.String(), Addr: addr.String(),
Handler: &handler, Handler: handler,
ReadTimeout: requestTimeout, ReadTimeout: requestTimeout,
} }
// We need to override server.TLSConfig.GetCertificate--but first // We need to override server.TLSConfig.GetCertificate--but first
@ -139,12 +140,13 @@ func (t *Transport) Listen(addr net.Addr) (*SnowflakeListener, error) {
listener.server = server listener.server = server
// Start a KCP engine, set up to read and write its packets over the // Start the KCP engines, set up to read and write their packets over the
// WebSocket connections that arrive at the web server. // WebSocket connections that arrive at the web server.
// handler.ServeHTTP is responsible for encapsulation/decapsulation of // handler.ServeHTTP is responsible for encapsulation/decapsulation of
// packets on behalf of KCP. KCP takes those packets and turns them into // packets on behalf of KCP. KCP takes those packets and turns them into
// sessions which appear in the acceptSessions function. // sessions which appear in the acceptSessions function.
ln, err := kcp.ServeConn(nil, 0, 0, handler.pconn) for i, pconn := range handler.pconns {
ln, err := kcp.ServeConn(nil, 0, 0, pconn)
if err != nil { if err != nil {
server.Close() server.Close()
return nil, err return nil, err
@ -153,11 +155,11 @@ func (t *Transport) Listen(addr net.Addr) (*SnowflakeListener, error) {
defer ln.Close() defer ln.Close()
err := listener.acceptSessions(ln) err := listener.acceptSessions(ln)
if err != nil { if err != nil {
log.Printf("acceptSessions: %v", err) log.Printf("acceptSessions %d: %v", i, err)
} }
}() }()
listener.ln = append(listener.ln, ln)
listener.ln = ln }
return listener, nil return listener, nil
@ -167,7 +169,7 @@ type SnowflakeListener struct {
addr net.Addr addr net.Addr
queue chan net.Conn queue chan net.Conn
server *http.Server server *http.Server
ln *kcp.Listener ln []*kcp.Listener
closed chan struct{} closed chan struct{}
closeOnce sync.Once closeOnce sync.Once
} }
@ -196,7 +198,9 @@ func (l *SnowflakeListener) Close() error {
l.closeOnce.Do(func() { l.closeOnce.Do(func() {
close(l.closed) close(l.closed)
l.server.Close() l.server.Close()
l.ln.Close() for _, ln := range l.ln {
ln.Close()
}
}) })
return nil return nil
} }