mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-14 05:11:19 -04:00
Make the proxy to report the number of clients to the broker
So the assignment of proxies is based on the load. The number of clients is ronded down to 8. Existing proxies that doesn't report the number of clients will be distributed equaly to new proxies until they get 8 clients, that is okish as the existing proxies do have a maximum capacity of 10. Fixes #40048
This commit is contained in:
parent
74bdb85b30
commit
7a1857c42f
9 changed files with 165 additions and 77 deletions
|
@ -124,16 +124,18 @@ type ProxyPoll struct {
|
||||||
id string
|
id string
|
||||||
proxyType string
|
proxyType string
|
||||||
natType string
|
natType string
|
||||||
|
clients int
|
||||||
offerChannel chan *ClientOffer
|
offerChannel chan *ClientOffer
|
||||||
}
|
}
|
||||||
|
|
||||||
// Registers a Snowflake and waits for some Client to send an offer,
|
// Registers a Snowflake and waits for some Client to send an offer,
|
||||||
// as part of the polling logic of the proxy handler.
|
// as part of the polling logic of the proxy handler.
|
||||||
func (ctx *BrokerContext) RequestOffer(id string, proxyType string, natType string) *ClientOffer {
|
func (ctx *BrokerContext) RequestOffer(id string, proxyType string, natType string, clients int) *ClientOffer {
|
||||||
request := new(ProxyPoll)
|
request := new(ProxyPoll)
|
||||||
request.id = id
|
request.id = id
|
||||||
request.proxyType = proxyType
|
request.proxyType = proxyType
|
||||||
request.natType = natType
|
request.natType = natType
|
||||||
|
request.clients = clients
|
||||||
request.offerChannel = make(chan *ClientOffer)
|
request.offerChannel = make(chan *ClientOffer)
|
||||||
ctx.proxyPolls <- request
|
ctx.proxyPolls <- request
|
||||||
// Block until an offer is available, or timeout which sends a nil offer.
|
// Block until an offer is available, or timeout which sends a nil offer.
|
||||||
|
@ -146,7 +148,7 @@ func (ctx *BrokerContext) RequestOffer(id string, proxyType string, natType stri
|
||||||
// client offer or nil on timeout / none are available.
|
// client offer or nil on timeout / none are available.
|
||||||
func (ctx *BrokerContext) Broker() {
|
func (ctx *BrokerContext) Broker() {
|
||||||
for request := range ctx.proxyPolls {
|
for request := range ctx.proxyPolls {
|
||||||
snowflake := ctx.AddSnowflake(request.id, request.proxyType, request.natType)
|
snowflake := ctx.AddSnowflake(request.id, request.proxyType, request.natType, request.clients)
|
||||||
// Wait for a client to avail an offer to the snowflake.
|
// Wait for a client to avail an offer to the snowflake.
|
||||||
go func(request *ProxyPoll) {
|
go func(request *ProxyPoll) {
|
||||||
select {
|
select {
|
||||||
|
@ -174,10 +176,10 @@ func (ctx *BrokerContext) Broker() {
|
||||||
// Create and add a Snowflake to the heap.
|
// Create and add a Snowflake to the heap.
|
||||||
// Required to keep track of proxies between providing them
|
// Required to keep track of proxies between providing them
|
||||||
// with an offer and awaiting their second POST with an answer.
|
// with an offer and awaiting their second POST with an answer.
|
||||||
func (ctx *BrokerContext) AddSnowflake(id string, proxyType string, natType string) *Snowflake {
|
func (ctx *BrokerContext) AddSnowflake(id string, proxyType string, natType string, clients int) *Snowflake {
|
||||||
snowflake := new(Snowflake)
|
snowflake := new(Snowflake)
|
||||||
snowflake.id = id
|
snowflake.id = id
|
||||||
snowflake.clients = 0
|
snowflake.clients = clients
|
||||||
snowflake.proxyType = proxyType
|
snowflake.proxyType = proxyType
|
||||||
snowflake.natType = natType
|
snowflake.natType = natType
|
||||||
snowflake.offerChannel = make(chan *ClientOffer)
|
snowflake.offerChannel = make(chan *ClientOffer)
|
||||||
|
@ -205,7 +207,7 @@ func proxyPolls(ctx *BrokerContext, w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sid, proxyType, natType, err := messages.DecodePollRequest(body)
|
sid, proxyType, natType, clients, err := messages.DecodePollRequest(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
w.WriteHeader(http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
|
@ -222,7 +224,7 @@ func proxyPolls(ctx *BrokerContext, w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for a client to avail an offer to the snowflake, or timeout if nil.
|
// Wait for a client to avail an offer to the snowflake, or timeout if nil.
|
||||||
offer := ctx.RequestOffer(sid, proxyType, natType)
|
offer := ctx.RequestOffer(sid, proxyType, natType, clients)
|
||||||
var b []byte
|
var b []byte
|
||||||
if nil == offer {
|
if nil == offer {
|
||||||
ctx.metrics.lock.Lock()
|
ctx.metrics.lock.Lock()
|
||||||
|
|
|
@ -32,7 +32,7 @@ func TestBroker(t *testing.T) {
|
||||||
Convey("Adds Snowflake", func() {
|
Convey("Adds Snowflake", func() {
|
||||||
So(ctx.snowflakes.Len(), ShouldEqual, 0)
|
So(ctx.snowflakes.Len(), ShouldEqual, 0)
|
||||||
So(len(ctx.idToSnowflake), ShouldEqual, 0)
|
So(len(ctx.idToSnowflake), ShouldEqual, 0)
|
||||||
ctx.AddSnowflake("foo", "", NATUnrestricted)
|
ctx.AddSnowflake("foo", "", NATUnrestricted, 0)
|
||||||
So(ctx.snowflakes.Len(), ShouldEqual, 1)
|
So(ctx.snowflakes.Len(), ShouldEqual, 1)
|
||||||
So(len(ctx.idToSnowflake), ShouldEqual, 1)
|
So(len(ctx.idToSnowflake), ShouldEqual, 1)
|
||||||
})
|
})
|
||||||
|
@ -59,7 +59,7 @@ func TestBroker(t *testing.T) {
|
||||||
Convey("Request an offer from the Snowflake Heap", func() {
|
Convey("Request an offer from the Snowflake Heap", func() {
|
||||||
done := make(chan *ClientOffer)
|
done := make(chan *ClientOffer)
|
||||||
go func() {
|
go func() {
|
||||||
offer := ctx.RequestOffer("test", "", NATUnrestricted)
|
offer := ctx.RequestOffer("test", "", NATUnrestricted, 0)
|
||||||
done <- offer
|
done <- offer
|
||||||
}()
|
}()
|
||||||
request := <-ctx.proxyPolls
|
request := <-ctx.proxyPolls
|
||||||
|
@ -84,7 +84,7 @@ func TestBroker(t *testing.T) {
|
||||||
Convey("with a proxy answer if available.", func() {
|
Convey("with a proxy answer if available.", func() {
|
||||||
done := make(chan bool)
|
done := make(chan bool)
|
||||||
// Prepare a fake proxy to respond with.
|
// Prepare a fake proxy to respond with.
|
||||||
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted)
|
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0)
|
||||||
go func() {
|
go func() {
|
||||||
clientOffers(ctx, w, r)
|
clientOffers(ctx, w, r)
|
||||||
done <- true
|
done <- true
|
||||||
|
@ -102,7 +102,7 @@ func TestBroker(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
done := make(chan bool)
|
done := make(chan bool)
|
||||||
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted)
|
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0)
|
||||||
go func() {
|
go func() {
|
||||||
clientOffers(ctx, w, r)
|
clientOffers(ctx, w, r)
|
||||||
// Takes a few seconds here...
|
// Takes a few seconds here...
|
||||||
|
@ -132,7 +132,7 @@ func TestBroker(t *testing.T) {
|
||||||
Convey("with a proxy answer if available.", func() {
|
Convey("with a proxy answer if available.", func() {
|
||||||
done := make(chan bool)
|
done := make(chan bool)
|
||||||
// Prepare a fake proxy to respond with.
|
// Prepare a fake proxy to respond with.
|
||||||
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted)
|
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0)
|
||||||
go func() {
|
go func() {
|
||||||
clientOffers(ctx, w, r)
|
clientOffers(ctx, w, r)
|
||||||
done <- true
|
done <- true
|
||||||
|
@ -150,7 +150,7 @@ func TestBroker(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
done := make(chan bool)
|
done := make(chan bool)
|
||||||
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted)
|
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0)
|
||||||
go func() {
|
go func() {
|
||||||
clientOffers(ctx, w, r)
|
clientOffers(ctx, w, r)
|
||||||
// Takes a few seconds here...
|
// Takes a few seconds here...
|
||||||
|
@ -201,7 +201,7 @@ func TestBroker(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
Convey("Responds to proxy answers...", func() {
|
Convey("Responds to proxy answers...", func() {
|
||||||
s := ctx.AddSnowflake("test", "", NATUnrestricted)
|
s := ctx.AddSnowflake("test", "", NATUnrestricted, 0)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
data := bytes.NewReader([]byte(`{"Version":"1.0","Sid":"test","Answer":"test"}`))
|
data := bytes.NewReader([]byte(`{"Version":"1.0","Sid":"test","Answer":"test"}`))
|
||||||
|
|
||||||
|
@ -314,7 +314,7 @@ func TestBroker(t *testing.T) {
|
||||||
// Manually do the Broker goroutine action here for full control.
|
// Manually do the Broker goroutine action here for full control.
|
||||||
p := <-ctx.proxyPolls
|
p := <-ctx.proxyPolls
|
||||||
So(p.id, ShouldEqual, "ymbcCMto7KHNGYlp")
|
So(p.id, ShouldEqual, "ymbcCMto7KHNGYlp")
|
||||||
s := ctx.AddSnowflake(p.id, "", NATUnrestricted)
|
s := ctx.AddSnowflake(p.id, "", NATUnrestricted, 0)
|
||||||
go func() {
|
go func() {
|
||||||
offer := <-s.offerChannel
|
offer := <-s.offerChannel
|
||||||
p.offerChannel <- offer
|
p.offerChannel <- offer
|
||||||
|
@ -593,7 +593,7 @@ func TestMetrics(t *testing.T) {
|
||||||
So(err, ShouldBeNil)
|
So(err, ShouldBeNil)
|
||||||
|
|
||||||
// Prepare a fake proxy to respond with.
|
// Prepare a fake proxy to respond with.
|
||||||
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted)
|
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0)
|
||||||
go func() {
|
go func() {
|
||||||
clientOffers(ctx, w, r)
|
clientOffers(ctx, w, r)
|
||||||
done <- true
|
done <- true
|
||||||
|
|
|
@ -15,6 +15,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
|
||||||
sid string
|
sid string
|
||||||
proxyType string
|
proxyType string
|
||||||
natType string
|
natType string
|
||||||
|
clients int
|
||||||
data string
|
data string
|
||||||
err error
|
err error
|
||||||
}{
|
}{
|
||||||
|
@ -23,6 +24,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
|
||||||
"ymbcCMto7KHNGYlp",
|
"ymbcCMto7KHNGYlp",
|
||||||
"",
|
"",
|
||||||
"unknown",
|
"unknown",
|
||||||
|
0,
|
||||||
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.0"}`,
|
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.0"}`,
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
|
@ -31,6 +33,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
|
||||||
"ymbcCMto7KHNGYlp",
|
"ymbcCMto7KHNGYlp",
|
||||||
"standalone",
|
"standalone",
|
||||||
"unknown",
|
"unknown",
|
||||||
|
0,
|
||||||
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.1","Type":"standalone"}`,
|
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.1","Type":"standalone"}`,
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
|
@ -39,14 +42,25 @@ func TestDecodeProxyPollRequest(t *testing.T) {
|
||||||
"ymbcCMto7KHNGYlp",
|
"ymbcCMto7KHNGYlp",
|
||||||
"standalone",
|
"standalone",
|
||||||
"restricted",
|
"restricted",
|
||||||
|
0,
|
||||||
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.2","Type":"standalone", "NAT":"restricted"}`,
|
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.2","Type":"standalone", "NAT":"restricted"}`,
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
//Version 1.2 proxy message with clients
|
||||||
|
"ymbcCMto7KHNGYlp",
|
||||||
|
"standalone",
|
||||||
|
"restricted",
|
||||||
|
24,
|
||||||
|
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.2","Type":"standalone", "NAT":"restricted","Clients":24}`,
|
||||||
|
nil,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
//Version 0.X proxy message:
|
//Version 0.X proxy message:
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
|
0,
|
||||||
"",
|
"",
|
||||||
&json.SyntaxError{},
|
&json.SyntaxError{},
|
||||||
},
|
},
|
||||||
|
@ -54,6 +68,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
|
0,
|
||||||
`{"Sid":"ymbcCMto7KHNGYlp"}`,
|
`{"Sid":"ymbcCMto7KHNGYlp"}`,
|
||||||
fmt.Errorf(""),
|
fmt.Errorf(""),
|
||||||
},
|
},
|
||||||
|
@ -61,6 +76,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
|
0,
|
||||||
"{}",
|
"{}",
|
||||||
fmt.Errorf(""),
|
fmt.Errorf(""),
|
||||||
},
|
},
|
||||||
|
@ -68,6 +84,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
|
0,
|
||||||
`{"Version":"1.0"}`,
|
`{"Version":"1.0"}`,
|
||||||
fmt.Errorf(""),
|
fmt.Errorf(""),
|
||||||
},
|
},
|
||||||
|
@ -75,14 +92,16 @@ func TestDecodeProxyPollRequest(t *testing.T) {
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
|
0,
|
||||||
`{"Version":"2.0"}`,
|
`{"Version":"2.0"}`,
|
||||||
fmt.Errorf(""),
|
fmt.Errorf(""),
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
sid, proxyType, natType, err := DecodePollRequest([]byte(test.data))
|
sid, proxyType, natType, clients, err := DecodePollRequest([]byte(test.data))
|
||||||
So(sid, ShouldResemble, test.sid)
|
So(sid, ShouldResemble, test.sid)
|
||||||
So(proxyType, ShouldResemble, test.proxyType)
|
So(proxyType, ShouldResemble, test.proxyType)
|
||||||
So(natType, ShouldResemble, test.natType)
|
So(natType, ShouldResemble, test.natType)
|
||||||
|
So(clients, ShouldEqual, test.clients)
|
||||||
So(err, ShouldHaveSameTypeAs, test.err)
|
So(err, ShouldHaveSameTypeAs, test.err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -91,12 +110,13 @@ func TestDecodeProxyPollRequest(t *testing.T) {
|
||||||
|
|
||||||
func TestEncodeProxyPollRequests(t *testing.T) {
|
func TestEncodeProxyPollRequests(t *testing.T) {
|
||||||
Convey("Context", t, func() {
|
Convey("Context", t, func() {
|
||||||
b, err := EncodePollRequest("ymbcCMto7KHNGYlp", "standalone", "unknown")
|
b, err := EncodePollRequest("ymbcCMto7KHNGYlp", "standalone", "unknown", 16)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
sid, proxyType, natType, err := DecodePollRequest(b)
|
sid, proxyType, natType, clients, err := DecodePollRequest(b)
|
||||||
So(sid, ShouldEqual, "ymbcCMto7KHNGYlp")
|
So(sid, ShouldEqual, "ymbcCMto7KHNGYlp")
|
||||||
So(proxyType, ShouldEqual, "standalone")
|
So(proxyType, ShouldEqual, "standalone")
|
||||||
So(natType, ShouldEqual, "unknown")
|
So(natType, ShouldEqual, "unknown")
|
||||||
|
So(clients, ShouldEqual, 16)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,8 +17,9 @@ const version = "1.2"
|
||||||
{
|
{
|
||||||
Sid: [generated session id of proxy],
|
Sid: [generated session id of proxy],
|
||||||
Version: 1.2,
|
Version: 1.2,
|
||||||
Type: ["badge"|"webext"|"standalone"]
|
Type: ["badge"|"webext"|"standalone"],
|
||||||
NAT: ["unknown"|"restricted"|"unrestricted"]
|
NAT: ["unknown"|"restricted"|"unrestricted"],
|
||||||
|
Clients: [number of current clients, rounded down to multiples of 8]
|
||||||
}
|
}
|
||||||
|
|
||||||
== ProxyPollResponse ==
|
== ProxyPollResponse ==
|
||||||
|
@ -79,43 +80,48 @@ type ProxyPollRequest struct {
|
||||||
Version string
|
Version string
|
||||||
Type string
|
Type string
|
||||||
NAT string
|
NAT string
|
||||||
|
Clients int
|
||||||
}
|
}
|
||||||
|
|
||||||
func EncodePollRequest(sid string, proxyType string, natType string) ([]byte, error) {
|
func EncodePollRequest(sid string, proxyType string, natType string, clients int) ([]byte, error) {
|
||||||
return json.Marshal(ProxyPollRequest{
|
return json.Marshal(ProxyPollRequest{
|
||||||
Sid: sid,
|
Sid: sid,
|
||||||
Version: version,
|
Version: version,
|
||||||
Type: proxyType,
|
Type: proxyType,
|
||||||
NAT: natType,
|
NAT: natType,
|
||||||
|
Clients: clients,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decodes a poll message from a snowflake proxy and returns the
|
// Decodes a poll message from a snowflake proxy and returns the
|
||||||
// sid and proxy type of the proxy on success and an error if it failed
|
// sid, proxy type, nat type and clients of the proxy on success
|
||||||
func DecodePollRequest(data []byte) (string, string, string, error) {
|
// and an error if it failed
|
||||||
|
func DecodePollRequest(data []byte) (sid string, proxyType string, natType string, clients int, err error) {
|
||||||
var message ProxyPollRequest
|
var message ProxyPollRequest
|
||||||
|
|
||||||
err := json.Unmarshal(data, &message)
|
err = json.Unmarshal(data, &message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", "", err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
majorVersion := strings.Split(message.Version, ".")[0]
|
majorVersion := strings.Split(message.Version, ".")[0]
|
||||||
if majorVersion != "1" {
|
if majorVersion != "1" {
|
||||||
return "", "", "", fmt.Errorf("using unknown version")
|
err = fmt.Errorf("using unknown version")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Version 1.x requires an Sid
|
// Version 1.x requires an Sid
|
||||||
if message.Sid == "" {
|
if message.Sid == "" {
|
||||||
return "", "", "", fmt.Errorf("no supplied session id")
|
err = fmt.Errorf("no supplied session id")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
natType := message.NAT
|
natType = message.NAT
|
||||||
if natType == "" {
|
if natType == "" {
|
||||||
natType = "unknown"
|
natType = "unknown"
|
||||||
}
|
}
|
||||||
|
|
||||||
return message.Sid, message.Type, natType, nil
|
return message.Sid, message.Type, natType, message.Clients, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProxyPollResponse struct {
|
type ProxyPollResponse struct {
|
||||||
|
|
|
@ -141,7 +141,9 @@ POST /proxy HTTP
|
||||||
{
|
{
|
||||||
Sid: [generated session id of proxy],
|
Sid: [generated session id of proxy],
|
||||||
Version: 1.1,
|
Version: 1.1,
|
||||||
Type: ["badge"|"webext"|"standalone"|"mobile"]
|
Type: ["badge"|"webext"|"standalone"|"mobile"],
|
||||||
|
NAT: ["unknown"|"restricted"|"unrestricted"],
|
||||||
|
Clients: [number of current clients, rounded down to multiples of 8]
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,6 @@ import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -337,8 +336,9 @@ func TestBrokerInteractions(t *testing.T) {
|
||||||
const sampleAnswer = `{"type":"answer","sdp":` + sampleSDP + `}`
|
const sampleAnswer = `{"type":"answer","sdp":` + sampleSDP + `}`
|
||||||
|
|
||||||
Convey("Proxy connections to broker", t, func() {
|
Convey("Proxy connections to broker", t, func() {
|
||||||
broker := new(SignalingServer)
|
broker, err := newSignalingServer("localhost", false)
|
||||||
broker.url, _ = url.Parse("localhost")
|
So(err, ShouldEqual, nil)
|
||||||
|
tokens = newTokens(0)
|
||||||
|
|
||||||
//Mock peerConnection
|
//Mock peerConnection
|
||||||
config = webrtc.Configuration{
|
config = webrtc.Configuration{
|
||||||
|
@ -469,17 +469,6 @@ func TestUtilityFuncs(t *testing.T) {
|
||||||
So(err, ShouldEqual, io.ErrClosedPipe)
|
So(err, ShouldEqual, io.ErrClosedPipe)
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
Convey("Tokens", t, func() {
|
|
||||||
tokens = make(chan bool, 2)
|
|
||||||
for i := uint(0); i < 2; i++ {
|
|
||||||
tokens <- true
|
|
||||||
}
|
|
||||||
So(len(tokens), ShouldEqual, 2)
|
|
||||||
getToken()
|
|
||||||
So(len(tokens), ShouldEqual, 1)
|
|
||||||
retToken()
|
|
||||||
So(len(tokens), ShouldEqual, 2)
|
|
||||||
})
|
|
||||||
Convey("SessionID Generation", t, func() {
|
Convey("SessionID Generation", t, func() {
|
||||||
sid1 := genSessionID()
|
sid1 := genSessionID()
|
||||||
sid2 := genSessionID()
|
sid2 := genSessionID()
|
||||||
|
|
|
@ -55,7 +55,7 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
tokens chan bool
|
tokens *tokens_t
|
||||||
config webrtc.Configuration
|
config webrtc.Configuration
|
||||||
client http.Client
|
client http.Client
|
||||||
)
|
)
|
||||||
|
@ -171,14 +171,6 @@ func (c *webRTCConn) SetWriteDeadline(t time.Time) error {
|
||||||
return fmt.Errorf("SetWriteDeadline not implemented")
|
return fmt.Errorf("SetWriteDeadline not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func getToken() {
|
|
||||||
<-tokens
|
|
||||||
}
|
|
||||||
|
|
||||||
func retToken() {
|
|
||||||
tokens <- true
|
|
||||||
}
|
|
||||||
|
|
||||||
func genSessionID() string {
|
func genSessionID() string {
|
||||||
buf := make([]byte, sessionIDLength)
|
buf := make([]byte, sessionIDLength)
|
||||||
_, err := rand.Read(buf)
|
_, err := rand.Read(buf)
|
||||||
|
@ -204,6 +196,21 @@ type SignalingServer struct {
|
||||||
keepLocalAddresses bool
|
keepLocalAddresses bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newSignalingServer(rawURL string, keepLocalAddresses bool) (*SignalingServer, error) {
|
||||||
|
var err error
|
||||||
|
s := new(SignalingServer)
|
||||||
|
s.keepLocalAddresses = keepLocalAddresses
|
||||||
|
s.url, err = url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid broker url: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.transport = http.DefaultTransport.(*http.Transport)
|
||||||
|
s.transport.(*http.Transport).ResponseHeaderTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) {
|
func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) {
|
||||||
|
|
||||||
req, err := http.NewRequest("POST", path, payload)
|
req, err := http.NewRequest("POST", path, payload)
|
||||||
|
@ -238,7 +245,8 @@ func (s *SignalingServer) pollOffer(sid string) *webrtc.SessionDescription {
|
||||||
timeOfNextPoll = now
|
timeOfNextPoll = now
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := messages.EncodePollRequest(sid, "standalone", currentNATType)
|
numClients := int((tokens.count() / 8) * 8) // Round down to 8
|
||||||
|
body, err := messages.EncodePollRequest(sid, "standalone", currentNATType, numClients)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error encoding poll message: %s", err.Error())
|
log.Printf("Error encoding poll message: %s", err.Error())
|
||||||
return nil
|
return nil
|
||||||
|
@ -323,7 +331,7 @@ func CopyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) {
|
||||||
// RemoteAddr). https://bugs.torproject.org/18628#comment:8
|
// RemoteAddr). https://bugs.torproject.org/18628#comment:8
|
||||||
func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
|
func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
defer retToken()
|
defer tokens.ret()
|
||||||
|
|
||||||
u, err := url.Parse(relayURL)
|
u, err := url.Parse(relayURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -494,14 +502,14 @@ func runSession(sid string) {
|
||||||
offer := broker.pollOffer(sid)
|
offer := broker.pollOffer(sid)
|
||||||
if offer == nil {
|
if offer == nil {
|
||||||
log.Printf("bad offer from broker")
|
log.Printf("bad offer from broker")
|
||||||
retToken()
|
tokens.ret()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dataChan := make(chan struct{})
|
dataChan := make(chan struct{})
|
||||||
pc, err := makePeerConnectionFromOffer(offer, config, dataChan, datachannelHandler)
|
pc, err := makePeerConnectionFromOffer(offer, config, dataChan, datachannelHandler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error making WebRTC connection: %s", err)
|
log.Printf("error making WebRTC connection: %s", err)
|
||||||
retToken()
|
tokens.ret()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = broker.sendAnswer(sid, pc)
|
err = broker.sendAnswer(sid, pc)
|
||||||
|
@ -510,7 +518,7 @@ func runSession(sid string) {
|
||||||
if inerr := pc.Close(); inerr != nil {
|
if inerr := pc.Close(); inerr != nil {
|
||||||
log.Printf("error calling pc.Close: %v", inerr)
|
log.Printf("error calling pc.Close: %v", inerr)
|
||||||
}
|
}
|
||||||
retToken()
|
tokens.ret()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Set a timeout on peerconnection. If the connection state has not
|
// Set a timeout on peerconnection. If the connection state has not
|
||||||
|
@ -524,7 +532,7 @@ func runSession(sid string) {
|
||||||
if err := pc.Close(); err != nil {
|
if err := pc.Close(); err != nil {
|
||||||
log.Printf("error calling pc.Close: %v", err)
|
log.Printf("error calling pc.Close: %v", err)
|
||||||
}
|
}
|
||||||
retToken()
|
tokens.ret()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -536,7 +544,7 @@ func main() {
|
||||||
var unsafeLogging bool
|
var unsafeLogging bool
|
||||||
var keepLocalAddresses bool
|
var keepLocalAddresses bool
|
||||||
|
|
||||||
flag.UintVar(&capacity, "capacity", 10, "maximum concurrent clients")
|
flag.UintVar(&capacity, "capacity", 0, "maximum concurrent clients")
|
||||||
flag.StringVar(&rawBrokerURL, "broker", defaultBrokerURL, "broker URL")
|
flag.StringVar(&rawBrokerURL, "broker", defaultBrokerURL, "broker URL")
|
||||||
flag.StringVar(&relayURL, "relay", defaultRelayURL, "websocket relay URL")
|
flag.StringVar(&relayURL, "relay", defaultRelayURL, "websocket relay URL")
|
||||||
flag.StringVar(&stunURL, "stun", defaultSTUNURL, "stun URL")
|
flag.StringVar(&stunURL, "stun", defaultSTUNURL, "stun URL")
|
||||||
|
@ -565,12 +573,11 @@ func main() {
|
||||||
log.Println("starting")
|
log.Println("starting")
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
broker = new(SignalingServer)
|
broker, err = newSignalingServer(rawBrokerURL, keepLocalAddresses)
|
||||||
broker.keepLocalAddresses = keepLocalAddresses
|
|
||||||
broker.url, err = url.Parse(rawBrokerURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("invalid broker url: %s", err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = url.Parse(stunURL)
|
_, err = url.Parse(stunURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("invalid stun url: %s", err)
|
log.Fatalf("invalid stun url: %s", err)
|
||||||
|
@ -580,8 +587,6 @@ func main() {
|
||||||
log.Fatalf("invalid relay url: %s", err)
|
log.Fatalf("invalid relay url: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
broker.transport = http.DefaultTransport.(*http.Transport)
|
|
||||||
broker.transport.(*http.Transport).ResponseHeaderTimeout = 15 * time.Second
|
|
||||||
config = webrtc.Configuration{
|
config = webrtc.Configuration{
|
||||||
ICEServers: []webrtc.ICEServer{
|
ICEServers: []webrtc.ICEServer{
|
||||||
{
|
{
|
||||||
|
@ -589,17 +594,14 @@ func main() {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
tokens = make(chan bool, capacity)
|
tokens = newTokens(capacity)
|
||||||
for i := uint(0); i < capacity; i++ {
|
|
||||||
tokens <- true
|
|
||||||
}
|
|
||||||
|
|
||||||
// use probetest to determine NAT compatability
|
// use probetest to determine NAT compatability
|
||||||
checkNATType(config, defaultProbeURL)
|
checkNATType(config, defaultProbeURL)
|
||||||
log.Printf("NAT type: %s", currentNATType)
|
log.Printf("NAT type: %s", currentNATType)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
getToken()
|
tokens.get()
|
||||||
sessionID := genSessionID()
|
sessionID := genSessionID()
|
||||||
runSession(sessionID)
|
runSession(sessionID)
|
||||||
}
|
}
|
||||||
|
@ -607,12 +609,7 @@ func main() {
|
||||||
|
|
||||||
func checkNATType(config webrtc.Configuration, probeURL string) {
|
func checkNATType(config webrtc.Configuration, probeURL string) {
|
||||||
|
|
||||||
var err error
|
probe, err := newSignalingServer(probeURL, false)
|
||||||
|
|
||||||
probe := new(SignalingServer)
|
|
||||||
probe.transport = http.DefaultTransport.(*http.Transport)
|
|
||||||
probe.transport.(*http.Transport).ResponseHeaderTimeout = 30 * time.Second
|
|
||||||
probe.url, err = url.Parse(probeURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error parsing url: %s", err.Error())
|
log.Printf("Error parsing url: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
44
proxy/tokens.go
Normal file
44
proxy/tokens.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tokens_t struct {
|
||||||
|
ch chan struct{}
|
||||||
|
capacity uint
|
||||||
|
clients int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTokens(capacity uint) *tokens_t {
|
||||||
|
var ch chan struct{}
|
||||||
|
if capacity != 0 {
|
||||||
|
ch = make(chan struct{}, capacity)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tokens_t{
|
||||||
|
ch: ch,
|
||||||
|
capacity: capacity,
|
||||||
|
clients: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokens_t) get() {
|
||||||
|
atomic.AddInt64(&t.clients, 1)
|
||||||
|
|
||||||
|
if t.capacity != 0 {
|
||||||
|
t.ch <- struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *tokens_t) ret() {
|
||||||
|
atomic.AddInt64(&t.clients, -1)
|
||||||
|
|
||||||
|
if t.capacity != 0 {
|
||||||
|
<-t.ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t tokens_t) count() int64 {
|
||||||
|
return atomic.LoadInt64(&t.clients)
|
||||||
|
}
|
28
proxy/tokens_test.go
Normal file
28
proxy/tokens_test.go
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTokens(t *testing.T) {
|
||||||
|
Convey("Tokens", t, func() {
|
||||||
|
tokens := newTokens(2)
|
||||||
|
So(tokens.count(), ShouldEqual, 0)
|
||||||
|
tokens.get()
|
||||||
|
So(tokens.count(), ShouldEqual, 1)
|
||||||
|
tokens.ret()
|
||||||
|
So(tokens.count(), ShouldEqual, 0)
|
||||||
|
})
|
||||||
|
Convey("Tokens capacity 0", t, func() {
|
||||||
|
tokens := newTokens(0)
|
||||||
|
So(tokens.count(), ShouldEqual, 0)
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
tokens.get()
|
||||||
|
}
|
||||||
|
So(tokens.count(), ShouldEqual, 20)
|
||||||
|
tokens.ret()
|
||||||
|
So(tokens.count(), ShouldEqual, 19)
|
||||||
|
})
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue