Add synchronization around destroying DataChannels and PeerConnections

From https://trac.torproject.org/projects/tor/ticket/21312#comment:33
This commit is contained in:
Arlo Breault 2018-03-14 13:35:39 -04:00
parent 40bf7664d4
commit 1114acbcb4
3 changed files with 54 additions and 16 deletions

View file

@ -5,6 +5,7 @@ import (
"errors" "errors"
"io" "io"
"log" "log"
"sync"
"time" "time"
"github.com/dchest/uniuri" "github.com/dchest/uniuri"
@ -35,6 +36,9 @@ type WebRTCPeer struct {
closed bool closed bool
lock sync.Mutex // Synchronization for DataChannel destruction
once sync.Once // Synchronization for PeerConnection destruction
BytesLogger BytesLogger
} }
@ -69,6 +73,8 @@ func (c *WebRTCPeer) Read(b []byte) (int, error) {
// Writes bytes out to remote WebRTC. // Writes bytes out to remote WebRTC.
// As part of |io.ReadWriter| // As part of |io.ReadWriter|
func (c *WebRTCPeer) Write(b []byte) (int, error) { func (c *WebRTCPeer) Write(b []byte) (int, error) {
c.lock.Lock()
defer c.lock.Unlock()
c.BytesLogger.AddOutbound(len(b)) c.BytesLogger.AddOutbound(len(b))
// TODO: Buffering could be improved / separated out of WebRTCPeer. // TODO: Buffering could be improved / separated out of WebRTCPeer.
if nil == c.transport { if nil == c.transport {
@ -82,14 +88,12 @@ func (c *WebRTCPeer) Write(b []byte) (int, error) {
// As part of |Snowflake| // As part of |Snowflake|
func (c *WebRTCPeer) Close() error { func (c *WebRTCPeer) Close() error {
if c.closed { // Skip if already closed. c.once.Do(func() {
return nil c.closed = true
} c.cleanup()
// Mark for deletion. c.Reset()
c.closed = true log.Printf("WebRTC: Closing")
c.cleanup() })
c.Reset()
log.Printf("WebRTC: Closing")
return nil return nil
} }
@ -194,6 +198,8 @@ func (c *WebRTCPeer) preparePeerConnection() error {
// Create a WebRTC DataChannel locally. // Create a WebRTC DataChannel locally.
func (c *WebRTCPeer) establishDataChannel() error { func (c *WebRTCPeer) establishDataChannel() error {
c.lock.Lock()
defer c.lock.Unlock()
if c.transport != nil { if c.transport != nil {
panic("Unexpected datachannel already exists!") panic("Unexpected datachannel already exists!")
} }
@ -206,6 +212,8 @@ func (c *WebRTCPeer) establishDataChannel() error {
return err return err
} }
dc.OnOpen = func() { dc.OnOpen = func() {
c.lock.Lock()
defer c.lock.Unlock()
log.Println("WebRTC: DataChannel.OnOpen") log.Println("WebRTC: DataChannel.OnOpen")
if nil != c.transport { if nil != c.transport {
panic("WebRTC: transport already exists.") panic("WebRTC: transport already exists.")
@ -220,10 +228,12 @@ func (c *WebRTCPeer) establishDataChannel() error {
c.transport = dc c.transport = dc
} }
dc.OnClose = func() { dc.OnClose = func() {
c.lock.Lock()
// Future writes will go to the buffer until a new DataChannel is available. // Future writes will go to the buffer until a new DataChannel is available.
if nil == c.transport { if nil == c.transport {
// Closed locally, as part of a reset. // Closed locally, as part of a reset.
log.Println("WebRTC: DataChannel.OnClose [locally]") log.Println("WebRTC: DataChannel.OnClose [locally]")
c.lock.Unlock()
return return
} }
// Closed remotely, need to reset everything. // Closed remotely, need to reset everything.
@ -231,6 +241,9 @@ func (c *WebRTCPeer) establishDataChannel() error {
log.Println("WebRTC: DataChannel.OnClose [remotely]") log.Println("WebRTC: DataChannel.OnClose [remotely]")
c.transport = nil c.transport = nil
c.pc.DeleteDataChannel(dc) c.pc.DeleteDataChannel(dc)
// Unlock before Close'ing, since it calls cleanup and asks for the
// lock to check if the transport needs to be be deleted.
c.lock.Unlock()
c.Close() c.Close()
} }
dc.OnMessage = func(msg []byte) { dc.OnMessage = func(msg []byte) {
@ -321,16 +334,23 @@ func (c *WebRTCPeer) cleanup() {
c.writePipe.Close() c.writePipe.Close()
c.writePipe = nil c.writePipe = nil
} }
c.lock.Lock()
if nil != c.transport { if nil != c.transport {
log.Printf("WebRTC: closing DataChannel") log.Printf("WebRTC: closing DataChannel")
dataChannel := c.transport dataChannel := c.transport
// Setting transport to nil *before* dc Close indicates to OnClose that // Setting transport to nil *before* dc Close indicates to OnClose that
// this was locally triggered. // this was locally triggered.
c.transport = nil c.transport = nil
// Release the lock before calling DeleteDataChannel (which in turn
// calls Close on the dataChannel), but after nil'ing out the transport,
// since otherwise we'll end up in the onClose handler in a deadlock.
c.lock.Unlock()
if c.pc == nil { if c.pc == nil {
panic("DataChannel w/o PeerConnection, not good.") panic("DataChannel w/o PeerConnection, not good.")
} }
c.pc.DeleteDataChannel(dataChannel.(*webrtc.DataChannel)) c.pc.DeleteDataChannel(dataChannel.(*webrtc.DataChannel))
} else {
c.lock.Unlock()
} }
if nil != c.pc { if nil != c.pc {
log.Printf("WebRTC: closing PeerConnection") log.Printf("WebRTC: closing PeerConnection")

View file

@ -62,6 +62,9 @@ type webRTCConn struct {
dc *webrtc.DataChannel dc *webrtc.DataChannel
pc *webrtc.PeerConnection pc *webrtc.PeerConnection
pr *io.PipeReader pr *io.PipeReader
lock sync.Mutex // Synchronization for DataChannel destruction
once sync.Once // Synchronization for PeerConnection destruction
} }
func (c *webRTCConn) Read(b []byte) (int, error) { func (c *webRTCConn) Read(b []byte) (int, error) {
@ -69,6 +72,8 @@ func (c *webRTCConn) Read(b []byte) (int, error) {
} }
func (c *webRTCConn) Write(b []byte) (int, error) { func (c *webRTCConn) Write(b []byte) (int, error) {
c.lock.Lock()
defer c.lock.Unlock()
// log.Printf("webrtc Write %d %+q", len(b), string(b)) // log.Printf("webrtc Write %d %+q", len(b), string(b))
log.Printf("Write %d bytes --> WebRTC", len(b)) log.Printf("Write %d bytes --> WebRTC", len(b))
if c.dc != nil { if c.dc != nil {
@ -77,8 +82,11 @@ func (c *webRTCConn) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
func (c *webRTCConn) Close() error { func (c *webRTCConn) Close() (err error) {
return c.pc.Destroy() c.once.Do(func() {
err = c.pc.Destroy()
})
return
} }
func (c *webRTCConn) LocalAddr() net.Addr { func (c *webRTCConn) LocalAddr() net.Addr {
@ -255,17 +263,18 @@ func makePeerConnectionFromOffer(sdp *webrtc.SessionDescription, config *webrtc.
log.Println("OnDataChannel") log.Println("OnDataChannel")
pr, pw := io.Pipe() pr, pw := io.Pipe()
conn := &webRTCConn{pc: pc, dc: dc, pr: pr} conn := &webRTCConn{pc: pc, dc: dc, pr: pr}
dc.OnOpen = func() { dc.OnOpen = func() {
log.Println("OnOpen channel") log.Println("OnOpen channel")
} }
dc.OnClose = func() { dc.OnClose = func() {
conn.lock.Lock()
defer conn.lock.Unlock()
log.Println("OnClose channel") log.Println("OnClose channel")
pw.Close()
conn.dc = nil conn.dc = nil
pc.DeleteDataChannel(dc) pc.DeleteDataChannel(dc)
pw.Close()
} }
dc.OnMessage = func(msg []byte) { dc.OnMessage = func(msg []byte) {
log.Printf("OnMessage <--- %d bytes", len(msg)) log.Printf("OnMessage <--- %d bytes", len(msg))

View file

@ -43,6 +43,9 @@ type webRTCConn struct {
dc *webrtc.DataChannel dc *webrtc.DataChannel
pc *webrtc.PeerConnection pc *webrtc.PeerConnection
pr *io.PipeReader pr *io.PipeReader
lock sync.Mutex // Synchronization for DataChannel destruction
once sync.Once // Synchronization for PeerConnection destruction
} }
func (c *webRTCConn) Read(b []byte) (int, error) { func (c *webRTCConn) Read(b []byte) (int, error) {
@ -50,6 +53,8 @@ func (c *webRTCConn) Read(b []byte) (int, error) {
} }
func (c *webRTCConn) Write(b []byte) (int, error) { func (c *webRTCConn) Write(b []byte) (int, error) {
c.lock.Lock()
defer c.lock.Unlock()
// log.Printf("webrtc Write %d %+q", len(b), string(b)) // log.Printf("webrtc Write %d %+q", len(b), string(b))
log.Printf("Write %d bytes --> WebRTC", len(b)) log.Printf("Write %d bytes --> WebRTC", len(b))
if c.dc != nil { if c.dc != nil {
@ -58,8 +63,11 @@ func (c *webRTCConn) Write(b []byte) (int, error) {
return len(b), nil return len(b), nil
} }
func (c *webRTCConn) Close() error { func (c *webRTCConn) Close() (err error) {
return c.pc.Destroy() c.once.Do(func() {
err = c.pc.Destroy()
})
return
} }
func (c *webRTCConn) LocalAddr() net.Addr { func (c *webRTCConn) LocalAddr() net.Addr {
@ -122,17 +130,18 @@ func makePeerConnectionFromOffer(sdp *webrtc.SessionDescription, config *webrtc.
log.Println("OnDataChannel") log.Println("OnDataChannel")
pr, pw := io.Pipe() pr, pw := io.Pipe()
conn := &webRTCConn{pc: pc, dc: dc, pr: pr} conn := &webRTCConn{pc: pc, dc: dc, pr: pr}
dc.OnOpen = func() { dc.OnOpen = func() {
log.Println("OnOpen channel") log.Println("OnOpen channel")
} }
dc.OnClose = func() { dc.OnClose = func() {
conn.lock.Lock()
defer conn.lock.Unlock()
log.Println("OnClose channel") log.Println("OnClose channel")
pw.Close()
conn.dc = nil conn.dc = nil
pc.DeleteDataChannel(dc) pc.DeleteDataChannel(dc)
pw.Close()
} }
dc.OnMessage = func(msg []byte) { dc.OnMessage = func(msg []byte) {
log.Printf("OnMessage <--- %d bytes", len(msg)) log.Printf("OnMessage <--- %d bytes", len(msg))