diff --git a/go.mod b/go.mod index df9fa59..8d1562a 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.30.2 github.com/aws/aws-sdk-go-v2/credentials v1.18.5 github.com/aws/aws-sdk-go-v2/service/sqs v1.38.5 + github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 github.com/golang/mock v1.6.0 github.com/gorilla/websocket v1.5.3 github.com/miekg/dns v1.1.65 diff --git a/go.sum b/go.sum index d4b9b0b..319a39d 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0 h1:pRcxfaAlK0vR6nOeQs7eAEvjJzdGXl8+KaBlcvpQTyQ= +github.com/cloudflare/backoff v0.0.0-20240920015135-e46b80a3a7d0/go.mod h1:rzgs2ZOiguV6/NpiDgADjRLPNyZlApIWxKpkT+X8SdY= github.com/cloudflare/circl v1.3.7 h1:qlCDlTPz2n9fu58M0Nh1J/JzcFpfgkFHHX3O35r5vcU= github.com/cloudflare/circl v1.3.7/go.mod h1:sRTcRWXGLrKw6yIGJ+l7amYJFfAXbZG0kBSc8r4zxgA= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= diff --git a/proxy/lib/snowflake.go b/proxy/lib/snowflake.go index bcdfbda..268f9d6 100644 --- a/proxy/lib/snowflake.go +++ b/proxy/lib/snowflake.go @@ -46,6 +46,8 @@ import ( "github.com/pion/transport/v3/stdnet" "github.com/pion/webrtc/v4" + "github.com/cloudflare/backoff" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/constants" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" @@ -163,6 +165,7 @@ type SnowflakeProxy struct { NATProbeURL string // NATTypeMeasurementInterval is time before NAT type is retested NATTypeMeasurementInterval time.Duration + NATMeasurementRetry bool // ProxyType is the type reported to the broker, if not provided it "standalone" will be used ProxyType string EventDispatcher event.SnowflakeEventDispatcher @@ -822,7 +825,14 @@ func (sf *SnowflakeProxy) Start() error { } tokens = newTokens(sf.Capacity) - err = sf.checkNATType(config, sf.NATProbeURL) + for { + err = sf.checkNATType(config, sf.NATProbeURL) + if getCurrentNATType() == NATUnrestricted || !sf.NATMeasurementRetry { + break + } + + <-time.After(5 * time.Second) + } if err != nil { // non-fatal error. Log it and continue log.Printf(err.Error()) @@ -833,7 +843,21 @@ func (sf *SnowflakeProxy) Start() error { NatRetestTask := task.Periodic{ Interval: sf.NATTypeMeasurementInterval, Execute: func() error { - return sf.checkNATType(config, sf.NATProbeURL) + var err error + + b := backoff.New(5*time.Minute, sf.NATTypeMeasurementInterval) + for { + err = sf.checkNATType(config, sf.NATProbeURL) + sf.EventDispatcher.OnNewSnowflakeEvent(event.EventOnCurrentNATTypeDetermined{CurNATType: getCurrentNATType()}) + + if getCurrentNATType() == NATUnrestricted || !sf.NATMeasurementRetry { + break + } + + <-time.After(b.Duration()) + } + + return err }, // Not setting OnError would shut down the periodic task on error by default. OnError: func(err error) { diff --git a/proxy/main.go b/proxy/main.go index 360703e..4a16249 100644 --- a/proxy/main.go +++ b/proxy/main.go @@ -38,6 +38,7 @@ func main() { allowNonTLSRelay := flag.Bool("allow-non-tls-relay", false, "allow this proxy to pass client's data to the relay in an unencrypted form.\nThis is only useful if the relay doesn't support encryption, e.g. for testing / development purposes.") NATTypeMeasurementInterval := flag.Duration("nat-retest-interval", time.Hour*24, "the time interval between NAT type is retests (see \"nat-probe-server\"). 0s disables retest. Valid time units are \"s\", \"m\", \"h\".") + natRetry := flag.Bool("nat-retry", false, "Retry NAT measurement when not unrestricted") summaryInterval := flag.Duration("summary-interval", time.Hour, "the time interval between summary log outputs, 0s disables summaries. Valid time units are \"s\", \"m\", \"h\".") disableStatsLogger := flag.Bool("disable-stats-logger", false, "disable the exposing mechanism for stats using logs") @@ -114,6 +115,7 @@ func main() { EphemeralMaxPort: ephemeralPortsRange[1], NATTypeMeasurementInterval: *NATTypeMeasurementInterval, + NATMeasurementRetry: *natRetry, EventDispatcher: eventLogger, RelayDomainNamePattern: *allowedRelayHostNamePattern,