Merge remote-tracking branch 'gitlab/main'

This commit is contained in:
meskio 2023-04-20 16:37:52 +02:00
commit f723cf52e8
No known key found for this signature in database
GPG key ID: 52B8F5AC97A2DA86
5 changed files with 127 additions and 28 deletions

View file

@ -27,23 +27,29 @@ type QueuePacketConn struct {
recvQueue chan taggedPacket recvQueue chan taggedPacket
closeOnce sync.Once closeOnce sync.Once
closed chan struct{} closed chan struct{}
mtu int
// Pool of reusable mtu-sized buffers.
bufPool sync.Pool
// What error to return when the QueuePacketConn is closed. // What error to return when the QueuePacketConn is closed.
err atomic.Value err atomic.Value
} }
// NewQueuePacketConn makes a new QueuePacketConn, set to track recent clients // NewQueuePacketConn makes a new QueuePacketConn, set to track recent clients
// for at least a duration of timeout. // for at least a duration of timeout. The maximum packet size is mtu.
func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration) *QueuePacketConn { func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration, mtu int) *QueuePacketConn {
return &QueuePacketConn{ return &QueuePacketConn{
clients: NewClientMap(timeout), clients: NewClientMap(timeout),
localAddr: localAddr, localAddr: localAddr,
recvQueue: make(chan taggedPacket, queueSize), recvQueue: make(chan taggedPacket, queueSize),
closed: make(chan struct{}), closed: make(chan struct{}),
mtu: mtu,
bufPool: sync.Pool{New: func() interface{} { return make([]byte, mtu) }},
} }
} }
// QueueIncoming queues and incoming packet and its source address, to be // QueueIncoming queues an incoming packet and its source address, to be
// returned in a future call to ReadFrom. // returned in a future call to ReadFrom. If p is longer than the MTU, only its
// first MTU bytes will be used.
func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) { func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
select { select {
case <-c.closed: case <-c.closed:
@ -52,12 +58,18 @@ func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
default: default:
} }
// Copy the slice so that the caller may reuse it. // Copy the slice so that the caller may reuse it.
buf := make([]byte, len(p)) buf := c.bufPool.Get().([]byte)
if len(p) < cap(buf) {
buf = buf[:len(p)]
} else {
buf = buf[:cap(buf)]
}
copy(buf, p) copy(buf, p)
select { select {
case c.recvQueue <- taggedPacket{buf, addr}: case c.recvQueue <- taggedPacket{buf, addr}:
default: default:
// Drop the incoming packet if the receive queue is full. // Drop the incoming packet if the receive queue is full.
c.Restore(buf)
} }
} }
@ -68,6 +80,16 @@ func (c *QueuePacketConn) OutgoingQueue(addr net.Addr) <-chan []byte {
return c.clients.SendQueue(addr) return c.clients.SendQueue(addr)
} }
// Restore adds a slice to the internal pool of packet buffers. Typically you
// will call this with a slice from the OutgoingQueue channel once you are done
// using it. (It is not an error to fail to do so, it will just result in more
// allocations.)
func (c *QueuePacketConn) Restore(p []byte) {
if cap(p) >= c.mtu {
c.bufPool.Put(p)
}
}
// ReadFrom returns a packet and address previously stored by QueueIncoming. // ReadFrom returns a packet and address previously stored by QueueIncoming.
func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) { func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
select { select {
@ -79,12 +101,15 @@ func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
case <-c.closed: case <-c.closed:
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)} return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
case packet := <-c.recvQueue: case packet := <-c.recvQueue:
return copy(p, packet.P), packet.Addr, nil n := copy(p, packet.P)
c.Restore(packet.P)
return n, packet.Addr, nil
} }
} }
// WriteTo queues an outgoing packet for the given address. The queue can later // WriteTo queues an outgoing packet for the given address. The queue can later
// be retrieved using the OutgoingQueue method. // be retrieved using the OutgoingQueue method. If p is longer than the MTU,
// only its first MTU bytes will be used.
func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) { func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
select { select {
case <-c.closed: case <-c.closed:
@ -92,14 +117,20 @@ func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
default: default:
} }
// Copy the slice so that the caller may reuse it. // Copy the slice so that the caller may reuse it.
buf := make([]byte, len(p)) buf := c.bufPool.Get().([]byte)
if len(p) < cap(buf) {
buf = buf[:len(p)]
} else {
buf = buf[:cap(buf)]
}
copy(buf, p) copy(buf, p)
select { select {
case c.clients.SendQueue(addr) <- buf: case c.clients.SendQueue(addr) <- buf:
return len(buf), nil return len(buf), nil
default: default:
// Drop the outgoing packet if the send queue is full. // Drop the outgoing packet if the send queue is full.
return len(buf), nil c.Restore(buf)
return len(p), nil
} }
} }

View file

@ -23,36 +23,96 @@ func (i intAddr) String() string { return fmt.Sprintf("%d", i) }
// Run with -benchmem to see memory allocations. // Run with -benchmem to see memory allocations.
func BenchmarkQueueIncoming(b *testing.B) { func BenchmarkQueueIncoming(b *testing.B) {
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour) conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500)
defer conn.Close() defer conn.Close()
b.ResetTimer() b.ResetTimer()
s := 500 var p [500]byte
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
// Use a variable for the length to stop the compiler from conn.QueueIncoming(p[:], emptyAddr{})
// optimizing out the allocation.
p := make([]byte, s)
conn.QueueIncoming(p, emptyAddr{})
} }
b.StopTimer() b.StopTimer()
} }
// BenchmarkWriteTo benchmarks the QueuePacketConn.WriteTo function. // BenchmarkWriteTo benchmarks the QueuePacketConn.WriteTo function.
func BenchmarkWriteTo(b *testing.B) { func BenchmarkWriteTo(b *testing.B) {
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour) conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500)
defer conn.Close() defer conn.Close()
b.ResetTimer() b.ResetTimer()
s := 500 var p [500]byte
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
// Use a variable for the length to stop the compiler from conn.WriteTo(p[:], emptyAddr{})
// optimizing out the allocation.
p := make([]byte, s)
conn.WriteTo(p, emptyAddr{})
} }
b.StopTimer() b.StopTimer()
} }
// TestQueueIncomingOversize tests that QueueIncoming truncates packets that are
// larger than the MTU.
func TestQueueIncomingOversize(t *testing.T) {
const payload = "abcdefghijklmnopqrstuvwxyz"
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, len(payload)-1)
defer conn.Close()
conn.QueueIncoming([]byte(payload), emptyAddr{})
var p [500]byte
n, _, err := conn.ReadFrom(p[:])
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(p[:n], []byte(payload[:len(payload)-1])) {
t.Fatalf("payload was %+q, expected %+q", p[:n], payload[:len(payload)-1])
}
}
// TestWriteToOversize tests that WriteTo truncates packets that are larger than
// the MTU.
func TestWriteToOversize(t *testing.T) {
const payload = "abcdefghijklmnopqrstuvwxyz"
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, len(payload)-1)
defer conn.Close()
conn.WriteTo([]byte(payload), emptyAddr{})
p := <-conn.OutgoingQueue(emptyAddr{})
if !bytes.Equal(p, []byte(payload[:len(payload)-1])) {
t.Fatalf("payload was %+q, expected %+q", p, payload[:len(payload)-1])
}
}
// TestRestoreMTU tests that Restore ignores any inputs that are not at least
// MTU-sized.
func TestRestoreMTU(t *testing.T) {
const mtu = 500
const payload = "hello"
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, mtu)
defer conn.Close()
conn.Restore(make([]byte, mtu-1))
// This WriteTo may use the short slice we just gave to Restore.
conn.WriteTo([]byte(payload), emptyAddr{})
// Read the queued slice and ensure its capacity is at least the MTU.
p := <-conn.OutgoingQueue(emptyAddr{})
if cap(p) != mtu {
t.Fatalf("cap was %v, expected %v", cap(p), mtu)
}
// Check the payload while we're at it.
if !bytes.Equal(p, []byte(payload)) {
t.Fatalf("payload was %+q, expected %+q", p, payload)
}
}
// TestRestoreCap tests that Restore can use slices whose cap is at least the
// MTU, even if the len is shorter.
func TestRestoreCap(t *testing.T) {
const mtu = 500
const payload = "hello"
conn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, mtu)
defer conn.Close()
conn.Restore(make([]byte, 0, mtu))
conn.WriteTo([]byte(payload), emptyAddr{})
p := <-conn.OutgoingQueue(emptyAddr{})
if !bytes.Equal(p, []byte(payload)) {
t.Fatalf("payload was %+q, expected %+q", p, payload)
}
}
// DiscardPacketConn is a net.PacketConn whose ReadFrom method block forever and // DiscardPacketConn is a net.PacketConn whose ReadFrom method block forever and
// whose WriteTo method discards whatever it is called with. // whose WriteTo method discards whatever it is called with.
type DiscardPacketConn struct{} type DiscardPacketConn struct{}
@ -105,10 +165,11 @@ func TestQueuePacketConnWriteToKCP(t *testing.T) {
defer readyClose.Do(func() { close(ready) }) defer readyClose.Do(func() { close(ready) })
pconn := DiscardPacketConn{} pconn := DiscardPacketConn{}
defer pconn.Close() defer pconn.Close()
loop:
for { for {
select { select {
case <-done: case <-done:
break break loop
default: default:
} }
// Create a new UDPSession, send once, then discard the // Create a new UDPSession, send once, then discard the
@ -127,7 +188,7 @@ func TestQueuePacketConnWriteToKCP(t *testing.T) {
} }
}() }()
pconn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour) pconn := NewQueuePacketConn(emptyAddr{}, 1*time.Hour, 500)
defer pconn.Close() defer pconn.Close()
addr1 := intAddr(1) addr1 := intAddr(1)
outgoing := pconn.OutgoingQueue(addr1) outgoing := pconn.OutgoingQueue(addr1)

View file

@ -69,10 +69,10 @@ type httpHandler struct {
// newHTTPHandler creates a new http.Handler that exchanges encapsulated packets // newHTTPHandler creates a new http.Handler that exchanges encapsulated packets
// over incoming WebSocket connections. // over incoming WebSocket connections.
func newHTTPHandler(localAddr net.Addr, numInstances int) *httpHandler { func newHTTPHandler(localAddr net.Addr, numInstances int, mtu int) *httpHandler {
pconns := make([]*turbotunnel.QueuePacketConn, 0, numInstances) pconns := make([]*turbotunnel.QueuePacketConn, 0, numInstances)
for i := 0; i < numInstances; i++ { for i := 0; i < numInstances; i++ {
pconns = append(pconns, turbotunnel.NewQueuePacketConn(localAddr, clientMapTimeout)) pconns = append(pconns, turbotunnel.NewQueuePacketConn(localAddr, clientMapTimeout, mtu))
} }
clientIDLookupKey := make([]byte, 16) clientIDLookupKey := make([]byte, 16)
@ -200,6 +200,7 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error
return return
} }
_, err := encapsulation.WriteData(bw, p) _, err := encapsulation.WriteData(bw, p)
pconn.Restore(p)
if err == nil { if err == nil {
err = bw.Flush() err = bw.Flush()
} }

View file

@ -79,7 +79,11 @@ func (t *Transport) Listen(addr net.Addr, numKCPInstances int) (*SnowflakeListen
ln: make([]*kcp.Listener, 0, numKCPInstances), ln: make([]*kcp.Listener, 0, numKCPInstances),
} }
handler := newHTTPHandler(addr, numKCPInstances) // kcp-go doesn't provide an accessor for the current MTU setting (and
// anyway we could not create a kcp.Listener without creating a
// net.PacketConn for it first), so assume the default kcp.IKCP_MTU_DEF
// (1400 bytes) and don't increase it elsewhere.
handler := newHTTPHandler(addr, numKCPInstances, kcp.IKCP_MTU_DEF)
server := &http.Server{ server := &http.Server{
Addr: addr.String(), Addr: addr.String(),
Handler: handler, Handler: handler,
@ -125,13 +129,15 @@ func (t *Transport) Listen(addr net.Addr, numKCPInstances int) (*SnowflakeListen
errChan <- err errChan <- err
} }
}() }()
select { select {
case err = <-errChan: case err = <-errChan:
break break
case <-time.After(listenAndServeErrorTimeout): case <-time.After(listenAndServeErrorTimeout):
break break
} }
if err != nil {
return nil, err
}
listener.server = server listener.server = server

View file

@ -1,6 +1,6 @@
package main package main
// This code handled periodic statistics logging. // This code handles periodic statistics logging.
// //
// The only thing it keeps track of is how many connections had the client_ip // The only thing it keeps track of is how many connections had the client_ip
// parameter. Write true to statsChannel to record a connection with client_ip; // parameter. Write true to statsChannel to record a connection with client_ip;