mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 11:11:30 -04:00
Simplify proxy NAT checking logic
This commit is contained in:
parent
54495ceb4e
commit
4ed5da7f2f
2 changed files with 44 additions and 47 deletions
|
@ -60,13 +60,13 @@ const (
|
|||
)
|
||||
|
||||
const (
|
||||
// NATUnknown represents a NAT type which is unknown.
|
||||
// NATUnknown is set if the proxy cannot connect to probetest.
|
||||
NATUnknown = "unknown"
|
||||
|
||||
// NATRestricted represents a restricted NAT.
|
||||
// NATRestricted is set if the proxy times out when connecting to a symmetric NAT.
|
||||
NATRestricted = "restricted"
|
||||
|
||||
// NATUnrestricted represents an unrestricted NAT.
|
||||
// NATUnrestricted is set if the proxy successfully connects to a symmetric NAT.
|
||||
NATUnrestricted = "unrestricted"
|
||||
)
|
||||
|
||||
|
@ -99,6 +99,12 @@ func getCurrentNATType() string {
|
|||
return currentNATType
|
||||
}
|
||||
|
||||
func setCurrentNATType(newType string) {
|
||||
currentNATTypeAccess.Lock()
|
||||
defer currentNATTypeAccess.Unlock()
|
||||
currentNATType = newType
|
||||
}
|
||||
|
||||
var (
|
||||
tokens *tokens_t
|
||||
config webrtc.Configuration
|
||||
|
@ -694,9 +700,13 @@ func (sf *SnowflakeProxy) Start() error {
|
|||
}
|
||||
tokens = newTokens(sf.Capacity)
|
||||
|
||||
sf.checkNATType(config, sf.NATProbeURL)
|
||||
currentNATTypeLoaded := getCurrentNATType()
|
||||
sf.EventDispatcher.OnNewSnowflakeEvent(&event.EventOnCurrentNATTypeDetermined{CurNATType: currentNATTypeLoaded})
|
||||
err = sf.checkNATType(config, sf.NATProbeURL)
|
||||
if err != nil {
|
||||
// non-fatal error. Log it and continue
|
||||
log.Printf(err.Error())
|
||||
setCurrentNATType(NATUnknown)
|
||||
}
|
||||
sf.EventDispatcher.OnNewSnowflakeEvent(&event.EventOnCurrentNATTypeDetermined{CurNATType: getCurrentNATType()})
|
||||
|
||||
NatRetestTask := task.Periodic{
|
||||
Interval: sf.NATTypeMeasurementInterval,
|
||||
|
@ -704,6 +714,9 @@ func (sf *SnowflakeProxy) Start() error {
|
|||
sf.checkNATType(config, sf.NATProbeURL)
|
||||
return nil
|
||||
},
|
||||
OnError: func(err error) {
|
||||
log.Printf("Periodic probetest failed: %s, retaining current NAT type: %s", err.Error(), getCurrentNATType())
|
||||
},
|
||||
}
|
||||
|
||||
if sf.NATTypeMeasurementInterval != 0 {
|
||||
|
@ -735,87 +748,64 @@ func (sf *SnowflakeProxy) Stop() {
|
|||
// checkNATType use probetest to determine NAT compatability by
|
||||
// attempting to connect with a known symmetric NAT. If success,
|
||||
// it is considered "unrestricted". If timeout it is considered "restricted"
|
||||
func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) {
|
||||
func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) error {
|
||||
probe, err := newSignalingServer(probeURL, false)
|
||||
if err != nil {
|
||||
log.Printf("Error parsing url: %s", err.Error())
|
||||
return fmt.Errorf("Error parsing url: %w", err)
|
||||
}
|
||||
|
||||
dataChan := make(chan struct{})
|
||||
pc, err := sf.makeNewPeerConnection(config, dataChan)
|
||||
if err != nil {
|
||||
log.Printf("error making WebRTC connection: %s", err)
|
||||
return
|
||||
return fmt.Errorf("Error making WebRTC connection: %w", err)
|
||||
}
|
||||
|
||||
offer := pc.LocalDescription()
|
||||
log.Printf("Probetest offer: \n\t%s", strings.ReplaceAll(offer.SDP, "\n", "\n\t"))
|
||||
sdp, err := util.SerializeSessionDescription(offer)
|
||||
if err != nil {
|
||||
log.Printf("Error encoding probe message: %s", err.Error())
|
||||
return
|
||||
return fmt.Errorf("Error encoding probe message: %w", err)
|
||||
}
|
||||
|
||||
// send offer
|
||||
body, err := messages.EncodePollResponse(sdp, true, "")
|
||||
if err != nil {
|
||||
log.Printf("Error encoding probe message: %s", err.Error())
|
||||
return
|
||||
return fmt.Errorf("Error encoding probe message: %w", err)
|
||||
}
|
||||
|
||||
resp, err := probe.Post(probe.url.String(), bytes.NewBuffer(body))
|
||||
if err != nil {
|
||||
log.Printf("error polling probe: %s", err.Error())
|
||||
return
|
||||
return fmt.Errorf("Error polling probe: %w", err)
|
||||
}
|
||||
|
||||
sdp, _, err = messages.DecodeAnswerRequest(resp)
|
||||
if err != nil {
|
||||
log.Printf("Error reading probe response: %s", err.Error())
|
||||
return
|
||||
return fmt.Errorf("Error reading probe response: %w", err)
|
||||
}
|
||||
|
||||
answer, err := util.DeserializeSessionDescription(sdp)
|
||||
if err != nil {
|
||||
log.Printf("Error setting answer: %s", err.Error())
|
||||
return
|
||||
return fmt.Errorf("Error setting answer: %w", err)
|
||||
}
|
||||
|
||||
err = pc.SetRemoteDescription(*answer)
|
||||
if err != nil {
|
||||
log.Printf("Error setting answer: %s", err.Error())
|
||||
return
|
||||
return fmt.Errorf("Error setting answer: %w", err)
|
||||
}
|
||||
|
||||
currentNATTypeLoaded := getCurrentNATType()
|
||||
prevNATType := getCurrentNATType()
|
||||
|
||||
currentNATTypeTestResult := NATUnknown
|
||||
select {
|
||||
case <-dataChan:
|
||||
currentNATTypeTestResult = NATUnrestricted
|
||||
setCurrentNATType(NATUnrestricted)
|
||||
case <-time.After(dataChannelTimeout):
|
||||
currentNATTypeTestResult = NATRestricted
|
||||
setCurrentNATType(NATRestricted)
|
||||
}
|
||||
|
||||
currentNATTypeToStore := NATUnknown
|
||||
switch currentNATTypeLoaded + "->" + currentNATTypeTestResult {
|
||||
case NATUnrestricted + "->" + NATUnknown:
|
||||
currentNATTypeToStore = NATUnrestricted
|
||||
|
||||
case NATRestricted + "->" + NATUnknown:
|
||||
currentNATTypeToStore = NATRestricted
|
||||
|
||||
default:
|
||||
currentNATTypeToStore = currentNATTypeTestResult
|
||||
}
|
||||
|
||||
log.Printf("NAT Type measurement: %v -> %v = %v\n", currentNATTypeLoaded, currentNATTypeTestResult, currentNATTypeToStore)
|
||||
|
||||
currentNATTypeAccess.Lock()
|
||||
currentNATType = currentNATTypeToStore
|
||||
currentNATTypeAccess.Unlock()
|
||||
log.Printf("NAT Type measurement: %v -> %v\n", prevNATType, getCurrentNATType())
|
||||
|
||||
if err := pc.Close(); err != nil {
|
||||
log.Printf("error calling pc.Close: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue