Implement NAT discover for go standalone proxies

This commit is contained in:
Cecylia Bocovich 2020-06-16 17:10:56 -04:00
parent bf924445e3
commit f6cf9a453b
4 changed files with 77 additions and 17 deletions

View file

@ -170,7 +170,7 @@ func proxyPolls(ctx *BrokerContext, w http.ResponseWriter, r *http.Request) {
return return
} }
sid, proxyType, err := messages.DecodePollRequest(body) sid, proxyType, _, err := messages.DecodePollRequest(body)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
return return

View file

@ -9,15 +9,16 @@ import (
"strings" "strings"
) )
const version = "1.1" const version = "1.2"
/* Version 1.1 specification: /* Version 1.2 specification:
== ProxyPollRequest == == ProxyPollRequest ==
{ {
Sid: [generated session id of proxy], Sid: [generated session id of proxy],
Version: 1.1, Version: 1.2,
Type: ["badge"|"webext"|"standalone"] Type: ["badge"|"webext"|"standalone"]
NAT: ["unknown"|"restricted"|"unrestricted"]
} }
== ProxyPollResponse == == ProxyPollResponse ==
@ -44,7 +45,7 @@ HTTP 400 BadRequest
== ProxyAnswerRequest == == ProxyAnswerRequest ==
{ {
Sid: [generated session id of proxy], Sid: [generated session id of proxy],
Version: 1.1, Version: 1.2,
Answer: Answer:
{ {
type: answer, type: answer,
@ -76,37 +77,44 @@ type ProxyPollRequest struct {
Sid string Sid string
Version string Version string
Type string Type string
NAT string
} }
func EncodePollRequest(sid string, proxyType string) ([]byte, error) { func EncodePollRequest(sid string, proxyType string, natType string) ([]byte, error) {
return json.Marshal(ProxyPollRequest{ return json.Marshal(ProxyPollRequest{
Sid: sid, Sid: sid,
Version: version, Version: version,
Type: proxyType, Type: proxyType,
NAT: natType,
}) })
} }
// Decodes a poll message from a snowflake proxy and returns the // Decodes a poll message from a snowflake proxy and returns the
// sid and proxy type of the proxy on success and an error if it failed // sid and proxy type of the proxy on success and an error if it failed
func DecodePollRequest(data []byte) (string, string, error) { func DecodePollRequest(data []byte) (string, string, string, error) {
var message ProxyPollRequest var message ProxyPollRequest
err := json.Unmarshal(data, &message) err := json.Unmarshal(data, &message)
if err != nil { if err != nil {
return "", "", err return "", "", "", err
} }
majorVersion := strings.Split(message.Version, ".")[0] majorVersion := strings.Split(message.Version, ".")[0]
if majorVersion != "1" { if majorVersion != "1" {
return "", "", fmt.Errorf("using unknown version") return "", "", "", fmt.Errorf("using unknown version")
} }
// Version 1.x requires an Sid // Version 1.x requires an Sid
if message.Sid == "" { if message.Sid == "" {
return "", "", fmt.Errorf("no supplied session id") return "", "", "", fmt.Errorf("no supplied session id")
} }
return message.Sid, message.Type, nil natType := message.NAT
if natType == "" {
natType = "unknown"
}
return message.Sid, message.Type, natType, nil
} }
type ProxyPollResponse struct { type ProxyPollResponse struct {
@ -159,7 +167,7 @@ type ProxyAnswerRequest struct {
func EncodeAnswerRequest(answer string, sid string) ([]byte, error) { func EncodeAnswerRequest(answer string, sid string) ([]byte, error) {
return json.Marshal(ProxyAnswerRequest{ return json.Marshal(ProxyAnswerRequest{
Version: "1.1", Version: version,
Sid: sid, Sid: sid,
Answer: answer, Answer: answer,
}) })

View file

@ -13,6 +13,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
for _, test := range []struct { for _, test := range []struct {
sid string sid string
proxyType string proxyType string
natType string
data string data string
err error err error
}{ }{
@ -20,6 +21,7 @@ func TestDecodeProxyPollRequest(t *testing.T) {
//Version 1.0 proxy message //Version 1.0 proxy message
"ymbcCMto7KHNGYlp", "ymbcCMto7KHNGYlp",
"", "",
"unknown",
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.0"}`, `{"Sid":"ymbcCMto7KHNGYlp","Version":"1.0"}`,
nil, nil,
}, },
@ -27,44 +29,59 @@ func TestDecodeProxyPollRequest(t *testing.T) {
//Version 1.1 proxy message //Version 1.1 proxy message
"ymbcCMto7KHNGYlp", "ymbcCMto7KHNGYlp",
"standalone", "standalone",
"unknown",
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.1","Type":"standalone"}`, `{"Sid":"ymbcCMto7KHNGYlp","Version":"1.1","Type":"standalone"}`,
nil, nil,
}, },
{
//Version 1.2 proxy message
"ymbcCMto7KHNGYlp",
"standalone",
"restricted",
`{"Sid":"ymbcCMto7KHNGYlp","Version":"1.2","Type":"standalone", "NAT":"restricted"}`,
nil,
},
{ {
//Version 0.X proxy message: //Version 0.X proxy message:
"", "",
"", "",
"ymbcCMto7KHNGYlp", "",
"",
&json.SyntaxError{}, &json.SyntaxError{},
}, },
{ {
"",
"", "",
"", "",
`{"Sid":"ymbcCMto7KHNGYlp"}`, `{"Sid":"ymbcCMto7KHNGYlp"}`,
fmt.Errorf(""), fmt.Errorf(""),
}, },
{ {
"",
"", "",
"", "",
"{}", "{}",
fmt.Errorf(""), fmt.Errorf(""),
}, },
{ {
"",
"", "",
"", "",
`{"Version":"1.0"}`, `{"Version":"1.0"}`,
fmt.Errorf(""), fmt.Errorf(""),
}, },
{ {
"",
"", "",
"", "",
`{"Version":"2.0"}`, `{"Version":"2.0"}`,
fmt.Errorf(""), fmt.Errorf(""),
}, },
} { } {
sid, proxyType, err := DecodePollRequest([]byte(test.data)) sid, proxyType, natType, err := DecodePollRequest([]byte(test.data))
So(sid, ShouldResemble, test.sid) So(sid, ShouldResemble, test.sid)
So(proxyType, ShouldResemble, test.proxyType) So(proxyType, ShouldResemble, test.proxyType)
So(natType, ShouldResemble, test.natType)
So(err, ShouldHaveSameTypeAs, test.err) So(err, ShouldHaveSameTypeAs, test.err)
} }
@ -73,11 +90,12 @@ func TestDecodeProxyPollRequest(t *testing.T) {
func TestEncodeProxyPollRequests(t *testing.T) { func TestEncodeProxyPollRequests(t *testing.T) {
Convey("Context", t, func() { Convey("Context", t, func() {
b, err := EncodePollRequest("ymbcCMto7KHNGYlp", "standalone") b, err := EncodePollRequest("ymbcCMto7KHNGYlp", "standalone", "unknown")
So(err, ShouldEqual, nil) So(err, ShouldEqual, nil)
sid, proxyType, err := DecodePollRequest(b) sid, proxyType, natType, err := DecodePollRequest(b)
So(sid, ShouldEqual, "ymbcCMto7KHNGYlp") So(sid, ShouldEqual, "ymbcCMto7KHNGYlp")
So(proxyType, ShouldEqual, "standalone") So(proxyType, ShouldEqual, "standalone")
So(natType, ShouldEqual, "unknown")
So(err, ShouldEqual, nil) So(err, ShouldEqual, nil)
}) })
} }

View file

@ -19,6 +19,7 @@ import (
"time" "time"
"git.torproject.org/pluggable-transports/snowflake.git/common/messages" "git.torproject.org/pluggable-transports/snowflake.git/common/messages"
"git.torproject.org/pluggable-transports/snowflake.git/common/nat"
"git.torproject.org/pluggable-transports/snowflake.git/common/safelog" "git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
"git.torproject.org/pluggable-transports/snowflake.git/common/util" "git.torproject.org/pluggable-transports/snowflake.git/common/util"
"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn" "git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
@ -30,6 +31,11 @@ const defaultBrokerURL = "https://snowflake-broker.bamsoftware.com/"
const defaultRelayURL = "wss://snowflake.bamsoftware.com/" const defaultRelayURL = "wss://snowflake.bamsoftware.com/"
const defaultSTUNURL = "stun:stun.l.google.com:19302" const defaultSTUNURL = "stun:stun.l.google.com:19302"
const pollInterval = 5 * time.Second const pollInterval = 5 * time.Second
const (
NATUnknown = "unknown"
NATRestricted = "restricted"
NATUnrestricted = "unrestricted"
)
//amount of time after sending an SDP answer before the proxy assumes the //amount of time after sending an SDP answer before the proxy assumes the
//client is not going to connect //client is not going to connect
@ -40,6 +46,8 @@ const readLimit = 100000 //Maximum number of bytes to be read from an HTTP reque
var broker *Broker var broker *Broker
var relayURL string var relayURL string
var currentNATType = NATUnknown
const ( const (
sessionIDLength = 16 sessionIDLength = 16
) )
@ -174,7 +182,7 @@ func (b *Broker) pollOffer(sid string) *webrtc.SessionDescription {
timeOfNextPoll = now timeOfNextPoll = now
} }
body, err := messages.EncodePollRequest(sid, "standalone") body, err := messages.EncodePollRequest(sid, "standalone", currentNATType)
if err != nil { if err != nil {
log.Printf("Error encoding poll message: %s", err.Error()) log.Printf("Error encoding poll message: %s", err.Error())
return nil return nil
@ -485,9 +493,35 @@ func main() {
tokens <- true tokens <- true
} }
// determine NAT type before polling
updateNATType(config.ICEServers)
log.Printf("NAT type: %s", currentNATType)
for { for {
getToken() getToken()
sessionID := genSessionID() sessionID := genSessionID()
runSession(sessionID) runSession(sessionID)
} }
} }
// use provided STUN server(s) to determine NAT type
func updateNATType(servers []webrtc.ICEServer) {
var restrictedNAT bool
var err error
for _, server := range servers {
addr := strings.TrimPrefix(server.URLs[0], "stun:")
restrictedNAT, err = nat.CheckIfRestrictedNAT(addr)
if err == nil {
if restrictedNAT {
currentNATType = NATRestricted
} else {
currentNATType = NATUnrestricted
}
break
}
}
if err != nil {
currentNATType = NATUnknown
}
}