mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 20:11:19 -04:00
Use a sync.Pool to reuse packet buffers in QueuePacketConn.
This is meant to reduce overall allocations. See past discussion at https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40260#note_2885524 ff.
This commit is contained in:
parent
97c930013b
commit
c097d5f3bc
4 changed files with 116 additions and 14 deletions
|
@ -27,23 +27,29 @@ type QueuePacketConn struct {
|
||||||
recvQueue chan taggedPacket
|
recvQueue chan taggedPacket
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
|
mtu int
|
||||||
|
// Pool of reusable mtu-sized buffers.
|
||||||
|
bufPool sync.Pool
|
||||||
// What error to return when the QueuePacketConn is closed.
|
// What error to return when the QueuePacketConn is closed.
|
||||||
err atomic.Value
|
err atomic.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewQueuePacketConn makes a new QueuePacketConn, set to track recent clients
|
// NewQueuePacketConn makes a new QueuePacketConn, set to track recent clients
|
||||||
// for at least a duration of timeout.
|
// for at least a duration of timeout. The maximum packet size is mtu.
|
||||||
func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration) *QueuePacketConn {
|
func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration, mtu int) *QueuePacketConn {
|
||||||
return &QueuePacketConn{
|
return &QueuePacketConn{
|
||||||
clients: NewClientMap(timeout),
|
clients: NewClientMap(timeout),
|
||||||
localAddr: localAddr,
|
localAddr: localAddr,
|
||||||
recvQueue: make(chan taggedPacket, queueSize),
|
recvQueue: make(chan taggedPacket, queueSize),
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
|
mtu: mtu,
|
||||||
|
bufPool: sync.Pool{New: func() interface{} { return make([]byte, mtu) }},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueueIncoming queues an incoming packet and its source address, to be
|
// QueueIncoming queues an incoming packet and its source address, to be
|
||||||
// returned in a future call to ReadFrom.
|
// returned in a future call to ReadFrom. If p is longer than the MTU, only its
|
||||||
|
// first MTU bytes will be used.
|
||||||
func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
|
func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
|
||||||
select {
|
select {
|
||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
|
@ -52,12 +58,18 @@ func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
// Copy the slice so that the caller may reuse it.
|
// Copy the slice so that the caller may reuse it.
|
||||||
buf := make([]byte, len(p))
|
buf := c.bufPool.Get().([]byte)
|
||||||
|
if len(p) < cap(buf) {
|
||||||
|
buf = buf[:len(p)]
|
||||||
|
} else {
|
||||||
|
buf = buf[:cap(buf)]
|
||||||
|
}
|
||||||
copy(buf, p)
|
copy(buf, p)
|
||||||
select {
|
select {
|
||||||
case c.recvQueue <- taggedPacket{buf, addr}:
|
case c.recvQueue <- taggedPacket{buf, addr}:
|
||||||
default:
|
default:
|
||||||
// Drop the incoming packet if the receive queue is full.
|
// Drop the incoming packet if the receive queue is full.
|
||||||
|
c.Restore(buf)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,6 +80,16 @@ func (c *QueuePacketConn) OutgoingQueue(addr net.Addr) <-chan []byte {
|
||||||
return c.clients.SendQueue(addr)
|
return c.clients.SendQueue(addr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Restore adds a slice to the internal pool of packet buffers. Typically you
|
||||||
|
// will call this with a slice from the OutgoingQueue channel once you are done
|
||||||
|
// using it. (It is not an error to fail to do so, it will just result in more
|
||||||
|
// allocations.)
|
||||||
|
func (c *QueuePacketConn) Restore(p []byte) {
|
||||||
|
if cap(p) >= c.mtu {
|
||||||
|
c.bufPool.Put(p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ReadFrom returns a packet and address previously stored by QueueIncoming.
|
// ReadFrom returns a packet and address previously stored by QueueIncoming.
|
||||||
func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||||
select {
|
select {
|
||||||
|
@ -79,12 +101,15 @@ func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
|
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
|
||||||
case packet := <-c.recvQueue:
|
case packet := <-c.recvQueue:
|
||||||
return copy(p, packet.P), packet.Addr, nil
|
n := copy(p, packet.P)
|
||||||
|
c.Restore(packet.P)
|
||||||
|
return n, packet.Addr, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteTo queues an outgoing packet for the given address. The queue can later
|
// WriteTo queues an outgoing packet for the given address. The queue can later
|
||||||
// be retrieved using the OutgoingQueue method.
|
// be retrieved using the OutgoingQueue method. If p is longer than the MTU,
|
||||||
|
// only its first MTU bytes will be used.
|
||||||
func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||||
select {
|
select {
|
||||||
case <-c.closed:
|
case <-c.closed:
|
||||||
|
@ -92,14 +117,20 @@ func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
// Copy the slice so that the caller may reuse it.
|
// Copy the slice so that the caller may reuse it.
|
||||||
buf := make([]byte, len(p))
|
buf := c.bufPool.Get().([]byte)
|
||||||
|
if len(p) < cap(buf) {
|
||||||
|
buf = buf[:len(p)]
|
||||||
|
} else {
|
||||||
|
buf = buf[:cap(buf)]
|
||||||
|
}
|
||||||
copy(buf, p)
|
copy(buf, p)
|
||||||
select {
|
select {
|
||||||
case c.clients.SendQueue(addr) <- buf:
|
case c.clients.SendQueue(addr) <- buf:
|
||||||
return len(buf), nil
|
return len(buf), nil
|
||||||
default:
|
default:
|
||||||
// Drop the outgoing packet if the send queue is full.
|
// Drop the outgoing packet if the send queue is full.
|
||||||
return len(buf), nil
|
c.Restore(buf)
|
||||||
|
return len(p), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,7 +23,7 @@ func (i intAddr) String() string { return fmt.Sprintf("%d", i) }
|
||||||
|
|
||||||
// Run with -benchmem to see memory allocations.
|
// Run with -benchmem to see memory allocations.
|
||||||
func BenchmarkQueueIncoming(b *testing.B) {
|
func BenchmarkQueueIncoming(b *testing.B) {
|
||||||
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour)
|
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
@ -36,7 +36,7 @@ func BenchmarkQueueIncoming(b *testing.B) {
|
||||||
|
|
||||||
// BenchmarkWriteTo benchmarks the QueuePacketConn.WriteTo function.
|
// BenchmarkWriteTo benchmarks the QueuePacketConn.WriteTo function.
|
||||||
func BenchmarkWriteTo(b *testing.B) {
|
func BenchmarkWriteTo(b *testing.B) {
|
||||||
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour)
|
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
@ -47,6 +47,72 @@ func BenchmarkWriteTo(b *testing.B) {
|
||||||
b.StopTimer()
|
b.StopTimer()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestQueueIncomingOversize tests that QueueIncoming truncates packets that are
|
||||||
|
// larger than the MTU.
|
||||||
|
func TestQueueIncomingOversize(t *testing.T) {
|
||||||
|
const payload = "abcdefghijklmnopqrstuvwxyz"
|
||||||
|
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, len(payload)-1)
|
||||||
|
defer conn.Close()
|
||||||
|
conn.QueueIncoming([]byte(payload), emptyAddr{})
|
||||||
|
var p [500]byte
|
||||||
|
n, _, err := conn.ReadFrom(p[:])
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(p[:n], []byte(payload[:len(payload)-1])) {
|
||||||
|
t.Fatalf("payload was %+q, expected %+q", p[:n], payload[:len(payload)-1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestWriteToOversize tests that WriteTo truncates packets that are larger than
|
||||||
|
// the MTU.
|
||||||
|
func TestWriteToOversize(t *testing.T) {
|
||||||
|
const payload = "abcdefghijklmnopqrstuvwxyz"
|
||||||
|
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, len(payload)-1)
|
||||||
|
defer conn.Close()
|
||||||
|
conn.WriteTo([]byte(payload), emptyAddr{})
|
||||||
|
p := <-conn.OutgoingQueue(emptyAddr{})
|
||||||
|
if !bytes.Equal(p, []byte(payload[:len(payload)-1])) {
|
||||||
|
t.Fatalf("payload was %+q, expected %+q", p, payload[:len(payload)-1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRestoreMTU tests that Restore ignores any inputs that are not at least
|
||||||
|
// MTU-sized.
|
||||||
|
func TestRestoreMTU(t *testing.T) {
|
||||||
|
const mtu = 500
|
||||||
|
const payload = "hello"
|
||||||
|
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, mtu)
|
||||||
|
defer conn.Close()
|
||||||
|
conn.Restore(make([]byte, mtu-1))
|
||||||
|
// This WriteTo may use the short slice we just gave to Restore.
|
||||||
|
conn.WriteTo([]byte(payload), emptyAddr{})
|
||||||
|
// Read the queued slice and ensure its capacity is at least the MTU.
|
||||||
|
p := <-conn.OutgoingQueue(emptyAddr{})
|
||||||
|
if cap(p) != mtu {
|
||||||
|
t.Fatalf("cap was %v, expected %v", cap(p), mtu)
|
||||||
|
}
|
||||||
|
// Check the payload while we're at it.
|
||||||
|
if !bytes.Equal(p, []byte(payload)) {
|
||||||
|
t.Fatalf("payload was %+q, expected %+q", p, payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRestoreCap tests that Restore can use slices whose cap is at least the
|
||||||
|
// MTU, even if the len is shorter.
|
||||||
|
func TestRestoreCap(t *testing.T) {
|
||||||
|
const mtu = 500
|
||||||
|
const payload = "hello"
|
||||||
|
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, mtu)
|
||||||
|
defer conn.Close()
|
||||||
|
conn.Restore(make([]byte, 0, mtu))
|
||||||
|
conn.WriteTo([]byte(payload), emptyAddr{})
|
||||||
|
p := <-conn.OutgoingQueue(emptyAddr{})
|
||||||
|
if !bytes.Equal(p, []byte(payload)) {
|
||||||
|
t.Fatalf("payload was %+q, expected %+q", p, payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// DiscardPacketConn is a net.PacketConn whose ReadFrom method block forever and
|
// DiscardPacketConn is a net.PacketConn whose ReadFrom method block forever and
|
||||||
// whose WriteTo method discards whatever it is called with.
|
// whose WriteTo method discards whatever it is called with.
|
||||||
type DiscardPacketConn struct{}
|
type DiscardPacketConn struct{}
|
||||||
|
@ -122,7 +188,7 @@ func TestQueuePacketConnWriteToKCP(t *testing.T) {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
pconn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour)
|
pconn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500)
|
||||||
defer pconn.Close()
|
defer pconn.Close()
|
||||||
addr1 := intAddr(1)
|
addr1 := intAddr(1)
|
||||||
outgoing := pconn.OutgoingQueue(addr1)
|
outgoing := pconn.OutgoingQueue(addr1)
|
||||||
|
|
|
@ -69,10 +69,10 @@ type httpHandler struct {
|
||||||
|
|
||||||
// newHTTPHandler creates a new http.Handler that exchanges encapsulated packets
|
// newHTTPHandler creates a new http.Handler that exchanges encapsulated packets
|
||||||
// over incoming WebSocket connections.
|
// over incoming WebSocket connections.
|
||||||
func newHTTPHandler(localAddr net.Addr, numInstances int) *httpHandler {
|
func newHTTPHandler(localAddr net.Addr, numInstances int, mtu int) *httpHandler {
|
||||||
pconns := make([]*turbotunnel.QueuePacketConn, 0, numInstances)
|
pconns := make([]*turbotunnel.QueuePacketConn, 0, numInstances)
|
||||||
for i := 0; i < numInstances; i++ {
|
for i := 0; i < numInstances; i++ {
|
||||||
pconns = append(pconns, turbotunnel.NewQueuePacketConn(localAddr, clientMapTimeout))
|
pconns = append(pconns, turbotunnel.NewQueuePacketConn(localAddr, clientMapTimeout, mtu))
|
||||||
}
|
}
|
||||||
|
|
||||||
clientIDLookupKey := make([]byte, 16)
|
clientIDLookupKey := make([]byte, 16)
|
||||||
|
@ -200,6 +200,7 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err := encapsulation.WriteData(bw, p)
|
_, err := encapsulation.WriteData(bw, p)
|
||||||
|
pconn.Restore(p)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
err = bw.Flush()
|
err = bw.Flush()
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,7 +79,11 @@ func (t *Transport) Listen(addr net.Addr, numKCPInstances int) (*SnowflakeListen
|
||||||
ln: make([]*kcp.Listener, 0, numKCPInstances),
|
ln: make([]*kcp.Listener, 0, numKCPInstances),
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := newHTTPHandler(addr, numKCPInstances)
|
// kcp-go doesn't provide an accessor for the current MTU setting (and
|
||||||
|
// anyway we could not create a kcp.Listener without creating a
|
||||||
|
// net.PacketConn for it first), so assume the default kcp.IKCP_MTU_DEF
|
||||||
|
// (1400 bytes) and don't increase it elsewhere.
|
||||||
|
handler := newHTTPHandler(addr, numKCPInstances, kcp.IKCP_MTU_DEF)
|
||||||
server := &http.Server{
|
server := &http.Server{
|
||||||
Addr: addr.String(),
|
Addr: addr.String(),
|
||||||
Handler: handler,
|
Handler: handler,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue