Simplify proxy NAT checking logic

This commit is contained in:
itchyonion 2024-05-17 14:26:21 -07:00
parent 54495ceb4e
commit 4ed5da7f2f
No known key found for this signature in database
GPG key ID: 4B87B720348500EA
2 changed files with 44 additions and 47 deletions

View file

@ -36,6 +36,8 @@ type Periodic struct {
Interval time.Duration Interval time.Duration
// Execute is the task function // Execute is the task function
Execute func() error Execute func() error
// OnError handles the error of the task
OnError func(error)
access sync.Mutex access sync.Mutex
timer *time.Timer timer *time.Timer
@ -55,10 +57,15 @@ func (t *Periodic) checkedExecute() error {
} }
if err := t.Execute(); err != nil { if err := t.Execute(); err != nil {
t.access.Lock() if t.OnError != nil {
t.running = false t.OnError(err)
t.access.Unlock() } else {
return err // default error handling is to shut down the task
t.access.Lock()
t.running = false
t.access.Unlock()
return err
}
} }
t.access.Lock() t.access.Lock()

View file

@ -60,13 +60,13 @@ const (
) )
const ( const (
// NATUnknown represents a NAT type which is unknown. // NATUnknown is set if the proxy cannot connect to probetest.
NATUnknown = "unknown" NATUnknown = "unknown"
// NATRestricted represents a restricted NAT. // NATRestricted is set if the proxy times out when connecting to a symmetric NAT.
NATRestricted = "restricted" NATRestricted = "restricted"
// NATUnrestricted represents an unrestricted NAT. // NATUnrestricted is set if the proxy successfully connects to a symmetric NAT.
NATUnrestricted = "unrestricted" NATUnrestricted = "unrestricted"
) )
@ -99,6 +99,12 @@ func getCurrentNATType() string {
return currentNATType return currentNATType
} }
func setCurrentNATType(newType string) {
currentNATTypeAccess.Lock()
defer currentNATTypeAccess.Unlock()
currentNATType = newType
}
var ( var (
tokens *tokens_t tokens *tokens_t
config webrtc.Configuration config webrtc.Configuration
@ -694,9 +700,13 @@ func (sf *SnowflakeProxy) Start() error {
} }
tokens = newTokens(sf.Capacity) tokens = newTokens(sf.Capacity)
sf.checkNATType(config, sf.NATProbeURL) err = sf.checkNATType(config, sf.NATProbeURL)
currentNATTypeLoaded := getCurrentNATType() if err != nil {
sf.EventDispatcher.OnNewSnowflakeEvent(&event.EventOnCurrentNATTypeDetermined{CurNATType: currentNATTypeLoaded}) // non-fatal error. Log it and continue
log.Printf(err.Error())
setCurrentNATType(NATUnknown)
}
sf.EventDispatcher.OnNewSnowflakeEvent(&event.EventOnCurrentNATTypeDetermined{CurNATType: getCurrentNATType()})
NatRetestTask := task.Periodic{ NatRetestTask := task.Periodic{
Interval: sf.NATTypeMeasurementInterval, Interval: sf.NATTypeMeasurementInterval,
@ -704,6 +714,9 @@ func (sf *SnowflakeProxy) Start() error {
sf.checkNATType(config, sf.NATProbeURL) sf.checkNATType(config, sf.NATProbeURL)
return nil return nil
}, },
OnError: func(err error) {
log.Printf("Periodic probetest failed: %s, retaining current NAT type: %s", err.Error(), getCurrentNATType())
},
} }
if sf.NATTypeMeasurementInterval != 0 { if sf.NATTypeMeasurementInterval != 0 {
@ -735,87 +748,64 @@ func (sf *SnowflakeProxy) Stop() {
// checkNATType use probetest to determine NAT compatability by // checkNATType use probetest to determine NAT compatability by
// attempting to connect with a known symmetric NAT. If success, // attempting to connect with a known symmetric NAT. If success,
// it is considered "unrestricted". If timeout it is considered "restricted" // it is considered "unrestricted". If timeout it is considered "restricted"
func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) { func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) error {
probe, err := newSignalingServer(probeURL, false) probe, err := newSignalingServer(probeURL, false)
if err != nil { if err != nil {
log.Printf("Error parsing url: %s", err.Error()) return fmt.Errorf("Error parsing url: %w", err)
} }
dataChan := make(chan struct{}) dataChan := make(chan struct{})
pc, err := sf.makeNewPeerConnection(config, dataChan) pc, err := sf.makeNewPeerConnection(config, dataChan)
if err != nil { if err != nil {
log.Printf("error making WebRTC connection: %s", err) return fmt.Errorf("Error making WebRTC connection: %w", err)
return
} }
offer := pc.LocalDescription() offer := pc.LocalDescription()
log.Printf("Probetest offer: \n\t%s", strings.ReplaceAll(offer.SDP, "\n", "\n\t")) log.Printf("Probetest offer: \n\t%s", strings.ReplaceAll(offer.SDP, "\n", "\n\t"))
sdp, err := util.SerializeSessionDescription(offer) sdp, err := util.SerializeSessionDescription(offer)
if err != nil { if err != nil {
log.Printf("Error encoding probe message: %s", err.Error()) return fmt.Errorf("Error encoding probe message: %w", err)
return
} }
// send offer // send offer
body, err := messages.EncodePollResponse(sdp, true, "") body, err := messages.EncodePollResponse(sdp, true, "")
if err != nil { if err != nil {
log.Printf("Error encoding probe message: %s", err.Error()) return fmt.Errorf("Error encoding probe message: %w", err)
return
} }
resp, err := probe.Post(probe.url.String(), bytes.NewBuffer(body)) resp, err := probe.Post(probe.url.String(), bytes.NewBuffer(body))
if err != nil { if err != nil {
log.Printf("error polling probe: %s", err.Error()) return fmt.Errorf("Error polling probe: %w", err)
return
} }
sdp, _, err = messages.DecodeAnswerRequest(resp) sdp, _, err = messages.DecodeAnswerRequest(resp)
if err != nil { if err != nil {
log.Printf("Error reading probe response: %s", err.Error()) return fmt.Errorf("Error reading probe response: %w", err)
return
} }
answer, err := util.DeserializeSessionDescription(sdp) answer, err := util.DeserializeSessionDescription(sdp)
if err != nil { if err != nil {
log.Printf("Error setting answer: %s", err.Error()) return fmt.Errorf("Error setting answer: %w", err)
return
} }
err = pc.SetRemoteDescription(*answer) err = pc.SetRemoteDescription(*answer)
if err != nil { if err != nil {
log.Printf("Error setting answer: %s", err.Error()) return fmt.Errorf("Error setting answer: %w", err)
return
} }
currentNATTypeLoaded := getCurrentNATType() prevNATType := getCurrentNATType()
currentNATTypeTestResult := NATUnknown
select { select {
case <-dataChan: case <-dataChan:
currentNATTypeTestResult = NATUnrestricted setCurrentNATType(NATUnrestricted)
case <-time.After(dataChannelTimeout): case <-time.After(dataChannelTimeout):
currentNATTypeTestResult = NATRestricted setCurrentNATType(NATRestricted)
} }
currentNATTypeToStore := NATUnknown log.Printf("NAT Type measurement: %v -> %v\n", prevNATType, getCurrentNATType())
switch currentNATTypeLoaded + "->" + currentNATTypeTestResult {
case NATUnrestricted + "->" + NATUnknown:
currentNATTypeToStore = NATUnrestricted
case NATRestricted + "->" + NATUnknown:
currentNATTypeToStore = NATRestricted
default:
currentNATTypeToStore = currentNATTypeTestResult
}
log.Printf("NAT Type measurement: %v -> %v = %v\n", currentNATTypeLoaded, currentNATTypeTestResult, currentNATTypeToStore)
currentNATTypeAccess.Lock()
currentNATType = currentNATTypeToStore
currentNATTypeAccess.Unlock()
if err := pc.Close(); err != nil { if err := pc.Close(); err != nil {
log.Printf("error calling pc.Close: %v", err) log.Printf("error calling pc.Close: %v", err)
} }
return nil
} }