mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 11:11:30 -04:00
Simplify proxy NAT checking logic
This commit is contained in:
parent
54495ceb4e
commit
4ed5da7f2f
2 changed files with 44 additions and 47 deletions
|
@ -36,6 +36,8 @@ type Periodic struct {
|
||||||
Interval time.Duration
|
Interval time.Duration
|
||||||
// Execute is the task function
|
// Execute is the task function
|
||||||
Execute func() error
|
Execute func() error
|
||||||
|
// OnError handles the error of the task
|
||||||
|
OnError func(error)
|
||||||
|
|
||||||
access sync.Mutex
|
access sync.Mutex
|
||||||
timer *time.Timer
|
timer *time.Timer
|
||||||
|
@ -55,10 +57,15 @@ func (t *Periodic) checkedExecute() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := t.Execute(); err != nil {
|
if err := t.Execute(); err != nil {
|
||||||
t.access.Lock()
|
if t.OnError != nil {
|
||||||
t.running = false
|
t.OnError(err)
|
||||||
t.access.Unlock()
|
} else {
|
||||||
return err
|
// default error handling is to shut down the task
|
||||||
|
t.access.Lock()
|
||||||
|
t.running = false
|
||||||
|
t.access.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
t.access.Lock()
|
t.access.Lock()
|
||||||
|
|
|
@ -60,13 +60,13 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// NATUnknown represents a NAT type which is unknown.
|
// NATUnknown is set if the proxy cannot connect to probetest.
|
||||||
NATUnknown = "unknown"
|
NATUnknown = "unknown"
|
||||||
|
|
||||||
// NATRestricted represents a restricted NAT.
|
// NATRestricted is set if the proxy times out when connecting to a symmetric NAT.
|
||||||
NATRestricted = "restricted"
|
NATRestricted = "restricted"
|
||||||
|
|
||||||
// NATUnrestricted represents an unrestricted NAT.
|
// NATUnrestricted is set if the proxy successfully connects to a symmetric NAT.
|
||||||
NATUnrestricted = "unrestricted"
|
NATUnrestricted = "unrestricted"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -99,6 +99,12 @@ func getCurrentNATType() string {
|
||||||
return currentNATType
|
return currentNATType
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setCurrentNATType(newType string) {
|
||||||
|
currentNATTypeAccess.Lock()
|
||||||
|
defer currentNATTypeAccess.Unlock()
|
||||||
|
currentNATType = newType
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
tokens *tokens_t
|
tokens *tokens_t
|
||||||
config webrtc.Configuration
|
config webrtc.Configuration
|
||||||
|
@ -694,9 +700,13 @@ func (sf *SnowflakeProxy) Start() error {
|
||||||
}
|
}
|
||||||
tokens = newTokens(sf.Capacity)
|
tokens = newTokens(sf.Capacity)
|
||||||
|
|
||||||
sf.checkNATType(config, sf.NATProbeURL)
|
err = sf.checkNATType(config, sf.NATProbeURL)
|
||||||
currentNATTypeLoaded := getCurrentNATType()
|
if err != nil {
|
||||||
sf.EventDispatcher.OnNewSnowflakeEvent(&event.EventOnCurrentNATTypeDetermined{CurNATType: currentNATTypeLoaded})
|
// non-fatal error. Log it and continue
|
||||||
|
log.Printf(err.Error())
|
||||||
|
setCurrentNATType(NATUnknown)
|
||||||
|
}
|
||||||
|
sf.EventDispatcher.OnNewSnowflakeEvent(&event.EventOnCurrentNATTypeDetermined{CurNATType: getCurrentNATType()})
|
||||||
|
|
||||||
NatRetestTask := task.Periodic{
|
NatRetestTask := task.Periodic{
|
||||||
Interval: sf.NATTypeMeasurementInterval,
|
Interval: sf.NATTypeMeasurementInterval,
|
||||||
|
@ -704,6 +714,9 @@ func (sf *SnowflakeProxy) Start() error {
|
||||||
sf.checkNATType(config, sf.NATProbeURL)
|
sf.checkNATType(config, sf.NATProbeURL)
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
|
OnError: func(err error) {
|
||||||
|
log.Printf("Periodic probetest failed: %s, retaining current NAT type: %s", err.Error(), getCurrentNATType())
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if sf.NATTypeMeasurementInterval != 0 {
|
if sf.NATTypeMeasurementInterval != 0 {
|
||||||
|
@ -735,87 +748,64 @@ func (sf *SnowflakeProxy) Stop() {
|
||||||
// checkNATType use probetest to determine NAT compatability by
|
// checkNATType use probetest to determine NAT compatability by
|
||||||
// attempting to connect with a known symmetric NAT. If success,
|
// attempting to connect with a known symmetric NAT. If success,
|
||||||
// it is considered "unrestricted". If timeout it is considered "restricted"
|
// it is considered "unrestricted". If timeout it is considered "restricted"
|
||||||
func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) {
|
func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) error {
|
||||||
probe, err := newSignalingServer(probeURL, false)
|
probe, err := newSignalingServer(probeURL, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error parsing url: %s", err.Error())
|
return fmt.Errorf("Error parsing url: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
dataChan := make(chan struct{})
|
dataChan := make(chan struct{})
|
||||||
pc, err := sf.makeNewPeerConnection(config, dataChan)
|
pc, err := sf.makeNewPeerConnection(config, dataChan)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error making WebRTC connection: %s", err)
|
return fmt.Errorf("Error making WebRTC connection: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
offer := pc.LocalDescription()
|
offer := pc.LocalDescription()
|
||||||
log.Printf("Probetest offer: \n\t%s", strings.ReplaceAll(offer.SDP, "\n", "\n\t"))
|
log.Printf("Probetest offer: \n\t%s", strings.ReplaceAll(offer.SDP, "\n", "\n\t"))
|
||||||
sdp, err := util.SerializeSessionDescription(offer)
|
sdp, err := util.SerializeSessionDescription(offer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error encoding probe message: %s", err.Error())
|
return fmt.Errorf("Error encoding probe message: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// send offer
|
// send offer
|
||||||
body, err := messages.EncodePollResponse(sdp, true, "")
|
body, err := messages.EncodePollResponse(sdp, true, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error encoding probe message: %s", err.Error())
|
return fmt.Errorf("Error encoding probe message: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp, err := probe.Post(probe.url.String(), bytes.NewBuffer(body))
|
resp, err := probe.Post(probe.url.String(), bytes.NewBuffer(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("error polling probe: %s", err.Error())
|
return fmt.Errorf("Error polling probe: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sdp, _, err = messages.DecodeAnswerRequest(resp)
|
sdp, _, err = messages.DecodeAnswerRequest(resp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error reading probe response: %s", err.Error())
|
return fmt.Errorf("Error reading probe response: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
answer, err := util.DeserializeSessionDescription(sdp)
|
answer, err := util.DeserializeSessionDescription(sdp)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error setting answer: %s", err.Error())
|
return fmt.Errorf("Error setting answer: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = pc.SetRemoteDescription(*answer)
|
err = pc.SetRemoteDescription(*answer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error setting answer: %s", err.Error())
|
return fmt.Errorf("Error setting answer: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
currentNATTypeLoaded := getCurrentNATType()
|
prevNATType := getCurrentNATType()
|
||||||
|
|
||||||
currentNATTypeTestResult := NATUnknown
|
|
||||||
select {
|
select {
|
||||||
case <-dataChan:
|
case <-dataChan:
|
||||||
currentNATTypeTestResult = NATUnrestricted
|
setCurrentNATType(NATUnrestricted)
|
||||||
case <-time.After(dataChannelTimeout):
|
case <-time.After(dataChannelTimeout):
|
||||||
currentNATTypeTestResult = NATRestricted
|
setCurrentNATType(NATRestricted)
|
||||||
}
|
}
|
||||||
|
|
||||||
currentNATTypeToStore := NATUnknown
|
log.Printf("NAT Type measurement: %v -> %v\n", prevNATType, getCurrentNATType())
|
||||||
switch currentNATTypeLoaded + "->" + currentNATTypeTestResult {
|
|
||||||
case NATUnrestricted + "->" + NATUnknown:
|
|
||||||
currentNATTypeToStore = NATUnrestricted
|
|
||||||
|
|
||||||
case NATRestricted + "->" + NATUnknown:
|
|
||||||
currentNATTypeToStore = NATRestricted
|
|
||||||
|
|
||||||
default:
|
|
||||||
currentNATTypeToStore = currentNATTypeTestResult
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("NAT Type measurement: %v -> %v = %v\n", currentNATTypeLoaded, currentNATTypeTestResult, currentNATTypeToStore)
|
|
||||||
|
|
||||||
currentNATTypeAccess.Lock()
|
|
||||||
currentNATType = currentNATTypeToStore
|
|
||||||
currentNATTypeAccess.Unlock()
|
|
||||||
|
|
||||||
if err := pc.Close(); err != nil {
|
if err := pc.Close(); err != nil {
|
||||||
log.Printf("error calling pc.Close: %v", err)
|
log.Printf("error calling pc.Close: %v", err)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue