Add unit tests for SQS rendezvous in client

Co-authored-by: Michael Pu <michael.pu@uwaterloo.ca>
This commit is contained in:
Anthony Chang 2024-01-12 23:23:33 -05:00 committed by Cecylia Bocovich
parent f3b062ddb2
commit 32e864b71d
No known key found for this signature in database
GPG key ID: 009DE379FD9B7B90
3 changed files with 137 additions and 38 deletions

View file

@ -23,6 +23,8 @@ type sqsRendezvous struct {
sqsClientID string
sqsClient sqsclient.SQSClient
sqsURL *url.URL
timeout time.Duration
numRetries int
}
func newSQSRendezvous(sqsQueue string, sqsAccessKeyId string, sqsSecretKey string, transport http.RoundTripper) (*sqsRendezvous, error) {
@ -66,6 +68,8 @@ func newSQSRendezvous(sqsQueue string, sqsAccessKeyId string, sqsSecretKey strin
sqsClientID: clientID,
sqsClient: client,
sqsURL: sqsURL,
timeout: time.Second,
numRetries: 5,
}, nil
}
@ -86,11 +90,10 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
return nil, err
}
time.Sleep(time.Second) // wait for client queue to be created by the broker
time.Sleep(r.timeout) // wait for client queue to be created by the broker
numRetries := 5
var responseQueueURL *string
for i := 0; i < numRetries; i++ {
for i := 0; i < r.numRetries; i++ {
// The SQS queue corresponding to the client where the SDP Answer will be placed
// may not be created yet. We will retry up to 5 times before we error out.
var res *sqs.GetQueueUrlOutput
@ -99,8 +102,8 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
})
if err != nil {
log.Println(err)
log.Printf("Attempt %d of %d to retrieve URL of response SQS queue failed.\n", i+1, numRetries)
time.Sleep(time.Second)
log.Printf("Attempt %d of %d to retrieve URL of response SQS queue failed.\n", i+1, r.numRetries)
time.Sleep(r.timeout)
} else {
responseQueueURL = res.QueueUrl
break
@ -111,7 +114,7 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
}
var answer string
for i := 0; i < numRetries; i++ {
for i := 0; i < r.numRetries; i++ {
// Waiting for SDP Answer from proxy to be placed in SQS queue.
// We will retry upt to 5 times before we error out.
res, err := r.sqsClient.ReceiveMessage(context.TODO(), &sqs.ReceiveMessageInput{
@ -123,9 +126,9 @@ func (r *sqsRendezvous) Exchange(encPollReq []byte) ([]byte, error) {
return nil, err
}
if len(res.Messages) == 0 {
log.Printf("Attempt %d of %d to receive message from response SQS queue failed. No message found in queue.\n", i+1, numRetries)
log.Printf("Attempt %d of %d to receive message from response SQS queue failed. No message found in queue.\n", i+1, r.numRetries)
delay := float64(i)/2.0 + 1
time.Sleep(time.Duration(delay*1000) * time.Millisecond)
time.Sleep(time.Duration(delay*1000) * (r.timeout / 1000))
} else {
answer = *res.Messages[0].Body
break

View file

@ -7,12 +7,18 @@ import (
"io"
"io/ioutil"
"net/http"
"net/url"
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
"github.com/golang/mock/gomock"
. "github.com/smartystreets/goconvey/convey"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/amp"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/nat"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
)
// mockTransport's RoundTrip method returns a response with a fake status and
@ -271,3 +277,123 @@ func TestAMPCacheRendezvous(t *testing.T) {
})
})
}
func TestSQSRendezvous(t *testing.T) {
Convey("SQS Rendezvous", t, func() {
Convey("Construct SQS queue rendezvous", func() {
transport := &mockTransport{http.StatusOK, []byte{}}
rend, err := newSQSRendezvous("https://sqs.us-east-1.amazonaws.com", "some-access-key-id", "some-secret-key", transport)
So(err, ShouldBeNil)
So(rend.sqsClientID, ShouldNotBeNil)
So(rend.sqsClient, ShouldNotBeNil)
So(rend.sqsURL, ShouldNotBeNil)
So(rend.sqsURL.String(), ShouldResemble, "https://sqs.us-east-1.amazonaws.com")
})
ctrl := gomock.NewController(t)
mockSqsClient := sqsclient.NewMockSQSClient(ctrl)
responseQueueURL := "https://sqs.us-east-1.amazonaws.com/testing"
sqsClientID := "test123"
sqsUrl, _ := url.Parse("https://sqs.us-east-1.amazonaws.com/broker")
fakeEncPollResp := makeEncPollResp(
`{"answer": "{\"type\":\"answer\",\"sdp\":\"fake\"}" }`,
"",
)
sqsRendezvous := sqsRendezvous{
transport: &mockTransport{http.StatusOK, []byte{}},
sqsClientID: sqsClientID,
sqsClient: mockSqsClient,
sqsURL: sqsUrl,
timeout: 0,
numRetries: 5,
}
Convey("sqsRendezvous.Exchange responds with answer", func() {
mockSqsClient.EXPECT().SendMessage(gomock.Any(), &sqs.SendMessageInput{
MessageAttributes: map[string]types.MessageAttributeValue{
"ClientID": {
DataType: aws.String("String"),
StringValue: aws.String(sqsClientID),
},
},
MessageBody: aws.String(string(fakeEncPollResp)),
QueueUrl: aws.String(sqsUrl.String()),
})
mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), &sqs.GetQueueUrlInput{
QueueName: aws.String("snowflake-client-" + sqsClientID),
}).Return(&sqs.GetQueueUrlOutput{
QueueUrl: aws.String(responseQueueURL),
}, nil)
mockSqsClient.EXPECT().ReceiveMessage(gomock.Any(), gomock.Eq(&sqs.ReceiveMessageInput{
QueueUrl: &responseQueueURL,
MaxNumberOfMessages: 1,
WaitTimeSeconds: 20,
})).Return(&sqs.ReceiveMessageOutput{
Messages: []types.Message{{Body: aws.String("answer")}},
}, nil)
answer, err := sqsRendezvous.Exchange(fakeEncPollResp)
So(answer, ShouldEqual, []byte("answer"))
So(err, ShouldBeNil)
})
Convey("sqsRendezvous.Exchange cannot get queue url", func() {
mockSqsClient.EXPECT().SendMessage(gomock.Any(), &sqs.SendMessageInput{
MessageAttributes: map[string]types.MessageAttributeValue{
"ClientID": {
DataType: aws.String("String"),
StringValue: aws.String(sqsClientID),
},
},
MessageBody: aws.String(string(fakeEncPollResp)),
QueueUrl: aws.String(sqsUrl.String()),
})
for i := 0; i < sqsRendezvous.numRetries; i++ {
mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), &sqs.GetQueueUrlInput{
QueueName: aws.String("snowflake-client-" + sqsClientID),
}).Return(nil, errors.New("test error"))
}
answer, err := sqsRendezvous.Exchange(fakeEncPollResp)
So(answer, ShouldBeNil)
So(err, ShouldNotBeNil)
So(err, ShouldEqual, errors.New("test error"))
})
Convey("sqsRendezvous.Exchange does not receive answer", func() {
mockSqsClient.EXPECT().SendMessage(gomock.Any(), &sqs.SendMessageInput{
MessageAttributes: map[string]types.MessageAttributeValue{
"ClientID": {
DataType: aws.String("String"),
StringValue: aws.String(sqsClientID),
},
},
MessageBody: aws.String(string(fakeEncPollResp)),
QueueUrl: aws.String(sqsUrl.String()),
})
mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), &sqs.GetQueueUrlInput{
QueueName: aws.String("snowflake-client-" + sqsClientID),
}).Return(&sqs.GetQueueUrlOutput{
QueueUrl: aws.String(responseQueueURL),
}, nil)
for i := 0; i < sqsRendezvous.numRetries; i++ {
mockSqsClient.EXPECT().ReceiveMessage(gomock.Any(), gomock.Eq(&sqs.ReceiveMessageInput{
QueueUrl: &responseQueueURL,
MaxNumberOfMessages: 1,
WaitTimeSeconds: 20,
})).Return(&sqs.ReceiveMessageOutput{
Messages: []types.Message{},
}, nil)
}
answer, err := sqsRendezvous.Exchange(fakeEncPollResp)
So(answer, ShouldEqual, []byte{})
So(err, ShouldBeNil)
})
})
}

View file

@ -1,30 +0,0 @@
package snowflake_client
import (
"context"
"testing"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/golang/mock/gomock"
. "github.com/smartystreets/goconvey/convey"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
)
func TestExample(t *testing.T) {
Convey("Test Example 1", t, func() {
ctrl := gomock.NewController(t)
mockSqsClient := sqsclient.NewMockSQSClient(ctrl)
mockSqsClient.EXPECT().GetQueueUrl(gomock.Any(), gomock.Any()).Return(&sqs.GetQueueUrlOutput{
QueueUrl: aws.String("https://wwww.google.com"),
}, nil)
output, err := mockSqsClient.GetQueueUrl(context.TODO(), &sqs.GetQueueUrlInput{
QueueName: aws.String("testing"),
})
ShouldBeNil(err)
ShouldEqual(output, sqs.GetQueueUrlOutput{
QueueUrl: aws.String("https://wwww.google.com"),
})
})
}