diff --git a/proxy/snowflake.go b/proxy/snowflake.go index 276ebed..f0fa2c0 100644 --- a/proxy/snowflake.go +++ b/proxy/snowflake.go @@ -19,7 +19,6 @@ import ( "time" "git.torproject.org/pluggable-transports/snowflake.git/common/messages" - "git.torproject.org/pluggable-transports/snowflake.git/common/nat" "git.torproject.org/pluggable-transports/snowflake.git/common/safelog" "git.torproject.org/pluggable-transports/snowflake.git/common/util" "git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn" @@ -29,6 +28,7 @@ import ( ) const defaultBrokerURL = "https://snowflake-broker.bamsoftware.com/" +const defaultProbeURL = "https://snowflake-broker.torproject.net:8443/probe" const defaultRelayURL = "wss://snowflake.bamsoftware.com/" const defaultSTUNURL = "stun:stun.stunprotocol.org:3478" const pollInterval = 5 * time.Second @@ -427,6 +427,48 @@ func makePeerConnectionFromOffer(sdp *webrtc.SessionDescription, return pc, nil } +// Create a new PeerConnection. Blocks until the gathering of ICE +// candidates is complete and the answer is available in LocalDescription. +func makeNewPeerConnection(config webrtc.Configuration, + dataChan chan struct{}) (*webrtc.PeerConnection, error) { + + pc, err := webrtc.NewPeerConnection(config) + if err != nil { + return nil, fmt.Errorf("accept: NewPeerConnection: %s", err) + } + + offer, err := pc.CreateOffer(nil) + // TODO: Potentially timeout and retry if ICE isn't working. + if err != nil { + log.Println("Failed to prepare offer", err) + pc.Close() + return nil, err + } + log.Println("WebRTC: Created offer") + err = pc.SetLocalDescription(offer) + if err != nil { + log.Println("Failed to prepare offer", err) + pc.Close() + return nil, err + } + log.Println("WebRTC: Set local description") + + dc, err := pc.CreateDataChannel("test", &webrtc.DataChannelInit{}) + if err != nil { + log.Printf("CreateDataChannel ERROR: %s", err) + return nil, err + } + dc.OnOpen(func() { + log.Println("WebRTC: DataChannel.OnOpen") + close(dataChan) + }) + dc.OnClose(func() { + log.Println("WebRTC: DataChannel.OnClose") + dc.Close() + }) + return pc, nil +} + func runSession(sid string) { offer := broker.pollOffer(sid) if offer == nil { @@ -531,8 +573,8 @@ func main() { tokens <- true } - // determine NAT type before polling - updateNATType(config.ICEServers) + // use probetest to determine NAT compatability + checkNATType(config, defaultProbeURL) log.Printf("NAT type: %s", currentNATType) for { @@ -542,24 +584,69 @@ func main() { } } -// use provided STUN server(s) to determine NAT type -func updateNATType(servers []webrtc.ICEServer) { +func checkNATType(config webrtc.Configuration, probeURL string) { - var restrictedNAT bool var err error - for _, server := range servers { - addr := strings.TrimPrefix(server.URLs[0], "stun:") - restrictedNAT, err = nat.CheckIfRestrictedNAT(addr) - if err == nil { - if restrictedNAT { - currentNATType = NATRestricted - } else { - currentNATType = NATUnrestricted - } - break - } - } + + probe := new(SignalingServer) + probe.transport = http.DefaultTransport.(*http.Transport) + probe.transport.(*http.Transport).ResponseHeaderTimeout = 30 * time.Second + probe.url, err = url.Parse(probeURL) if err != nil { - currentNATType = NATUnknown + log.Printf("Error parsing url: %s", err.Error()) } + + // create offer + dataChan := make(chan struct{}) + pc, err := makeNewPeerConnection(config, dataChan) + if err != nil { + log.Printf("error making WebRTC connection: %s", err) + return + } + + offer := pc.LocalDescription() + sdp, err := util.SerializeSessionDescription(offer) + if err != nil { + log.Printf("Error encoding probe message: %s", err.Error()) + return + } + + // send offer + body, err := messages.EncodePollResponse(sdp, true, "") + if err != nil { + log.Printf("Error encoding probe message: %s", err.Error()) + return + } + resp, err := probe.Post(probe.url.String(), bytes.NewBuffer(body)) + if err != nil { + log.Printf("error polling probe: %s", err.Error()) + return + } + + sdp, _, err = messages.DecodeAnswerRequest(resp) + if err != nil { + log.Printf("Error reading probe response: %s", err.Error()) + return + } + answer, err := util.DeserializeSessionDescription(sdp) + if err != nil { + log.Printf("Error setting answer: %s", err.Error()) + return + } + err = pc.SetRemoteDescription(*answer) + if err != nil { + log.Printf("Error setting answer: %s", err.Error()) + return + } + + select { + case <-dataChan: + currentNATType = NATUnrestricted + case <-time.After(dataChannelTimeout): + currentNATType = NATRestricted + } + if err := pc.Close(); err != nil { + log.Printf("error calling pc.Close: %v", err) + } + }