Implement DataChannel flow control

This commit is contained in:
Vort 2023-03-27 19:02:10 +03:00 committed by Shelikhoo
parent f8eb86f24d
commit ea01c92cf1
No known key found for this signature in database
GPG key ID: C4D5E79D22B25316
2 changed files with 30 additions and 1 deletions

View file

@ -83,6 +83,8 @@ const (
sessionIDLength = 16 sessionIDLength = 16
) )
const bufferedAmountLowThreshold uint64 = 256 * 1024 // 256 KB
var broker *SignalingServer var broker *SignalingServer
var currentNATTypeAccess = &sync.RWMutex{} var currentNATTypeAccess = &sync.RWMutex{}
@ -408,6 +410,15 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(sdp *webrtc.SessionDescrip
pr, pw := io.Pipe() pr, pw := io.Pipe()
conn := newWebRTCConn(pc, dc, pr, sf.EventDispatcher) conn := newWebRTCConn(pc, dc, pr, sf.EventDispatcher)
dc.SetBufferedAmountLowThreshold(bufferedAmountLowThreshold)
dc.OnBufferedAmountLow(func() {
select {
case conn.sendMoreCh <- struct{}{}:
default:
}
})
dc.OnOpen(func() { dc.OnOpen(func() {
log.Printf("Data Channel %s-%d open\n", dc.Label(), dc.ID()) log.Printf("Data Channel %s-%d open\n", dc.Label(), dc.ID())

View file

@ -16,6 +16,8 @@ import (
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/event"
) )
const maxBufferedAmount uint64 = 512 * 1024 // 512 KB
var remoteIPPatterns = []*regexp.Regexp{ var remoteIPPatterns = []*regexp.Regexp{
/* IPv4 */ /* IPv4 */
regexp.MustCompile(`(?m)^c=IN IP4 ([\d.]+)(?:(?:\/\d+)?\/\d+)?(:? |\r?\n)`), regexp.MustCompile(`(?m)^c=IN IP4 ([\d.]+)(?:(?:\/\d+)?\/\d+)?(:? |\r?\n)`),
@ -31,18 +33,23 @@ type webRTCConn struct {
lock sync.Mutex // Synchronization for DataChannel destruction lock sync.Mutex // Synchronization for DataChannel destruction
once sync.Once // Synchronization for PeerConnection destruction once sync.Once // Synchronization for PeerConnection destruction
isClosing bool
bytesLogger bytesLogger bytesLogger bytesLogger
eventLogger event.SnowflakeEventReceiver eventLogger event.SnowflakeEventReceiver
inactivityTimeout time.Duration inactivityTimeout time.Duration
activity chan struct{} activity chan struct{}
sendMoreCh chan struct{}
cancelTimeoutLoop context.CancelFunc cancelTimeoutLoop context.CancelFunc
} }
func newWebRTCConn(pc *webrtc.PeerConnection, dc *webrtc.DataChannel, pr *io.PipeReader, eventLogger event.SnowflakeEventReceiver) *webRTCConn { func newWebRTCConn(pc *webrtc.PeerConnection, dc *webrtc.DataChannel, pr *io.PipeReader, eventLogger event.SnowflakeEventReceiver) *webRTCConn {
conn := &webRTCConn{pc: pc, dc: dc, pr: pr, eventLogger: eventLogger} conn := &webRTCConn{pc: pc, dc: dc, pr: pr, eventLogger: eventLogger}
conn.isClosing = false
conn.bytesLogger = newBytesSyncLogger() conn.bytesLogger = newBytesSyncLogger()
conn.activity = make(chan struct{}, 100) conn.activity = make(chan struct{}, 100)
conn.sendMoreCh = make(chan struct{}, 1)
conn.inactivityTimeout = 30 * time.Second conn.inactivityTimeout = 30 * time.Second
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
conn.cancelTimeoutLoop = cancel conn.cancelTimeoutLoop = cancel
@ -76,16 +83,27 @@ func (c *webRTCConn) Read(b []byte) (int, error) {
func (c *webRTCConn) Write(b []byte) (int, error) { func (c *webRTCConn) Write(b []byte) (int, error) {
c.bytesLogger.AddInbound(int64(len(b))) c.bytesLogger.AddInbound(int64(len(b)))
c.activity <- struct{}{} select {
case c.activity <- struct{}{}:
default:
}
c.lock.Lock() c.lock.Lock()
defer c.lock.Unlock() defer c.lock.Unlock()
if c.dc != nil { if c.dc != nil {
c.dc.Send(b) c.dc.Send(b)
if !c.isClosing && c.dc.BufferedAmount()+uint64(len(b)) > maxBufferedAmount {
<-c.sendMoreCh
}
} }
return len(b), nil return len(b), nil
} }
func (c *webRTCConn) Close() (err error) { func (c *webRTCConn) Close() (err error) {
c.isClosing = true
select {
case c.sendMoreCh <- struct{}{}:
default:
}
c.once.Do(func() { c.once.Do(func() {
c.cancelTimeoutLoop() c.cancelTimeoutLoop()
err = c.pc.Close() err = c.pc.Close()