Simplify proxy NAT checking logic

This commit is contained in:
itchyonion 2024-05-17 14:26:21 -07:00
parent 54495ceb4e
commit 4ed5da7f2f
No known key found for this signature in database
GPG key ID: 4B87B720348500EA
2 changed files with 44 additions and 47 deletions

View file

@ -60,13 +60,13 @@ const (
)
const (
// NATUnknown represents a NAT type which is unknown.
// NATUnknown is set if the proxy cannot connect to probetest.
NATUnknown = "unknown"
// NATRestricted represents a restricted NAT.
// NATRestricted is set if the proxy times out when connecting to a symmetric NAT.
NATRestricted = "restricted"
// NATUnrestricted represents an unrestricted NAT.
// NATUnrestricted is set if the proxy successfully connects to a symmetric NAT.
NATUnrestricted = "unrestricted"
)
@ -99,6 +99,12 @@ func getCurrentNATType() string {
return currentNATType
}
func setCurrentNATType(newType string) {
currentNATTypeAccess.Lock()
defer currentNATTypeAccess.Unlock()
currentNATType = newType
}
var (
tokens *tokens_t
config webrtc.Configuration
@ -694,9 +700,13 @@ func (sf *SnowflakeProxy) Start() error {
}
tokens = newTokens(sf.Capacity)
sf.checkNATType(config, sf.NATProbeURL)
currentNATTypeLoaded := getCurrentNATType()
sf.EventDispatcher.OnNewSnowflakeEvent(&event.EventOnCurrentNATTypeDetermined{CurNATType: currentNATTypeLoaded})
err = sf.checkNATType(config, sf.NATProbeURL)
if err != nil {
// non-fatal error. Log it and continue
log.Printf(err.Error())
setCurrentNATType(NATUnknown)
}
sf.EventDispatcher.OnNewSnowflakeEvent(&event.EventOnCurrentNATTypeDetermined{CurNATType: getCurrentNATType()})
NatRetestTask := task.Periodic{
Interval: sf.NATTypeMeasurementInterval,
@ -704,6 +714,9 @@ func (sf *SnowflakeProxy) Start() error {
sf.checkNATType(config, sf.NATProbeURL)
return nil
},
OnError: func(err error) {
log.Printf("Periodic probetest failed: %s, retaining current NAT type: %s", err.Error(), getCurrentNATType())
},
}
if sf.NATTypeMeasurementInterval != 0 {
@ -735,87 +748,64 @@ func (sf *SnowflakeProxy) Stop() {
// checkNATType use probetest to determine NAT compatability by
// attempting to connect with a known symmetric NAT. If success,
// it is considered "unrestricted". If timeout it is considered "restricted"
func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) {
func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) error {
probe, err := newSignalingServer(probeURL, false)
if err != nil {
log.Printf("Error parsing url: %s", err.Error())
return fmt.Errorf("Error parsing url: %w", err)
}
dataChan := make(chan struct{})
pc, err := sf.makeNewPeerConnection(config, dataChan)
if err != nil {
log.Printf("error making WebRTC connection: %s", err)
return
return fmt.Errorf("Error making WebRTC connection: %w", err)
}
offer := pc.LocalDescription()
log.Printf("Probetest offer: \n\t%s", strings.ReplaceAll(offer.SDP, "\n", "\n\t"))
sdp, err := util.SerializeSessionDescription(offer)
if err != nil {
log.Printf("Error encoding probe message: %s", err.Error())
return
return fmt.Errorf("Error encoding probe message: %w", err)
}
// send offer
body, err := messages.EncodePollResponse(sdp, true, "")
if err != nil {
log.Printf("Error encoding probe message: %s", err.Error())
return
return fmt.Errorf("Error encoding probe message: %w", err)
}
resp, err := probe.Post(probe.url.String(), bytes.NewBuffer(body))
if err != nil {
log.Printf("error polling probe: %s", err.Error())
return
return fmt.Errorf("Error polling probe: %w", err)
}
sdp, _, err = messages.DecodeAnswerRequest(resp)
if err != nil {
log.Printf("Error reading probe response: %s", err.Error())
return
return fmt.Errorf("Error reading probe response: %w", err)
}
answer, err := util.DeserializeSessionDescription(sdp)
if err != nil {
log.Printf("Error setting answer: %s", err.Error())
return
return fmt.Errorf("Error setting answer: %w", err)
}
err = pc.SetRemoteDescription(*answer)
if err != nil {
log.Printf("Error setting answer: %s", err.Error())
return
return fmt.Errorf("Error setting answer: %w", err)
}
currentNATTypeLoaded := getCurrentNATType()
prevNATType := getCurrentNATType()
currentNATTypeTestResult := NATUnknown
select {
case <-dataChan:
currentNATTypeTestResult = NATUnrestricted
setCurrentNATType(NATUnrestricted)
case <-time.After(dataChannelTimeout):
currentNATTypeTestResult = NATRestricted
setCurrentNATType(NATRestricted)
}
currentNATTypeToStore := NATUnknown
switch currentNATTypeLoaded + "->" + currentNATTypeTestResult {
case NATUnrestricted + "->" + NATUnknown:
currentNATTypeToStore = NATUnrestricted
case NATRestricted + "->" + NATUnknown:
currentNATTypeToStore = NATRestricted
default:
currentNATTypeToStore = currentNATTypeTestResult
}
log.Printf("NAT Type measurement: %v -> %v = %v\n", currentNATTypeLoaded, currentNATTypeTestResult, currentNATTypeToStore)
currentNATTypeAccess.Lock()
currentNATType = currentNATTypeToStore
currentNATTypeAccess.Unlock()
log.Printf("NAT Type measurement: %v -> %v\n", prevNATType, getCurrentNATType())
if err := pc.Close(); err != nil {
log.Printf("error calling pc.Close: %v", err)
}
return nil
}