mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 20:11:19 -04:00
Add unit tests for SQS rendezvous in client
Co-authored-by: Michael Pu <michael.pu@uwaterloo.ca>
This commit is contained in:
parent
f3b062ddb2
commit
32e864b71d
3 changed files with 137 additions and 38 deletions
|
@ -23,6 +23,8 @@ type sqsRendezvous struct {
|
||||||
sqsClientID string
|
sqsClientID string
|
||||||
sqsClient sqsclient.SQSClient
|
sqsClient sqsclient.SQSClient
|
||||||
sqsURL *url.URL
|
sqsURL *url.URL
|
||||||
|
timeout time.Duration
|
||||||
|
numRetries int
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSQSRendezvous(sqsQueue string, sqsAccessKeyId string, sqsSecretKey string, transport http.RoundTripper) (*sqsRendezvous, error) {
|
func newSQSRendezvous(sqsQueue string, sqsAccessKeyId string, sqsSecretKey string, transport http.RoundTripper) (*sqsRendezvous, error) {
|
||||||
|
@ -66,6 +68,8 @@ func newSQSRendezvous(sqsQueue string, sqsAccessKeyId string, sqsSecretKey strin
|
||||||
sqsClientID: clientID,
|
sqsClientID: clientID,
|
||||||
sqsClient: client,
|
sqsClient: client,
|
||||||
sqsURL: sqsURL,
|
sqsURL: sqsURL,
|
||||||
|
timeout: time.Second,
|
||||||
|
numRetries: 5,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -86,11 +90,10 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(time.Second) // wait for client queue to be created by the broker
|
time.Sleep(r.timeout) // wait for client queue to be created by the broker
|
||||||
|
|
||||||
numRetries := 5
|
|
||||||
var responseQueueURL *string
|
var responseQueueURL *string
|
||||||
for i := 0; i < numRetries; i++ {
|
for i := 0; i < r.numRetries; i++ {
|
||||||
// The SQS queue corresponding to the client where the SDP Answer will be placed
|
// The SQS queue corresponding to the client where the SDP Answer will be placed
|
||||||
// may not be created yet. We will retry up to 5 times before we error out.
|
// may not be created yet. We will retry up to 5 times before we error out.
|
||||||
var res *sqs.GetQueueUrlOutput
|
var res *sqs.GetQueueUrlOutput
|
||||||
|
@ -99,8 +102,8 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Println(err)
|
log.Println(err)
|
||||||
log.Printf("Attempt %d of %d to retrieve URL of response SQS queue failed.\n", i+1, numRetries)
|
log.Printf("Attempt %d of %d to retrieve URL of response SQS queue failed.\n", i+1, r.numRetries)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(r.timeout)
|
||||||
} else {
|
} else {
|
||||||
responseQueueURL = res.QueueUrl
|
responseQueueURL = res.QueueUrl
|
||||||
break
|
break
|
||||||
|
@ -111,7 +114,7 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var answer string
|
var answer string
|
||||||
for i := 0; i < numRetries; i++ {
|
for i := 0; i < r.numRetries; i++ {
|
||||||
// Waiting for SDP Answer from proxy to be placed in SQS queue.
|
// Waiting for SDP Answer from proxy to be placed in SQS queue.
|
||||||
// We will retry upt to 5 times before we error out.
|
// We will retry upt to 5 times before we error out.
|
||||||
res, err := r.sqsClient.ReceiveMessage(context.TODO(), &sqs.ReceiveMessageInput{
|
res, err := r.sqsClient.ReceiveMessage(context.TODO(), &sqs.ReceiveMessageInput{
|
||||||
|
@ -123,9 +126,9 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if len(res.Messages) == 0 {
|
if len(res.Messages) == 0 {
|
||||||
log.Printf("Attempt %d of %d to receive message from response SQS queue failed. No message found in queue.\n", i+1, numRetries)
|
log.Printf("Attempt %d of %d to receive message from response SQS queue failed. No message found in queue.\n", i+1, r.numRetries)
|
||||||
delay := float64(i)/2.0 + 1
|
delay := float64(i)/2.0 + 1
|
||||||
time.Sleep(time.Duration(delay*1000) * time.Millisecond)
|
time.Sleep(time.Duration(delay*1000) * (r.timeout / 1000))
|
||||||
} else {
|
} else {
|
||||||
answer = *res.Messages[0].Body
|
answer = *res.Messages[0].Body
|
||||||
break
|
break
|
||||||
|
|
|
@ -7,12 +7,18 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/sqs"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/amp"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/amp"
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/nat"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/nat"
|
||||||
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
|
||||||
)
|
)
|
||||||
|
|
||||||
// mockTransport's RoundTrip method returns a response with a fake status and
|
// mockTransport's RoundTrip method returns a response with a fake status and
|
||||||
|
@ -271,3 +277,123 @@ func TestAMPCacheRendezvous(t *testing.T) {
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSQSRendezvous(t *testing.T) {
|
||||||
|
Convey("SQS Rendezvous", t, func() {
|
||||||
|
|
||||||
|
Convey("Construct SQS queue rendezvous", func() {
|
||||||
|
transport := &mockTransport{http.StatusOK, []byte{}}
|
||||||
|
rend, err := newSQSRendezvous("https://sqs.us-east-1.amazonaws.com", "some-access-key-id", "some-secret-key", transport)
|
||||||
|
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
So(rend.sqsClientID, ShouldNotBeNil)
|
||||||
|
So(rend.sqsClient, ShouldNotBeNil)
|
||||||
|
So(rend.sqsURL, ShouldNotBeNil)
|
||||||
|
So(rend.sqsURL.String(), ShouldResemble, "https://sqs.us-east-1.amazonaws.com")
|
||||||
|
})
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mockSqsClient := sqsclient.NewMockSQSClient(ctrl)
|
||||||
|
responseQueueURL := "https://sqs.us-east-1.amazonaws.com/testing"
|
||||||
|
sqsClientID := "test123"
|
||||||
|
sqsUrl, _ := url.Parse("https://sqs.us-east-1.amazonaws.com/broker")
|
||||||
|
fakeEncPollResp := makeEncPollResp(
|
||||||
|
`{"answer": "{\"type\":\"answer\",\"sdp\":\"fake\"}" }`,
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
sqsRendezvous := sqsRendezvous{
|
||||||
|
transport: &mockTransport{http.StatusOK, []byte{}},
|
||||||
|
sqsClientID: sqsClientID,
|
||||||
|
sqsClient: mockSqsClient,
|
||||||
|
sqsURL: sqsUrl,
|
||||||
|
timeout: 0,
|
||||||
|
numRetries: 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
Convey("sqsRendezvous.Exchange responds with answer", func() {
|
||||||
|
mockSqsClient.EXPECT().SendMessage(gomock.Any(), &sqs.SendMessageInput{
|
||||||
|
MessageAttributes: map[string]types.MessageAttributeValue{
|
||||||
|
"ClientID": {
|
||||||
|
DataType: aws.String("String"),
|
||||||
|
StringValue: aws.String(sqsClientID),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageBody: aws.String(string(fakeEncPollResp)),
|
||||||
|
QueueUrl: aws.String(sqsUrl.String()),
|
||||||
|
})
|
||||||
|
mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), &sqs.GetQueueUrlInput{
|
||||||
|
QueueName: aws.String("snowflake-client-" + sqsClientID),
|
||||||
|
}).Return(&sqs.GetQueueUrlOutput{
|
||||||
|
QueueUrl: aws.String(responseQueueURL),
|
||||||
|
}, nil)
|
||||||
|
mockSqsClient.EXPECT().ReceiveMessage(gomock.Any(), gomock.Eq(&sqs.ReceiveMessageInput{
|
||||||
|
QueueUrl: &responseQueueURL,
|
||||||
|
MaxNumberOfMessages: 1,
|
||||||
|
WaitTimeSeconds: 20,
|
||||||
|
})).Return(&sqs.ReceiveMessageOutput{
|
||||||
|
Messages: []types.Message{{Body: aws.String("answer")}},
|
||||||
|
}, nil)
|
||||||
|
|
||||||
|
answer, err := sqsRendezvous.Exchange(fakeEncPollResp)
|
||||||
|
|
||||||
|
So(answer, ShouldEqual, []byte("answer"))
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
})
|
||||||
|
|
||||||
|
Convey("sqsRendezvous.Exchange cannot get queue url", func() {
|
||||||
|
mockSqsClient.EXPECT().SendMessage(gomock.Any(), &sqs.SendMessageInput{
|
||||||
|
MessageAttributes: map[string]types.MessageAttributeValue{
|
||||||
|
"ClientID": {
|
||||||
|
DataType: aws.String("String"),
|
||||||
|
StringValue: aws.String(sqsClientID),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageBody: aws.String(string(fakeEncPollResp)),
|
||||||
|
QueueUrl: aws.String(sqsUrl.String()),
|
||||||
|
})
|
||||||
|
for i := 0; i < sqsRendezvous.numRetries; i++ {
|
||||||
|
mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), &sqs.GetQueueUrlInput{
|
||||||
|
QueueName: aws.String("snowflake-client-" + sqsClientID),
|
||||||
|
}).Return(nil, errors.New("test error"))
|
||||||
|
}
|
||||||
|
|
||||||
|
answer, err := sqsRendezvous.Exchange(fakeEncPollResp)
|
||||||
|
|
||||||
|
So(answer, ShouldBeNil)
|
||||||
|
So(err, ShouldNotBeNil)
|
||||||
|
So(err, ShouldEqual, errors.New("test error"))
|
||||||
|
})
|
||||||
|
|
||||||
|
Convey("sqsRendezvous.Exchange does not receive answer", func() {
|
||||||
|
mockSqsClient.EXPECT().SendMessage(gomock.Any(), &sqs.SendMessageInput{
|
||||||
|
MessageAttributes: map[string]types.MessageAttributeValue{
|
||||||
|
"ClientID": {
|
||||||
|
DataType: aws.String("String"),
|
||||||
|
StringValue: aws.String(sqsClientID),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
MessageBody: aws.String(string(fakeEncPollResp)),
|
||||||
|
QueueUrl: aws.String(sqsUrl.String()),
|
||||||
|
})
|
||||||
|
mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), &sqs.GetQueueUrlInput{
|
||||||
|
QueueName: aws.String("snowflake-client-" + sqsClientID),
|
||||||
|
}).Return(&sqs.GetQueueUrlOutput{
|
||||||
|
QueueUrl: aws.String(responseQueueURL),
|
||||||
|
}, nil)
|
||||||
|
for i := 0; i < sqsRendezvous.numRetries; i++ {
|
||||||
|
mockSqsClient.EXPECT().ReceiveMessage(gomock.Any(), gomock.Eq(&sqs.ReceiveMessageInput{
|
||||||
|
QueueUrl: &responseQueueURL,
|
||||||
|
MaxNumberOfMessages: 1,
|
||||||
|
WaitTimeSeconds: 20,
|
||||||
|
})).Return(&sqs.ReceiveMessageOutput{
|
||||||
|
Messages: []types.Message{},
|
||||||
|
}, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
answer, err := sqsRendezvous.Exchange(fakeEncPollResp)
|
||||||
|
|
||||||
|
So(answer, ShouldEqual, []byte{})
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -1,30 +0,0 @@
|
||||||
package snowflake_client
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go-v2/aws"
|
|
||||||
"github.com/aws/aws-sdk-go-v2/service/sqs"
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
. "github.com/smartystreets/goconvey/convey"
|
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestExample(t *testing.T) {
|
|
||||||
Convey("Test Example 1", t, func() {
|
|
||||||
ctrl := gomock.NewController(t)
|
|
||||||
mockSqsClient := sqsclient.NewMockSQSClient(ctrl)
|
|
||||||
mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), gomock.Any()).Return(&sqs.GetQueueUrlOutput{
|
|
||||||
QueueUrl: aws.String("https://wwww.google.com"),
|
|
||||||
}, nil)
|
|
||||||
|
|
||||||
output, err := mockSqsClient.GetQueueUrl(context.TODO(), &sqs.GetQueueUrlInput{
|
|
||||||
QueueName: aws.String("testing"),
|
|
||||||
})
|
|
||||||
ShouldBeNil(err)
|
|
||||||
ShouldEqual(output, sqs.GetQueueUrlOutput{
|
|
||||||
QueueUrl: aws.String("https://wwww.google.com"),
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
Loading…
Add table
Add a link
Reference in a new issue