diff --git a/client/lib/util.go b/client/lib/util.go index 6dbbf2f..536fa17 100644 --- a/client/lib/util.go +++ b/client/lib/util.go @@ -1,9 +1,7 @@ package snowflake_client import ( - "crypto/rand" "log" - "math/big" "time" ) @@ -71,12 +69,3 @@ func (b *bytesSyncLogger) addOutbound(amount int64) { func (b *bytesSyncLogger) addInbound(amount int64) { b.inboundChan <- amount } - -func randomInt(min, max int) int { - nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max-min))) - if err != nil { - panic(err) - } - - return int(nBig.Int64()) + min -} diff --git a/client/lib/webrtc.go b/client/lib/webrtc.go index 845b4a8..5851e8e 100644 --- a/client/lib/webrtc.go +++ b/client/lib/webrtc.go @@ -15,9 +15,9 @@ import ( "github.com/pion/transport/v3" "github.com/pion/transport/v3/stdnet" "github.com/pion/webrtc/v4" - "github.com/pion/webrtc/v4/pkg/media" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/media" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/proxy" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util" ) @@ -44,6 +44,7 @@ type WebRTCPeer struct { bytesLogger bytesLogger eventsLogger event.SnowflakeEventReceiver proxy *url.URL + mediaChannel *media.MediaChannel } // Deprecated: Use NewWebRTCPeerWithNatPolicyAndEventsAndProxy Instead. @@ -107,6 +108,7 @@ func NewWebRTCPeerWithNatPolicyAndEventsAndProxy( connection.eventsLogger = eventsLogger connection.proxy = proxy + connection.mediaChannel = media.NewMediaChannel() err := connection.connect(config, broker, natPolicy) if err != nil { @@ -342,7 +344,10 @@ func (c *WebRTCPeer) preparePeerConnection( c.open = make(chan struct{}) log.Println("WebRTC: DataChannel created") - c.openMediaTrack() + err = c.mediaChannel.Start(c.pc) + if err != nil { + log.Printf("Failed to setup media channel: %v", err) + } offer, err := c.pc.CreateOffer(nil) // TODO: Potentially timeout and retry if ICE isn't working. @@ -369,85 +374,12 @@ func (c *WebRTCPeer) preparePeerConnection( return nil } -func (c *WebRTCPeer) openMediaTrack() { - videoTrack, err := webrtc.NewTrackLocalStaticSample( - webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeAV1}, "video", "pion", - ) - if err != nil { - log.Printf("webrtc.NewTrackLocalStaticSample ERROR: %s", err) - return - } - - rtpSender, err := c.pc.AddTrack(videoTrack) - if err != nil { - log.Printf("webrtc.AddTrack ERROR: %s", err) - return - } - - go func() { - rtcpBuf := make([]byte, 1500) - for { - if _, _, rtcpErr := rtpSender.Read(rtcpBuf); rtcpErr != nil { - return - } - } - }() - - go func() { - ticker := time.NewTicker(time.Second) - defer ticker.Stop() - - for ; true; <-ticker.C { - // Add jitter to simulate "realistic" media patterns - jitterDelay := time.Duration(randomInt(0, 200)) * time.Millisecond - time.Sleep(jitterDelay) - - // Vary packet sizes for specific frames types - var bufSize int - frameType := randomInt(1, 100) - switch { - case frameType <= 5: // I-frames: 5% chance, larger - bufSize = randomInt(8000, 15000) - case frameType <= 35: // P-frames: 30% chance, medium - bufSize = randomInt(2000, 5000) - default: // B-frames: 65% chance, smaller - bufSize = randomInt(500, 2000) - } - - buf := make([]byte, bufSize) - - // Add some timing variation - frameDuration := time.Duration(randomInt(900, 1100)) * time.Millisecond - - err = videoTrack.WriteSample(media.Sample{Data: buf, Duration: frameDuration}) - if err != nil { - log.Printf("webrtc.WriteSample ERROR: %s", err) - } - - // Simulate some burst of smaller packets - if randomInt(1, 10) == 1 { // 10% chance - burstCount := randomInt(2, 5) - for i := 0; i < burstCount; i++ { - smallBuf := make([]byte, randomInt(100, 400)) - time.Sleep(time.Duration(randomInt(10, 50)) * time.Millisecond) - - frameDuration = time.Duration(randomInt(16, 33)) * time.Millisecond - - err = videoTrack.WriteSample(media.Sample{Data: smallBuf, Duration: frameDuration}) - if err != nil { - log.Printf("webrtc.WriteSample burst ERROR: %s", err) - break - } - } - } - } - }() - - log.Println("WebRTC: Media track opened") -} - // cleanup closes all channels and transports func (c *WebRTCPeer) cleanup() { + // Stop media channel + if c.mediaChannel != nil { + c.mediaChannel.Stop() + } // Close this side of the SOCKS pipe. if c.writePipe != nil { // c.writePipe can be nil in tests. c.writePipe.Close() diff --git a/common/media/channel.go b/common/media/channel.go new file mode 100644 index 0000000..7f68615 --- /dev/null +++ b/common/media/channel.go @@ -0,0 +1,145 @@ +package media + +import ( + "crypto/rand" + "log" + "math/big" + "time" + + "github.com/pion/interceptor" + "github.com/pion/webrtc/v4" + "github.com/pion/webrtc/v4/pkg/media" +) + +func randomInt(min, max int) int { + nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max-min))) + if err != nil { + panic(err) + } + return int(nBig.Int64()) + min +} + +type RTPReader interface { + Read(b []byte) (n int, a interceptor.Attributes, err error) +} + +// MediaChannel handles media track simulation for WebRTC connections +type MediaChannel struct { + stopCh chan struct{} +} + +// NewMediaChannel creates a new media channel +func NewMediaChannel() *MediaChannel { + return &MediaChannel{ + stopCh: make(chan struct{}), + } +} + +// StartVideoTrack starts video track simulation on the given peer connection +func (mc *MediaChannel) StartVideoTrack(pc *webrtc.PeerConnection) error { + videoTrack, err := webrtc.NewTrackLocalStaticSample( + webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeAV1}, "video", "pion", + ) + if err != nil { + log.Printf("webrtc.NewTrackLocalStaticSample ERROR: %s", err) + return err + } + + rtpSender, err := pc.AddTrack(videoTrack) + if err != nil { + log.Printf("webrtc.AddTrack ERROR: %s", err) + return err + } + + go mc.handleRTCP(rtpSender) + go mc.simulateVideoFrames(videoTrack) + + log.Println("WebRTC: Media track opened") + return nil +} + +// Stop stops the media simulation +func (mc *MediaChannel) Stop() { + close(mc.stopCh) +} + +func (mc *MediaChannel) handleRTCP(reader RTPReader) { + rtcpBuf := make([]byte, 1500) + for { + select { + case <-mc.stopCh: + return + default: + if _, _, err := reader.Read(rtcpBuf); err != nil { + return + } + } + } +} + +func (mc *MediaChannel) simulateVideoFrames(track *webrtc.TrackLocalStaticSample) { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-mc.stopCh: + return + case <-ticker.C: + // Add jitter to simulate "realistic" media patterns + jitterDelay := time.Duration(randomInt(0, 200)) * time.Millisecond + time.Sleep(jitterDelay) + + // Vary packet sizes for specific frames types + var bufSize int + frameType := randomInt(1, 100) + switch { + case frameType <= 5: // I-frames: 5% chance, larger + bufSize = randomInt(8000, 15000) + case frameType <= 35: // P-frames: 30% chance, medium + bufSize = randomInt(2000, 5000) + default: // B-frames: 65% chance, smaller + bufSize = randomInt(500, 2000) + } + + buf := make([]byte, bufSize) + + // Add some timing variation + frameDuration := time.Duration(randomInt(900, 1100)) * time.Millisecond + + err := track.WriteSample(media.Sample{Data: buf, Duration: frameDuration}) + if err != nil { + log.Printf("webrtc.WriteSample ERROR: %s", err) + } + + // Simulate some burst of smaller packets + if randomInt(1, 10) == 1 { // 10% chance + burstCount := randomInt(2, 5) + for i := 0; i < burstCount; i++ { + smallBuf := make([]byte, randomInt(100, 400)) + time.Sleep(time.Duration(randomInt(10, 50)) * time.Millisecond) + + frameDuration = time.Duration(randomInt(16, 33)) * time.Millisecond + + err = track.WriteSample(media.Sample{Data: smallBuf, Duration: frameDuration}) + if err != nil { + log.Printf("webrtc.WriteSample burst ERROR: %s", err) + break + } + } + } + } + } +} + +// Start sets up duplex media handling (both incoming and outgoing tracks) +func (mc *MediaChannel) Start(pc *webrtc.PeerConnection) error { + // Set up handler for incoming tracks + pc.OnTrack(func(remote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { + log.Printf("Media Track received: streamId(%s) id(%s) rid(%s)", remote.StreamID(), remote.ID(), remote.RID()) + go mc.handleRTCP(receiver) + }) + + // Set up outgoing media track + return mc.StartVideoTrack(pc) +} diff --git a/common/media/channel_test.go b/common/media/channel_test.go new file mode 100644 index 0000000..b1f6353 --- /dev/null +++ b/common/media/channel_test.go @@ -0,0 +1,20 @@ +package media + +import ( + "testing" +) + +func TestMediaChannelStop(t *testing.T) { + mc := NewMediaChannel() + + // This should not panic + mc.Stop() + + // Verify that the stop channel is closed + select { + case <-mc.stopCh: + // Channel is closed, which is expected + default: + t.Fatal("MediaChannel stopCh should be closed after Stop()") + } +} diff --git a/proxy/lib/snowflake.go b/proxy/lib/snowflake.go index 4d6881c..19945e2 100644 --- a/proxy/lib/snowflake.go +++ b/proxy/lib/snowflake.go @@ -48,6 +48,7 @@ import ( "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/constants" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/media" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/namematcher" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/task" @@ -451,18 +452,13 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer( return nil, fmt.Errorf("accept: NewPeerConnection: %s", err) } - pc.OnTrack(func(remote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { - log.Printf("Track has started streamId(%s) id(%s) rid(%s) \n", remote.StreamID(), remote.ID(), remote.RID()) - for { - rtcpBuf := make([]byte, 1500) - for { - if _, _, err := receiver.Read(rtcpBuf); err != nil { - return - } - } - } - }) + // Start duplex media handling (both incoming and outgoing tracks) + mediaChannel := media.NewMediaChannel() + err = mediaChannel.Start(pc) + if err != nil { + log.Printf("Failed to setup proxy media channel: %v", err) + } pc.OnDataChannel(func(dc *webrtc.DataChannel) { log.Printf("New Data Channel %s-%d\n", dc.Label(), dc.ID()) @@ -512,6 +508,9 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer( } sf.EventDispatcher.OnNewSnowflakeEvent(event.EventOnProxyConnectionOver{Country: country}) + // Clean up media channel + mediaChannel.Stop() + conn.dc = nil dc.Close() pw.Close() @@ -838,7 +837,7 @@ func (sf *SnowflakeProxy) Start() error { err = sf.checkNATType(config, sf.NATProbeURL) if err != nil { // non-fatal error. Log it and continue - log.Printf(err.Error()) + log.Printf("%s", err.Error()) setCurrentNATType(NATUnknown) } sf.EventDispatcher.OnNewSnowflakeEvent(event.EventOnCurrentNATTypeDetermined{CurNATType: getCurrentNATType()})