diff --git a/broker/broker.go b/broker/broker.go index 906c210..fc4727d 100644 --- a/broker/broker.go +++ b/broker/broker.go @@ -124,16 +124,18 @@ type ProxyPoll struct { id string proxyType string natType string + clients int offerChannel chan *ClientOffer } // Registers a Snowflake and waits for some Client to send an offer, // as part of the polling logic of the proxy handler. -func (ctx *BrokerContext) RequestOffer(id string, proxyType string, natType string) *ClientOffer { +func (ctx *BrokerContext) RequestOffer(id string, proxyType string, natType string, clients int) *ClientOffer { request := new(ProxyPoll) request.id = id request.proxyType = proxyType request.natType = natType + request.clients = clients request.offerChannel = make(chan *ClientOffer) ctx.proxyPolls <- request // Block until an offer is available, or timeout which sends a nil offer. @@ -146,7 +148,7 @@ func (ctx *BrokerContext) RequestOffer(id string, proxyType string, natType stri // client offer or nil on timeout / none are available. func (ctx *BrokerContext) Broker() { for request := range ctx.proxyPolls { - snowflake := ctx.AddSnowflake(request.id, request.proxyType, request.natType) + snowflake := ctx.AddSnowflake(request.id, request.proxyType, request.natType, request.clients) // Wait for a client to avail an offer to the snowflake. go func(request *ProxyPoll) { select { @@ -174,10 +176,10 @@ func (ctx *BrokerContext) Broker() { // Create and add a Snowflake to the heap. // Required to keep track of proxies between providing them // with an offer and awaiting their second POST with an answer. -func (ctx *BrokerContext) AddSnowflake(id string, proxyType string, natType string) *Snowflake { +func (ctx *BrokerContext) AddSnowflake(id string, proxyType string, natType string, clients int) *Snowflake { snowflake := new(Snowflake) snowflake.id = id - snowflake.clients = 0 + snowflake.clients = clients snowflake.proxyType = proxyType snowflake.natType = natType snowflake.offerChannel = make(chan *ClientOffer) @@ -205,7 +207,7 @@ func proxyPolls(ctx *BrokerContext, w http.ResponseWriter, r *http.Request) { return } - sid, proxyType, natType, err := messages.DecodePollRequest(body) + sid, proxyType, natType, clients, err := messages.DecodePollRequest(body) if err != nil { w.WriteHeader(http.StatusBadRequest) return @@ -222,7 +224,7 @@ func proxyPolls(ctx *BrokerContext, w http.ResponseWriter, r *http.Request) { } // Wait for a client to avail an offer to the snowflake, or timeout if nil. - offer := ctx.RequestOffer(sid, proxyType, natType) + offer := ctx.RequestOffer(sid, proxyType, natType, clients) var b []byte if nil == offer { ctx.metrics.lock.Lock() diff --git a/broker/snowflake-broker_test.go b/broker/snowflake-broker_test.go index 646fb02..825bc6f 100644 --- a/broker/snowflake-broker_test.go +++ b/broker/snowflake-broker_test.go @@ -32,7 +32,7 @@ func TestBroker(t *testing.T) { Convey("Adds Snowflake", func() { So(ctx.snowflakes.Len(), ShouldEqual, 0) So(len(ctx.idToSnowflake), ShouldEqual, 0) - ctx.AddSnowflake("foo", "", NATUnrestricted) + ctx.AddSnowflake("foo", "", NATUnrestricted, 0) So(ctx.snowflakes.Len(), ShouldEqual, 1) So(len(ctx.idToSnowflake), ShouldEqual, 1) }) @@ -59,7 +59,7 @@ func TestBroker(t *testing.T) { Convey("Request an offer from the Snowflake Heap", func() { done := make(chan *ClientOffer) go func() { - offer := ctx.RequestOffer("test", "", NATUnrestricted) + offer := ctx.RequestOffer("test", "", NATUnrestricted, 0) done <- offer }() request := <-ctx.proxyPolls @@ -84,7 +84,7 @@ func TestBroker(t *testing.T) { Convey("with a proxy answer if available.", func() { done := make(chan bool) // Prepare a fake proxy to respond with. - snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted) + snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0) go func() { clientOffers(ctx, w, r) done <- true @@ -102,7 +102,7 @@ func TestBroker(t *testing.T) { return } done := make(chan bool) - snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted) + snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0) go func() { clientOffers(ctx, w, r) // Takes a few seconds here... @@ -132,7 +132,7 @@ func TestBroker(t *testing.T) { Convey("with a proxy answer if available.", func() { done := make(chan bool) // Prepare a fake proxy to respond with. - snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted) + snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0) go func() { clientOffers(ctx, w, r) done <- true @@ -150,7 +150,7 @@ func TestBroker(t *testing.T) { return } done := make(chan bool) - snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted) + snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0) go func() { clientOffers(ctx, w, r) // Takes a few seconds here... @@ -201,7 +201,7 @@ func TestBroker(t *testing.T) { }) Convey("Responds to proxy answers...", func() { - s := ctx.AddSnowflake("test", "", NATUnrestricted) + s := ctx.AddSnowflake("test", "", NATUnrestricted, 0) w := httptest.NewRecorder() data := bytes.NewReader([]byte(`{"Version":"1.0","Sid":"test","Answer":"test"}`)) @@ -314,7 +314,7 @@ func TestBroker(t *testing.T) { // Manually do the Broker goroutine action here for full control. p := <-ctx.proxyPolls So(p.id, ShouldEqual, "ymbcCMto7KHNGYlp") - s := ctx.AddSnowflake(p.id, "", NATUnrestricted) + s := ctx.AddSnowflake(p.id, "", NATUnrestricted, 0) go func() { offer := <-s.offerChannel p.offerChannel <- offer @@ -593,7 +593,7 @@ func TestMetrics(t *testing.T) { So(err, ShouldBeNil) // Prepare a fake proxy to respond with. - snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted) + snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0) go func() { clientOffers(ctx, w, r) done <- true diff --git a/common/messages/messages_test.go b/common/messages/messages_test.go index 3962d3b..abb978d 100644 --- a/common/messages/messages_test.go +++ b/common/messages/messages_test.go @@ -15,6 +15,7 @@ func TestDecodeProxyPollRequest(t *testing.T) { sid string proxyType string natType string + clients int data string err error }{ @@ -23,6 +24,7 @@ func TestDecodeProxyPollRequest(t *testing.T) { "ymbcCMto7KHNGYlp", "", "unknown", + 0, `{"Sid":"ymbcCMto7KHNGYlp","Version":"1.0"}`, nil, }, @@ -31,6 +33,7 @@ func TestDecodeProxyPollRequest(t *testing.T) { "ymbcCMto7KHNGYlp", "standalone", "unknown", + 0, `{"Sid":"ymbcCMto7KHNGYlp","Version":"1.1","Type":"standalone"}`, nil, }, @@ -39,14 +42,25 @@ func TestDecodeProxyPollRequest(t *testing.T) { "ymbcCMto7KHNGYlp", "standalone", "restricted", + 0, `{"Sid":"ymbcCMto7KHNGYlp","Version":"1.2","Type":"standalone", "NAT":"restricted"}`, nil, }, + { + //Version 1.2 proxy message with clients + "ymbcCMto7KHNGYlp", + "standalone", + "restricted", + 24, + `{"Sid":"ymbcCMto7KHNGYlp","Version":"1.2","Type":"standalone", "NAT":"restricted","Clients":24}`, + nil, + }, { //Version 0.X proxy message: "", "", "", + 0, "", &json.SyntaxError{}, }, @@ -54,6 +68,7 @@ func TestDecodeProxyPollRequest(t *testing.T) { "", "", "", + 0, `{"Sid":"ymbcCMto7KHNGYlp"}`, fmt.Errorf(""), }, @@ -61,6 +76,7 @@ func TestDecodeProxyPollRequest(t *testing.T) { "", "", "", + 0, "{}", fmt.Errorf(""), }, @@ -68,6 +84,7 @@ func TestDecodeProxyPollRequest(t *testing.T) { "", "", "", + 0, `{"Version":"1.0"}`, fmt.Errorf(""), }, @@ -75,14 +92,16 @@ func TestDecodeProxyPollRequest(t *testing.T) { "", "", "", + 0, `{"Version":"2.0"}`, fmt.Errorf(""), }, } { - sid, proxyType, natType, err := DecodePollRequest([]byte(test.data)) + sid, proxyType, natType, clients, err := DecodePollRequest([]byte(test.data)) So(sid, ShouldResemble, test.sid) So(proxyType, ShouldResemble, test.proxyType) So(natType, ShouldResemble, test.natType) + So(clients, ShouldEqual, test.clients) So(err, ShouldHaveSameTypeAs, test.err) } @@ -91,12 +110,13 @@ func TestDecodeProxyPollRequest(t *testing.T) { func TestEncodeProxyPollRequests(t *testing.T) { Convey("Context", t, func() { - b, err := EncodePollRequest("ymbcCMto7KHNGYlp", "standalone", "unknown") + b, err := EncodePollRequest("ymbcCMto7KHNGYlp", "standalone", "unknown", 16) So(err, ShouldEqual, nil) - sid, proxyType, natType, err := DecodePollRequest(b) + sid, proxyType, natType, clients, err := DecodePollRequest(b) So(sid, ShouldEqual, "ymbcCMto7KHNGYlp") So(proxyType, ShouldEqual, "standalone") So(natType, ShouldEqual, "unknown") + So(clients, ShouldEqual, 16) So(err, ShouldEqual, nil) }) } diff --git a/common/messages/proxy.go b/common/messages/proxy.go index 2d9e58d..366e833 100644 --- a/common/messages/proxy.go +++ b/common/messages/proxy.go @@ -17,8 +17,9 @@ const version = "1.2" { Sid: [generated session id of proxy], Version: 1.2, - Type: ["badge"|"webext"|"standalone"] - NAT: ["unknown"|"restricted"|"unrestricted"] + Type: ["badge"|"webext"|"standalone"], + NAT: ["unknown"|"restricted"|"unrestricted"], + Clients: [number of current clients, rounded down to multiples of 8] } == ProxyPollResponse == @@ -79,43 +80,48 @@ type ProxyPollRequest struct { Version string Type string NAT string + Clients int } -func EncodePollRequest(sid string, proxyType string, natType string) ([]byte, error) { +func EncodePollRequest(sid string, proxyType string, natType string, clients int) ([]byte, error) { return json.Marshal(ProxyPollRequest{ Sid: sid, Version: version, Type: proxyType, NAT: natType, + Clients: clients, }) } // Decodes a poll message from a snowflake proxy and returns the -// sid and proxy type of the proxy on success and an error if it failed -func DecodePollRequest(data []byte) (string, string, string, error) { +// sid, proxy type, nat type and clients of the proxy on success +// and an error if it failed +func DecodePollRequest(data []byte) (sid string, proxyType string, natType string, clients int, err error) { var message ProxyPollRequest - err := json.Unmarshal(data, &message) + err = json.Unmarshal(data, &message) if err != nil { - return "", "", "", err + return } majorVersion := strings.Split(message.Version, ".")[0] if majorVersion != "1" { - return "", "", "", fmt.Errorf("using unknown version") + err = fmt.Errorf("using unknown version") + return } // Version 1.x requires an Sid if message.Sid == "" { - return "", "", "", fmt.Errorf("no supplied session id") + err = fmt.Errorf("no supplied session id") + return } - natType := message.NAT + natType = message.NAT if natType == "" { natType = "unknown" } - return message.Sid, message.Type, natType, nil + return message.Sid, message.Type, natType, message.Clients, nil } type ProxyPollResponse struct { diff --git a/doc/broker-spec.txt b/doc/broker-spec.txt index 9e4b8ae..f2cd231 100644 --- a/doc/broker-spec.txt +++ b/doc/broker-spec.txt @@ -141,7 +141,9 @@ POST /proxy HTTP { Sid: [generated session id of proxy], Version: 1.1, - Type: ["badge"|"webext"|"standalone"|"mobile"] + Type: ["badge"|"webext"|"standalone"|"mobile"], + NAT: ["unknown"|"restricted"|"unrestricted"], + Clients: [number of current clients, rounded down to multiples of 8] } ``` diff --git a/proxy/proxy-go_test.go b/proxy/proxy-go_test.go index e935ad9..183b1b4 100644 --- a/proxy/proxy-go_test.go +++ b/proxy/proxy-go_test.go @@ -7,7 +7,6 @@ import ( "io/ioutil" "net" "net/http" - "net/url" "strconv" "strings" "testing" @@ -337,8 +336,9 @@ func TestBrokerInteractions(t *testing.T) { const sampleAnswer = `{"type":"answer","sdp":` + sampleSDP + `}` Convey("Proxy connections to broker", t, func() { - broker := new(SignalingServer) - broker.url, _ = url.Parse("localhost") + broker, err := newSignalingServer("localhost", false) + So(err, ShouldEqual, nil) + tokens = newTokens(0) //Mock peerConnection config = webrtc.Configuration{ @@ -469,17 +469,6 @@ func TestUtilityFuncs(t *testing.T) { So(err, ShouldEqual, io.ErrClosedPipe) }) }) - Convey("Tokens", t, func() { - tokens = make(chan bool, 2) - for i := uint(0); i < 2; i++ { - tokens <- true - } - So(len(tokens), ShouldEqual, 2) - getToken() - So(len(tokens), ShouldEqual, 1) - retToken() - So(len(tokens), ShouldEqual, 2) - }) Convey("SessionID Generation", t, func() { sid1 := genSessionID() sid2 := genSessionID() diff --git a/proxy/snowflake.go b/proxy/snowflake.go index 86ae0b2..f7eacf8 100644 --- a/proxy/snowflake.go +++ b/proxy/snowflake.go @@ -55,7 +55,7 @@ const ( ) var ( - tokens chan bool + tokens *tokens_t config webrtc.Configuration client http.Client ) @@ -171,14 +171,6 @@ func (c *webRTCConn) SetWriteDeadline(t time.Time) error { return fmt.Errorf("SetWriteDeadline not implemented") } -func getToken() { - <-tokens -} - -func retToken() { - tokens <- true -} - func genSessionID() string { buf := make([]byte, sessionIDLength) _, err := rand.Read(buf) @@ -204,6 +196,21 @@ type SignalingServer struct { keepLocalAddresses bool } +func newSignalingServer(rawURL string, keepLocalAddresses bool) (*SignalingServer, error) { + var err error + s := new(SignalingServer) + s.keepLocalAddresses = keepLocalAddresses + s.url, err = url.Parse(rawURL) + if err != nil { + return nil, fmt.Errorf("invalid broker url: %s", err) + } + + s.transport = http.DefaultTransport.(*http.Transport) + s.transport.(*http.Transport).ResponseHeaderTimeout = 30 * time.Second + + return s, nil +} + func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) { req, err := http.NewRequest("POST", path, payload) @@ -238,7 +245,8 @@ func (s *SignalingServer) pollOffer(sid string) *webrtc.SessionDescription { timeOfNextPoll = now } - body, err := messages.EncodePollRequest(sid, "standalone", currentNATType) + numClients := int((tokens.count() / 8) * 8) // Round down to 8 + body, err := messages.EncodePollRequest(sid, "standalone", currentNATType, numClients) if err != nil { log.Printf("Error encoding poll message: %s", err.Error()) return nil @@ -323,7 +331,7 @@ func CopyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) { // RemoteAddr). https://bugs.torproject.org/18628#comment:8 func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) { defer conn.Close() - defer retToken() + defer tokens.ret() u, err := url.Parse(relayURL) if err != nil { @@ -494,14 +502,14 @@ func runSession(sid string) { offer := broker.pollOffer(sid) if offer == nil { log.Printf("bad offer from broker") - retToken() + tokens.ret() return } dataChan := make(chan struct{}) pc, err := makePeerConnectionFromOffer(offer, config, dataChan, datachannelHandler) if err != nil { log.Printf("error making WebRTC connection: %s", err) - retToken() + tokens.ret() return } err = broker.sendAnswer(sid, pc) @@ -510,7 +518,7 @@ func runSession(sid string) { if inerr := pc.Close(); inerr != nil { log.Printf("error calling pc.Close: %v", inerr) } - retToken() + tokens.ret() return } // Set a timeout on peerconnection. If the connection state has not @@ -524,7 +532,7 @@ func runSession(sid string) { if err := pc.Close(); err != nil { log.Printf("error calling pc.Close: %v", err) } - retToken() + tokens.ret() } } @@ -536,7 +544,7 @@ func main() { var unsafeLogging bool var keepLocalAddresses bool - flag.UintVar(&capacity, "capacity", 10, "maximum concurrent clients") + flag.UintVar(&capacity, "capacity", 0, "maximum concurrent clients") flag.StringVar(&rawBrokerURL, "broker", defaultBrokerURL, "broker URL") flag.StringVar(&relayURL, "relay", defaultRelayURL, "websocket relay URL") flag.StringVar(&stunURL, "stun", defaultSTUNURL, "stun URL") @@ -565,12 +573,11 @@ func main() { log.Println("starting") var err error - broker = new(SignalingServer) - broker.keepLocalAddresses = keepLocalAddresses - broker.url, err = url.Parse(rawBrokerURL) + broker, err = newSignalingServer(rawBrokerURL, keepLocalAddresses) if err != nil { - log.Fatalf("invalid broker url: %s", err) + log.Fatal(err) } + _, err = url.Parse(stunURL) if err != nil { log.Fatalf("invalid stun url: %s", err) @@ -580,8 +587,6 @@ func main() { log.Fatalf("invalid relay url: %s", err) } - broker.transport = http.DefaultTransport.(*http.Transport) - broker.transport.(*http.Transport).ResponseHeaderTimeout = 15 * time.Second config = webrtc.Configuration{ ICEServers: []webrtc.ICEServer{ { @@ -589,17 +594,14 @@ func main() { }, }, } - tokens = make(chan bool, capacity) - for i := uint(0); i < capacity; i++ { - tokens <- true - } + tokens = newTokens(capacity) // use probetest to determine NAT compatability checkNATType(config, defaultProbeURL) log.Printf("NAT type: %s", currentNATType) for { - getToken() + tokens.get() sessionID := genSessionID() runSession(sessionID) } @@ -607,12 +609,7 @@ func main() { func checkNATType(config webrtc.Configuration, probeURL string) { - var err error - - probe := new(SignalingServer) - probe.transport = http.DefaultTransport.(*http.Transport) - probe.transport.(*http.Transport).ResponseHeaderTimeout = 30 * time.Second - probe.url, err = url.Parse(probeURL) + probe, err := newSignalingServer(probeURL, false) if err != nil { log.Printf("Error parsing url: %s", err.Error()) } diff --git a/proxy/tokens.go b/proxy/tokens.go new file mode 100644 index 0000000..fedb8f7 --- /dev/null +++ b/proxy/tokens.go @@ -0,0 +1,44 @@ +package main + +import ( + "sync/atomic" +) + +type tokens_t struct { + ch chan struct{} + capacity uint + clients int64 +} + +func newTokens(capacity uint) *tokens_t { + var ch chan struct{} + if capacity != 0 { + ch = make(chan struct{}, capacity) + } + + return &tokens_t{ + ch: ch, + capacity: capacity, + clients: 0, + } +} + +func (t *tokens_t) get() { + atomic.AddInt64(&t.clients, 1) + + if t.capacity != 0 { + t.ch <- struct{}{} + } +} + +func (t *tokens_t) ret() { + atomic.AddInt64(&t.clients, -1) + + if t.capacity != 0 { + <-t.ch + } +} + +func (t tokens_t) count() int64 { + return atomic.LoadInt64(&t.clients) +} diff --git a/proxy/tokens_test.go b/proxy/tokens_test.go new file mode 100644 index 0000000..622cc05 --- /dev/null +++ b/proxy/tokens_test.go @@ -0,0 +1,28 @@ +package main + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestTokens(t *testing.T) { + Convey("Tokens", t, func() { + tokens := newTokens(2) + So(tokens.count(), ShouldEqual, 0) + tokens.get() + So(tokens.count(), ShouldEqual, 1) + tokens.ret() + So(tokens.count(), ShouldEqual, 0) + }) + Convey("Tokens capacity 0", t, func() { + tokens := newTokens(0) + So(tokens.count(), ShouldEqual, 0) + for i := 0; i < 20; i++ { + tokens.get() + } + So(tokens.count(), ShouldEqual, 20) + tokens.ret() + So(tokens.count(), ShouldEqual, 19) + }) +}