diff --git a/client/lib/rendezvous_test.go b/client/lib/rendezvous_test.go index d2460d3..89b437d 100644 --- a/client/lib/rendezvous_test.go +++ b/client/lib/rendezvous_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net/http" + "net/http/httptest" "net/url" "testing" @@ -13,11 +14,13 @@ import ( "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/aws/aws-sdk-go-v2/service/sqs/types" "github.com/golang/mock/gomock" + "github.com/pion/webrtc/v3" . "github.com/smartystreets/goconvey/convey" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/amp" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/nat" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient" + "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util" ) // mockTransport's RoundTrip method returns a response with a fake status and @@ -386,3 +389,51 @@ func TestSQSRendezvous(t *testing.T) { }) }) } + +func TestBrokerChannel(t *testing.T) { + Convey("Requests a proxy and handles response", t, func() { + answerSdp := &webrtc.SessionDescription{ + Type: webrtc.SDPTypeAnswer, + SDP: "test", + } + answerSdpStr, _ := util.SerializeSessionDescription(answerSdp) + serverResponse, _ := (&messages.ClientPollResponse{ + Answer: answerSdpStr, + }).EncodePollResponse() + + offerSdp := &webrtc.SessionDescription{ + Type: webrtc.SDPTypeOffer, + SDP: "test", + } + + requestBodyChan := make(chan []byte) + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + go func() { + requestBodyChan <- body + }() + w.Write(serverResponse) + })) + defer mockServer.Close() + + brokerChannel, err := newBrokerChannelFromConfig(ClientConfig{ + BrokerURL: mockServer.URL, + BridgeFingerprint: "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", + }) + So(err, ShouldBeNil) + brokerChannel.SetNATType(nat.NATRestricted) + + answerSdpReturned, err := brokerChannel.Negotiate(offerSdp) + So(err, ShouldBeNil) + So(answerSdpReturned, ShouldEqual, answerSdp) + + body := <-requestBodyChan + pollReq, err := messages.DecodeClientPollRequest(body) + So(err, ShouldBeNil) + So(pollReq.Fingerprint, ShouldEqual, "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") + So(pollReq.NAT, ShouldEqual, nat.NATRestricted) + requestSdp, err := util.DeserializeSessionDescription(pollReq.Offer) + So(err, ShouldBeNil) + So(requestSdp, ShouldEqual, offerSdp) + }) +}