pull copyLoop out of goroutine, better pop and reset

This commit is contained in:
Serene Han 2016-06-14 17:07:21 -07:00
parent a71c98c0ae
commit 2bf0e5457e
5 changed files with 70 additions and 35 deletions

View file

@ -65,9 +65,9 @@ func (f FakeSocksConn) Grant(addr *net.TCPAddr) error { return nil }
type FakePeers struct{ toRelease *webRTCConn } type FakePeers struct{ toRelease *webRTCConn }
func (f FakePeers) Collect() error { return nil } func (f FakePeers) Collect() (Snowflake, error) { return &webRTCConn{}, nil }
func (f FakePeers) Pop() Snowflake { return nil } func (f FakePeers) Pop() Snowflake { return nil }
func (f FakePeers) Melted() <-chan struct{} { return nil } func (f FakePeers) Melted() <-chan struct{} { return nil }
func TestSnowflakeClient(t *testing.T) { func TestSnowflakeClient(t *testing.T) {
@ -81,16 +81,16 @@ func TestSnowflakeClient(t *testing.T) {
Convey("Collecting a Snowflake requires a Tongue.", func() { Convey("Collecting a Snowflake requires a Tongue.", func() {
p := NewPeers(1) p := NewPeers(1)
err := p.Collect() _, err := p.Collect()
So(err, ShouldNotBeNil) So(err, ShouldNotBeNil)
So(p.Count(), ShouldEqual, 0) So(p.Count(), ShouldEqual, 0)
// Set the dialer so that collection is possible. // Set the dialer so that collection is possible.
p.Tongue = FakeDialer{} p.Tongue = FakeDialer{}
err = p.Collect() _, err = p.Collect()
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(p.Count(), ShouldEqual, 1) So(p.Count(), ShouldEqual, 1)
// S // S
err = p.Collect() _, err = p.Collect()
}) })
Convey("Collection continues until capacity.", func() { Convey("Collection continues until capacity.", func() {
@ -100,13 +100,13 @@ func TestSnowflakeClient(t *testing.T) {
// Fill up to capacity. // Fill up to capacity.
for i := 0; i < c; i++ { for i := 0; i < c; i++ {
fmt.Println("Adding snowflake ", i) fmt.Println("Adding snowflake ", i)
err := p.Collect() _, err := p.Collect()
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(p.Count(), ShouldEqual, i+1) So(p.Count(), ShouldEqual, i+1)
} }
// But adding another gives an error. // But adding another gives an error.
So(p.Count(), ShouldEqual, c) So(p.Count(), ShouldEqual, c)
err := p.Collect() _, err := p.Collect()
So(err, ShouldNotBeNil) So(err, ShouldNotBeNil)
So(p.Count(), ShouldEqual, c) So(p.Count(), ShouldEqual, c)
@ -116,7 +116,7 @@ func TestSnowflakeClient(t *testing.T) {
So(s, ShouldNotBeNil) So(s, ShouldNotBeNil)
So(p.Count(), ShouldEqual, c-1) So(p.Count(), ShouldEqual, c-1)
err = p.Collect() _, err = p.Collect()
So(err, ShouldBeNil) So(err, ShouldBeNil)
So(p.Count(), ShouldEqual, c) So(p.Count(), ShouldEqual, c)
}) })
@ -149,6 +149,26 @@ func TestSnowflakeClient(t *testing.T) {
So(p.Count(), ShouldEqual, 0) So(p.Count(), ShouldEqual, 0)
}) })
Convey("Pop skips over closed peers.", func() {
p := NewPeers(4)
p.Tongue = FakeDialer{}
wc1, _ := p.Collect()
wc2, _ := p.Collect()
wc3, _ := p.Collect()
So(wc1, ShouldNotBeNil)
So(wc2, ShouldNotBeNil)
So(wc3, ShouldNotBeNil)
wc1.Close()
r := p.Pop()
So(p.Count(), ShouldEqual, 2)
So(r, ShouldEqual, wc2)
wc4, _ := p.Collect()
wc2.Close()
wc3.Close()
r = p.Pop()
So(r, ShouldEqual, wc4)
})
}) })
Convey("Snowflake", t, func() { Convey("Snowflake", t, func() {

View file

@ -17,10 +17,9 @@ type Resetter interface {
// Interface for a single remote WebRTC peer. // Interface for a single remote WebRTC peer.
// In the Client context, "Snowflake" refers to the remote browser proxy. // In the Client context, "Snowflake" refers to the remote browser proxy.
type Snowflake interface { type Snowflake interface {
io.ReadWriter io.ReadWriteCloser
Resetter Resetter
Connector Connector
Close() error
} }
// Interface for catching Snowflakes. (aka the remote dialer) // Interface for catching Snowflakes. (aka the remote dialer)
@ -34,7 +33,7 @@ type SnowflakeCollector interface {
// Add a Snowflake to the collection. // Add a Snowflake to the collection.
// Implementation should decide how to connect and maintain the webRTCConn. // Implementation should decide how to connect and maintain the webRTCConn.
Collect() error Collect() (Snowflake, error)
// Remove and return the most available Snowflake from the collection. // Remove and return the most available Snowflake from the collection.
Pop() Snowflake Pop() Snowflake
@ -52,6 +51,6 @@ type SocksConnector interface {
// Interface for the Snowflake's transport. (Typically just webrtc.DataChannel) // Interface for the Snowflake's transport. (Typically just webrtc.DataChannel)
type SnowflakeDataChannel interface { type SnowflakeDataChannel interface {
io.Closer
Send([]byte) Send([]byte)
Close() error
} }

View file

@ -40,34 +40,43 @@ func NewPeers(max int) *Peers {
} }
// As part of |SnowflakeCollector| interface. // As part of |SnowflakeCollector| interface.
func (p *Peers) Collect() error { func (p *Peers) Collect() (Snowflake, error) {
cnt := p.Count() cnt := p.Count()
s := fmt.Sprintf("Currently at [%d/%d]", cnt, p.capacity) s := fmt.Sprintf("Currently at [%d/%d]", cnt, p.capacity)
if cnt >= p.capacity { if cnt >= p.capacity {
s := fmt.Sprintf("At capacity [%d/%d]", cnt, p.capacity) s := fmt.Sprintf("At capacity [%d/%d]", cnt, p.capacity)
return errors.New(s) return nil, errors.New(s)
} }
log.Println("WebRTC: Collecting a new Snowflake.", s) log.Println("WebRTC: Collecting a new Snowflake.", s)
// Engage the Snowflake Catching interface, which must be available. // Engage the Snowflake Catching interface, which must be available.
if nil == p.Tongue { if nil == p.Tongue {
return errors.New("Missing Tongue to catch Snowflakes with.") return nil, errors.New("Missing Tongue to catch Snowflakes with.")
} }
// BUG: some broker conflict here.
connection, err := p.Tongue.Catch() connection, err := p.Tongue.Catch()
if nil == connection || nil != err { if nil != err {
return err return nil, err
} }
// Track new valid Snowflake in internal collection and pass along. // Track new valid Snowflake in internal collection and pass along.
p.activePeers.PushBack(connection) p.activePeers.PushBack(connection)
p.snowflakeChan <- connection p.snowflakeChan <- connection
return nil return connection, nil
} }
// As part of |SnowflakeCollector| interface. // As part of |SnowflakeCollector| interface.
func (p *Peers) Pop() Snowflake { func (p *Peers) Pop() Snowflake {
// Blocks until an available snowflake appears. // Blocks until an available, valid snowflake appears.
snowflake, ok := <-p.snowflakeChan var snowflake Snowflake
if !ok { var ok bool
return nil for nil == snowflake {
snowflake, ok = <-p.snowflakeChan
conn := snowflake.(*webRTCConn)
if !ok {
return nil
}
if conn.closed {
snowflake = nil
}
} }
// Set to use the same rate-limited traffic logger to keep consistency. // Set to use the same rate-limited traffic logger to keep consistency.
snowflake.(*webRTCConn).BytesLogger = p.BytesLogger snowflake.(*webRTCConn).BytesLogger = p.BytesLogger
@ -105,7 +114,6 @@ func (p *Peers) End() {
p.melt <- struct{}{} p.melt <- struct{}{}
cnt := p.Count() cnt := p.Count()
for e := p.activePeers.Front(); e != nil; { for e := p.activePeers.Front(); e != nil; {
log.Println(e, e.Value)
next := e.Next() next := e.Next()
conn := e.Value.(*webRTCConn) conn := e.Value.(*webRTCConn)
conn.Close() conn.Close()

View file

@ -31,7 +31,7 @@ var handlerChan = make(chan int)
func ConnectLoop(snowflakes SnowflakeCollector) { func ConnectLoop(snowflakes SnowflakeCollector) {
for { for {
// Check if ending is necessary. // Check if ending is necessary.
err := snowflakes.Collect() _, err := snowflakes.Collect()
if nil != err { if nil != err {
log.Println("WebRTC:", err, log.Println("WebRTC:", err,
" Retrying in", ReconnectTimeout, "seconds...") " Retrying in", ReconnectTimeout, "seconds...")
@ -51,6 +51,7 @@ func socksAcceptLoop(ln *pt.SocksListener, snowflakes SnowflakeCollector) error
defer ln.Close() defer ln.Close()
log.Println("Started SOCKS listener.") log.Println("Started SOCKS listener.")
for { for {
log.Println("SOCKS listening...")
conn, err := ln.AcceptSocks() conn, err := ln.AcceptSocks()
log.Println("SOCKS accepted: ", conn.Req) log.Println("SOCKS accepted: ", conn.Req)
if err != nil { if err != nil {
@ -81,20 +82,22 @@ func handler(socks SocksConnector, snowflakes SnowflakeCollector) error {
return errors.New("handler: Received invalid Snowflake") return errors.New("handler: Received invalid Snowflake")
} }
defer socks.Close() defer socks.Close()
log.Println("---- Snowflake assigned ----") log.Println("---- Handler: snowflake assigned ----")
err := socks.Grant(&net.TCPAddr{IP: net.IPv4zero, Port: 0}) err := socks.Grant(&net.TCPAddr{IP: net.IPv4zero, Port: 0})
if err != nil { if err != nil {
return err return err
} }
// Begin exchanging data. go func() {
// BUG(serene): There's a leak here when multiplexed. // When WebRTC resets, close the SOCKS connection, which ends
go copyLoop(socks, snowflake) // the copyLoop below and induces new handler.
snowflake.WaitForReset()
socks.Close()
}()
// When WebRTC resets, close the SOCKS connection, which induces new handler. // Begin exchanging data.
// TODO: Double check this / fix it. copyLoop(socks, snowflake)
snowflake.WaitForReset() log.Println("---- Handler: closed ---")
log.Println("---- Closed ---")
return nil return nil
} }

View file

@ -31,7 +31,7 @@ type webRTCConn struct {
BytesLogger BytesLogger
} }
// Read bytes from remote WebRTC. // Read bytes from local SOCKS.
// As part of |io.ReadWriter| // As part of |io.ReadWriter|
func (c *webRTCConn) Read(b []byte) (int, error) { func (c *webRTCConn) Read(b []byte) (int, error) {
return c.recvPipe.Read(b) return c.recvPipe.Read(b)
@ -62,11 +62,11 @@ func (c *webRTCConn) Close() error {
// As part of |Resetter| // As part of |Resetter|
func (c *webRTCConn) Reset() { func (c *webRTCConn) Reset() {
c.Close()
go func() { go func() {
c.reset <- struct{}{} c.reset <- struct{}{}
log.Println("WebRTC resetting...") log.Println("WebRTC resetting...")
}() }()
c.Close()
} }
// As part of |Resetter| // As part of |Resetter|
@ -282,6 +282,11 @@ func (c *webRTCConn) cleanup() {
if nil != c.errorChannel { if nil != c.errorChannel {
close(c.errorChannel) close(c.errorChannel)
} }
// Close this side of the SOCKS pipe.
if nil != c.writePipe {
c.writePipe.Close()
c.writePipe = nil
}
if nil != c.transport { if nil != c.transport {
log.Printf("WebRTC: closing DataChannel") log.Printf("WebRTC: closing DataChannel")
dataChannel := c.transport dataChannel := c.transport