mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 11:11:30 -04:00
Merge branch 'encapsulation-readdata-buffer'
This commit is contained in:
commit
aa06e7bef3
4 changed files with 132 additions and 60 deletions
|
@ -37,11 +37,11 @@ func newEncapsulationPacketConn(
|
|||
|
||||
// ReadFrom reads an encapsulated packet from the stream.
|
||||
func (c *encapsulationPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
data, err := encapsulation.ReadData(c.ReadWriteCloser)
|
||||
if err != nil {
|
||||
return 0, c.remoteAddr, err
|
||||
n, err := encapsulation.ReadData(c.ReadWriteCloser, p)
|
||||
if err == io.ErrShortBuffer {
|
||||
err = nil
|
||||
}
|
||||
return copy(p, data), c.remoteAddr, nil
|
||||
return n, c.remoteAddr, err
|
||||
}
|
||||
|
||||
// WriteTo writes an encapsulated packet to the stream.
|
||||
|
|
|
@ -51,54 +51,64 @@ import (
|
|||
// encode in a 3-byte length prefix.
|
||||
var ErrTooLong = errors.New("length prefix is too long")
|
||||
|
||||
// ReadData returns a new slice with the contents of the next available data
|
||||
// chunk, skipping over any padding chunks that may come first. The returned
|
||||
// error value is nil if and only if a data chunk was present and was read in
|
||||
// its entirety. The returned error is io.EOF only if r ended before the first
|
||||
// byte of a length prefix. If r ended in the middle of a length prefix or
|
||||
// data/padding, the returned error is io.ErrUnexpectedEOF.
|
||||
func ReadData(r io.Reader) ([]byte, error) {
|
||||
// ReadData the next available data chunk, skipping over any padding chunks that
|
||||
// may come first, and copies the data into p. If p is shorter than the length
|
||||
// of the data chunk, only the first len(p) bytes are copied into p, and the
|
||||
// error return is io.ErrShortBuffer. The returned error value is nil if and
|
||||
// only if a data chunk was present and was read in its entirety. The returned
|
||||
// error is io.EOF only if r ended before the first byte of a length prefix. If
|
||||
// r ended in the middle of a length prefix or data/padding, the returned error
|
||||
// is io.ErrUnexpectedEOF.
|
||||
func ReadData(r io.Reader, p []byte) (int, error) {
|
||||
for {
|
||||
var b [1]byte
|
||||
_, err := r.Read(b[:])
|
||||
if err != nil {
|
||||
// This is the only place we may return a real io.EOF.
|
||||
return nil, err
|
||||
return 0, err
|
||||
}
|
||||
isData := (b[0] & 0x80) != 0
|
||||
moreLength := (b[0] & 0x40) != 0
|
||||
n := int(b[0] & 0x3f)
|
||||
for i := 0; moreLength; i++ {
|
||||
if i >= 2 {
|
||||
return nil, ErrTooLong
|
||||
return 0, ErrTooLong
|
||||
}
|
||||
_, err := r.Read(b[:])
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return 0, err
|
||||
}
|
||||
moreLength = (b[0] & 0x80) != 0
|
||||
n = (n << 7) | int(b[0]&0x7f)
|
||||
}
|
||||
if isData {
|
||||
p := make([]byte, n)
|
||||
_, err := io.ReadFull(r, p)
|
||||
if len(p) > n {
|
||||
p = p[:n]
|
||||
}
|
||||
numData, err := io.ReadFull(r, p)
|
||||
if err == nil && numData < n {
|
||||
// If the caller's buffer was too short, discard
|
||||
// the rest of the data and return
|
||||
// io.ErrShortBuffer.
|
||||
_, err = io.CopyN(ioutil.Discard, r, int64(n-numData))
|
||||
if err == nil {
|
||||
err = io.ErrShortBuffer
|
||||
}
|
||||
}
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, err
|
||||
} else {
|
||||
return numData, err
|
||||
} else if n > 0 {
|
||||
_, err := io.CopyN(ioutil.Discard, r, int64(n))
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -54,12 +54,13 @@ func TestRoundtrip(t *testing.T) {
|
|||
t.Fatalf("size %d, returned length was %d, written length was %d",
|
||||
i, n, enc.Len())
|
||||
}
|
||||
inverse, err := ReadData(&enc)
|
||||
inverse := make([]byte, i)
|
||||
n, err = ReadData(&enc, inverse)
|
||||
if err != nil {
|
||||
t.Fatalf("size %d, ReadData returned error %v", i, err)
|
||||
}
|
||||
if !bytes.Equal(inverse, original) {
|
||||
t.Fatalf("size %d, got <%x>, expected <%x>", i, inverse, original)
|
||||
if !bytes.Equal(inverse[:n], original) {
|
||||
t.Fatalf("size %d, got <%x>, expected <%x>", i, inverse[:n], original)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -106,25 +107,26 @@ func TestSkipPadding(t *testing.T) {
|
|||
mustWritePadding(&enc, 10)
|
||||
mustWritePadding(&enc, 10)
|
||||
for i, expected := range data {
|
||||
actual, err := ReadData(&enc)
|
||||
var actual [10]byte
|
||||
n, err := ReadData(&enc, actual[:])
|
||||
if err != nil {
|
||||
t.Fatalf("slice %d, got error %v, expected %v", i, err, nil)
|
||||
}
|
||||
if !bytes.Equal(actual, expected) {
|
||||
t.Fatalf("slice %d, got <%x>, expected <%x>", i, actual, expected)
|
||||
if !bytes.Equal(actual[:n], expected) {
|
||||
t.Fatalf("slice %d, got <%x>, expected <%x>", i, actual[:n], expected)
|
||||
}
|
||||
}
|
||||
p, err := ReadData(&enc)
|
||||
if p != nil || err != io.EOF {
|
||||
t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, io.EOF)
|
||||
n, err := ReadData(&enc, nil)
|
||||
if n != 0 || err != io.EOF {
|
||||
t.Fatalf("got (%v, %v), expected (%v, %v)", n, err, 0, io.EOF)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that EOF before a length prefix returns io.EOF.
|
||||
func TestEOF(t *testing.T) {
|
||||
p, err := ReadData(bytes.NewReader(nil))
|
||||
if p != nil || err != io.EOF {
|
||||
t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, io.EOF)
|
||||
n, err := ReadData(bytes.NewReader(nil), nil)
|
||||
if n != 0 || err != io.EOF {
|
||||
t.Fatalf("got (%v, %v), expected (%v, %v)", n, err, 0, io.EOF)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -149,9 +151,9 @@ func TestUnexpectedEOF(t *testing.T) {
|
|||
{0x41, 0x80, 0x00, 'X'}, // expecting 32767 bytes of padding
|
||||
{0xc1, 0x80, 0x00, 'X'}, // expecting 32767 bytes of data
|
||||
} {
|
||||
p, err := ReadData(bytes.NewReader(test))
|
||||
if p != nil || err != io.ErrUnexpectedEOF {
|
||||
t.Fatalf("<%x> got (<%x>, %v), expected (%v, %v)", test, p, err, nil, io.ErrUnexpectedEOF)
|
||||
n, err := ReadData(bytes.NewReader(test), nil)
|
||||
if n != 0 || err != io.ErrUnexpectedEOF {
|
||||
t.Fatalf("<%x> got (%v, %v), expected (%v, %v)", test, n, err, 0, io.ErrUnexpectedEOF)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -167,12 +169,13 @@ func TestNonMinimalLengthEncoding(t *testing.T) {
|
|||
{[]byte{0xc0, 0x01, 'X'}, []byte("X")},
|
||||
{[]byte{0xc0, 0x80, 0x01, 'X'}, []byte("X")},
|
||||
} {
|
||||
p, err := ReadData(bytes.NewReader(test.enc))
|
||||
var p [10]byte
|
||||
n, err := ReadData(bytes.NewReader(test.enc), p[:])
|
||||
if err != nil {
|
||||
t.Fatalf("<%x> got error %v, expected %v", test.enc, err, nil)
|
||||
}
|
||||
if !bytes.Equal(p, test.expected) {
|
||||
t.Fatalf("<%x> got <%x>, expected <%x>", test.enc, p, test.expected)
|
||||
if !bytes.Equal(p[:n], test.expected) {
|
||||
t.Fatalf("<%x> got <%x>, expected <%x>", test.enc, p[:n], test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -184,27 +187,28 @@ func TestReadLimits(t *testing.T) {
|
|||
maxLength := (0x3f << 14) | (0x7f << 7) | 0x7f
|
||||
data := bytes.Repeat([]byte{'X'}, maxLength)
|
||||
prefix := []byte{0xff, 0xff, 0x7f} // encodes 0xfffff
|
||||
p, err := ReadData(bytes.NewReader(append(prefix, data...)))
|
||||
var p [0xfffff]byte
|
||||
n, err := ReadData(bytes.NewReader(append(prefix, data...)), p[:])
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v, expected %v", err, nil)
|
||||
}
|
||||
if !bytes.Equal(p, data) {
|
||||
if !bytes.Equal(p[:n], data) {
|
||||
t.Fatalf("got %d bytes unequal to %d bytes", len(p), len(data))
|
||||
}
|
||||
// Test a 4-byte prefix.
|
||||
prefix = []byte{0xc0, 0xc0, 0x80, 0x80} // encodes 0x100000
|
||||
data = bytes.Repeat([]byte{'X'}, maxLength+1)
|
||||
p, err = ReadData(bytes.NewReader(append(prefix, data...)))
|
||||
if p != nil || err != ErrTooLong {
|
||||
t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong)
|
||||
n, err = ReadData(bytes.NewReader(append(prefix, data...)), nil)
|
||||
if n != 0 || err != ErrTooLong {
|
||||
t.Fatalf("got (%v, %v), expected (%v, %v)", n, err, 0, ErrTooLong)
|
||||
}
|
||||
// Test that 4 bytes don't work, even when they encode an integer that
|
||||
// would fix in 3 bytes.
|
||||
prefix = []byte{0xc0, 0x80, 0x80, 0x80} // encodes 0x0
|
||||
data = []byte{}
|
||||
p, err = ReadData(bytes.NewReader(append(prefix, data...)))
|
||||
if p != nil || err != ErrTooLong {
|
||||
t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong)
|
||||
n, err = ReadData(bytes.NewReader(append(prefix, data...)), nil)
|
||||
if n != 0 || err != ErrTooLong {
|
||||
t.Fatalf("got (%v, %v), expected (%v, %v)", n, err, 0, ErrTooLong)
|
||||
}
|
||||
|
||||
// Do the same tests with padding lengths.
|
||||
|
@ -213,28 +217,28 @@ func TestReadLimits(t *testing.T) {
|
|||
padding := bytes.Repeat([]byte{'X'}, maxLength)
|
||||
enc := bytes.NewBuffer(append(prefix, padding...))
|
||||
mustWriteData(enc, data)
|
||||
p, err = ReadData(enc)
|
||||
n, err = ReadData(enc, p[:])
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v, expected %v", err, nil)
|
||||
}
|
||||
if !bytes.Equal(p, data) {
|
||||
t.Fatalf("got <%x>, expected <%x>", p, data)
|
||||
if !bytes.Equal(p[:n], data) {
|
||||
t.Fatalf("got <%x>, expected <%x>", p[:n], data)
|
||||
}
|
||||
prefix = []byte{0x40, 0xc0, 0x80, 0x80} // encodes 0x100000
|
||||
padding = bytes.Repeat([]byte{'X'}, maxLength+1)
|
||||
enc = bytes.NewBuffer(append(prefix, padding...))
|
||||
mustWriteData(enc, data)
|
||||
p, err = ReadData(enc)
|
||||
if p != nil || err != ErrTooLong {
|
||||
t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong)
|
||||
n, err = ReadData(enc, nil)
|
||||
if n != 0 || err != ErrTooLong {
|
||||
t.Fatalf("got (%v, %v), expected (%v, %v)", n, err, 0, ErrTooLong)
|
||||
}
|
||||
prefix = []byte{0x40, 0x80, 0x80, 0x80} // encodes 0x0
|
||||
padding = []byte{}
|
||||
enc = bytes.NewBuffer(append(prefix, padding...))
|
||||
mustWriteData(enc, data)
|
||||
p, err = ReadData(enc)
|
||||
if p != nil || err != ErrTooLong {
|
||||
t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong)
|
||||
n, err = ReadData(enc, nil)
|
||||
if n != 0 || err != ErrTooLong {
|
||||
t.Fatalf("got (%v, %v), expected (%v, %v)", n, err, 0, ErrTooLong)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -329,6 +333,59 @@ func TestMaxDataForSize(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// Test that ReadData truncates the data when the destination slice is too
|
||||
// short.
|
||||
func TestReadDataTruncate(t *testing.T) {
|
||||
var enc bytes.Buffer
|
||||
mustWriteData(&enc, []byte("12345678"))
|
||||
mustWriteData(&enc, []byte("abcdefgh"))
|
||||
var p [4]byte
|
||||
// First ReadData should return truncated "1234".
|
||||
n, err := ReadData(&enc, p[:])
|
||||
if err != io.ErrShortBuffer {
|
||||
t.Fatalf("got error %v, expected %v", err, io.ErrShortBuffer)
|
||||
}
|
||||
if !bytes.Equal(p[:n], []byte("1234")) {
|
||||
t.Fatalf("got <%x>, expected <%x>", p[:n], []byte("1234"))
|
||||
}
|
||||
// Second ReadData should return truncated "abcd", not the rest of
|
||||
// "12345678".
|
||||
n, err = ReadData(&enc, p[:])
|
||||
if err != io.ErrShortBuffer {
|
||||
t.Fatalf("got error %v, expected %v", err, io.ErrShortBuffer)
|
||||
}
|
||||
if !bytes.Equal(p[:n], []byte("abcd")) {
|
||||
t.Fatalf("got <%x>, expected <%x>", p[:n], []byte("abcd"))
|
||||
}
|
||||
// Last ReadData should give io.EOF.
|
||||
n, err = ReadData(&enc, p[:])
|
||||
if err != io.EOF {
|
||||
t.Fatalf("got error %v, expected %v", err, io.EOF)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that even when the result is truncated, ReadData fills the provided
|
||||
// buffer as much as possible (and not stop at the boundary of an internal Read,
|
||||
// say).
|
||||
func TestReadDataTruncateFull(t *testing.T) {
|
||||
pr, pw := io.Pipe()
|
||||
go func() {
|
||||
// Send one data chunk that will be delivered across two Read
|
||||
// calls.
|
||||
pw.Write([]byte{0x8a, 'h', 'e', 'l', 'l', 'o'})
|
||||
pw.Write([]byte{'w', 'o', 'r', 'l', 'd'})
|
||||
}()
|
||||
var p [8]byte
|
||||
n, err := ReadData(pr, p[:])
|
||||
if err != io.ErrShortBuffer {
|
||||
t.Fatalf("got error %v, expected %v", err, io.ErrShortBuffer)
|
||||
}
|
||||
// Should not stop after "hello".
|
||||
if !bytes.Equal(p[:n], []byte("hellowor")) {
|
||||
t.Fatalf("got <%x>, expected <%x>", p[:n], []byte("hellowor"))
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark the ReadData function when reading from a stream of data packets of
|
||||
// different sizes.
|
||||
func BenchmarkReadData(b *testing.B) {
|
||||
|
@ -341,8 +398,9 @@ func BenchmarkReadData(b *testing.B) {
|
|||
}
|
||||
}()
|
||||
|
||||
var p [128]byte
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := ReadData(pr)
|
||||
_, err := ReadData(pr, p[:])
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -173,12 +173,16 @@ func (handler *httpHandler) turbotunnelMode(conn net.Conn, addr net.Addr) error
|
|||
go func() {
|
||||
defer wg.Done()
|
||||
defer close(done) // Signal the write loop to finish
|
||||
var p [2048]byte
|
||||
for {
|
||||
p, err := encapsulation.ReadData(conn)
|
||||
n, err := encapsulation.ReadData(conn, p[:])
|
||||
if err == io.ErrShortBuffer {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
pconn.QueueIncoming(p, clientID)
|
||||
pconn.QueueIncoming(p[:n], clientID)
|
||||
}
|
||||
}()
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue