Make the proxy to report the number of clients to the broker

So the assignment of proxies is based on the load. The number of clients
is ronded down to 8. Existing proxies that doesn't report the number
of clients will be distributed equaly to new proxies until they get 8
clients, that is okish as the existing proxies do have a maximum
capacity of 10.

Fixes #40048
This commit is contained in:
meskio 2021-06-25 13:47:47 +02:00
parent 74bdb85b30
commit 7a1857c42f
No known key found for this signature in database
GPG key ID: 52B8F5AC97A2DA86
9 changed files with 165 additions and 77 deletions

View file

@ -124,16 +124,18 @@ type ProxyPoll struct {
id string
proxyType string
natType string
clients int
offerChannel chan *ClientOffer
}
// Registers a Snowflake and waits for some Client to send an offer,
// as part of the polling logic of the proxy handler.
func (ctx *BrokerContext) RequestOffer(id string, proxyType string, natType string) *ClientOffer {
func (ctx *BrokerContext) RequestOffer(id string, proxyType string, natType string, clients int) *ClientOffer {
request := new(ProxyPoll)
request.id = id
request.proxyType = proxyType
request.natType = natType
request.clients = clients
request.offerChannel = make(chan *ClientOffer)
ctx.proxyPolls <- request
// Block until an offer is available, or timeout which sends a nil offer.
@ -146,7 +148,7 @@ func (ctx *BrokerContext) RequestOffer(id string, proxyType string, natType stri
// client offer or nil on timeout / none are available.
func (ctx *BrokerContext) Broker() {
for request := range ctx.proxyPolls {
snowflake := ctx.AddSnowflake(request.id, request.proxyType, request.natType)
snowflake := ctx.AddSnowflake(request.id, request.proxyType, request.natType, request.clients)
// Wait for a client to avail an offer to the snowflake.
go func(request *ProxyPoll) {
select {
@ -174,10 +176,10 @@ func (ctx *BrokerContext) Broker() {
// Create and add a Snowflake to the heap.
// Required to keep track of proxies between providing them
// with an offer and awaiting their second POST with an answer.
func (ctx *BrokerContext) AddSnowflake(id string, proxyType string, natType string) *Snowflake {
func (ctx *BrokerContext) AddSnowflake(id string, proxyType string, natType string, clients int) *Snowflake {
snowflake := new(Snowflake)
snowflake.id = id
snowflake.clients = 0
snowflake.clients = clients
snowflake.proxyType = proxyType
snowflake.natType = natType
snowflake.offerChannel = make(chan *ClientOffer)
@ -205,7 +207,7 @@ func proxyPolls(ctx *BrokerContext, w http.ResponseWriter, r *http.Request) {
return
}
sid, proxyType, natType, err := messages.DecodePollRequest(body)
sid, proxyType, natType, clients, err := messages.DecodePollRequest(body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
@ -222,7 +224,7 @@ func proxyPolls(ctx *BrokerContext, w http.ResponseWriter, r *http.Request) {
}
// Wait for a client to avail an offer to the snowflake, or timeout if nil.
offer := ctx.RequestOffer(sid, proxyType, natType)
offer := ctx.RequestOffer(sid, proxyType, natType, clients)
var b []byte
if nil == offer {
ctx.metrics.lock.Lock()

View file

@ -32,7 +32,7 @@ func TestBroker(t *testing.T) {
Convey("Adds Snowflake", func() {
So(ctx.snowflakes.Len(), ShouldEqual, 0)
So(len(ctx.idToSnowflake), ShouldEqual, 0)
ctx.AddSnowflake("foo", "", NATUnrestricted)
ctx.AddSnowflake("foo", "", NATUnrestricted, 0)
So(ctx.snowflakes.Len(), ShouldEqual, 1)
So(len(ctx.idToSnowflake), ShouldEqual, 1)
})
@ -59,7 +59,7 @@ func TestBroker(t *testing.T) {
Convey("Request an offer from the Snowflake Heap", func() {
done := make(chan *ClientOffer)
go func() {
offer := ctx.RequestOffer("test", "", NATUnrestricted)
offer := ctx.RequestOffer("test", "", NATUnrestricted, 0)
done <- offer
}()
request := <-ctx.proxyPolls
@ -84,7 +84,7 @@ func TestBroker(t *testing.T) {
Convey("with a proxy answer if available.", func() {
done := make(chan bool)
// Prepare a fake proxy to respond with.
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted)
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0)
go func() {
clientOffers(ctx, w, r)
done <- true
@ -102,7 +102,7 @@ func TestBroker(t *testing.T) {
return
}
done := make(chan bool)
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted)
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0)
go func() {
clientOffers(ctx, w, r)
// Takes a few seconds here...
@ -132,7 +132,7 @@ func TestBroker(t *testing.T) {
Convey("with a proxy answer if available.", func() {
done := make(chan bool)
// Prepare a fake proxy to respond with.
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted)
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0)
go func() {
clientOffers(ctx, w, r)
done <- true
@ -150,7 +150,7 @@ func TestBroker(t *testing.T) {
return
}
done := make(chan bool)
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted)
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0)
go func() {
clientOffers(ctx, w, r)
// Takes a few seconds here...
@ -201,7 +201,7 @@ func TestBroker(t *testing.T) {
})
Convey("Responds to proxy answers...", func() {
s := ctx.AddSnowflake("test", "", NATUnrestricted)
s := ctx.AddSnowflake("test", "", NATUnrestricted, 0)
w := httptest.NewRecorder()
data := bytes.NewReader([]byte(`{"Version":"1.0","Sid":"test","Answer":"test"}`))
@ -314,7 +314,7 @@ func TestBroker(t *testing.T) {
// Manually do the Broker goroutine action here for full control.
p := <-ctx.proxyPolls
So(p.id, ShouldEqual, "ymbcCMto7KHNGYlp")
s := ctx.AddSnowflake(p.id, "", NATUnrestricted)
s := ctx.AddSnowflake(p.id, "", NATUnrestricted, 0)
go func() {
offer := <-s.offerChannel
p.offerChannel <- offer
@ -593,7 +593,7 @@ func TestMetrics(t *testing.T) {
So(err, ShouldBeNil)
// Prepare a fake proxy to respond with.
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted)
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0)
go func() {
clientOffers(ctx, w, r)
done <- true

View file

@ -15,6 +15,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
sid string
proxyType string
natType string
clients int
data string
err error
}{
@ -23,6 +24,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
"ymbcCMto7KHNGYlp",
"",
"unknown",
0,
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.0"}`,
nil,
},
@ -31,6 +33,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
"ymbcCMto7KHNGYlp",
"standalone",
"unknown",
0,
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.1","Type":"standalone"}`,
nil,
},
@ -39,14 +42,25 @@ func TestDecodeProxyPollRequest(t *testing.T) {
"ymbcCMto7KHNGYlp",
"standalone",
"restricted",
0,
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.2","Type":"standalone", "NAT":"restricted"}`,
nil,
},
{
//Version 1.2 proxy message with clients
"ymbcCMto7KHNGYlp",
"standalone",
"restricted",
24,
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.2","Type":"standalone", "NAT":"restricted","Clients":24}`,
nil,
},
{
//Version 0.X proxy message:
"",
"",
"",
0,
"",
&json.SyntaxError{},
},
@ -54,6 +68,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
"",
"",
"",
0,
`{"Sid":"ymbcCMto7KHNGYlp"}`,
fmt.Errorf(""),
},
@ -61,6 +76,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
"",
"",
"",
0,
"{}",
fmt.Errorf(""),
},
@ -68,6 +84,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
"",
"",
"",
0,
`{"Version":"1.0"}`,
fmt.Errorf(""),
},
@ -75,14 +92,16 @@ func TestDecodeProxyPollRequest(t *testing.T) {
"",
"",
"",
0,
`{"Version":"2.0"}`,
fmt.Errorf(""),
},
} {
sid, proxyType, natType, err := DecodePollRequest([]byte(test.data))
sid, proxyType, natType, clients, err := DecodePollRequest([]byte(test.data))
So(sid, ShouldResemble, test.sid)
So(proxyType, ShouldResemble, test.proxyType)
So(natType, ShouldResemble, test.natType)
So(clients, ShouldEqual, test.clients)
So(err, ShouldHaveSameTypeAs, test.err)
}
@ -91,12 +110,13 @@ func TestDecodeProxyPollRequest(t *testing.T) {
func TestEncodeProxyPollRequests(t *testing.T) {
Convey("Context", t, func() {
b, err := EncodePollRequest("ymbcCMto7KHNGYlp", "standalone", "unknown")
b, err := EncodePollRequest("ymbcCMto7KHNGYlp", "standalone", "unknown", 16)
So(err, ShouldEqual, nil)
sid, proxyType, natType, err := DecodePollRequest(b)
sid, proxyType, natType, clients, err := DecodePollRequest(b)
So(sid, ShouldEqual, "ymbcCMto7KHNGYlp")
So(proxyType, ShouldEqual, "standalone")
So(natType, ShouldEqual, "unknown")
So(clients, ShouldEqual, 16)
So(err, ShouldEqual, nil)
})
}

View file

@ -17,8 +17,9 @@ const version = "1.2"
{
Sid: [generated session id of proxy],
Version: 1.2,
Type: ["badge"|"webext"|"standalone"]
NAT: ["unknown"|"restricted"|"unrestricted"]
Type: ["badge"|"webext"|"standalone"],
NAT: ["unknown"|"restricted"|"unrestricted"],
Clients: [number of current clients, rounded down to multiples of 8]
}
== ProxyPollResponse ==
@ -79,43 +80,48 @@ type ProxyPollRequest struct {
Version string
Type string
NAT string
Clients int
}
func EncodePollRequest(sid string, proxyType string, natType string) ([]byte, error) {
func EncodePollRequest(sid string, proxyType string, natType string, clients int) ([]byte, error) {
return json.Marshal(ProxyPollRequest{
Sid: sid,
Version: version,
Type: proxyType,
NAT: natType,
Clients: clients,
})
}
// Decodes a poll message from a snowflake proxy and returns the
// sid and proxy type of the proxy on success and an error if it failed
func DecodePollRequest(data []byte) (string, string, string, error) {
// sid, proxy type, nat type and clients of the proxy on success
// and an error if it failed
func DecodePollRequest(data []byte) (sid string, proxyType string, natType string, clients int, err error) {
var message ProxyPollRequest
err := json.Unmarshal(data, &message)
err = json.Unmarshal(data, &message)
if err != nil {
return "", "", "", err
return
}
majorVersion := strings.Split(message.Version, ".")[0]
if majorVersion != "1" {
return "", "", "", fmt.Errorf("using unknown version")
err = fmt.Errorf("using unknown version")
return
}
// Version 1.x requires an Sid
if message.Sid == "" {
return "", "", "", fmt.Errorf("no supplied session id")
err = fmt.Errorf("no supplied session id")
return
}
natType := message.NAT
natType = message.NAT
if natType == "" {
natType = "unknown"
}
return message.Sid, message.Type, natType, nil
return message.Sid, message.Type, natType, message.Clients, nil
}
type ProxyPollResponse struct {

View file

@ -141,7 +141,9 @@ POST /proxy HTTP
{
Sid: [generated session id of proxy],
Version: 1.1,
Type: ["badge"|"webext"|"standalone"|"mobile"]
Type: ["badge"|"webext"|"standalone"|"mobile"],
NAT: ["unknown"|"restricted"|"unrestricted"],
Clients: [number of current clients, rounded down to multiples of 8]
}
```

View file

@ -7,7 +7,6 @@ import (
"io/ioutil"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"testing"
@ -337,8 +336,9 @@ func TestBrokerInteractions(t *testing.T) {
const sampleAnswer = `{"type":"answer","sdp":` + sampleSDP + `}`
Convey("Proxy connections to broker", t, func() {
broker := new(SignalingServer)
broker.url, _ = url.Parse("localhost")
broker, err := newSignalingServer("localhost", false)
So(err, ShouldEqual, nil)
tokens = newTokens(0)
//Mock peerConnection
config = webrtc.Configuration{
@ -469,17 +469,6 @@ func TestUtilityFuncs(t *testing.T) {
So(err, ShouldEqual, io.ErrClosedPipe)
})
})
Convey("Tokens", t, func() {
tokens = make(chan bool, 2)
for i := uint(0); i < 2; i++ {
tokens <- true
}
So(len(tokens), ShouldEqual, 2)
getToken()
So(len(tokens), ShouldEqual, 1)
retToken()
So(len(tokens), ShouldEqual, 2)
})
Convey("SessionID Generation", t, func() {
sid1 := genSessionID()
sid2 := genSessionID()

View file

@ -55,7 +55,7 @@ const (
)
var (
tokens chan bool
tokens *tokens_t
config webrtc.Configuration
client http.Client
)
@ -171,14 +171,6 @@ func (c *webRTCConn) SetWriteDeadline(t time.Time) error {
return fmt.Errorf("SetWriteDeadline not implemented")
}
func getToken() {
<-tokens
}
func retToken() {
tokens <- true
}
func genSessionID() string {
buf := make([]byte, sessionIDLength)
_, err := rand.Read(buf)
@ -204,6 +196,21 @@ type SignalingServer struct {
keepLocalAddresses bool
}
func newSignalingServer(rawURL string, keepLocalAddresses bool) (*SignalingServer, error) {
var err error
s := new(SignalingServer)
s.keepLocalAddresses = keepLocalAddresses
s.url, err = url.Parse(rawURL)
if err != nil {
return nil, fmt.Errorf("invalid broker url: %s", err)
}
s.transport = http.DefaultTransport.(*http.Transport)
s.transport.(*http.Transport).ResponseHeaderTimeout = 30 * time.Second
return s, nil
}
func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) {
req, err := http.NewRequest("POST", path, payload)
@ -238,7 +245,8 @@ func (s *SignalingServer) pollOffer(sid string) *webrtc.SessionDescription {
timeOfNextPoll = now
}
body, err := messages.EncodePollRequest(sid, "standalone", currentNATType)
numClients := int((tokens.count() / 8) * 8) // Round down to 8
body, err := messages.EncodePollRequest(sid, "standalone", currentNATType, numClients)
if err != nil {
log.Printf("Error encoding poll message: %s", err.Error())
return nil
@ -323,7 +331,7 @@ func CopyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) {
// RemoteAddr). https://bugs.torproject.org/18628#comment:8
func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
defer conn.Close()
defer retToken()
defer tokens.ret()
u, err := url.Parse(relayURL)
if err != nil {
@ -494,14 +502,14 @@ func runSession(sid string) {
offer := broker.pollOffer(sid)
if offer == nil {
log.Printf("bad offer from broker")
retToken()
tokens.ret()
return
}
dataChan := make(chan struct{})
pc, err := makePeerConnectionFromOffer(offer, config, dataChan, datachannelHandler)
if err != nil {
log.Printf("error making WebRTC connection: %s", err)
retToken()
tokens.ret()
return
}
err = broker.sendAnswer(sid, pc)
@ -510,7 +518,7 @@ func runSession(sid string) {
if inerr := pc.Close(); inerr != nil {
log.Printf("error calling pc.Close: %v", inerr)
}
retToken()
tokens.ret()
return
}
// Set a timeout on peerconnection. If the connection state has not
@ -524,7 +532,7 @@ func runSession(sid string) {
if err := pc.Close(); err != nil {
log.Printf("error calling pc.Close: %v", err)
}
retToken()
tokens.ret()
}
}
@ -536,7 +544,7 @@ func main() {
var unsafeLogging bool
var keepLocalAddresses bool
flag.UintVar(&capacity, "capacity", 10, "maximum concurrent clients")
flag.UintVar(&capacity, "capacity", 0, "maximum concurrent clients")
flag.StringVar(&rawBrokerURL, "broker", defaultBrokerURL, "broker URL")
flag.StringVar(&relayURL, "relay", defaultRelayURL, "websocket relay URL")
flag.StringVar(&stunURL, "stun", defaultSTUNURL, "stun URL")
@ -565,12 +573,11 @@ func main() {
log.Println("starting")
var err error
broker = new(SignalingServer)
broker.keepLocalAddresses = keepLocalAddresses
broker.url, err = url.Parse(rawBrokerURL)
broker, err = newSignalingServer(rawBrokerURL, keepLocalAddresses)
if err != nil {
log.Fatalf("invalid broker url: %s", err)
log.Fatal(err)
}
_, err = url.Parse(stunURL)
if err != nil {
log.Fatalf("invalid stun url: %s", err)
@ -580,8 +587,6 @@ func main() {
log.Fatalf("invalid relay url: %s", err)
}
broker.transport = http.DefaultTransport.(*http.Transport)
broker.transport.(*http.Transport).ResponseHeaderTimeout = 15 * time.Second
config = webrtc.Configuration{
ICEServers: []webrtc.ICEServer{
{
@ -589,17 +594,14 @@ func main() {
},
},
}
tokens = make(chan bool, capacity)
for i := uint(0); i < capacity; i++ {
tokens <- true
}
tokens = newTokens(capacity)
// use probetest to determine NAT compatability
checkNATType(config, defaultProbeURL)
log.Printf("NAT type: %s", currentNATType)
for {
getToken()
tokens.get()
sessionID := genSessionID()
runSession(sessionID)
}
@ -607,12 +609,7 @@ func main() {
func checkNATType(config webrtc.Configuration, probeURL string) {
var err error
probe := new(SignalingServer)
probe.transport = http.DefaultTransport.(*http.Transport)
probe.transport.(*http.Transport).ResponseHeaderTimeout = 30 * time.Second
probe.url, err = url.Parse(probeURL)
probe, err := newSignalingServer(probeURL, false)
if err != nil {
log.Printf("Error parsing url: %s", err.Error())
}

44
proxy/tokens.go Normal file
View file

@ -0,0 +1,44 @@
package main
import (
"sync/atomic"
)
type tokens_t struct {
ch chan struct{}
capacity uint
clients int64
}
func newTokens(capacity uint) *tokens_t {
var ch chan struct{}
if capacity != 0 {
ch = make(chan struct{}, capacity)
}
return &tokens_t{
ch: ch,
capacity: capacity,
clients: 0,
}
}
func (t *tokens_t) get() {
atomic.AddInt64(&t.clients, 1)
if t.capacity != 0 {
t.ch <- struct{}{}
}
}
func (t *tokens_t) ret() {
atomic.AddInt64(&t.clients, -1)
if t.capacity != 0 {
<-t.ch
}
}
func (t tokens_t) count() int64 {
return atomic.LoadInt64(&t.clients)
}

28
proxy/tokens_test.go Normal file
View file

@ -0,0 +1,28 @@
package main
import (
"testing"
. "github.com/smartystreets/goconvey/convey"
)
func TestTokens(t *testing.T) {
Convey("Tokens", t, func() {
tokens := newTokens(2)
So(tokens.count(), ShouldEqual, 0)
tokens.get()
So(tokens.count(), ShouldEqual, 1)
tokens.ret()
So(tokens.count(), ShouldEqual, 0)
})
Convey("Tokens capacity 0", t, func() {
tokens := newTokens(0)
So(tokens.count(), ShouldEqual, 0)
for i := 0; i < 20; i++ {
tokens.get()
}
So(tokens.count(), ShouldEqual, 20)
tokens.ret()
So(tokens.count(), ShouldEqual, 19)
})
}