Store net.Addr in clientIDAddrMap

This fixes a stats collection bug where we were converting client
addresses between a string and net.Addr using the clientAddr function
multiple times, resulting in an empty string for all addresses.
This commit is contained in:
Cecylia Bocovich 2021-06-19 11:16:38 -04:00
parent aefabe683f
commit 6634f2bec9
4 changed files with 55 additions and 44 deletions

View file

@ -138,7 +138,7 @@ func turbotunnelMode(conn net.Conn, addr net.Addr, pconn *turbotunnel.QueuePacke
// recent WebSocket connection that has had to do with a session, at the
// time the session is established, is the IP address that should be
// credited for the entire KCP session.
clientIDAddrMap.Set(clientID, addr.String())
clientIDAddrMap.Set(clientID, addr)
var wg sync.WaitGroup
wg.Add(2)

View file

@ -181,7 +181,7 @@ func (l *SnowflakeListener) acceptStreams(conn *kcp.UDPSession) error {
}
return err
}
l.QueueConn(&SnowflakeClientConn{Conn: stream, address: clientAddr(addr)})
l.QueueConn(&SnowflakeClientConn{Conn: stream, address: addr})
}
}

View file

@ -1,12 +1,13 @@
package lib
import (
"net"
"sync"
"git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel"
)
// clientIDMap is a fixed-capacity mapping from ClientIDs to address strings.
// clientIDMap is a fixed-capacity mapping from ClientIDs to a net.Addr.
// Adding a new entry using the Set method causes the oldest existing entry to
// be forgotten.
//
@ -23,7 +24,7 @@ type clientIDMap struct {
// entries is a circular buffer of (ClientID, addr) pairs.
entries []struct {
clientID turbotunnel.ClientID
addr string
addr net.Addr
}
// oldest is the index of the oldest member of the entries buffer, the
// one that will be overwritten at the next call to Set.
@ -38,7 +39,7 @@ func newClientIDMap(capacity int) *clientIDMap {
return &clientIDMap{
entries: make([]struct {
clientID turbotunnel.ClientID
addr string
addr net.Addr
}, capacity),
oldest: 0,
current: make(map[turbotunnel.ClientID]int),
@ -48,7 +49,7 @@ func newClientIDMap(capacity int) *clientIDMap {
// Set adds a mapping from clientID to addr, replacing any previous mapping for
// clientID. It may also cause the clientIDMap to forget at most one other
// mapping, the oldest one.
func (m *clientIDMap) Set(clientID turbotunnel.ClientID, addr string) {
func (m *clientIDMap) Set(clientID turbotunnel.ClientID, addr net.Addr) {
m.lock.Lock()
defer m.lock.Unlock()
if len(m.entries) == 0 {
@ -73,13 +74,13 @@ func (m *clientIDMap) Set(clientID turbotunnel.ClientID, addr string) {
// Get returns a previously stored mapping. The second return value indicates
// whether clientID was actually present in the map. If it is false, then the
// returned address string will be "".
func (m *clientIDMap) Get(clientID turbotunnel.ClientID) (string, bool) {
// returned address will be nil.
func (m *clientIDMap) Get(clientID turbotunnel.ClientID) (net.Addr, bool) {
m.lock.Lock()
defer m.lock.Unlock()
if i, ok := m.current[clientID]; ok {
return m.entries[i].addr, true
} else {
return "", false
return nil, false
}
}

View file

@ -2,6 +2,7 @@ package lib
import (
"encoding/binary"
"net"
"testing"
"git.torproject.org/pluggable-transports/snowflake.git/common/turbotunnel"
@ -19,7 +20,7 @@ func TestClientIDMap(t *testing.T) {
expectGet := func(m *clientIDMap, clientID turbotunnel.ClientID, expectedAddr string, expectedOK bool) {
t.Helper()
addr, ok := m.Get(clientID)
if addr != expectedAddr || ok != expectedOK {
if (ok && addr.String() != expectedAddr) || ok != expectedOK {
t.Errorf("expected (%+q, %v), got (%+q, %v)", expectedAddr, expectedOK, addr, ok)
}
}
@ -32,6 +33,15 @@ func TestClientIDMap(t *testing.T) {
}
}
// Convert a string to a net.Addr
ip := func(addr string) net.Addr {
ret, err := net.ResolveIPAddr("ip", addr)
if err != nil {
t.Errorf("received error: %s", err.Error())
}
return ret
}
// Zero-capacity map can't remember anything.
{
m := newClientIDMap(0)
@ -39,12 +49,12 @@ func TestClientIDMap(t *testing.T) {
expectGet(m, id(0), "", false)
expectGet(m, id(1234), "", false)
m.Set(id(0), "A")
m.Set(id(0), ip("1.1.1.1"))
expectSize(m, 0)
expectGet(m, id(0), "", false)
expectGet(m, id(1234), "", false)
m.Set(id(1234), "A")
m.Set(id(1234), ip("1.1.1.1"))
expectSize(m, 0)
expectGet(m, id(0), "", false)
expectGet(m, id(1234), "", false)
@ -56,60 +66,60 @@ func TestClientIDMap(t *testing.T) {
expectGet(m, id(0), "", false)
expectGet(m, id(1), "", false)
m.Set(id(0), "A")
m.Set(id(0), ip("1.1.1.1"))
expectSize(m, 1)
expectGet(m, id(0), "A", true)
expectGet(m, id(0), "1.1.1.1", true)
expectGet(m, id(1), "", false)
m.Set(id(1), "B") // forgets the (0, "A") entry
m.Set(id(1), ip("1.1.1.2")) // forgets the (0, "1.1.1.1") entry
expectSize(m, 1)
expectGet(m, id(0), "", false)
expectGet(m, id(1), "B", true)
expectGet(m, id(1), "1.1.1.2", true)
m.Set(id(1), "C") // forgets the (1, "B") entry
m.Set(id(1), ip("1.1.1.3")) // forgets the (1, "1.1.1.2") entry
expectSize(m, 1)
expectGet(m, id(0), "", false)
expectGet(m, id(1), "C", true)
expectGet(m, id(1), "1.1.1.3", true)
}
{
m := newClientIDMap(5)
m.Set(id(0), "A")
m.Set(id(1), "B")
m.Set(id(2), "C")
m.Set(id(0), "D") // shadows the (0, "D") entry
m.Set(id(3), "E")
m.Set(id(0), ip("1.1.1.1"))
m.Set(id(1), ip("1.1.1.2"))
m.Set(id(2), ip("1.1.1.3"))
m.Set(id(0), ip("1.1.1.4")) // shadows the (0, "1.1.1.1") entry
m.Set(id(3), ip("1.1.1.5"))
expectSize(m, 4)
expectGet(m, id(0), "D", true)
expectGet(m, id(1), "B", true)
expectGet(m, id(2), "C", true)
expectGet(m, id(3), "E", true)
expectGet(m, id(0), "1.1.1.4", true)
expectGet(m, id(1), "1.1.1.2", true)
expectGet(m, id(2), "1.1.1.3", true)
expectGet(m, id(3), "1.1.1.5", true)
expectGet(m, id(4), "", false)
m.Set(id(4), "F") // forgets the (0, "A") entry but should preserve (0, "D")
m.Set(id(4), ip("1.1.1.6")) // forgets the (0, "1.1.1.1") entry but should preserve (0, "1.1.1.4")
expectSize(m, 5)
expectGet(m, id(0), "D", true)
expectGet(m, id(1), "B", true)
expectGet(m, id(2), "C", true)
expectGet(m, id(3), "E", true)
expectGet(m, id(4), "F", true)
expectGet(m, id(0), "1.1.1.4", true)
expectGet(m, id(1), "1.1.1.2", true)
expectGet(m, id(2), "1.1.1.3", true)
expectGet(m, id(3), "1.1.1.5", true)
expectGet(m, id(4), "1.1.1.6", true)
m.Set(id(5), "G") // forgets the (1, "B") entry
m.Set(id(0), "H") // forgets the (2, "C") entry and shadows (0, "D")
m.Set(id(5), ip("1.1.1.7")) // forgets the (1, "1.1.1.2") entry
m.Set(id(0), ip("1.1.1.8")) // forgets the (2, "1.1.1.3") entry and shadows (0, "1.1.1.4")
expectSize(m, 4)
expectGet(m, id(0), "H", true)
expectGet(m, id(0), "1.1.1.8", true)
expectGet(m, id(1), "", false)
expectGet(m, id(2), "", false)
expectGet(m, id(3), "E", true)
expectGet(m, id(4), "F", true)
expectGet(m, id(5), "G", true)
expectGet(m, id(3), "1.1.1.5", true)
expectGet(m, id(4), "1.1.1.6", true)
expectGet(m, id(5), "1.1.1.7", true)
m.Set(id(0), "I") // forgets the (0, "D") entry and shadows (0, "H")
m.Set(id(0), "J") // forgets the (3, "E") entry and shadows (0, "I")
m.Set(id(0), "K") // forgets the (4, "F") entry and shadows (0, "J")
m.Set(id(0), "L") // forgets the (5, "G") entry and shadows (0, "K")
m.Set(id(0), ip("1.1.1.9")) // forgets the (0, "1.1.1.4") entry and shadows (0, "1.1.1.8")
m.Set(id(0), ip("1.1.1.10")) // forgets the (3, "1.1.1.5") entry and shadows (0, "1.1.1.9")
m.Set(id(0), ip("1.1.1.11")) // forgets the (4, "1.1.1.6") entry and shadows (0, "1.1.1.10")
m.Set(id(0), ip("1.1.1.12")) // forgets the (5, "1.1.1.7") entry and shadows (0, "1.1.1.11")
expectSize(m, 1)
expectGet(m, id(0), "L", true)
expectGet(m, id(0), "1.1.1.12", true)
expectGet(m, id(1), "", false)
expectGet(m, id(2), "", false)
expectGet(m, id(3), "", false)