mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 20:11:19 -04:00
Use a sync.Pool to reuse packet buffers in QueuePacketConn.
This is meant to reduce overall allocations. See past discussion at https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/40260#note_2885524 ff.
This commit is contained in:
parent
97c930013b
commit
c097d5f3bc
4 changed files with 116 additions and 14 deletions
|
@ -27,23 +27,29 @@ type QueuePacketConn struct {
|
|||
recvQueue chan taggedPacket
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
mtu int
|
||||
// Pool of reusable mtu-sized buffers.
|
||||
bufPool sync.Pool
|
||||
// What error to return when the QueuePacketConn is closed.
|
||||
err atomic.Value
|
||||
}
|
||||
|
||||
// NewQueuePacketConn makes a new QueuePacketConn, set to track recent clients
|
||||
// for at least a duration of timeout.
|
||||
func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration) *QueuePacketConn {
|
||||
// for at least a duration of timeout. The maximum packet size is mtu.
|
||||
func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration, mtu int) *QueuePacketConn {
|
||||
return &QueuePacketConn{
|
||||
clients: NewClientMap(timeout),
|
||||
localAddr: localAddr,
|
||||
recvQueue: make(chan taggedPacket, queueSize),
|
||||
closed: make(chan struct{}),
|
||||
mtu: mtu,
|
||||
bufPool: sync.Pool{New: func() interface{} { return make([]byte, mtu) }},
|
||||
}
|
||||
}
|
||||
|
||||
// QueueIncoming queues an incoming packet and its source address, to be
|
||||
// returned in a future call to ReadFrom.
|
||||
// returned in a future call to ReadFrom. If p is longer than the MTU, only its
|
||||
// first MTU bytes will be used.
|
||||
func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
|
@ -52,12 +58,18 @@ func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
|
|||
default:
|
||||
}
|
||||
// Copy the slice so that the caller may reuse it.
|
||||
buf := make([]byte, len(p))
|
||||
buf := c.bufPool.Get().([]byte)
|
||||
if len(p) < cap(buf) {
|
||||
buf = buf[:len(p)]
|
||||
} else {
|
||||
buf = buf[:cap(buf)]
|
||||
}
|
||||
copy(buf, p)
|
||||
select {
|
||||
case c.recvQueue <- taggedPacket{buf, addr}:
|
||||
default:
|
||||
// Drop the incoming packet if the receive queue is full.
|
||||
c.Restore(buf)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -68,6 +80,16 @@ func (c *QueuePacketConn) OutgoingQueue(addr net.Addr) <-chan []byte {
|
|||
return c.clients.SendQueue(addr)
|
||||
}
|
||||
|
||||
// Restore adds a slice to the internal pool of packet buffers. Typically you
|
||||
// will call this with a slice from the OutgoingQueue channel once you are done
|
||||
// using it. (It is not an error to fail to do so, it will just result in more
|
||||
// allocations.)
|
||||
func (c *QueuePacketConn) Restore(p []byte) {
|
||||
if cap(p) >= c.mtu {
|
||||
c.bufPool.Put(p)
|
||||
}
|
||||
}
|
||||
|
||||
// ReadFrom returns a packet and address previously stored by QueueIncoming.
|
||||
func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
select {
|
||||
|
@ -79,12 +101,15 @@ func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
|||
case <-c.closed:
|
||||
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
|
||||
case packet := <-c.recvQueue:
|
||||
return copy(p, packet.P), packet.Addr, nil
|
||||
n := copy(p, packet.P)
|
||||
c.Restore(packet.P)
|
||||
return n, packet.Addr, nil
|
||||
}
|
||||
}
|
||||
|
||||
// WriteTo queues an outgoing packet for the given address. The queue can later
|
||||
// be retrieved using the OutgoingQueue method.
|
||||
// be retrieved using the OutgoingQueue method. If p is longer than the MTU,
|
||||
// only its first MTU bytes will be used.
|
||||
func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
|
@ -92,14 +117,20 @@ func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
|||
default:
|
||||
}
|
||||
// Copy the slice so that the caller may reuse it.
|
||||
buf := make([]byte, len(p))
|
||||
buf := c.bufPool.Get().([]byte)
|
||||
if len(p) < cap(buf) {
|
||||
buf = buf[:len(p)]
|
||||
} else {
|
||||
buf = buf[:cap(buf)]
|
||||
}
|
||||
copy(buf, p)
|
||||
select {
|
||||
case c.clients.SendQueue(addr) <- buf:
|
||||
return len(buf), nil
|
||||
default:
|
||||
// Drop the outgoing packet if the send queue is full.
|
||||
return len(buf), nil
|
||||
c.Restore(buf)
|
||||
return len(p), nil
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue