mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 20:11:19 -04:00
refactor: move media channel handling to common/ and setup duplex channel
a duplex media channel should be more realistic, you generally both send and receive media when doing video call and stuff
This commit is contained in:
parent
ff738ff045
commit
01bde142d4
5 changed files with 187 additions and 102 deletions
|
@ -1,9 +1,7 @@
|
||||||
package snowflake_client
|
package snowflake_client
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"log"
|
"log"
|
||||||
"math/big"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -71,12 +69,3 @@ func (b *bytesSyncLogger) addOutbound(amount int64) {
|
||||||
func (b *bytesSyncLogger) addInbound(amount int64) {
|
func (b *bytesSyncLogger) addInbound(amount int64) {
|
||||||
b.inboundChan <- amount
|
b.inboundChan <- amount
|
||||||
}
|
}
|
||||||
|
|
||||||
func randomInt(min, max int) int {
|
|
||||||
nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max-min)))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return int(nBig.Int64()) + min
|
|
||||||
}
|
|
||||||
|
|
|
@ -15,9 +15,9 @@ import (
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
"github.com/pion/webrtc/v4"
|
"github.com/pion/webrtc/v4"
|
||||||
"github.com/pion/webrtc/v4/pkg/media"
|
|
||||||
|
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event"
|
||||||
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/media"
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/proxy"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/proxy"
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util"
|
||||||
)
|
)
|
||||||
|
@ -44,6 +44,7 @@ type WebRTCPeer struct {
|
||||||
bytesLogger bytesLogger
|
bytesLogger bytesLogger
|
||||||
eventsLogger event.SnowflakeEventReceiver
|
eventsLogger event.SnowflakeEventReceiver
|
||||||
proxy *url.URL
|
proxy *url.URL
|
||||||
|
mediaChannel *media.MediaChannel
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: Use NewWebRTCPeerWithNatPolicyAndEventsAndProxy Instead.
|
// Deprecated: Use NewWebRTCPeerWithNatPolicyAndEventsAndProxy Instead.
|
||||||
|
@ -107,6 +108,7 @@ func NewWebRTCPeerWithNatPolicyAndEventsAndProxy(
|
||||||
|
|
||||||
connection.eventsLogger = eventsLogger
|
connection.eventsLogger = eventsLogger
|
||||||
connection.proxy = proxy
|
connection.proxy = proxy
|
||||||
|
connection.mediaChannel = media.NewMediaChannel()
|
||||||
|
|
||||||
err := connection.connect(config, broker, natPolicy)
|
err := connection.connect(config, broker, natPolicy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -342,7 +344,10 @@ func (c *WebRTCPeer) preparePeerConnection(
|
||||||
c.open = make(chan struct{})
|
c.open = make(chan struct{})
|
||||||
log.Println("WebRTC: DataChannel created")
|
log.Println("WebRTC: DataChannel created")
|
||||||
|
|
||||||
c.openMediaTrack()
|
err = c.mediaChannel.Start(c.pc)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to setup media channel: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
offer, err := c.pc.CreateOffer(nil)
|
offer, err := c.pc.CreateOffer(nil)
|
||||||
// TODO: Potentially timeout and retry if ICE isn't working.
|
// TODO: Potentially timeout and retry if ICE isn't working.
|
||||||
|
@ -369,85 +374,12 @@ func (c *WebRTCPeer) preparePeerConnection(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WebRTCPeer) openMediaTrack() {
|
|
||||||
videoTrack, err := webrtc.NewTrackLocalStaticSample(
|
|
||||||
webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeAV1}, "video", "pion",
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("webrtc.NewTrackLocalStaticSample ERROR: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rtpSender, err := c.pc.AddTrack(videoTrack)
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("webrtc.AddTrack ERROR: %s", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
rtcpBuf := make([]byte, 1500)
|
|
||||||
for {
|
|
||||||
if _, _, rtcpErr := rtpSender.Read(rtcpBuf); rtcpErr != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
ticker := time.NewTicker(time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for ; true; <-ticker.C {
|
|
||||||
// Add jitter to simulate "realistic" media patterns
|
|
||||||
jitterDelay := time.Duration(randomInt(0, 200)) * time.Millisecond
|
|
||||||
time.Sleep(jitterDelay)
|
|
||||||
|
|
||||||
// Vary packet sizes for specific frames types
|
|
||||||
var bufSize int
|
|
||||||
frameType := randomInt(1, 100)
|
|
||||||
switch {
|
|
||||||
case frameType <= 5: // I-frames: 5% chance, larger
|
|
||||||
bufSize = randomInt(8000, 15000)
|
|
||||||
case frameType <= 35: // P-frames: 30% chance, medium
|
|
||||||
bufSize = randomInt(2000, 5000)
|
|
||||||
default: // B-frames: 65% chance, smaller
|
|
||||||
bufSize = randomInt(500, 2000)
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, bufSize)
|
|
||||||
|
|
||||||
// Add some timing variation
|
|
||||||
frameDuration := time.Duration(randomInt(900, 1100)) * time.Millisecond
|
|
||||||
|
|
||||||
err = videoTrack.WriteSample(media.Sample{Data: buf, Duration: frameDuration})
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("webrtc.WriteSample ERROR: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulate some burst of smaller packets
|
|
||||||
if randomInt(1, 10) == 1 { // 10% chance
|
|
||||||
burstCount := randomInt(2, 5)
|
|
||||||
for i := 0; i < burstCount; i++ {
|
|
||||||
smallBuf := make([]byte, randomInt(100, 400))
|
|
||||||
time.Sleep(time.Duration(randomInt(10, 50)) * time.Millisecond)
|
|
||||||
|
|
||||||
frameDuration = time.Duration(randomInt(16, 33)) * time.Millisecond
|
|
||||||
|
|
||||||
err = videoTrack.WriteSample(media.Sample{Data: smallBuf, Duration: frameDuration})
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("webrtc.WriteSample burst ERROR: %s", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
log.Println("WebRTC: Media track opened")
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanup closes all channels and transports
|
// cleanup closes all channels and transports
|
||||||
func (c *WebRTCPeer) cleanup() {
|
func (c *WebRTCPeer) cleanup() {
|
||||||
|
// Stop media channel
|
||||||
|
if c.mediaChannel != nil {
|
||||||
|
c.mediaChannel.Stop()
|
||||||
|
}
|
||||||
// Close this side of the SOCKS pipe.
|
// Close this side of the SOCKS pipe.
|
||||||
if c.writePipe != nil { // c.writePipe can be nil in tests.
|
if c.writePipe != nil { // c.writePipe can be nil in tests.
|
||||||
c.writePipe.Close()
|
c.writePipe.Close()
|
||||||
|
|
145
common/media/channel.go
Normal file
145
common/media/channel.go
Normal file
|
@ -0,0 +1,145 @@
|
||||||
|
package media
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"log"
|
||||||
|
"math/big"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pion/interceptor"
|
||||||
|
"github.com/pion/webrtc/v4"
|
||||||
|
"github.com/pion/webrtc/v4/pkg/media"
|
||||||
|
)
|
||||||
|
|
||||||
|
func randomInt(min, max int) int {
|
||||||
|
nBig, err := rand.Int(rand.Reader, big.NewInt(int64(max-min)))
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return int(nBig.Int64()) + min
|
||||||
|
}
|
||||||
|
|
||||||
|
type RTPReader interface {
|
||||||
|
Read(b []byte) (n int, a interceptor.Attributes, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MediaChannel handles media track simulation for WebRTC connections
|
||||||
|
type MediaChannel struct {
|
||||||
|
stopCh chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMediaChannel creates a new media channel
|
||||||
|
func NewMediaChannel() *MediaChannel {
|
||||||
|
return &MediaChannel{
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StartVideoTrack starts video track simulation on the given peer connection
|
||||||
|
func (mc *MediaChannel) StartVideoTrack(pc *webrtc.PeerConnection) error {
|
||||||
|
videoTrack, err := webrtc.NewTrackLocalStaticSample(
|
||||||
|
webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeAV1}, "video", "pion",
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("webrtc.NewTrackLocalStaticSample ERROR: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
rtpSender, err := pc.AddTrack(videoTrack)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("webrtc.AddTrack ERROR: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
go mc.handleRTCP(rtpSender)
|
||||||
|
go mc.simulateVideoFrames(videoTrack)
|
||||||
|
|
||||||
|
log.Println("WebRTC: Media track opened")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the media simulation
|
||||||
|
func (mc *MediaChannel) Stop() {
|
||||||
|
close(mc.stopCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *MediaChannel) handleRTCP(reader RTPReader) {
|
||||||
|
rtcpBuf := make([]byte, 1500)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-mc.stopCh:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if _, _, err := reader.Read(rtcpBuf); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *MediaChannel) simulateVideoFrames(track *webrtc.TrackLocalStaticSample) {
|
||||||
|
ticker := time.NewTicker(time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-mc.stopCh:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
// Add jitter to simulate "realistic" media patterns
|
||||||
|
jitterDelay := time.Duration(randomInt(0, 200)) * time.Millisecond
|
||||||
|
time.Sleep(jitterDelay)
|
||||||
|
|
||||||
|
// Vary packet sizes for specific frames types
|
||||||
|
var bufSize int
|
||||||
|
frameType := randomInt(1, 100)
|
||||||
|
switch {
|
||||||
|
case frameType <= 5: // I-frames: 5% chance, larger
|
||||||
|
bufSize = randomInt(8000, 15000)
|
||||||
|
case frameType <= 35: // P-frames: 30% chance, medium
|
||||||
|
bufSize = randomInt(2000, 5000)
|
||||||
|
default: // B-frames: 65% chance, smaller
|
||||||
|
bufSize = randomInt(500, 2000)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, bufSize)
|
||||||
|
|
||||||
|
// Add some timing variation
|
||||||
|
frameDuration := time.Duration(randomInt(900, 1100)) * time.Millisecond
|
||||||
|
|
||||||
|
err := track.WriteSample(media.Sample{Data: buf, Duration: frameDuration})
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("webrtc.WriteSample ERROR: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate some burst of smaller packets
|
||||||
|
if randomInt(1, 10) == 1 { // 10% chance
|
||||||
|
burstCount := randomInt(2, 5)
|
||||||
|
for i := 0; i < burstCount; i++ {
|
||||||
|
smallBuf := make([]byte, randomInt(100, 400))
|
||||||
|
time.Sleep(time.Duration(randomInt(10, 50)) * time.Millisecond)
|
||||||
|
|
||||||
|
frameDuration = time.Duration(randomInt(16, 33)) * time.Millisecond
|
||||||
|
|
||||||
|
err = track.WriteSample(media.Sample{Data: smallBuf, Duration: frameDuration})
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("webrtc.WriteSample burst ERROR: %s", err)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start sets up duplex media handling (both incoming and outgoing tracks)
|
||||||
|
func (mc *MediaChannel) Start(pc *webrtc.PeerConnection) error {
|
||||||
|
// Set up handler for incoming tracks
|
||||||
|
pc.OnTrack(func(remote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
||||||
|
log.Printf("Media Track received: streamId(%s) id(%s) rid(%s)", remote.StreamID(), remote.ID(), remote.RID())
|
||||||
|
go mc.handleRTCP(receiver)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Set up outgoing media track
|
||||||
|
return mc.StartVideoTrack(pc)
|
||||||
|
}
|
20
common/media/channel_test.go
Normal file
20
common/media/channel_test.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package media
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMediaChannelStop(t *testing.T) {
|
||||||
|
mc := NewMediaChannel()
|
||||||
|
|
||||||
|
// This should not panic
|
||||||
|
mc.Stop()
|
||||||
|
|
||||||
|
// Verify that the stop channel is closed
|
||||||
|
select {
|
||||||
|
case <-mc.stopCh:
|
||||||
|
// Channel is closed, which is expected
|
||||||
|
default:
|
||||||
|
t.Fatal("MediaChannel stopCh should be closed after Stop()")
|
||||||
|
}
|
||||||
|
}
|
|
@ -48,6 +48,7 @@ import (
|
||||||
|
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/constants"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/constants"
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event"
|
||||||
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/media"
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/namematcher"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/namematcher"
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/task"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/task"
|
||||||
|
@ -451,18 +452,13 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
|
||||||
return nil, fmt.Errorf("accept: NewPeerConnection: %s", err)
|
return nil, fmt.Errorf("accept: NewPeerConnection: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
pc.OnTrack(func(remote *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) {
|
|
||||||
log.Printf("Track has started streamId(%s) id(%s) rid(%s) \n", remote.StreamID(), remote.ID(), remote.RID())
|
|
||||||
|
|
||||||
for {
|
// Start duplex media handling (both incoming and outgoing tracks)
|
||||||
rtcpBuf := make([]byte, 1500)
|
mediaChannel := media.NewMediaChannel()
|
||||||
for {
|
err = mediaChannel.Start(pc)
|
||||||
if _, _, err := receiver.Read(rtcpBuf); err != nil {
|
if err != nil {
|
||||||
return
|
log.Printf("Failed to setup proxy media channel: %v", err)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
|
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
|
||||||
log.Printf("New Data Channel %s-%d\n", dc.Label(), dc.ID())
|
log.Printf("New Data Channel %s-%d\n", dc.Label(), dc.ID())
|
||||||
|
@ -512,6 +508,9 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
|
||||||
}
|
}
|
||||||
sf.EventDispatcher.OnNewSnowflakeEvent(event.EventOnProxyConnectionOver{Country: country})
|
sf.EventDispatcher.OnNewSnowflakeEvent(event.EventOnProxyConnectionOver{Country: country})
|
||||||
|
|
||||||
|
// Clean up media channel
|
||||||
|
mediaChannel.Stop()
|
||||||
|
|
||||||
conn.dc = nil
|
conn.dc = nil
|
||||||
dc.Close()
|
dc.Close()
|
||||||
pw.Close()
|
pw.Close()
|
||||||
|
@ -838,7 +837,7 @@ func (sf *SnowflakeProxy) Start() error {
|
||||||
err = sf.checkNATType(config, sf.NATProbeURL)
|
err = sf.checkNATType(config, sf.NATProbeURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// non-fatal error. Log it and continue
|
// non-fatal error. Log it and continue
|
||||||
log.Printf(err.Error())
|
log.Printf("%s", err.Error())
|
||||||
setCurrentNATType(NATUnknown)
|
setCurrentNATType(NATUnknown)
|
||||||
}
|
}
|
||||||
sf.EventDispatcher.OnNewSnowflakeEvent(event.EventOnCurrentNATTypeDetermined{CurNATType: getCurrentNATType()})
|
sf.EventDispatcher.OnNewSnowflakeEvent(event.EventOnCurrentNATTypeDetermined{CurNATType: getCurrentNATType()})
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue