Add RelayURL support in proxy

This commit is contained in:
Shelikhoo 2022-04-07 21:32:55 +01:00
parent 613ceaf970
commit 863a8296e8
No known key found for this signature in database
GPG key ID: C4D5E79D22B25316
2 changed files with 33 additions and 14 deletions

View file

@ -365,7 +365,7 @@ func TestBrokerInteractions(t *testing.T) {
b, b,
} }
sdp := broker.pollOffer(sampleOffer, DefaultProxyType, nil) sdp, _ := broker.pollOffer(sampleOffer, DefaultProxyType, "", nil)
expectedSDP, _ := strconv.Unquote(sampleSDP) expectedSDP, _ := strconv.Unquote(sampleSDP)
So(sdp.SDP, ShouldResemble, expectedSDP) So(sdp.SDP, ShouldResemble, expectedSDP)
}) })
@ -379,7 +379,7 @@ func TestBrokerInteractions(t *testing.T) {
b, b,
} }
sdp := broker.pollOffer(sampleOffer, DefaultProxyType, nil) sdp, _ := broker.pollOffer(sampleOffer, DefaultProxyType, "", nil)
So(sdp, ShouldBeNil) So(sdp, ShouldBeNil)
}) })
Convey("sends answer to broker", func() { Convey("sends answer to broker", func() {

View file

@ -112,6 +112,12 @@ type SnowflakeProxy struct {
KeepLocalAddresses bool KeepLocalAddresses bool
// RelayURL is the URL of the Snowflake server that all traffic will be relayed to // RelayURL is the URL of the Snowflake server that all traffic will be relayed to
RelayURL string RelayURL string
// RelayDomainNamePattern is the pattern specify allowed domain name for relay
// If the pattern starts with ^ then an exact match is required.
// The rest of pattern is the suffix of domain name.
// There is no look ahead assertion when matching domain name suffix,
// thus the string prepend the suffix does not need to be empty or ends with a dot.
RelayDomainNamePattern string
// NATProbeURL is the URL of the probe service we use for NAT checks // NATProbeURL is the URL of the probe service we use for NAT checks
NATProbeURL string NATProbeURL string
// NATTypeMeasurementInterval is time before NAT type is retested // NATTypeMeasurementInterval is time before NAT type is retested
@ -188,7 +194,7 @@ func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) {
return limitedRead(resp.Body, readLimit) return limitedRead(resp.Body, readLimit)
} }
func (s *SignalingServer) pollOffer(sid string, proxyType string, shutdown chan struct{}) *webrtc.SessionDescription { func (s *SignalingServer) pollOffer(sid string, proxyType string, acceptedRelayPattern string, shutdown chan struct{}) (*webrtc.SessionDescription, string) {
brokerPath := s.url.ResolveReference(&url.URL{Path: "proxy"}) brokerPath := s.url.ResolveReference(&url.URL{Path: "proxy"})
ticker := time.NewTicker(pollInterval) ticker := time.NewTicker(pollInterval)
@ -198,38 +204,38 @@ func (s *SignalingServer) pollOffer(sid string, proxyType string, shutdown chan
for ; true; <-ticker.C { for ; true; <-ticker.C {
select { select {
case <-shutdown: case <-shutdown:
return nil return nil, ""
default: default:
numClients := int((tokens.count() / 8) * 8) // Round down to 8 numClients := int((tokens.count() / 8) * 8) // Round down to 8
currentNATTypeLoaded := getCurrentNATType() currentNATTypeLoaded := getCurrentNATType()
body, err := messages.EncodeProxyPollRequest(sid, proxyType, currentNATTypeLoaded, numClients) body, err := messages.EncodeProxyPollRequest(sid, proxyType, currentNATTypeLoaded, numClients)
if err != nil { if err != nil {
log.Printf("Error encoding poll message: %s", err.Error()) log.Printf("Error encoding poll message: %s", err.Error())
return nil return nil, ""
} }
resp, err := s.Post(brokerPath.String(), bytes.NewBuffer(body)) resp, err := s.Post(brokerPath.String(), bytes.NewBuffer(body))
if err != nil { if err != nil {
log.Printf("error polling broker: %s", err.Error()) log.Printf("error polling broker: %s", err.Error())
} }
offer, _, err := messages.DecodePollResponse(resp) offer, _, relayURL, err := messages.DecodePollResponseWithRelayURL(resp)
if err != nil { if err != nil {
log.Printf("Error reading broker response: %s", err.Error()) log.Printf("Error reading broker response: %s", err.Error())
log.Printf("body: %s", resp) log.Printf("body: %s", resp)
return nil return nil, ""
} }
if offer != "" { if offer != "" {
offer, err := util.DeserializeSessionDescription(offer) offer, err := util.DeserializeSessionDescription(offer)
if err != nil { if err != nil {
log.Printf("Error processing session description: %s", err.Error()) log.Printf("Error processing session description: %s", err.Error())
return nil return nil, ""
} }
return offer return offer, relayURL
} }
} }
} }
return nil return nil, ""
} }
func (s *SignalingServer) sendAnswer(sid string, pc *webrtc.PeerConnection) error { func (s *SignalingServer) sendAnswer(sid string, pc *webrtc.PeerConnection) error {
@ -295,11 +301,14 @@ func copyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser, shutdown chan struct
// conn.RemoteAddr() inside this function, as a workaround for a hang that // conn.RemoteAddr() inside this function, as a workaround for a hang that
// otherwise occurs inside of conn.pc.RemoteDescription() (called by // otherwise occurs inside of conn.pc.RemoteDescription() (called by
// RemoteAddr). https://bugs.torproject.org/18628#comment:8 // RemoteAddr). https://bugs.torproject.org/18628#comment:8
func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) { func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteAddr net.Addr, relayURL string) {
defer conn.Close() defer conn.Close()
defer tokens.ret() defer tokens.ret()
u, err := url.Parse(sf.RelayURL) if relayURL == "" {
relayURL = sf.RelayURL
}
u, err := url.Parse(relayURL)
if err != nil { if err != nil {
log.Fatalf("invalid relay url: %s", err) log.Fatalf("invalid relay url: %s", err)
} }
@ -326,6 +335,15 @@ func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteAddr net.Ad
log.Printf("datachannelHandler ends") log.Printf("datachannelHandler ends")
} }
type dataChannelHandlerWithRelayURL struct {
RelayURL string
sf *SnowflakeProxy
}
func (d dataChannelHandlerWithRelayURL) datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
d.sf.datachannelHandler(conn, remoteAddr, d.RelayURL)
}
// Create a PeerConnection from an SDP offer. Blocks until the gathering of ICE // Create a PeerConnection from an SDP offer. Blocks until the gathering of ICE
// candidates is complete and the answer is available in LocalDescription. // candidates is complete and the answer is available in LocalDescription.
// Installs an OnDataChannel callback that creates a webRTCConn and passes it to // Installs an OnDataChannel callback that creates a webRTCConn and passes it to
@ -470,14 +488,15 @@ func (sf *SnowflakeProxy) makeNewPeerConnection(config webrtc.Configuration,
} }
func (sf *SnowflakeProxy) runSession(sid string) { func (sf *SnowflakeProxy) runSession(sid string) {
offer := broker.pollOffer(sid, sf.ProxyType, sf.shutdown) offer, relayURL := broker.pollOffer(sid, sf.ProxyType, sf.RelayDomainNamePattern, sf.shutdown)
if offer == nil { if offer == nil {
log.Printf("bad offer from broker") log.Printf("bad offer from broker")
tokens.ret() tokens.ret()
return return
} }
dataChan := make(chan struct{}) dataChan := make(chan struct{})
pc, err := sf.makePeerConnectionFromOffer(offer, config, dataChan, sf.datachannelHandler) dataChannelAdaptor := dataChannelHandlerWithRelayURL{RelayURL: relayURL, sf: sf}
pc, err := sf.makePeerConnectionFromOffer(offer, config, dataChan, dataChannelAdaptor.datachannelHandler)
if err != nil { if err != nil {
log.Printf("error making WebRTC connection: %s", err) log.Printf("error making WebRTC connection: %s", err)
tokens.ret() tokens.ret()