Turn the proxy code into a library

Allow other go programs to easily import the snowflake proxy library and
start/stop a snowflake proxy.
This commit is contained in:
idk 2021-10-25 22:51:40 -04:00 committed by Cecylia Bocovich
parent 54ab79384f
commit 50e4f4fd61
7 changed files with 184 additions and 99 deletions

View file

@ -1,4 +1,4 @@
package main package snowflake
import ( import (
"bytes" "bytes"
@ -365,7 +365,7 @@ func TestBrokerInteractions(t *testing.T) {
b, b,
} }
sdp := broker.pollOffer(sampleOffer) sdp := broker.pollOffer(sampleOffer, nil)
expectedSDP, _ := strconv.Unquote(sampleSDP) expectedSDP, _ := strconv.Unquote(sampleSDP)
So(sdp.SDP, ShouldResemble, expectedSDP) So(sdp.SDP, ShouldResemble, expectedSDP)
}) })
@ -379,7 +379,7 @@ func TestBrokerInteractions(t *testing.T) {
b, b,
} }
sdp := broker.pollOffer(sampleOffer) sdp := broker.pollOffer(sampleOffer, nil)
So(sdp, ShouldBeNil) So(sdp, ShouldBeNil)
}) })
Convey("sends answer to broker", func() { Convey("sends answer to broker", func() {
@ -478,7 +478,7 @@ func TestUtilityFuncs(t *testing.T) {
Convey("CopyLoop", t, func() { Convey("CopyLoop", t, func() {
c1, s1 := net.Pipe() c1, s1 := net.Pipe()
c2, s2 := net.Pipe() c2, s2 := net.Pipe()
go CopyLoop(s1, s2) go copyLoop(s1, s2, nil)
go func() { go func() {
bytes := []byte("Hello!") bytes := []byte("Hello!")
c1.Write(bytes) c1.Write(bytes)

View file

@ -1,10 +1,9 @@
package main package snowflake
import ( import (
"bytes" "bytes"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"flag"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -12,27 +11,44 @@ import (
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"strings" "strings"
"sync" "sync"
"time" "time"
"git.torproject.org/pluggable-transports/snowflake.git/common/messages" "git.torproject.org/pluggable-transports/snowflake.git/common/messages"
"git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
"git.torproject.org/pluggable-transports/snowflake.git/common/util" "git.torproject.org/pluggable-transports/snowflake.git/common/util"
"git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn" "git.torproject.org/pluggable-transports/snowflake.git/common/websocketconn"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/pion/webrtc/v3" "github.com/pion/webrtc/v3"
) )
const defaultBrokerURL = "https://snowflake-broker.torproject.net/" // DefaultBrokerURL is the bamsoftware.com broker, https://snowflake-broker.bamsoftware.com
const defaultProbeURL = "https://snowflake-broker.torproject.net:8443/probe" // Changing this will change the default broker. The recommended way of changing
const defaultRelayURL = "wss://snowflake.torproject.net/" // the broker that gets used is by passing an argument to Main.
const defaultSTUNURL = "stun:stun.stunprotocol.org:3478" const DefaultBrokerURL = "https://snowflake-broker.bamsoftware.com/"
// DefaultProbeURL is the torproject.org ProbeURL, https://snowflake-broker.torproject.net:8443/probe
// Changing this will change the default Probe URL. The recommended way of changing
// the probe that gets used is by passing an argument to Main.
const DefaultProbeURL = "https://snowflake-broker.torproject.net:8443/probe"
// DefaultRelayURL is the bamsoftware.com Websocket Relay, wss://snowflake.bamsoftware.com/
// Changing this will change the default Relay URL. The recommended way of changing
// the relay that gets used is by passing an argument to Main.
const DefaultRelayURL = "wss://snowflake.bamsoftware.com/"
// DefaultSTUNURL is a stunprotocol.org STUN URL. stun:stun.stunprotocol.org:3478
// Changing this will change the default STUN URL. The recommended way of changing
// the STUN Server that gets used is by passing an argument to Main.
const DefaultSTUNURL = "stun:stun.stunprotocol.org:3478"
const pollInterval = 5 * time.Second const pollInterval = 5 * time.Second
const ( const (
NATUnknown = "unknown" // NATUnknown represents a NAT type which is unknown.
NATRestricted = "restricted" NATUnknown = "unknown"
// NATRestricted represents a restricted NAT.
NATRestricted = "restricted"
// NATUnrestricted represents an unrestricted NAT.
NATUnrestricted = "unrestricted" NATUnrestricted = "unrestricted"
) )
@ -43,7 +59,6 @@ const dataChannelTimeout = 20 * time.Second
const readLimit = 100000 //Maximum number of bytes to be read from an HTTP request const readLimit = 100000 //Maximum number of bytes to be read from an HTTP request
var broker *SignalingServer var broker *SignalingServer
var relayURL string
var currentNATType = NATUnknown var currentNATType = NATUnknown
@ -57,6 +72,18 @@ var (
client http.Client client http.Client
) )
// SnowflakeProxy is a structure which is used to configure an embedded
// Snowflake in another Go application.
type SnowflakeProxy struct {
Capacity uint
StunURL string
RawBrokerURL string
KeepLocalAddresses bool
RelayURL string
LogOutput io.Writer
shutdown chan struct{}
}
// Checks whether an IP address is a remote address for the client // Checks whether an IP address is a remote address for the client
func isRemoteAddress(ip net.IP) bool { func isRemoteAddress(ip net.IP) bool {
return !(util.IsLocal(ip) || ip.IsUnspecified() || ip.IsLoopback()) return !(util.IsLocal(ip) || ip.IsUnspecified() || ip.IsLoopback())
@ -81,6 +108,7 @@ func limitedRead(r io.Reader, limit int64) ([]byte, error) {
return p, err return p, err
} }
// SignalingServer keeps track of the SignalingServer in use by the Snowflake
type SignalingServer struct { type SignalingServer struct {
url *url.URL url *url.URL
transport http.RoundTripper transport http.RoundTripper
@ -102,6 +130,7 @@ func newSignalingServer(rawURL string, keepLocalAddresses bool) (*SignalingServe
return s, nil return s, nil
} }
// Post sends a POST request to the SignalingServer
func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) { func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) {
req, err := http.NewRequest("POST", path, payload) req, err := http.NewRequest("POST", path, payload)
@ -121,7 +150,7 @@ func (s *SignalingServer) Post(path string, payload io.Reader) ([]byte, error) {
return limitedRead(resp.Body, readLimit) return limitedRead(resp.Body, readLimit)
} }
func (s *SignalingServer) pollOffer(sid string) *webrtc.SessionDescription { func (s *SignalingServer) pollOffer(sid string, shutdown chan struct{}) *webrtc.SessionDescription {
brokerPath := s.url.ResolveReference(&url.URL{Path: "proxy"}) brokerPath := s.url.ResolveReference(&url.URL{Path: "proxy"})
ticker := time.NewTicker(pollInterval) ticker := time.NewTicker(pollInterval)
@ -129,31 +158,36 @@ func (s *SignalingServer) pollOffer(sid string) *webrtc.SessionDescription {
// Run the loop once before hitting the ticker // Run the loop once before hitting the ticker
for ; true; <-ticker.C { for ; true; <-ticker.C {
numClients := int((tokens.count() / 8) * 8) // Round down to 8 select {
body, err := messages.EncodePollRequest(sid, "standalone", currentNATType, numClients) case <-shutdown:
if err != nil {
log.Printf("Error encoding poll message: %s", err.Error())
return nil return nil
} default:
resp, err := s.Post(brokerPath.String(), bytes.NewBuffer(body)) numClients := int((tokens.count() / 8) * 8) // Round down to 8
if err != nil { body, err := messages.EncodePollRequest(sid, "standalone", currentNATType, numClients)
log.Printf("error polling broker: %s", err.Error())
}
offer, _, err := messages.DecodePollResponse(resp)
if err != nil {
log.Printf("Error reading broker response: %s", err.Error())
log.Printf("body: %s", resp)
return nil
}
if offer != "" {
offer, err := util.DeserializeSessionDescription(offer)
if err != nil { if err != nil {
log.Printf("Error processing session description: %s", err.Error()) log.Printf("Error encoding poll message: %s", err.Error())
return nil return nil
} }
return offer resp, err := s.Post(brokerPath.String(), bytes.NewBuffer(body))
if err != nil {
log.Printf("error polling broker: %s", err.Error())
}
offer, _, err := messages.DecodePollResponse(resp)
if err != nil {
log.Printf("Error reading broker response: %s", err.Error())
log.Printf("body: %s", resp)
return nil
}
if offer != "" {
offer, err := util.DeserializeSessionDescription(offer)
if err != nil {
log.Printf("Error processing session description: %s", err.Error())
return nil
}
return offer
}
} }
} }
return nil return nil
@ -192,33 +226,41 @@ func (s *SignalingServer) sendAnswer(sid string, pc *webrtc.PeerConnection) erro
return nil return nil
} }
func CopyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser) { func copyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser, shutdown chan struct{}) {
var wg sync.WaitGroup var once sync.Once
defer c2.Close()
defer c1.Close()
done := make(chan struct{})
copyer := func(dst io.ReadWriteCloser, src io.ReadWriteCloser) { copyer := func(dst io.ReadWriteCloser, src io.ReadWriteCloser) {
defer wg.Done()
// Ignore io.ErrClosedPipe because it is likely caused by the // Ignore io.ErrClosedPipe because it is likely caused by the
// termination of copyer in the other direction. // termination of copyer in the other direction.
if _, err := io.Copy(dst, src); err != nil && err != io.ErrClosedPipe { if _, err := io.Copy(dst, src); err != nil && err != io.ErrClosedPipe {
log.Printf("io.Copy inside CopyLoop generated an error: %v", err) log.Printf("io.Copy inside CopyLoop generated an error: %v", err)
} }
dst.Close() once.Do(func() {
src.Close() close(done)
})
} }
wg.Add(2)
go copyer(c1, c2) go copyer(c1, c2)
go copyer(c2, c1) go copyer(c2, c1)
wg.Wait()
select {
case <-done:
case <-shutdown:
}
log.Println("copy loop ended")
} }
// We pass conn.RemoteAddr() as an additional parameter, rather than calling // We pass conn.RemoteAddr() as an additional parameter, rather than calling
// conn.RemoteAddr() inside this function, as a workaround for a hang that // conn.RemoteAddr() inside this function, as a workaround for a hang that
// otherwise occurs inside of conn.pc.RemoteDescription() (called by // otherwise occurs inside of conn.pc.RemoteDescription() (called by
// RemoteAddr). https://bugs.torproject.org/18628#comment:8 // RemoteAddr). https://bugs.torproject.org/18628#comment:8
func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) { func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
defer conn.Close() defer conn.Close()
defer tokens.ret() defer tokens.ret()
u, err := url.Parse(relayURL) u, err := url.Parse(sf.RelayURL)
if err != nil { if err != nil {
log.Fatalf("invalid relay url: %s", err) log.Fatalf("invalid relay url: %s", err)
} }
@ -241,7 +283,7 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
wsConn := websocketconn.New(ws) wsConn := websocketconn.New(ws)
log.Printf("connected to relay") log.Printf("connected to relay")
defer wsConn.Close() defer wsConn.Close()
CopyLoop(conn, wsConn) copyLoop(conn, wsConn, sf.shutdown)
log.Printf("datachannelHandler ends") log.Printf("datachannelHandler ends")
} }
@ -249,7 +291,7 @@ func datachannelHandler(conn *webRTCConn, remoteAddr net.Addr) {
// candidates is complete and the answer is available in LocalDescription. // candidates is complete and the answer is available in LocalDescription.
// Installs an OnDataChannel callback that creates a webRTCConn and passes it to // Installs an OnDataChannel callback that creates a webRTCConn and passes it to
// datachannelHandler. // datachannelHandler.
func makePeerConnectionFromOffer(sdp *webrtc.SessionDescription, func (sf *SnowflakeProxy) makePeerConnectionFromOffer(sdp *webrtc.SessionDescription,
config webrtc.Configuration, config webrtc.Configuration,
dataChan chan struct{}, dataChan chan struct{},
handler func(conn *webRTCConn, remoteAddr net.Addr)) (*webrtc.PeerConnection, error) { handler func(conn *webRTCConn, remoteAddr net.Addr)) (*webrtc.PeerConnection, error) {
@ -333,7 +375,7 @@ func makePeerConnectionFromOffer(sdp *webrtc.SessionDescription,
// Create a new PeerConnection. Blocks until the gathering of ICE // Create a new PeerConnection. Blocks until the gathering of ICE
// candidates is complete and the answer is available in LocalDescription. // candidates is complete and the answer is available in LocalDescription.
func makeNewPeerConnection(config webrtc.Configuration, func (sf *SnowflakeProxy) makeNewPeerConnection(config webrtc.Configuration,
dataChan chan struct{}) (*webrtc.PeerConnection, error) { dataChan chan struct{}) (*webrtc.PeerConnection, error) {
pc, err := webrtc.NewPeerConnection(config) pc, err := webrtc.NewPeerConnection(config)
@ -383,15 +425,15 @@ func makeNewPeerConnection(config webrtc.Configuration,
return pc, nil return pc, nil
} }
func runSession(sid string) { func (sf *SnowflakeProxy) runSession(sid string) {
offer := broker.pollOffer(sid) offer := broker.pollOffer(sid, sf.shutdown)
if offer == nil { if offer == nil {
log.Printf("bad offer from broker") log.Printf("bad offer from broker")
tokens.ret() tokens.ret()
return return
} }
dataChan := make(chan struct{}) dataChan := make(chan struct{})
pc, err := makePeerConnectionFromOffer(offer, config, dataChan, datachannelHandler) pc, err := sf.makePeerConnectionFromOffer(offer, config, dataChan, sf.datachannelHandler)
if err != nil { if err != nil {
log.Printf("error making WebRTC connection: %s", err) log.Printf("error making WebRTC connection: %s", err)
tokens.ret() tokens.ret()
@ -421,53 +463,28 @@ func runSession(sid string) {
} }
} }
func main() { // Start configures and starts a Snowflake, fully formed and special. In the
var capacity uint // case of an empty map, defaults are configured automatically and can be
var stunURL string // found in the GoDoc and in main.go
var logFilename string func (sf *SnowflakeProxy) Start() {
var rawBrokerURL string
var unsafeLogging bool
var keepLocalAddresses bool
flag.UintVar(&capacity, "capacity", 0, "maximum concurrent clients") sf.shutdown = make(chan struct{})
flag.StringVar(&rawBrokerURL, "broker", defaultBrokerURL, "broker URL")
flag.StringVar(&relayURL, "relay", defaultRelayURL, "websocket relay URL")
flag.StringVar(&stunURL, "stun", defaultSTUNURL, "stun URL")
flag.StringVar(&logFilename, "log", "", "log filename")
flag.BoolVar(&unsafeLogging, "unsafe-logging", false, "prevent logs from being scrubbed")
flag.BoolVar(&keepLocalAddresses, "keep-local-addresses", false, "keep local LAN address ICE candidates")
flag.Parse()
var logOutput io.Writer = os.Stderr
log.SetFlags(log.LstdFlags | log.LUTC) log.SetFlags(log.LstdFlags | log.LUTC)
if logFilename != "" {
f, err := os.OpenFile(logFilename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
log.Fatal(err)
}
defer f.Close()
logOutput = io.MultiWriter(os.Stderr, f)
}
if unsafeLogging {
log.SetOutput(logOutput)
} else {
// We want to send the log output through our scrubber first
log.SetOutput(&safelog.LogScrubber{Output: logOutput})
}
log.Println("starting") log.Println("starting")
var err error var err error
broker, err = newSignalingServer(rawBrokerURL, keepLocalAddresses) broker, err = newSignalingServer(sf.RawBrokerURL, sf.KeepLocalAddresses)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
_, err = url.Parse(stunURL) _, err = url.Parse(sf.StunURL)
if err != nil { if err != nil {
log.Fatalf("invalid stun url: %s", err) log.Fatalf("invalid stun url: %s", err)
} }
_, err = url.Parse(relayURL) _, err = url.Parse(sf.RelayURL)
if err != nil { if err != nil {
log.Fatalf("invalid relay url: %s", err) log.Fatalf("invalid relay url: %s", err)
} }
@ -475,27 +492,37 @@ func main() {
config = webrtc.Configuration{ config = webrtc.Configuration{
ICEServers: []webrtc.ICEServer{ ICEServers: []webrtc.ICEServer{
{ {
URLs: []string{stunURL}, URLs: []string{sf.StunURL},
}, },
}, },
} }
tokens = newTokens(capacity) tokens = newTokens(sf.Capacity)
// use probetest to determine NAT compatability // use probetest to determine NAT compatability
checkNATType(config, defaultProbeURL) sf.checkNATType(config, DefaultProbeURL)
log.Printf("NAT type: %s", currentNATType) log.Printf("NAT type: %s", currentNATType)
ticker := time.NewTicker(pollInterval) ticker := time.NewTicker(pollInterval)
defer ticker.Stop() defer ticker.Stop()
for ; true; <-ticker.C { for ; true; <-ticker.C {
tokens.get() select {
sessionID := genSessionID() case <-sf.shutdown:
runSession(sessionID) return
default:
tokens.get()
sessionID := genSessionID()
sf.runSession(sessionID)
}
} }
} }
func checkNATType(config webrtc.Configuration, probeURL string) { // Stop calls close on the sf.shutdown channel shutting down the Snowflake.
func (sf *SnowflakeProxy) Stop() {
close(sf.shutdown)
}
func (sf *SnowflakeProxy) checkNATType(config webrtc.Configuration, probeURL string) {
probe, err := newSignalingServer(probeURL, false) probe, err := newSignalingServer(probeURL, false)
if err != nil { if err != nil {
@ -504,7 +531,7 @@ func checkNATType(config webrtc.Configuration, probeURL string) {
// create offer // create offer
dataChan := make(chan struct{}) dataChan := make(chan struct{})
pc, err := makeNewPeerConnection(config, dataChan) pc, err := sf.makeNewPeerConnection(config, dataChan)
if err != nil { if err != nil {
log.Printf("error making WebRTC connection: %s", err) log.Printf("error making WebRTC connection: %s", err)
return return

View file

@ -1,4 +1,4 @@
package main package snowflake
import ( import (
"sync/atomic" "sync/atomic"

View file

@ -1,4 +1,4 @@
package main package snowflake
import ( import (
"testing" "testing"

View file

@ -1,21 +1,28 @@
package main package snowflake
import ( import (
"fmt" "fmt"
"time" "time"
) )
// BytesLogger is an interface which is used to allow logging the throughput
// of the Snowflake. A default BytesLogger(BytesNullLogger) does nothing.
type BytesLogger interface { type BytesLogger interface {
AddOutbound(int) AddOutbound(int)
AddInbound(int) AddInbound(int)
ThroughputSummary() string ThroughputSummary() string
} }
// Default BytesLogger does nothing. // BytesNullLogger Default BytesLogger does nothing.
type BytesNullLogger struct{} type BytesNullLogger struct{}
func (b BytesNullLogger) AddOutbound(amount int) {} // AddOutbound in BytesNullLogger does nothing
func (b BytesNullLogger) AddInbound(amount int) {} func (b BytesNullLogger) AddOutbound(amount int) {}
// AddInbound in BytesNullLogger does nothing
func (b BytesNullLogger) AddInbound(amount int) {}
// ThroughputSummary in BytesNullLogger does nothing
func (b BytesNullLogger) ThroughputSummary() string { return "" } func (b BytesNullLogger) ThroughputSummary() string { return "" }
// BytesSyncLogger uses channels to safely log from multiple sources with output // BytesSyncLogger uses channels to safely log from multiple sources with output
@ -50,14 +57,17 @@ func (b *BytesSyncLogger) log() {
} }
} }
// AddOutbound add a number of bytes to the outbound total reported by the logger
func (b *BytesSyncLogger) AddOutbound(amount int) { func (b *BytesSyncLogger) AddOutbound(amount int) {
b.outboundChan <- amount b.outboundChan <- amount
} }
// AddInbound add a number of bytes to the inbound total reported by the logger
func (b *BytesSyncLogger) AddInbound(amount int) { func (b *BytesSyncLogger) AddInbound(amount int) {
b.inboundChan <- amount b.inboundChan <- amount
} }
// ThroughputSummary view a formatted summary of the throughput totals
func (b *BytesSyncLogger) ThroughputSummary() string { func (b *BytesSyncLogger) ThroughputSummary() string {
var inUnit, outUnit string var inUnit, outUnit string
units := []string{"B", "KB", "MB", "GB"} units := []string{"B", "KB", "MB", "GB"}

View file

@ -1,4 +1,4 @@
package main package snowflake
import ( import (
"fmt" "fmt"

48
proxy/main.go Normal file
View file

@ -0,0 +1,48 @@
package main
import (
"flag"
"io"
"log"
"os"
"git.torproject.org/pluggable-transports/snowflake.git/common/safelog"
"git.torproject.org/pluggable-transports/snowflake.git/proxy/lib"
)
func main() {
capacity := flag.Int("capacity", 10, "maximum concurrent clients")
stunURL := flag.String("stun", snowflake.DefaultSTUNURL, "broker URL")
logFilename := flag.String("log", "", "log filename")
rawBrokerURL := flag.String("broker", snowflake.DefaultBrokerURL, "broker URL")
unsafeLogging := flag.Bool("unsafe-logging", false, "prevent logs from being scrubbed")
keepLocalAddresses := flag.Bool("keep-local-addresses", false, "keep local LAN address ICE candidates")
relayURL := flag.String("relay", snowflake.DefaultRelayURL, "websocket relay URL")
flag.Parse()
sf := snowflake.SnowflakeProxy{
Capacity: uint(*capacity),
StunURL: *stunURL,
RawBrokerURL: *rawBrokerURL,
KeepLocalAddresses: *keepLocalAddresses,
RelayURL: *relayURL,
LogOutput: os.Stderr,
}
if *logFilename != "" {
f, err := os.OpenFile(*logFilename, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0600)
if err != nil {
log.Fatal(err)
}
defer f.Close()
sf.LogOutput = io.MultiWriter(os.Stderr, f)
}
if *unsafeLogging {
log.SetOutput(sf.LogOutput)
} else {
log.SetOutput(&safelog.LogScrubber{Output: sf.LogOutput})
}
sf.Start()
}