diff --git a/client/lib/webrtc.go b/client/lib/webrtc.go index 397ba85..5851e8e 100644 --- a/client/lib/webrtc.go +++ b/client/lib/webrtc.go @@ -17,6 +17,7 @@ import ( "github.com/pion/webrtc/v4" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/media" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/proxy" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util" ) @@ -43,6 +44,7 @@ type WebRTCPeer struct { bytesLogger bytesLogger eventsLogger event.SnowflakeEventReceiver proxy *url.URL + mediaChannel *media.MediaChannel } // Deprecated: Use NewWebRTCPeerWithNatPolicyAndEventsAndProxy Instead. @@ -106,6 +108,7 @@ func NewWebRTCPeerWithNatPolicyAndEventsAndProxy( connection.eventsLogger = eventsLogger connection.proxy = proxy + connection.mediaChannel = media.NewMediaChannel() err := connection.connect(config, broker, natPolicy) if err != nil { @@ -296,6 +299,7 @@ func (c *WebRTCPeer) preparePeerConnection( log.Printf("NewPeerConnection ERROR: %s", err) return err } + ordered := true dataChannelOptions := &webrtc.DataChannelInit{ Ordered: &ordered, @@ -340,6 +344,11 @@ func (c *WebRTCPeer) preparePeerConnection( c.open = make(chan struct{}) log.Println("WebRTC: DataChannel created") + err = c.mediaChannel.Start(c.pc) + if err != nil { + log.Printf("Failed to setup media channel: %v", err) + } + offer, err := c.pc.CreateOffer(nil) // TODO: Potentially timeout and retry if ICE isn't working. if err != nil { @@ -367,6 +376,10 @@ func (c *WebRTCPeer) preparePeerConnection( // cleanup closes all channels and transports func (c *WebRTCPeer) cleanup() { + // Stop media channel + if c.mediaChannel != nil { + c.mediaChannel.Stop() + } // Close this side of the SOCKS pipe. if c.writePipe != nil { // c.writePipe can be nil in tests. c.writePipe.Close() diff --git a/common/media/channel.go b/common/media/channel.go new file mode 100644 index 0000000..7f68615 --- /dev/null +++ b/common/media/channel.go @@ -0,0 +1,145 @@ +package media + +import ( + "crypto/rand" + "log" + "math/big" + "time" + + "github.com/pion/interceptor" + "github.com/pion/webrtc/v4" + "github.com/pion/webrtc/v4/pkg/media" +) + +func randomInt(min, max int) int { + nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max-min))) + if err != nil { + panic(err) + } + return int(nBig.Int64()) + min +} + +type RTPReader interface { + Read(b []byte) (n int, a interceptor.Attributes, err error) +} + +// MediaChannel handles media track simulation for WebRTC connections +type MediaChannel struct { + stopCh chan struct{} +} + +// NewMediaChannel creates a new media channel +func NewMediaChannel() *MediaChannel { + return &MediaChannel{ + stopCh: make(chan struct{}), + } +} + +// StartVideoTrack starts video track simulation on the given peer connection +func (mc *MediaChannel) StartVideoTrack(pc *webrtc.PeerConnection) error { + videoTrack, err := webrtc.NewTrackLocalStaticSample( + webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeAV1}, "video", "pion", + ) + if err != nil { + log.Printf("webrtc.NewTrackLocalStaticSample ERROR: %s", err) + return err + } + + rtpSender, err := pc.AddTrack(videoTrack) + if err != nil { + log.Printf("webrtc.AddTrack ERROR: %s", err) + return err + } + + go mc.handleRTCP(rtpSender) + go mc.simulateVideoFrames(videoTrack) + + log.Println("WebRTC: Media track opened") + return nil +} + +// Stop stops the media simulation +func (mc *MediaChannel) Stop() { + close(mc.stopCh) +} + +func (mc *MediaChannel) handleRTCP(reader RTPReader) { + rtcpBuf := make([]byte, 1500) + for { + select { + case <-mc.stopCh: + return + default: + if _, _, err := reader.Read(rtcpBuf); err != nil { + return + } + } + } +} + +func (mc *MediaChannel) simulateVideoFrames(track *webrtc.TrackLocalStaticSample) { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-mc.stopCh: + return + case <-ticker.C: + // Add jitter to simulate "realistic" media patterns + jitterDelay := time.Duration(randomInt(0, 200)) * time.Millisecond + time.Sleep(jitterDelay) + + // Vary packet sizes for specific frames types + var bufSize int + frameType := randomInt(1, 100) + switch { + case frameType <= 5: // I-frames: 5% chance, larger + bufSize = randomInt(8000, 15000) + case frameType <= 35: // P-frames: 30% chance, medium + bufSize = randomInt(2000, 5000) + default: // B-frames: 65% chance, smaller + bufSize = randomInt(500, 2000) + } + + buf := make([]byte, bufSize) + + // Add some timing variation + frameDuration := time.Duration(randomInt(900, 1100)) * time.Millisecond + + err := track.WriteSample(media.Sample{Data: buf, Duration: frameDuration}) + if err != nil { + log.Printf("webrtc.WriteSample ERROR: %s", err) + } + + // Simulate some burst of smaller packets + if randomInt(1, 10) == 1 { // 10% chance + burstCount := randomInt(2, 5) + for i := 0; i < burstCount; i++ { + smallBuf := make([]byte, randomInt(100, 400)) + time.Sleep(time.Duration(randomInt(10, 50)) * time.Millisecond) + + frameDuration = time.Duration(randomInt(16, 33)) * time.Millisecond + + err = track.WriteSample(media.Sample{Data: smallBuf, Duration: frameDuration}) + if err != nil { + log.Printf("webrtc.WriteSample burst ERROR: %s", err) + break + } + } + } + } + } +} + +// Start sets up duplex media handling (both incoming and outgoing tracks) +func (mc *MediaChannel) Start(pc *webrtc.PeerConnection) error { + // Set up handler for incoming tracks + pc.OnTrack(func(remote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { + log.Printf("Media Track received: streamId(%s) id(%s) rid(%s)", remote.StreamID(), remote.ID(), remote.RID()) + go mc.handleRTCP(receiver) + }) + + // Set up outgoing media track + return mc.StartVideoTrack(pc) +} diff --git a/common/media/channel_test.go b/common/media/channel_test.go new file mode 100644 index 0000000..b1f6353 --- /dev/null +++ b/common/media/channel_test.go @@ -0,0 +1,20 @@ +package media + +import ( + "testing" +) + +func TestMediaChannelStop(t *testing.T) { + mc := NewMediaChannel() + + // This should not panic + mc.Stop() + + // Verify that the stop channel is closed + select { + case <-mc.stopCh: + // Channel is closed, which is expected + default: + t.Fatal("MediaChannel stopCh should be closed after Stop()") + } +} diff --git a/proxy/lib/snowflake.go b/proxy/lib/snowflake.go index bcdfbda..19945e2 100644 --- a/proxy/lib/snowflake.go +++ b/proxy/lib/snowflake.go @@ -48,6 +48,7 @@ import ( "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/constants" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/media" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/namematcher" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/task" @@ -451,6 +452,14 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer( return nil, fmt.Errorf("accept: NewPeerConnection: %s", err) } + + // Start duplex media handling (both incoming and outgoing tracks) + mediaChannel := media.NewMediaChannel() + err = mediaChannel.Start(pc) + if err != nil { + log.Printf("Failed to setup proxy media channel: %v", err) + } + pc.OnDataChannel(func(dc *webrtc.DataChannel) { log.Printf("New Data Channel %s-%d\n", dc.Label(), dc.ID()) close(dataChan) @@ -499,6 +508,9 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer( } sf.EventDispatcher.OnNewSnowflakeEvent(event.EventOnProxyConnectionOver{Country: country}) + // Clean up media channel + mediaChannel.Stop() + conn.dc = nil dc.Close() pw.Close() @@ -825,7 +837,7 @@ func (sf *SnowflakeProxy) Start() error { err = sf.checkNATType(config, sf.NATProbeURL) if err != nil { // non-fatal error. Log it and continue - log.Printf(err.Error()) + log.Printf("%s", err.Error()) setCurrentNATType(NATUnknown) } sf.EventDispatcher.OnNewSnowflakeEvent(event.EventOnCurrentNATTypeDetermined{CurNATType: getCurrentNATType()})