mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 20:11:19 -04:00
Import Turbo Tunnel support code.
Copied and slightly modified from
https://gitweb.torproject.org/pluggable-transports/meek.git/log/?h=turbotunnel&id=7eb94209f857fc71c2155907b0462cc587fc76cc
https://github.com/net4people/bbs/issues/21
RedialPacketConn is adapted from clientPacketConn in
c64a61c6da/obfs4proxy/turbotunnel_client.go
https://github.com/net4people/bbs/issues/14#issuecomment-544747519
This commit is contained in:
parent
904af9cb8a
commit
222ab3d85a
7 changed files with 1050 additions and 0 deletions
194
common/encapsulation/encapsulation.go
Normal file
194
common/encapsulation/encapsulation.go
Normal file
|
@ -0,0 +1,194 @@
|
|||
// Package encapsulation implements a way of encoding variable-size chunks of
|
||||
// data and padding into a byte stream.
|
||||
//
|
||||
// Each chunk of data or padding starts with a variable-size length prefix. One
|
||||
// bit ("d") in the first byte of the prefix indicates whether the chunk
|
||||
// represents data or padding (1=data, 0=padding). Another bit ("c" for
|
||||
// "continuation") is the indicates whether there are more bytes in the length
|
||||
// prefix. The remaining 6 bits ("x") encode part of the length value.
|
||||
// dcxxxxxx
|
||||
// If the continuation bit is set, then the next byte is also part of the length
|
||||
// prefix. It lacks the "d" bit, has its own "c" bit, and 7 value-carrying bits
|
||||
// ("y").
|
||||
// cyyyyyyy
|
||||
// The length is decoded by concatenating value-carrying bits, from left to
|
||||
// right, of all value-carrying bits, up to and including the first byte whose
|
||||
// "c" bit is 0. Although in principle this encoding would allow for length
|
||||
// prefixes of any size, length prefixes are arbitrarily limited to 3 bytes and
|
||||
// any attempt to read or write a longer one is an error. These are therefore
|
||||
// the only valid formats:
|
||||
// 00xxxxxx xxxxxx₂ bytes of padding
|
||||
// 10xxxxxx xxxxxx₂ bytes of data
|
||||
// 01xxxxxx 0yyyyyyy xxxxxxyyyyyyy₂ bytes of padding
|
||||
// 11xxxxxx 0yyyyyyy xxxxxxyyyyyyy₂ bytes of data
|
||||
// 01xxxxxx 1yyyyyyy 0zzzzzzz xxxxxxyyyyyyyzzzzzzz₂ bytes of padding
|
||||
// 11xxxxxx 1yyyyyyy 0zzzzzzz xxxxxxyyyyyyyzzzzzzz₂ bytes of data
|
||||
// The maximum encodable length is 11111111111111111111₂ = 0xfffff = 1048575.
|
||||
// There is no requirement to use a length prefix of minimum size; i.e. 00000100
|
||||
// and 01000000 00000100 are both valid encodings of the value 4.
|
||||
//
|
||||
// After the length prefix follow that many bytes of padding or data. There are
|
||||
// no restrictions on the value of bytes comprising padding.
|
||||
//
|
||||
// The idea for this encapsulation is sketched here:
|
||||
// https://github.com/net4people/bbs/issues/9#issuecomment-524095186
|
||||
package encapsulation
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
// ErrTooLong is the error returned when an encoded length prefix is longer than
|
||||
// 3 bytes, or when ReadData receives an input whose length is too large to
|
||||
// encode in a 3-byte length prefix.
|
||||
var ErrTooLong = errors.New("length prefix is too long")
|
||||
|
||||
// ReadData returns a new slice with the contents of the next available data
|
||||
// chunk, skipping over any padding chunks that may come first. The returned
|
||||
// error value is nil if and only if a data chunk was present and was read in
|
||||
// its entirety. The returned error is io.EOF only if r ended before the first
|
||||
// byte of a length prefix. If r ended in the middle of a length prefix or
|
||||
// data/padding, the returned error is io.ErrUnexpectedEOF.
|
||||
func ReadData(r io.Reader) ([]byte, error) {
|
||||
for {
|
||||
var b [1]byte
|
||||
_, err := r.Read(b[:])
|
||||
if err != nil {
|
||||
// This is the only place we may return a real io.EOF.
|
||||
return nil, err
|
||||
}
|
||||
isData := (b[0] & 0x80) != 0
|
||||
moreLength := (b[0] & 0x40) != 0
|
||||
n := int(b[0] & 0x3f)
|
||||
for i := 0; moreLength; i++ {
|
||||
if i >= 2 {
|
||||
return nil, ErrTooLong
|
||||
}
|
||||
_, err := r.Read(b[:])
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
moreLength = (b[0] & 0x80) != 0
|
||||
n = (n << 7) | int(b[0]&0x7f)
|
||||
}
|
||||
if isData {
|
||||
p := make([]byte, n)
|
||||
_, err := io.ReadFull(r, p)
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, err
|
||||
} else {
|
||||
_, err := io.CopyN(ioutil.Discard, r, int64(n))
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// dataPrefixForLength returns a length prefix for the given length, with the
|
||||
// "d" bit set to 1.
|
||||
func dataPrefixForLength(n int) ([]byte, error) {
|
||||
switch {
|
||||
case (n>>0)&0x3f == (n >> 0):
|
||||
return []byte{0x80 | byte((n>>0)&0x3f)}, nil
|
||||
case (n>>7)&0x3f == (n >> 7):
|
||||
return []byte{0xc0 | byte((n>>7)&0x3f), byte((n >> 0) & 0x7f)}, nil
|
||||
case (n>>14)&0x3f == (n >> 14):
|
||||
return []byte{0xc0 | byte((n>>14)&0x3f), 0x80 | byte((n>>7)&0x7f), byte((n >> 0) & 0x7f)}, nil
|
||||
default:
|
||||
return nil, ErrTooLong
|
||||
}
|
||||
}
|
||||
|
||||
// WriteData encodes a data chunk into w. It returns the total number of bytes
|
||||
// written; i.e., including the length prefix. The error is ErrTooLong if the
|
||||
// length of data cannot fit into a length prefix.
|
||||
func WriteData(w io.Writer, data []byte) (int, error) {
|
||||
prefix, err := dataPrefixForLength(len(data))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
total := 0
|
||||
n, err := w.Write(prefix)
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
n, err = w.Write(data)
|
||||
total += n
|
||||
return total, err
|
||||
}
|
||||
|
||||
var paddingBuffer = make([]byte, 1024)
|
||||
|
||||
// WritePadding encodes padding chunks, whose total size (including their own
|
||||
// length prefixes) is n. Returns the total number of bytes written to w, which
|
||||
// will be exactly n unless there was an error. The error cannot be ErrTooLong
|
||||
// because this function will write multiple padding chunks if necessary to
|
||||
// reach the requested size. Panics if n is negative.
|
||||
func WritePadding(w io.Writer, n int) (int, error) {
|
||||
if n < 0 {
|
||||
panic("negative length")
|
||||
}
|
||||
total := 0
|
||||
for n > 0 {
|
||||
p := len(paddingBuffer)
|
||||
if p > n {
|
||||
p = n
|
||||
}
|
||||
n -= p
|
||||
var prefix []byte
|
||||
switch {
|
||||
case ((p-1)>>0)&0x3f == ((p - 1) >> 0):
|
||||
p = p - 1
|
||||
prefix = []byte{byte((p >> 0) & 0x3f)}
|
||||
case ((p-2)>>7)&0x3f == ((p - 2) >> 7):
|
||||
p = p - 2
|
||||
prefix = []byte{0x40 | byte((p>>7)&0x3f), byte((p >> 0) & 0x7f)}
|
||||
case ((p-3)>>14)&0x3f == ((p - 3) >> 14):
|
||||
p = p - 3
|
||||
prefix = []byte{0x40 | byte((p>>14)&0x3f), 0x80 | byte((p>>7)&0x3f), byte((p >> 0) & 0x7f)}
|
||||
}
|
||||
nn, err := w.Write(prefix)
|
||||
total += nn
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
nn, err = w.Write(paddingBuffer[:p])
|
||||
total += nn
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// MaxDataForSize returns the length of the longest slice that can pe passed to
|
||||
// WriteData, whose total encoded size (including length prefix) is no larger
|
||||
// than n. Call this to find out if a chunk of data will fit into a length
|
||||
// budget. Panics if n == 0.
|
||||
func MaxDataForSize(n int) int {
|
||||
if n == 0 {
|
||||
panic("zero length")
|
||||
}
|
||||
prefix, err := dataPrefixForLength(n)
|
||||
if err == ErrTooLong {
|
||||
return (1 << (6 + 7 + 7)) - 1 - 3
|
||||
} else if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return n - len(prefix)
|
||||
}
|
330
common/encapsulation/encapsulation_test.go
Normal file
330
common/encapsulation/encapsulation_test.go
Normal file
|
@ -0,0 +1,330 @@
|
|||
package encapsulation
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Return a byte slice with non-trivial contents.
|
||||
func pseudorandomBuffer(n int) []byte {
|
||||
source := rand.NewSource(0)
|
||||
p := make([]byte, n)
|
||||
for i := 0; i < len(p); i++ {
|
||||
p[i] = byte(source.Int63() & 0xff)
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func mustWriteData(w io.Writer, p []byte) int {
|
||||
n, err := WriteData(w, p)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func mustWritePadding(w io.Writer, n int) int {
|
||||
n, err := WritePadding(w, n)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// Test that ReadData(WriteData()) recovers the original data.
|
||||
func TestRoundtrip(t *testing.T) {
|
||||
// Test above and below interesting thresholds.
|
||||
for _, i := range []int{
|
||||
0x00, 0x01,
|
||||
0x3e, 0x3f, 0x40, 0x41,
|
||||
0xfe, 0xff, 0x100, 0x101,
|
||||
0x1ffe, 0x1fff, 0x2000, 0x2001,
|
||||
0xfffe, 0xffff, 0x10000, 0x10001,
|
||||
0xffffe, 0xfffff,
|
||||
} {
|
||||
original := pseudorandomBuffer(i)
|
||||
var enc bytes.Buffer
|
||||
n, err := WriteData(&enc, original)
|
||||
if err != nil {
|
||||
t.Fatalf("size %d, WriteData returned error %v", i, err)
|
||||
}
|
||||
if enc.Len() != n {
|
||||
t.Fatalf("size %d, returned length was %d, written length was %d",
|
||||
i, n, enc.Len())
|
||||
}
|
||||
inverse, err := ReadData(&enc)
|
||||
if err != nil {
|
||||
t.Fatalf("size %d, ReadData returned error %v", i, err)
|
||||
}
|
||||
if !bytes.Equal(inverse, original) {
|
||||
t.Fatalf("size %d, got <%x>, expected <%x>", i, inverse, original)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test that WritePadding writes exactly as much as requested.
|
||||
func TestPaddingLength(t *testing.T) {
|
||||
// Test above and below interesting thresholds. WritePadding also gets
|
||||
// values above 0xfffff, the maximum value of a single length prefix.
|
||||
for _, i := range []int{
|
||||
0x00, 0x01,
|
||||
0x3f, 0x40, 0x41, 0x42,
|
||||
0xff, 0x100, 0x101, 0x102,
|
||||
0x2000, 0x2001, 0x2002, 0x2003,
|
||||
0x10000, 0x10001, 0x10002, 0x10003,
|
||||
0x100001, 0x100002, 0x100003, 0x100004,
|
||||
} {
|
||||
var enc bytes.Buffer
|
||||
n, err := WritePadding(&enc, i)
|
||||
if err != nil {
|
||||
t.Fatalf("size %d, WritePadding returned error %v", i, err)
|
||||
}
|
||||
if n != i {
|
||||
t.Fatalf("requested %d bytes, returned %d", i, n)
|
||||
}
|
||||
if enc.Len() != n {
|
||||
t.Fatalf("requested %d bytes, wrote %d bytes", i, enc.Len())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test that ReadData skips over padding.
|
||||
func TestSkipPadding(t *testing.T) {
|
||||
var data = [][]byte{{}, {}, []byte("hello"), {}, []byte("world")}
|
||||
var enc bytes.Buffer
|
||||
mustWritePadding(&enc, 10)
|
||||
mustWritePadding(&enc, 100)
|
||||
mustWriteData(&enc, data[0])
|
||||
mustWriteData(&enc, data[1])
|
||||
mustWritePadding(&enc, 10)
|
||||
mustWriteData(&enc, data[2])
|
||||
mustWriteData(&enc, data[3])
|
||||
mustWritePadding(&enc, 10)
|
||||
mustWriteData(&enc, data[4])
|
||||
mustWritePadding(&enc, 10)
|
||||
mustWritePadding(&enc, 10)
|
||||
for i, expected := range data {
|
||||
actual, err := ReadData(&enc)
|
||||
if err != nil {
|
||||
t.Fatalf("slice %d, got error %v, expected %v", i, err, nil)
|
||||
}
|
||||
if !bytes.Equal(actual, expected) {
|
||||
t.Fatalf("slice %d, got <%x>, expected <%x>", i, actual, expected)
|
||||
}
|
||||
}
|
||||
p, err := ReadData(&enc)
|
||||
if p != nil || err != io.EOF {
|
||||
t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, io.EOF)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that EOF before a length prefix returns io.EOF.
|
||||
func TestEOF(t *testing.T) {
|
||||
p, err := ReadData(bytes.NewReader(nil))
|
||||
if p != nil || err != io.EOF {
|
||||
t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, io.EOF)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that an EOF while reading a length prefix, or while reading the
|
||||
// subsequent data/padding, returns io.ErrUnexpectedEOF.
|
||||
func TestUnexpectedEOF(t *testing.T) {
|
||||
for _, test := range [][]byte{
|
||||
{0x40}, // expecting a second length byte
|
||||
{0xc0}, // expecting a second length byte
|
||||
{0x41, 0x80}, // expecting a third length byte
|
||||
{0xc1, 0x80}, // expecting a third length byte
|
||||
{0x02}, // expecting 2 bytes of padding
|
||||
{0x82}, // expecting 2 bytes of data
|
||||
{0x02, 'X'}, // expecting 1 byte of padding
|
||||
{0x82, 'X'}, // expecting 1 byte of data
|
||||
{0x41, 0x00}, // expecting 128 bytes of padding
|
||||
{0xc1, 0x00}, // expecting 128 bytes of data
|
||||
{0x41, 0x00, 'X'}, // expecting 127 bytes of padding
|
||||
{0xc1, 0x00, 'X'}, // expecting 127 bytes of data
|
||||
{0x41, 0x80, 0x00}, // expecting 32768 bytes of padding
|
||||
{0xc1, 0x80, 0x00}, // expecting 32768 bytes of data
|
||||
{0x41, 0x80, 0x00, 'X'}, // expecting 32767 bytes of padding
|
||||
{0xc1, 0x80, 0x00, 'X'}, // expecting 32767 bytes of data
|
||||
} {
|
||||
p, err := ReadData(bytes.NewReader(test))
|
||||
if p != nil || err != io.ErrUnexpectedEOF {
|
||||
t.Fatalf("<%x> got (<%x>, %v), expected (%v, %v)", test, p, err, nil, io.ErrUnexpectedEOF)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test that length encodings that are longer than they could be are still
|
||||
// interpreted.
|
||||
func TestNonMinimalLengthEncoding(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
enc []byte
|
||||
expected []byte
|
||||
}{
|
||||
{[]byte{0x81, 'X'}, []byte("X")},
|
||||
{[]byte{0xc0, 0x01, 'X'}, []byte("X")},
|
||||
{[]byte{0xc0, 0x80, 0x01, 'X'}, []byte("X")},
|
||||
} {
|
||||
p, err := ReadData(bytes.NewReader(test.enc))
|
||||
if err != nil {
|
||||
t.Fatalf("<%x> got error %v, expected %v", test.enc, err, nil)
|
||||
}
|
||||
if !bytes.Equal(p, test.expected) {
|
||||
t.Fatalf("<%x> got <%x>, expected <%x>", test.enc, p, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Test that ReadData only reads up to 3 bytes of length prefix.
|
||||
func TestReadLimits(t *testing.T) {
|
||||
// Test the maximum length that's possible with 3 bytes of length
|
||||
// prefix.
|
||||
maxLength := (0x3f << 14) | (0x7f << 7) | 0x7f
|
||||
data := bytes.Repeat([]byte{'X'}, maxLength)
|
||||
prefix := []byte{0xff, 0xff, 0x7f} // encodes 0xfffff
|
||||
p, err := ReadData(bytes.NewReader(append(prefix, data...)))
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v, expected %v", err, nil)
|
||||
}
|
||||
if !bytes.Equal(p, data) {
|
||||
t.Fatalf("got %d bytes unequal to %d bytes", len(p), len(data))
|
||||
}
|
||||
// Test a 4-byte prefix.
|
||||
prefix = []byte{0xc0, 0xc0, 0x80, 0x80} // encodes 0x100000
|
||||
data = bytes.Repeat([]byte{'X'}, maxLength+1)
|
||||
p, err = ReadData(bytes.NewReader(append(prefix, data...)))
|
||||
if p != nil || err != ErrTooLong {
|
||||
t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong)
|
||||
}
|
||||
// Test that 4 bytes don't work, even when they encode an integer that
|
||||
// would fix in 3 bytes.
|
||||
prefix = []byte{0xc0, 0x80, 0x80, 0x80} // encodes 0x0
|
||||
data = []byte{}
|
||||
p, err = ReadData(bytes.NewReader(append(prefix, data...)))
|
||||
if p != nil || err != ErrTooLong {
|
||||
t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong)
|
||||
}
|
||||
|
||||
// Do the same tests with padding lengths.
|
||||
data = []byte("hello")
|
||||
prefix = []byte{0x7f, 0xff, 0x7f} // encodes 0xfffff
|
||||
padding := bytes.Repeat([]byte{'X'}, maxLength)
|
||||
enc := bytes.NewBuffer(append(prefix, padding...))
|
||||
mustWriteData(enc, data)
|
||||
p, err = ReadData(enc)
|
||||
if err != nil {
|
||||
t.Fatalf("got error %v, expected %v", err, nil)
|
||||
}
|
||||
if !bytes.Equal(p, data) {
|
||||
t.Fatalf("got <%x>, expected <%x>", p, data)
|
||||
}
|
||||
prefix = []byte{0x40, 0xc0, 0x80, 0x80} // encodes 0x100000
|
||||
padding = bytes.Repeat([]byte{'X'}, maxLength+1)
|
||||
enc = bytes.NewBuffer(append(prefix, padding...))
|
||||
mustWriteData(enc, data)
|
||||
p, err = ReadData(enc)
|
||||
if p != nil || err != ErrTooLong {
|
||||
t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong)
|
||||
}
|
||||
prefix = []byte{0x40, 0x80, 0x80, 0x80} // encodes 0x0
|
||||
padding = []byte{}
|
||||
enc = bytes.NewBuffer(append(prefix, padding...))
|
||||
mustWriteData(enc, data)
|
||||
p, err = ReadData(enc)
|
||||
if p != nil || err != ErrTooLong {
|
||||
t.Fatalf("got (<%x>, %v), expected (%v, %v)", p, err, nil, ErrTooLong)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that WriteData and WritePadding only accept lengths that can be encoded
|
||||
// in up to 3 bytes of length prefix.
|
||||
func TestWriteLimits(t *testing.T) {
|
||||
maxLength := (0x3f << 14) | (0x7f << 7) | 0x7f
|
||||
var enc bytes.Buffer
|
||||
n, err := WriteData(&enc, bytes.Repeat([]byte{'X'}, maxLength))
|
||||
if n != maxLength+3 || err != nil {
|
||||
t.Fatalf("got (%d, %v), expected (%d, %v)", n, err, maxLength, nil)
|
||||
}
|
||||
enc.Reset()
|
||||
n, err = WriteData(&enc, bytes.Repeat([]byte{'X'}, maxLength+1))
|
||||
if n != 0 || err != ErrTooLong {
|
||||
t.Fatalf("got (%d, %v), expected (%d, %v)", n, err, 0, ErrTooLong)
|
||||
}
|
||||
|
||||
// Padding gets an extra 3 bytes because the prefix is counted as part
|
||||
// of the length.
|
||||
enc.Reset()
|
||||
n, err = WritePadding(&enc, maxLength+3)
|
||||
if n != maxLength+3 || err != nil {
|
||||
t.Fatalf("got (%d, %v), expected (%d, %v)", n, err, maxLength+3, nil)
|
||||
}
|
||||
// Writing a too-long padding is okay because WritePadding will break it
|
||||
// into smaller chunks.
|
||||
enc.Reset()
|
||||
n, err = WritePadding(&enc, maxLength+4)
|
||||
if n != maxLength+4 || err != nil {
|
||||
t.Fatalf("got (%d, %v), expected (%d, %v)", n, err, maxLength+4, nil)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that WritePadding panics when given a negative length.
|
||||
func TestNegativeLength(t *testing.T) {
|
||||
for _, n := range []int{-1, ^0} {
|
||||
var enc bytes.Buffer
|
||||
panicked, nn, err := testNegativeLengthSub(t, &enc, n)
|
||||
if !panicked {
|
||||
t.Fatalf("WritePadding(%d) returned (%d, %v) instead of panicking", n, nn, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Calls WritePadding(w, n) and augments the return value with a flag indicating
|
||||
// whether the call panicked.
|
||||
func testNegativeLengthSub(t *testing.T, w io.Writer, n int) (panicked bool, nn int, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicked = true
|
||||
}
|
||||
}()
|
||||
t.Helper()
|
||||
nn, err = WritePadding(w, n)
|
||||
return false, n, err
|
||||
}
|
||||
|
||||
// Test that MaxDataForSize panics when given a 0 length.
|
||||
func TestMaxDataForSizeZero(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal("didn't panic")
|
||||
}
|
||||
}()
|
||||
MaxDataForSize(0)
|
||||
}
|
||||
|
||||
// Test thresholds of available sizes for MaxDataForSize.
|
||||
func TestMaxDataForSize(t *testing.T) {
|
||||
for _, test := range []struct {
|
||||
size int
|
||||
expected int
|
||||
}{
|
||||
{0x01, 0x00},
|
||||
{0x02, 0x01},
|
||||
{0x3f, 0x3e},
|
||||
{0x40, 0x3e},
|
||||
{0x41, 0x3f},
|
||||
{0x1fff, 0x1ffd},
|
||||
{0x2000, 0x1ffd},
|
||||
{0x2001, 0x1ffe},
|
||||
{0xfffff, 0xffffc},
|
||||
{0x100000, 0xffffc},
|
||||
{0x100001, 0xffffc},
|
||||
{0x7fffffff, 0xffffc},
|
||||
} {
|
||||
max := MaxDataForSize(test.size)
|
||||
if max != test.expected {
|
||||
t.Fatalf("size %d, got %d, expected %d", test.size, max, test.expected)
|
||||
}
|
||||
}
|
||||
}
|
28
common/turbotunnel/clientid.go
Normal file
28
common/turbotunnel/clientid.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package turbotunnel
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
// ClientID is an abstract identifier that binds together all the communications
|
||||
// belonging to a single client session, even though those communications may
|
||||
// arrive from multiple IP addresses or over multiple lower-level connections.
|
||||
// It plays the same role that an (IP address, port number) tuple plays in a
|
||||
// net.UDPConn: it's the return address pertaining to a long-lived abstract
|
||||
// client session. The client attaches its ClientID to each of its
|
||||
// communications, enabling the server to disambiguate requests among its many
|
||||
// clients. ClientID implements the net.Addr interface.
|
||||
type ClientID [8]byte
|
||||
|
||||
func NewClientID() ClientID {
|
||||
var id ClientID
|
||||
_, err := rand.Read(id[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
func (id ClientID) Network() string { return "clientid" }
|
||||
func (id ClientID) String() string { return hex.EncodeToString(id[:]) }
|
144
common/turbotunnel/clientmap.go
Normal file
144
common/turbotunnel/clientmap.go
Normal file
|
@ -0,0 +1,144 @@
|
|||
package turbotunnel
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// clientRecord is a record of a recently seen client, with the time it was last
|
||||
// seen and a send queue.
|
||||
type clientRecord struct {
|
||||
Addr net.Addr
|
||||
LastSeen time.Time
|
||||
SendQueue chan []byte
|
||||
}
|
||||
|
||||
// ClientMap manages a mapping of live clients (keyed by address, which will be
|
||||
// a ClientID) to their respective send queues. ClientMap's functions are safe
|
||||
// to call from multiple goroutines.
|
||||
type ClientMap struct {
|
||||
// We use an inner structure to avoid exposing public heap.Interface
|
||||
// functions to users of clientMap.
|
||||
inner clientMapInner
|
||||
// Synchronizes access to inner.
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// NewClientMap creates a ClientMap that expires clients after a timeout.
|
||||
//
|
||||
// The timeout does not have to be kept in sync with QUIC's internal idle
|
||||
// timeout. If a client is removed from the client map while the QUIC session is
|
||||
// still live, the worst that can happen is a loss of whatever packets were in
|
||||
// the send queue at the time. If QUIC later decides to send more packets to the
|
||||
// same client, we'll instantiate a new send queue, and if the client ever
|
||||
// connects again with the proper client ID, we'll deliver them.
|
||||
func NewClientMap(timeout time.Duration) *ClientMap {
|
||||
m := &ClientMap{
|
||||
inner: clientMapInner{
|
||||
byAge: make([]*clientRecord, 0),
|
||||
byAddr: make(map[net.Addr]int),
|
||||
},
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(timeout / 2)
|
||||
now := time.Now()
|
||||
m.lock.Lock()
|
||||
m.inner.removeExpired(now, timeout)
|
||||
m.lock.Unlock()
|
||||
}
|
||||
}()
|
||||
return m
|
||||
}
|
||||
|
||||
// SendQueue returns the send queue corresponding to addr, creating it if
|
||||
// necessary.
|
||||
func (m *ClientMap) SendQueue(addr net.Addr) chan []byte {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
return m.inner.SendQueue(addr, time.Now())
|
||||
}
|
||||
|
||||
// clientMapInner is the inner type of ClientMap, implementing heap.Interface.
|
||||
// byAge is the backing store, a heap ordered by LastSeen time, to facilitate
|
||||
// expiring old client records. byAddr is a map from addresses (i.e., ClientIDs)
|
||||
// to heap indices, to allow looking up by address. Unlike ClientMap,
|
||||
// clientMapInner requires external synchonization.
|
||||
type clientMapInner struct {
|
||||
byAge []*clientRecord
|
||||
byAddr map[net.Addr]int
|
||||
}
|
||||
|
||||
// removeExpired removes all client records whose LastSeen timestamp is more
|
||||
// than timeout in the past.
|
||||
func (inner *clientMapInner) removeExpired(now time.Time, timeout time.Duration) {
|
||||
for len(inner.byAge) > 0 && now.Sub(inner.byAge[0].LastSeen) >= timeout {
|
||||
heap.Pop(inner)
|
||||
}
|
||||
}
|
||||
|
||||
// SendQueue finds the existing client record corresponding to addr, or creates
|
||||
// a new one if none exists yet. It updates the client record's LastSeen time
|
||||
// and returns its SendQueue.
|
||||
func (inner *clientMapInner) SendQueue(addr net.Addr, now time.Time) chan []byte {
|
||||
var record *clientRecord
|
||||
i, ok := inner.byAddr[addr]
|
||||
if ok {
|
||||
// Found one, update its LastSeen.
|
||||
record = inner.byAge[i]
|
||||
record.LastSeen = now
|
||||
heap.Fix(inner, i)
|
||||
} else {
|
||||
// Not found, create a new one.
|
||||
record = &clientRecord{
|
||||
Addr: addr,
|
||||
LastSeen: now,
|
||||
SendQueue: make(chan []byte, queueSize),
|
||||
}
|
||||
heap.Push(inner, record)
|
||||
}
|
||||
return record.SendQueue
|
||||
}
|
||||
|
||||
// heap.Interface for clientMapInner.
|
||||
|
||||
func (inner *clientMapInner) Len() int {
|
||||
if len(inner.byAge) != len(inner.byAddr) {
|
||||
panic("inconsistent clientMap")
|
||||
}
|
||||
return len(inner.byAge)
|
||||
}
|
||||
|
||||
func (inner *clientMapInner) Less(i, j int) bool {
|
||||
return inner.byAge[i].LastSeen.Before(inner.byAge[j].LastSeen)
|
||||
}
|
||||
|
||||
func (inner *clientMapInner) Swap(i, j int) {
|
||||
inner.byAge[i], inner.byAge[j] = inner.byAge[j], inner.byAge[i]
|
||||
inner.byAddr[inner.byAge[i].Addr] = i
|
||||
inner.byAddr[inner.byAge[j].Addr] = j
|
||||
}
|
||||
|
||||
func (inner *clientMapInner) Push(x interface{}) {
|
||||
record := x.(*clientRecord)
|
||||
if _, ok := inner.byAddr[record.Addr]; ok {
|
||||
panic("duplicate address in clientMap")
|
||||
}
|
||||
// Insert into byAddr map.
|
||||
inner.byAddr[record.Addr] = len(inner.byAge)
|
||||
// Insert into byAge slice.
|
||||
inner.byAge = append(inner.byAge, record)
|
||||
}
|
||||
|
||||
func (inner *clientMapInner) Pop() interface{} {
|
||||
n := len(inner.byAddr)
|
||||
// Remove from byAge slice.
|
||||
record := inner.byAge[n-1]
|
||||
inner.byAge[n-1] = nil
|
||||
inner.byAge = inner.byAge[:n-1]
|
||||
// Remove from byAddr map.
|
||||
delete(inner.byAddr, record.Addr)
|
||||
return record
|
||||
}
|
13
common/turbotunnel/consts.go
Normal file
13
common/turbotunnel/consts.go
Normal file
|
@ -0,0 +1,13 @@
|
|||
// Package turbotunnel provides support for overlaying a virtual net.PacketConn
|
||||
// on some other network carrier.
|
||||
//
|
||||
// https://github.com/net4people/bbs/issues/9
|
||||
package turbotunnel
|
||||
|
||||
import "errors"
|
||||
|
||||
// The size of receive and send queues.
|
||||
const queueSize = 32
|
||||
|
||||
var errClosedPacketConn = errors.New("operation on closed connection")
|
||||
var errNotImplemented = errors.New("not implemented")
|
137
common/turbotunnel/queuepacketconn.go
Normal file
137
common/turbotunnel/queuepacketconn.go
Normal file
|
@ -0,0 +1,137 @@
|
|||
package turbotunnel
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// taggedPacket is a combination of a []byte and a net.Addr, encapsulating the
|
||||
// return type of PacketConn.ReadFrom.
|
||||
type taggedPacket struct {
|
||||
P []byte
|
||||
Addr net.Addr
|
||||
}
|
||||
|
||||
// QueuePacketConn implements net.PacketConn by storing queues of packets. There
|
||||
// is one incoming queue (where packets are additionally tagged by the source
|
||||
// address of the client that sent them). There are many outgoing queues, one
|
||||
// for each client address that has been recently seen. The QueueIncoming method
|
||||
// inserts a packet into the incoming queue, to eventually be returned by
|
||||
// ReadFrom. WriteTo inserts a packet into an address-specific outgoing queue,
|
||||
// which can later by accessed through the OutgoingQueue method.
|
||||
type QueuePacketConn struct {
|
||||
clients *ClientMap
|
||||
localAddr net.Addr
|
||||
recvQueue chan taggedPacket
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
// What error to return when the QueuePacketConn is closed.
|
||||
err atomic.Value
|
||||
}
|
||||
|
||||
// NewQueuePacketConn makes a new QueuePacketConn, set to track recent clients
|
||||
// for at least a duration of timeout.
|
||||
func NewQueuePacketConn(localAddr net.Addr, timeout time.Duration) *QueuePacketConn {
|
||||
return &QueuePacketConn{
|
||||
clients: NewClientMap(timeout),
|
||||
localAddr: localAddr,
|
||||
recvQueue: make(chan taggedPacket, queueSize),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// QueueIncoming queues and incoming packet and its source address, to be
|
||||
// returned in a future call to ReadFrom.
|
||||
func (c *QueuePacketConn) QueueIncoming(p []byte, addr net.Addr) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
// If we're closed, silently drop it.
|
||||
return
|
||||
default:
|
||||
}
|
||||
// Copy the slice so that the caller may reuse it.
|
||||
buf := make([]byte, len(p))
|
||||
copy(buf, p)
|
||||
select {
|
||||
case c.recvQueue <- taggedPacket{buf, addr}:
|
||||
default:
|
||||
// Drop the incoming packet if the receive queue is full.
|
||||
}
|
||||
}
|
||||
|
||||
// OutgoingQueue returns the queue of outgoing packets corresponding to addr,
|
||||
// creating it if necessary. The contents of the queue will be packets that are
|
||||
// written to the address in question using WriteTo.
|
||||
func (c *QueuePacketConn) OutgoingQueue(addr net.Addr) <-chan []byte {
|
||||
return c.clients.SendQueue(addr)
|
||||
}
|
||||
|
||||
// ReadFrom returns a packet and address previously stored by QueueIncoming.
|
||||
func (c *QueuePacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-c.closed:
|
||||
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
|
||||
case packet := <-c.recvQueue:
|
||||
return copy(p, packet.P), packet.Addr, nil
|
||||
}
|
||||
}
|
||||
|
||||
// WriteTo queues an outgoing packet for the given address. The queue can later
|
||||
// be retrieved using the OutgoingQueue method.
|
||||
func (c *QueuePacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
|
||||
default:
|
||||
}
|
||||
// Copy the slice so that the caller may reuse it.
|
||||
buf := make([]byte, len(p))
|
||||
copy(buf, p)
|
||||
select {
|
||||
case c.clients.SendQueue(addr) <- buf:
|
||||
return len(buf), nil
|
||||
default:
|
||||
// Drop the outgoing packet if the send queue is full.
|
||||
return len(buf), nil
|
||||
}
|
||||
}
|
||||
|
||||
// closeWithError unblocks pending operations and makes future operations fail
|
||||
// with the given error. If err is nil, it becomes errClosedPacketConn.
|
||||
func (c *QueuePacketConn) closeWithError(err error) error {
|
||||
var newlyClosed bool
|
||||
c.closeOnce.Do(func() {
|
||||
newlyClosed = true
|
||||
// Store the error to be returned by future PacketConn
|
||||
// operations.
|
||||
if err == nil {
|
||||
err = errClosedPacketConn
|
||||
}
|
||||
c.err.Store(err)
|
||||
close(c.closed)
|
||||
})
|
||||
if !newlyClosed {
|
||||
return &net.OpError{Op: "close", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close unblocks pending operations and makes future operations fail with a
|
||||
// "closed connection" error.
|
||||
func (c *QueuePacketConn) Close() error {
|
||||
return c.closeWithError(nil)
|
||||
}
|
||||
|
||||
// LocalAddr returns the localAddr value that was passed to NewQueuePacketConn.
|
||||
func (c *QueuePacketConn) LocalAddr() net.Addr { return c.localAddr }
|
||||
|
||||
func (c *QueuePacketConn) SetDeadline(t time.Time) error { return errNotImplemented }
|
||||
func (c *QueuePacketConn) SetReadDeadline(t time.Time) error { return errNotImplemented }
|
||||
func (c *QueuePacketConn) SetWriteDeadline(t time.Time) error { return errNotImplemented }
|
204
common/turbotunnel/redialpacketconn.go
Normal file
204
common/turbotunnel/redialpacketconn.go
Normal file
|
@ -0,0 +1,204 @@
|
|||
package turbotunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RedialPacketConn implements a long-lived net.PacketConn atop a sequence of
|
||||
// other, transient net.PacketConns. RedialPacketConn creates a new
|
||||
// net.PacketConn by calling a provided dialContext function. Whenever the
|
||||
// net.PacketConn experiences a ReadFrom or WriteTo error, RedialPacketConn
|
||||
// calls the dialContext function again and starts sending and receiving packets
|
||||
// on the new net.PacketConn. RedialPacketConn's own ReadFrom and WriteTo
|
||||
// methods return an error only when the dialContext function returns an error.
|
||||
//
|
||||
// RedialPacketConn uses static local and remote addresses that are independent
|
||||
// of those of any dialed net.PacketConn.
|
||||
type RedialPacketConn struct {
|
||||
localAddr net.Addr
|
||||
remoteAddr net.Addr
|
||||
dialContext func(context.Context) (net.PacketConn, error)
|
||||
recvQueue chan []byte
|
||||
sendQueue chan []byte
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
// The first dial error, which causes the clientPacketConn to be
|
||||
// closed and is returned from future read/write operations. Compare to
|
||||
// the rerr and werr in io.Pipe.
|
||||
err atomic.Value
|
||||
}
|
||||
|
||||
// NewQueuePacketConn makes a new RedialPacketConn, with the given static local
|
||||
// and remote addresses, and dialContext function.
|
||||
func NewRedialPacketConn(
|
||||
localAddr, remoteAddr net.Addr,
|
||||
dialContext func(context.Context) (net.PacketConn, error),
|
||||
) *RedialPacketConn {
|
||||
c := &RedialPacketConn{
|
||||
localAddr: localAddr,
|
||||
remoteAddr: remoteAddr,
|
||||
dialContext: dialContext,
|
||||
recvQueue: make(chan []byte, queueSize),
|
||||
sendQueue: make(chan []byte, queueSize),
|
||||
closed: make(chan struct{}),
|
||||
err: atomic.Value{},
|
||||
}
|
||||
go c.dialLoop()
|
||||
return c
|
||||
}
|
||||
|
||||
// dialLoop repeatedly calls c.dialContext and passes the resulting
|
||||
// net.PacketConn to c.exchange. It returns only when c is closed or dialContext
|
||||
// returns an error.
|
||||
func (c *RedialPacketConn) dialLoop() {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
for {
|
||||
select {
|
||||
case <-c.closed:
|
||||
cancel()
|
||||
return
|
||||
default:
|
||||
}
|
||||
conn, err := c.dialContext(ctx)
|
||||
if err != nil {
|
||||
c.closeWithError(err)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
c.exchange(conn)
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// exchange calls ReadFrom on the given net.PacketConn and places the resulting
|
||||
// packets in the receive queue, and takes packets from the send queue and calls
|
||||
// WriteTo on them, making the current net.PacketConn active.
|
||||
func (c *RedialPacketConn) exchange(conn net.PacketConn) {
|
||||
readErrCh := make(chan error)
|
||||
writeErrCh := make(chan error)
|
||||
|
||||
go func() {
|
||||
defer close(readErrCh)
|
||||
for {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
case <-writeErrCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
var buf [1500]byte
|
||||
n, _, err := conn.ReadFrom(buf[:])
|
||||
if err != nil {
|
||||
readErrCh <- err
|
||||
return
|
||||
}
|
||||
p := make([]byte, n)
|
||||
copy(p, buf[:])
|
||||
select {
|
||||
case c.recvQueue <- p:
|
||||
default: // OK to drop packets.
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer close(writeErrCh)
|
||||
for {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return
|
||||
case <-readErrCh:
|
||||
return
|
||||
case p := <-c.sendQueue:
|
||||
_, err := conn.WriteTo(p, c.remoteAddr)
|
||||
if err != nil {
|
||||
writeErrCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-readErrCh:
|
||||
case <-writeErrCh:
|
||||
}
|
||||
}
|
||||
|
||||
// ReadFrom reads a packet from the currently active net.PacketConn. The
|
||||
// packet's original remote address is replaced with the RedialPacketConn's own
|
||||
// remote address.
|
||||
func (c *RedialPacketConn) ReadFrom(p []byte) (int, net.Addr, error) {
|
||||
select {
|
||||
case <-c.closed:
|
||||
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)}
|
||||
default:
|
||||
}
|
||||
select {
|
||||
case <-c.closed:
|
||||
return 0, nil, &net.OpError{Op: "read", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)}
|
||||
case buf := <-c.recvQueue:
|
||||
return copy(p, buf), c.remoteAddr, nil
|
||||
}
|
||||
}
|
||||
|
||||
// WriteTo writes a packet to the currently active net.PacketConn. The addr
|
||||
// argument is ignored and instead replaced with the RedialPacketConn's own
|
||||
// remote address.
|
||||
func (c *RedialPacketConn) WriteTo(p []byte, addr net.Addr) (int, error) {
|
||||
// addr is ignored.
|
||||
select {
|
||||
case <-c.closed:
|
||||
return 0, &net.OpError{Op: "write", Net: c.LocalAddr().Network(), Source: c.LocalAddr(), Addr: c.remoteAddr, Err: c.err.Load().(error)}
|
||||
default:
|
||||
}
|
||||
buf := make([]byte, len(p))
|
||||
copy(buf, p)
|
||||
select {
|
||||
case c.sendQueue <- buf:
|
||||
return len(buf), nil
|
||||
default:
|
||||
// Drop the outgoing packet if the send queue is full.
|
||||
return len(buf), nil
|
||||
}
|
||||
}
|
||||
|
||||
// closeWithError unblocks pending operations and makes future operations fail
|
||||
// with the given error. If err is nil, it becomes errClosedPacketConn.
|
||||
func (c *RedialPacketConn) closeWithError(err error) error {
|
||||
var once bool
|
||||
c.closeOnce.Do(func() {
|
||||
// Store the error to be returned by future read/write
|
||||
// operations.
|
||||
if err == nil {
|
||||
err = errors.New("operation on closed connection")
|
||||
}
|
||||
c.err.Store(err)
|
||||
close(c.closed)
|
||||
once = true
|
||||
})
|
||||
if !once {
|
||||
return &net.OpError{Op: "close", Net: c.LocalAddr().Network(), Addr: c.LocalAddr(), Err: c.err.Load().(error)}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close unblocks pending operations and makes future operations fail with a
|
||||
// "closed connection" error.
|
||||
func (c *RedialPacketConn) Close() error {
|
||||
return c.closeWithError(nil)
|
||||
}
|
||||
|
||||
// LocalAddr returns the localAddr value that was passed to NewRedialPacketConn.
|
||||
func (c *RedialPacketConn) LocalAddr() net.Addr { return c.localAddr }
|
||||
|
||||
func (c *RedialPacketConn) SetDeadline(t time.Time) error { return errNotImplemented }
|
||||
func (c *RedialPacketConn) SetReadDeadline(t time.Time) error { return errNotImplemented }
|
||||
func (c *RedialPacketConn) SetWriteDeadline(t time.Time) error { return errNotImplemented }
|
Loading…
Add table
Add a link
Reference in a new issue