Have encapsulation.ReadData return an error when the buffer is short.

https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/merge_requests/154#note_2919109

Still ignoring the io.ErrShortBuffer at the callers, which retains
current behavior.
This commit is contained in:
David Fifield 2023-11-07 05:49:48 +00:00
parent 001f691b47
commit d99f31d881
4 changed files with 24 additions and 16 deletions

View file

@ -38,10 +38,10 @@ func newEncapsulationPacketConn(
// ReadFrom reads an encapsulated packet from the stream. // ReadFrom reads an encapsulated packet from the stream.
func (c *encapsulationPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { func (c *encapsulationPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
n, err := encapsulation.ReadData(c.ReadWriteCloser, p) n, err := encapsulation.ReadData(c.ReadWriteCloser, p)
if err != nil { if err == io.ErrShortBuffer {
return n, c.remoteAddr, err err = nil
} }
return n, c.remoteAddr, nil return n, c.remoteAddr, err
} }
// WriteTo writes an encapsulated packet to the stream. // WriteTo writes an encapsulated packet to the stream.

View file

@ -53,11 +53,12 @@ var ErrTooLong = errors.New("length prefix is too long")
// ReadData the next available data chunk, skipping over any padding chunks that // ReadData the next available data chunk, skipping over any padding chunks that
// may come first, and copies the data into p. If p is shorter than the length // may come first, and copies the data into p. If p is shorter than the length
// of the data chunk, only the first len(p) bytes are copied into p. The // of the data chunk, only the first len(p) bytes are copied into p, and the
// returned error value is nil if and only if a data chunk was present and was // error return is io.ErrShortBuffer. The returned error value is nil if and
// read in its entirety. The returned error is io.EOF only if r ended before the // only if a data chunk was present and was read in its entirety. The returned
// first byte of a length prefix. If r ended in the middle of a length prefix or // error is io.EOF only if r ended before the first byte of a length prefix. If
// data/padding, the returned error is io.ErrUnexpectedEOF. // r ended in the middle of a length prefix or data/padding, the returned error
// is io.ErrUnexpectedEOF.
func ReadData(r io.Reader, p []byte) (int, error) { func ReadData(r io.Reader, p []byte) (int, error) {
for { for {
var b [1]byte var b [1]byte
@ -89,9 +90,13 @@ func ReadData(r io.Reader, p []byte) (int, error) {
} }
numData, err := io.ReadFull(r, p) numData, err := io.ReadFull(r, p)
if err == nil && numData < n { if err == nil && numData < n {
// Discard the rest of the data, if the caller's // If the caller's buffer was too short, discard
// buffer was too short. // the rest of the data and return
// io.ErrShortBuffer.
_, err = io.CopyN(ioutil.Discard, r, int64(n-numData)) _, err = io.CopyN(ioutil.Discard, r, int64(n-numData))
if err == nil {
err = io.ErrShortBuffer
}
} }
if err == io.EOF { if err == io.EOF {
err = io.ErrUnexpectedEOF err = io.ErrUnexpectedEOF

View file

@ -342,8 +342,8 @@ func TestReadDataTruncate(t *testing.T) {
var p [4]byte var p [4]byte
// First ReadData should return truncated "1234". // First ReadData should return truncated "1234".
n, err := ReadData(&enc, p[:]) n, err := ReadData(&enc, p[:])
if err != nil { if err != io.ErrShortBuffer {
t.Fatalf("got error %v, expected %v", err, nil) t.Fatalf("got error %v, expected %v", err, io.ErrShortBuffer)
} }
if !bytes.Equal(p[:n], []byte("1234")) { if !bytes.Equal(p[:n], []byte("1234")) {
t.Fatalf("got <%x>, expected <%x>", p[:n], []byte("1234")) t.Fatalf("got <%x>, expected <%x>", p[:n], []byte("1234"))
@ -351,8 +351,8 @@ func TestReadDataTruncate(t *testing.T) {
// Second ReadData should return truncated "abcd", not the rest of // Second ReadData should return truncated "abcd", not the rest of
// "12345678". // "12345678".
n, err = ReadData(&enc, p[:]) n, err = ReadData(&enc, p[:])
if err != nil { if err != io.ErrShortBuffer {
t.Fatalf("got error %v, expected %v", err, nil) t.Fatalf("got error %v, expected %v", err, io.ErrShortBuffer)
} }
if !bytes.Equal(p[:n], []byte("abcd")) { if !bytes.Equal(p[:n], []byte("abcd")) {
t.Fatalf("got <%x>, expected <%x>", p[:n], []byte("abcd")) t.Fatalf("got <%x>, expected <%x>", p[:n], []byte("abcd"))
@ -377,8 +377,8 @@ func TestReadDataTruncateFull(t *testing.T) {
}() }()
var p [8]byte var p [8]byte
n, err := ReadData(pr, p[:]) n, err := ReadData(pr, p[:])
if err != nil { if err != io.ErrShortBuffer {
t.Fatalf("got error %v, expected %v", err, nil) t.Fatalf("got error %v, expected %v", err, io.ErrShortBuffer)
} }
// Should not stop after "hello". // Should not stop after "hello".
if !bytes.Equal(p[:n], []byte("hellowor")) { if !bytes.Equal(p[:n], []byte("hellowor")) {

View file

@ -176,6 +176,9 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error
var p [2048]byte var p [2048]byte
for { for {
n, err := encapsulation.ReadData(conn, p[:]) n, err := encapsulation.ReadData(conn, p[:])
if err == io.ErrShortBuffer {
err = nil
}
if err != nil { if err != nil {
return return
} }