mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 20:11:19 -04:00
Make the proxy to report the number of clients to the broker
So the assignment of proxies is based on the load. The number of clients is ronded down to 8. Existing proxies that doesn't report the number of clients will be distributed equaly to new proxies until they get 8 clients, that is okish as the existing proxies do have a maximum capacity of 10. Fixes #40048
This commit is contained in:
parent
74bdb85b30
commit
7a1857c42f
9 changed files with 165 additions and 77 deletions
|
@ -7,7 +7,6 @@ import (
|
|||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
@ -337,8 +336,9 @@ func TestBrokerInteractions(t *testing.T) {
|
|||
const sampleAnswer = `{"type":"answer","sdp":` + sampleSDP + `}`
|
||||
|
||||
Convey("Proxy connections to broker", t, func() {
|
||||
broker := new(SignalingServer)
|
||||
broker.url, _ = url.Parse("localhost")
|
||||
broker, err := newSignalingServer("localhost", false)
|
||||
So(err, ShouldEqual, nil)
|
||||
tokens = newTokens(0)
|
||||
|
||||
//Mock peerConnection
|
||||
config = webrtc.Configuration{
|
||||
|
@ -469,17 +469,6 @@ func TestUtilityFuncs(t *testing.T) {
|
|||
So(err, ShouldEqual, io.ErrClosedPipe)
|
||||
})
|
||||
})
|
||||
Convey("Tokens", t, func() {
|
||||
tokens = make(chan bool, 2)
|
||||
for i := uint(0); i < 2; i++ {
|
||||
tokens <- true
|
||||
}
|
||||
So(len(tokens), ShouldEqual, 2)
|
||||
getToken()
|
||||
So(len(tokens), ShouldEqual, 1)
|
||||
retToken()
|
||||
So(len(tokens), ShouldEqual, 2)
|
||||
})
|
||||
Convey("SessionID Generation", t, func() {
|
||||
sid1 := genSessionID()
|
||||
sid2 := genSessionID()
|
||||
|
|
|
@ -55,7 +55,7 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
tokens chan bool
|
||||
tokens *tokens_t
|
||||
config webrtc.Configuration
|
||||
client http.Client
|
||||
)
|
||||
|
@ -171,14 +171,6 @@ func (c *webRTCConn) SetWriteDeadline(t time.Time) error {
|
|||
return fmt.Errorf("SetWriteDeadline not implemented")
|
||||
}
|
||||
|
||||
func getToken() {
|
||||
<-tokens
|
||||
}
|
||||
|
||||
func retToken() {
|
||||
tokens <- true
|
||||
}
|
||||
|
||||
func genSessionID() string {
|
||||
buf := make([]byte, sessionIDLength)
|
||||
_, err := rand.Read(buf)
|
||||
|
@ -204,6 +196,21 @@ type SignalingServer struct {
|
|||
keepLocalAddresses bool
|
||||
}
|
||||
|
||||
func newSignalingServer(rawURL string, keepLocalAddresses bool) (*SignalingServer, error) {
|
||||
var err error
|
||||
s := new(SignalingServer)
|
||||
s.keepLocalAddresses = keepLocalAddresses
|
||||
s.url, err = url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid broker url: %s", err)
|
||||
}
|
||||
|
||||
s.transport = http.DefaultTransport.(*http.Transport)
|
||||
s.transport.(*http.Transport).ResponseHeaderTimeout = 30 * time.Second
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) {
|
||||
|
||||
req, err := http.NewRequest("POST", path, payload)
|
||||
|
@ -238,7 +245,8 @@ func (s *SignalingServer) pollOffer(sid string) *webrtc.SessionDescription {
|
|||
timeOfNextPoll = now
|
||||
}
|
||||
|
||||
body, err := messages.EncodePollRequest(sid, "standalone", currentNATType)
|
||||
numClients := int((tokens.count() / 8) * 8) // Round down to 8
|
||||
body, err := messages.EncodePollRequest(sid, "standalone", currentNATType, numClients)
|
||||
if err != nil {
|
||||
log.Printf("Error encoding poll message: %s", err.Error())
|
||||
return nil
|
||||
|
@ -323,7 +331,7 @@ func CopyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) {
|
|||
// RemoteAddr). https://bugs.torproject.org/18628#comment:8
|
||||
func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
|
||||
defer conn.Close()
|
||||
defer retToken()
|
||||
defer tokens.ret()
|
||||
|
||||
u, err := url.Parse(relayURL)
|
||||
if err != nil {
|
||||
|
@ -494,14 +502,14 @@ func runSession(sid string) {
|
|||
offer := broker.pollOffer(sid)
|
||||
if offer == nil {
|
||||
log.Printf("bad offer from broker")
|
||||
retToken()
|
||||
tokens.ret()
|
||||
return
|
||||
}
|
||||
dataChan := make(chan struct{})
|
||||
pc, err := makePeerConnectionFromOffer(offer, config, dataChan, datachannelHandler)
|
||||
if err != nil {
|
||||
log.Printf("error making WebRTC connection: %s", err)
|
||||
retToken()
|
||||
tokens.ret()
|
||||
return
|
||||
}
|
||||
err = broker.sendAnswer(sid, pc)
|
||||
|
@ -510,7 +518,7 @@ func runSession(sid string) {
|
|||
if inerr := pc.Close(); inerr != nil {
|
||||
log.Printf("error calling pc.Close: %v", inerr)
|
||||
}
|
||||
retToken()
|
||||
tokens.ret()
|
||||
return
|
||||
}
|
||||
// Set a timeout on peerconnection. If the connection state has not
|
||||
|
@ -524,7 +532,7 @@ func runSession(sid string) {
|
|||
if err := pc.Close(); err != nil {
|
||||
log.Printf("error calling pc.Close: %v", err)
|
||||
}
|
||||
retToken()
|
||||
tokens.ret()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -536,7 +544,7 @@ func main() {
|
|||
var unsafeLogging bool
|
||||
var keepLocalAddresses bool
|
||||
|
||||
flag.UintVar(&capacity, "capacity", 10, "maximum concurrent clients")
|
||||
flag.UintVar(&capacity, "capacity", 0, "maximum concurrent clients")
|
||||
flag.StringVar(&rawBrokerURL, "broker", defaultBrokerURL, "broker URL")
|
||||
flag.StringVar(&relayURL, "relay", defaultRelayURL, "websocket relay URL")
|
||||
flag.StringVar(&stunURL, "stun", defaultSTUNURL, "stun URL")
|
||||
|
@ -565,12 +573,11 @@ func main() {
|
|||
log.Println("starting")
|
||||
|
||||
var err error
|
||||
broker = new(SignalingServer)
|
||||
broker.keepLocalAddresses = keepLocalAddresses
|
||||
broker.url, err = url.Parse(rawBrokerURL)
|
||||
broker, err = newSignalingServer(rawBrokerURL, keepLocalAddresses)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid broker url: %s", err)
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = url.Parse(stunURL)
|
||||
if err != nil {
|
||||
log.Fatalf("invalid stun url: %s", err)
|
||||
|
@ -580,8 +587,6 @@ func main() {
|
|||
log.Fatalf("invalid relay url: %s", err)
|
||||
}
|
||||
|
||||
broker.transport = http.DefaultTransport.(*http.Transport)
|
||||
broker.transport.(*http.Transport).ResponseHeaderTimeout = 15 * time.Second
|
||||
config = webrtc.Configuration{
|
||||
ICEServers: []webrtc.ICEServer{
|
||||
{
|
||||
|
@ -589,17 +594,14 @@ func main() {
|
|||
},
|
||||
},
|
||||
}
|
||||
tokens = make(chan bool, capacity)
|
||||
for i := uint(0); i < capacity; i++ {
|
||||
tokens <- true
|
||||
}
|
||||
tokens = newTokens(capacity)
|
||||
|
||||
// use probetest to determine NAT compatability
|
||||
checkNATType(config, defaultProbeURL)
|
||||
log.Printf("NAT type: %s", currentNATType)
|
||||
|
||||
for {
|
||||
getToken()
|
||||
tokens.get()
|
||||
sessionID := genSessionID()
|
||||
runSession(sessionID)
|
||||
}
|
||||
|
@ -607,12 +609,7 @@ func main() {
|
|||
|
||||
func checkNATType(config webrtc.Configuration, probeURL string) {
|
||||
|
||||
var err error
|
||||
|
||||
probe := new(SignalingServer)
|
||||
probe.transport = http.DefaultTransport.(*http.Transport)
|
||||
probe.transport.(*http.Transport).ResponseHeaderTimeout = 30 * time.Second
|
||||
probe.url, err = url.Parse(probeURL)
|
||||
probe, err := newSignalingServer(probeURL, false)
|
||||
if err != nil {
|
||||
log.Printf("Error parsing url: %s", err.Error())
|
||||
}
|
||||
|
|
44
proxy/tokens.go
Normal file
44
proxy/tokens.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type tokens_t struct {
|
||||
ch chan struct{}
|
||||
capacity uint
|
||||
clients int64
|
||||
}
|
||||
|
||||
func newTokens(capacity uint) *tokens_t {
|
||||
var ch chan struct{}
|
||||
if capacity != 0 {
|
||||
ch = make(chan struct{}, capacity)
|
||||
}
|
||||
|
||||
return &tokens_t{
|
||||
ch: ch,
|
||||
capacity: capacity,
|
||||
clients: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tokens_t) get() {
|
||||
atomic.AddInt64(&t.clients, 1)
|
||||
|
||||
if t.capacity != 0 {
|
||||
t.ch <- struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *tokens_t) ret() {
|
||||
atomic.AddInt64(&t.clients, -1)
|
||||
|
||||
if t.capacity != 0 {
|
||||
<-t.ch
|
||||
}
|
||||
}
|
||||
|
||||
func (t tokens_t) count() int64 {
|
||||
return atomic.LoadInt64(&t.clients)
|
||||
}
|
28
proxy/tokens_test.go
Normal file
28
proxy/tokens_test.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
. "github.com/smartystreets/goconvey/convey"
|
||||
)
|
||||
|
||||
func TestTokens(t *testing.T) {
|
||||
Convey("Tokens", t, func() {
|
||||
tokens := newTokens(2)
|
||||
So(tokens.count(), ShouldEqual, 0)
|
||||
tokens.get()
|
||||
So(tokens.count(), ShouldEqual, 1)
|
||||
tokens.ret()
|
||||
So(tokens.count(), ShouldEqual, 0)
|
||||
})
|
||||
Convey("Tokens capacity 0", t, func() {
|
||||
tokens := newTokens(0)
|
||||
So(tokens.count(), ShouldEqual, 0)
|
||||
for i := 0; i < 20; i++ {
|
||||
tokens.get()
|
||||
}
|
||||
So(tokens.count(), ShouldEqual, 20)
|
||||
tokens.ret()
|
||||
So(tokens.count(), ShouldEqual, 19)
|
||||
})
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue