diff --git a/proxy/lib/snowflake.go b/proxy/lib/snowflake.go index e39fcfb..7d237de 100644 --- a/proxy/lib/snowflake.go +++ b/proxy/lib/snowflake.go @@ -73,6 +73,10 @@ const readLimit = 100000 //Maximum number of bytes to be read from an HTTP reque var broker *SignalingServer +var currentNATTypeAccess = &sync.RWMutex{} + +// currentNATType describes local network environment. +// Obtain currentNATTypeAccess before access. var currentNATType = NATUnknown const ( @@ -183,7 +187,10 @@ func (s *SignalingServer) pollOffer(sid string, shutdown chan struct{}) *webrtc. return nil default: numClients := int((tokens.count() / 8) * 8) // Round down to 8 - body, err := messages.EncodePollRequest(sid, "standalone", currentNATType, numClients) + currentNATTypeAccess.RLock() + currentNATTypeLoaded := currentNATType + currentNATTypeAccess.RUnlock() + body, err := messages.EncodePollRequest(sid, "standalone", currentNATTypeLoaded, numClients) if err != nil { log.Printf("Error encoding poll message: %s", err.Error()) return nil @@ -530,7 +537,12 @@ func (sf *SnowflakeProxy) Start() error { // use probetest to determine NAT compatability sf.checkNATType(config, sf.NATProbeURL) - log.Printf("NAT type: %s", currentNATType) + + currentNATTypeAccess.RLock() + currentNATTypeLoaded := currentNATType + currentNATTypeAccess.RUnlock() + + log.Printf("NAT type: %s", currentNATTypeLoaded) ticker := time.NewTicker(pollInterval) defer ticker.Stop() @@ -604,12 +616,54 @@ func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL str return } + currentNATTypeAccess.RLock() + currentNATTypeLoaded := currentNATType + currentNATTypeAccess.RUnlock() + + currentNATTypeTestResult := NATUnknown select { case <-dataChan: - currentNATType = NATUnrestricted + currentNATTypeTestResult = NATUnrestricted case <-time.After(dataChannelTimeout): - currentNATType = NATRestricted + currentNATTypeTestResult = NATRestricted } + + currentNATTypeToStore := NATUnknown + switch currentNATTypeLoaded + "->" + currentNATTypeTestResult { + case NATUnknown + "->" + NATUnknown: + currentNATTypeToStore = NATUnknown + + case NATUnknown + "->" + NATUnrestricted: + currentNATTypeToStore = NATUnrestricted + + case NATUnknown + "->" + NATRestricted: + currentNATTypeToStore = NATRestricted + + case NATUnrestricted + "->" + NATUnknown: + currentNATTypeToStore = NATUnrestricted + + case NATUnrestricted + "->" + NATUnrestricted: + currentNATTypeToStore = NATUnrestricted + + case NATUnrestricted + "->" + NATRestricted: + currentNATTypeToStore = NATRestricted + + case NATRestricted + "->" + NATUnknown: + currentNATTypeToStore = NATRestricted + + case NATRestricted + "->" + NATUnrestricted: + currentNATTypeToStore = NATUnrestricted + + case NATRestricted + "->" + NATRestricted: + currentNATTypeToStore = NATRestricted + } + + log.Printf("NAT Type measurement: %v -> %v = %v\n", currentNATTypeLoaded, currentNATTypeTestResult, currentNATTypeToStore) + + currentNATTypeAccess.Lock() + currentNATType = currentNATTypeToStore + currentNATTypeAccess.Unlock() + if err := pc.Close(); err != nil { log.Printf("error calling pc.Close: %v", err) }