diff --git a/server/lib/snowflake.go b/server/lib/snowflake.go index 469ed91..5f995ac 100644 --- a/server/lib/snowflake.go +++ b/server/lib/snowflake.go @@ -232,7 +232,7 @@ func (l *SnowflakeListener) acceptStreams(conn *kcp.UDPSession) error { } return err } - l.queueConn(&SnowflakeClientConn{Conn: stream, address: addr}) + l.queueConn(&SnowflakeClientConn{stream: stream, address: addr}) } } @@ -280,15 +280,35 @@ func (l *SnowflakeListener) queueConn(conn net.Conn) error { } } -// SnowflakeClientConn is a wrapper for the underlying turbotunnel -// conn. We need to reference our client address map to determine the -// remote address +// SnowflakeClientConn is a wrapper for the underlying turbotunnel conn +// (smux.Stream). It implements the net.Conn and io.WriterTo interfaces. The +// RemoteAddr method is overridden to refer to a real IP address, looked up from +// the client address map, rather than an abstract client ID. type SnowflakeClientConn struct { - net.Conn + stream *smux.Stream address net.Addr } -// RemoteAddr returns the mapped client address of the Snowflake connection +// Forward net.Conn methods, other than RemoteAddr, to the inner stream. +func (conn *SnowflakeClientConn) Read(b []byte) (int, error) { return conn.stream.Read(b) } +func (conn *SnowflakeClientConn) Write(b []byte) (int, error) { return conn.stream.Write(b) } +func (conn *SnowflakeClientConn) Close() error { return conn.stream.Close() } +func (conn *SnowflakeClientConn) LocalAddr() net.Addr { return conn.stream.LocalAddr() } +func (conn *SnowflakeClientConn) SetDeadline(t time.Time) error { return conn.stream.SetDeadline(t) } +func (conn *SnowflakeClientConn) SetReadDeadline(t time.Time) error { + return conn.stream.SetReadDeadline(t) +} +func (conn *SnowflakeClientConn) SetWriteDeadline(t time.Time) error { + return conn.stream.SetWriteDeadline(t) +} + +// RemoteAddr returns the mapped client address of the Snowflake connection. func (conn *SnowflakeClientConn) RemoteAddr() net.Addr { return conn.address } + +// WriteTo implements the io.WriterTo interface by passing the call to the +// underlying smux.Stream. +func (conn *SnowflakeClientConn) WriteTo(w io.Writer) (int64, error) { + return conn.stream.WriteTo(w) +} diff --git a/server/server.go b/server/server.go index ed1e876..484e37a 100644 --- a/server/server.go +++ b/server/server.go @@ -57,7 +57,7 @@ func proxy(local *net.TCPConn, conn net.Conn) { wg.Done() }() go func() { - if _, err := io.Copy(local, conn); err != nil && !errors.Is(err, io.ErrClosedPipe) { + if _, err := io.Copy(local, conn); err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrClosedPipe) { log.Printf("error copying WebSocket to ORPort %v", err) } local.CloseWrite()