mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-14 05:11:19 -04:00
Add unit tests for SQS rendezvous in broker
Co-authored-by: Michael Pu <michael.pu@uwaterloo.ca>
This commit is contained in:
parent
32e864b71d
commit
9b90b77d69
4 changed files with 438 additions and 93 deletions
|
@ -23,6 +23,8 @@ import (
|
||||||
|
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/bridgefingerprint"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/bridgefingerprint"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/config"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/sqs"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/namematcher"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/namematcher"
|
||||||
|
@ -283,8 +285,14 @@ func main() {
|
||||||
|
|
||||||
// Run SQS Handler to continuously poll and process messages from SQS
|
// Run SQS Handler to continuously poll and process messages from SQS
|
||||||
if brokerSQSQueueName != "" && brokerSQSQueueRegion != "" {
|
if brokerSQSQueueName != "" && brokerSQSQueueRegion != "" {
|
||||||
|
log.Printf("Loading SQSHandler using SQS Queue %s in region %s\n", brokerSQSQueueName, brokerSQSQueueRegion)
|
||||||
sqsHandlerContext := context.Background()
|
sqsHandlerContext := context.Background()
|
||||||
sqsHandler, err := newSQSHandler(sqsHandlerContext, brokerSQSQueueName, brokerSQSQueueRegion, i)
|
cfg, err := config.LoadDefaultConfig(sqsHandlerContext, config.WithRegion(brokerSQSQueueRegion))
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
client := sqs.NewFromConfig(cfg)
|
||||||
|
sqsHandler, err := newSQSHandler(sqsHandlerContext, client, brokerSQSQueueName, brokerSQSQueueRegion, i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/aws/aws-sdk-go-v2/aws"
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
"github.com/aws/aws-sdk-go-v2/config"
|
|
||||||
"github.com/aws/aws-sdk-go-v2/service/sqs"
|
"github.com/aws/aws-sdk-go-v2/service/sqs"
|
||||||
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
|
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
|
||||||
|
@ -16,7 +15,6 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
cleanupInterval = time.Second * 30
|
|
||||||
cleanupThreshold = -2 * time.Minute
|
cleanupThreshold = -2 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -24,11 +22,17 @@ type sqsHandler struct {
|
||||||
SQSClient sqsclient.SQSClient
|
SQSClient sqsclient.SQSClient
|
||||||
SQSQueueURL *string
|
SQSQueueURL *string
|
||||||
IPC *IPC
|
IPC *IPC
|
||||||
|
cleanupInterval time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *sqsHandler) pollMessages(context context.Context, chn chan<- *types.Message) {
|
func (r *sqsHandler) pollMessages(ctx context.Context, chn chan<- *types.Message) {
|
||||||
for {
|
for {
|
||||||
res, err := r.SQSClient.ReceiveMessage(context, &sqs.ReceiveMessageInput{
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// if context is cancelled
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
res, err := r.SQSClient.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
|
||||||
QueueUrl: r.SQSQueueURL,
|
QueueUrl: r.SQSQueueURL,
|
||||||
MaxNumberOfMessages: 10,
|
MaxNumberOfMessages: 10,
|
||||||
WaitTimeSeconds: 15,
|
WaitTimeSeconds: 15,
|
||||||
|
@ -39,6 +43,7 @@ func (r *sqsHandler) pollMessages(context context.Context, chn chan<- *types.Mes
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("SQSHandler: encountered error while polling for messages: %v\n", err)
|
log.Printf("SQSHandler: encountered error while polling for messages: %v\n", err)
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, message := range res.Messages {
|
for _, message := range res.Messages {
|
||||||
|
@ -46,14 +51,20 @@ func (r *sqsHandler) pollMessages(context context.Context, chn chan<- *types.Mes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *sqsHandler) cleanupClientQueues(context context.Context) {
|
func (r *sqsHandler) cleanupClientQueues(ctx context.Context) {
|
||||||
for range time.Tick(cleanupInterval) {
|
for range time.NewTicker(r.cleanupInterval).C {
|
||||||
// Runs at fixed intervals to clean up any client queues that were last changed more than 2 minutes ago
|
// Runs at fixed intervals to clean up any client queues that were last changed more than 2 minutes ago
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
// if context is cancelled
|
||||||
|
return
|
||||||
|
default:
|
||||||
queueURLsList := []string{}
|
queueURLsList := []string{}
|
||||||
var nextToken *string
|
var nextToken *string
|
||||||
for {
|
for {
|
||||||
res, err := r.SQSClient.ListQueues(context, &sqs.ListQueuesInput{
|
res, err := r.SQSClient.ListQueues(ctx, &sqs.ListQueuesInput{
|
||||||
QueueNamePrefix: aws.String("snowflake-client-"),
|
QueueNamePrefix: aws.String("snowflake-client-"),
|
||||||
MaxResults: aws.Int32(1000),
|
MaxResults: aws.Int32(1000),
|
||||||
NextToken: nextToken,
|
NextToken: nextToken,
|
||||||
|
@ -75,7 +86,7 @@ func (r *sqsHandler) cleanupClientQueues(context context.Context) {
|
||||||
if !strings.Contains(queueURL, "snowflake-client-") {
|
if !strings.Contains(queueURL, "snowflake-client-") {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
res, err := r.SQSClient.GetQueueAttributes(context, &sqs.GetQueueAttributesInput{
|
res, err := r.SQSClient.GetQueueAttributes(ctx, &sqs.GetQueueAttributesInput{
|
||||||
QueueUrl: aws.String(queueURL),
|
QueueUrl: aws.String(queueURL),
|
||||||
AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp},
|
AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp},
|
||||||
})
|
})
|
||||||
|
@ -93,7 +104,7 @@ func (r *sqsHandler) cleanupClientQueues(context context.Context) {
|
||||||
}
|
}
|
||||||
lastModified := time.Unix(lastModifiedInt64, 0)
|
lastModified := time.Unix(lastModifiedInt64, 0)
|
||||||
if lastModified.Before(cleanupCutoff) {
|
if lastModified.Before(cleanupCutoff) {
|
||||||
_, err := r.SQSClient.DeleteQueue(context, &sqs.DeleteQueueInput{
|
_, err := r.SQSClient.DeleteQueue(ctx, &sqs.DeleteQueueInput{
|
||||||
QueueUrl: aws.String(queueURL),
|
QueueUrl: aws.String(queueURL),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -108,6 +119,7 @@ func (r *sqsHandler) cleanupClientQueues(context context.Context) {
|
||||||
log.Printf("SQSHandler: finished running iteration of client queue cleanup. found and deleted %d client queues.\n", numDeleted)
|
log.Printf("SQSHandler: finished running iteration of client queue cleanup. found and deleted %d client queues.\n", numDeleted)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *sqsHandler) handleMessage(context context.Context, message *types.Message) {
|
func (r *sqsHandler) handleMessage(context context.Context, message *types.Message) {
|
||||||
var encPollReq []byte
|
var encPollReq []byte
|
||||||
|
@ -123,10 +135,11 @@ func (r *sqsHandler) handleMessage(context context.Context, message *types.Messa
|
||||||
res, err := r.SQSClient.CreateQueue(context, &sqs.CreateQueueInput{
|
res, err := r.SQSClient.CreateQueue(context, &sqs.CreateQueueInput{
|
||||||
QueueName: aws.String("snowflake-client-" + *clientID),
|
QueueName: aws.String("snowflake-client-" + *clientID),
|
||||||
})
|
})
|
||||||
answerSQSURL := res.QueueUrl
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("SQSHandler: error encountered when creating answer queue for client %s: %v\n", *clientID, err)
|
log.Printf("SQSHandler: error encountered when creating answer queue for client %s: %v\n", *clientID, err)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
answerSQSURL := res.QueueUrl
|
||||||
|
|
||||||
encPollReq = []byte(*message.Body)
|
encPollReq = []byte(*message.Body)
|
||||||
arg := messages.Arg{
|
arg := messages.Arg{
|
||||||
|
@ -153,15 +166,7 @@ func (r *sqsHandler) deleteMessage(context context.Context, message *types.Messa
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSQSHandler(context context.Context, sqsQueueName string, region string, i *IPC) (*sqsHandler, error) {
|
func newSQSHandler(context context.Context, client sqsclient.SQSClient, sqsQueueName string, region string, i *IPC) (*sqsHandler, error) {
|
||||||
log.Printf("Loading SQSHandler using SQS Queue %s in region %s\n", sqsQueueName, region)
|
|
||||||
cfg, err := config.LoadDefaultConfig(context, config.WithRegion(region))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
client := sqs.NewFromConfig(cfg)
|
|
||||||
|
|
||||||
// Creates the queue if a queue with the same name doesn't exist. If a queue with the same name and attributes
|
// Creates the queue if a queue with the same name doesn't exist. If a queue with the same name and attributes
|
||||||
// already exists, then nothing will happen. If a queue with the same name, but different attributes exists, then
|
// already exists, then nothing will happen. If a queue with the same name, but different attributes exists, then
|
||||||
// an error will be returned
|
// an error will be returned
|
||||||
|
@ -180,17 +185,24 @@ func newSQSHandler(context context.Context, sqsQueueName string, region string,
|
||||||
SQSClient: client,
|
SQSClient: client,
|
||||||
SQSQueueURL: res.QueueUrl,
|
SQSQueueURL: res.QueueUrl,
|
||||||
IPC: i,
|
IPC: i,
|
||||||
|
cleanupInterval: time.Second * 30,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *sqsHandler) PollAndHandleMessages(context context.Context) {
|
func (r *sqsHandler) PollAndHandleMessages(ctx context.Context) {
|
||||||
log.Println("SQSHandler: Starting to poll for messages at: " + *r.SQSQueueURL)
|
log.Println("SQSHandler: Starting to poll for messages at: " + *r.SQSQueueURL)
|
||||||
messagesChn := make(chan *types.Message, 2)
|
messagesChn := make(chan *types.Message, 2)
|
||||||
go r.pollMessages(context, messagesChn)
|
go r.pollMessages(ctx, messagesChn)
|
||||||
go r.cleanupClientQueues(context)
|
go r.cleanupClientQueues(ctx)
|
||||||
|
|
||||||
for message := range messagesChn {
|
for message := range messagesChn {
|
||||||
r.handleMessage(context, message)
|
select {
|
||||||
r.deleteMessage(context, message)
|
case <-ctx.Done():
|
||||||
|
// if context is cancelled
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
r.handleMessage(ctx, message)
|
||||||
|
r.deleteMessage(ctx, message)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
327
broker/sqs_test.go
Normal file
327
broker/sqs_test.go
Normal file
|
@ -0,0 +1,327 @@
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"log"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go-v2/aws"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/sqs"
|
||||||
|
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
. "github.com/smartystreets/goconvey/convey"
|
||||||
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSQS(t *testing.T) {
|
||||||
|
|
||||||
|
Convey("Context", t, func() {
|
||||||
|
ctx := NewBrokerContext(NullLogger())
|
||||||
|
i := &IPC{ctx}
|
||||||
|
|
||||||
|
var logBuffer bytes.Buffer
|
||||||
|
log.SetOutput(&logBuffer)
|
||||||
|
|
||||||
|
Convey("Responds to SQS client offers...", func() {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mockSQSClient := sqsclient.NewMockSQSClient(ctrl)
|
||||||
|
|
||||||
|
brokerSQSQueueName := "example-name"
|
||||||
|
responseQueueURL := aws.String("https://sqs.us-east-1.amazonaws.com/testing")
|
||||||
|
|
||||||
|
runSQSHandler := func(sqsHandlerContext context.Context) {
|
||||||
|
mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqs.CreateQueueInput{
|
||||||
|
QueueName: aws.String(brokerSQSQueueName),
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"MessageRetentionPeriod": strconv.FormatInt(int64((5 * time.Minute).Seconds()), 10),
|
||||||
|
},
|
||||||
|
}).Return(&sqs.CreateQueueOutput{
|
||||||
|
QueueUrl: responseQueueURL,
|
||||||
|
}, nil).Times(1)
|
||||||
|
sqsHandler, err := newSQSHandler(sqsHandlerContext, mockSQSClient, brokerSQSQueueName, "example-region", i)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
go sqsHandler.PollAndHandleMessages(sqsHandlerContext)
|
||||||
|
}
|
||||||
|
|
||||||
|
messageBody := aws.String("1.0\n{\"offer\": \"fake\", \"nat\": \"unknown\"}")
|
||||||
|
receiptHandle := "fake-receipt-handle"
|
||||||
|
sqsReceiveMessageInput := sqs.ReceiveMessageInput{
|
||||||
|
QueueUrl: responseQueueURL,
|
||||||
|
MaxNumberOfMessages: 10,
|
||||||
|
WaitTimeSeconds: 15,
|
||||||
|
MessageAttributeNames: []string{
|
||||||
|
string(types.QueueAttributeNameAll),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sqsDeleteMessageInput := sqs.DeleteMessageInput{
|
||||||
|
QueueUrl: responseQueueURL,
|
||||||
|
ReceiptHandle: &receiptHandle,
|
||||||
|
}
|
||||||
|
|
||||||
|
Convey("by ignoring it if no client id specified", func(c C) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer sqsCancelFunc()
|
||||||
|
defer wg.Wait()
|
||||||
|
mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(1).DoAndReturn(
|
||||||
|
func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
|
||||||
|
return &sqs.ReceiveMessageOutput{
|
||||||
|
Messages: []types.Message{
|
||||||
|
{
|
||||||
|
Body: messageBody,
|
||||||
|
ReceiptHandle: &receiptHandle,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).Times(1).Do(
|
||||||
|
func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) {
|
||||||
|
defer wg.Done()
|
||||||
|
c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.")
|
||||||
|
mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
runSQSHandler(sqsHandlerContext)
|
||||||
|
})
|
||||||
|
|
||||||
|
Convey("by doing nothing if an error occurs upon receipt of the message", func(c C) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer sqsCancelFunc()
|
||||||
|
defer wg.Wait()
|
||||||
|
|
||||||
|
numTimes := 0
|
||||||
|
// When ReceiveMessage is called for the first time, the error has not had a chance to be logged yet.
|
||||||
|
// Therefore, we opt to wait for the second call because we are guaranteed that the error was logged
|
||||||
|
// by then.
|
||||||
|
mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(2).DoAndReturn(
|
||||||
|
func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
|
||||||
|
numTimes += 1
|
||||||
|
if numTimes <= 2 {
|
||||||
|
wg.Done()
|
||||||
|
if numTimes == 2 {
|
||||||
|
c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: encountered error while polling for messages: error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, errors.New("error")
|
||||||
|
},
|
||||||
|
)
|
||||||
|
runSQSHandler(sqsHandlerContext)
|
||||||
|
})
|
||||||
|
|
||||||
|
Convey("by attempting to create a new sqs queue...", func() {
|
||||||
|
clientId := "fake-id"
|
||||||
|
sqsCreateQueueInput := sqs.CreateQueueInput{
|
||||||
|
QueueName: aws.String("snowflake-client-fake-id"),
|
||||||
|
}
|
||||||
|
|
||||||
|
expectReceiveMessageReturnsValidMessage := func(sqsHandlerContext context.Context) {
|
||||||
|
mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn(
|
||||||
|
func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
|
||||||
|
return &sqs.ReceiveMessageOutput{
|
||||||
|
Messages: []types.Message{
|
||||||
|
{
|
||||||
|
Body: messageBody,
|
||||||
|
MessageAttributes: map[string]types.MessageAttributeValue{
|
||||||
|
"ClientID": {StringValue: &clientId},
|
||||||
|
},
|
||||||
|
ReceiptHandle: &receiptHandle,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
Convey("and does not attempt to send a message via SQS if queue creation fails.", func(c C) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer sqsCancelFunc()
|
||||||
|
defer wg.Wait()
|
||||||
|
|
||||||
|
expectReceiveMessageReturnsValidMessage(sqsHandlerContext)
|
||||||
|
mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(nil, errors.New("error")).AnyTimes()
|
||||||
|
numTimes := 0
|
||||||
|
mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).MinTimes(2).Do(
|
||||||
|
func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) {
|
||||||
|
numTimes += 1
|
||||||
|
if numTimes <= 2 {
|
||||||
|
wg.Done()
|
||||||
|
if numTimes == 2 {
|
||||||
|
c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: error encountered when creating answer queue for client fake-id: error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
runSQSHandler(sqsHandlerContext)
|
||||||
|
})
|
||||||
|
|
||||||
|
Convey("and responds with a proxy answer if available.", func(c C) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
|
||||||
|
sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer sqsCancelFunc()
|
||||||
|
defer wg.Wait()
|
||||||
|
|
||||||
|
expectReceiveMessageReturnsValidMessage(sqsHandlerContext)
|
||||||
|
mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(&sqs.CreateQueueOutput{
|
||||||
|
QueueUrl: responseQueueURL,
|
||||||
|
}, nil).AnyTimes()
|
||||||
|
mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes()
|
||||||
|
numTimes := 0
|
||||||
|
mockSQSClient.EXPECT().SendMessage(sqsHandlerContext, gomock.Any()).MinTimes(1).DoAndReturn(
|
||||||
|
func(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) {
|
||||||
|
numTimes += 1
|
||||||
|
if numTimes == 1 {
|
||||||
|
c.So(input.MessageBody, ShouldEqual, aws.String("{\"answer\":\"fake answer\"}"))
|
||||||
|
wg.Done()
|
||||||
|
}
|
||||||
|
return &sqs.SendMessageOutput{}, nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
runSQSHandler(sqsHandlerContext)
|
||||||
|
|
||||||
|
snowflake := ctx.AddSnowflake("fake", "", NATUnrestricted, 0)
|
||||||
|
|
||||||
|
offer := <-snowflake.offerChannel
|
||||||
|
So(offer.sdp, ShouldResemble, []byte("fake"))
|
||||||
|
|
||||||
|
snowflake.answerChannel <- "fake answer"
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
Convey("Cleans up SQS client queues...", func() {
|
||||||
|
brokerSQSQueueName := "example-name"
|
||||||
|
responseQueueURL := aws.String("https://sqs.us-east-1.amazonaws.com/testing")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
mockSQSClient := sqsclient.NewMockSQSClient(ctrl)
|
||||||
|
|
||||||
|
runSQSHandler := func(sqsHandlerContext context.Context) {
|
||||||
|
|
||||||
|
mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqs.CreateQueueInput{
|
||||||
|
QueueName: aws.String(brokerSQSQueueName),
|
||||||
|
Attributes: map[string]string{
|
||||||
|
"MessageRetentionPeriod": strconv.FormatInt(int64((5 * time.Minute).Seconds()), 10),
|
||||||
|
},
|
||||||
|
}).Return(&sqs.CreateQueueOutput{
|
||||||
|
QueueUrl: responseQueueURL,
|
||||||
|
}, nil).Times(1)
|
||||||
|
|
||||||
|
mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, gomock.Any()).AnyTimes().Return(
|
||||||
|
&sqs.ReceiveMessageOutput{
|
||||||
|
Messages: []types.Message{},
|
||||||
|
}, nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
sqsHandler, err := newSQSHandler(sqsHandlerContext, mockSQSClient, brokerSQSQueueName, "example-region", i)
|
||||||
|
So(err, ShouldBeNil)
|
||||||
|
// Set the cleanup interval to 1 ns so we can immediately test the cleanup logic
|
||||||
|
sqsHandler.cleanupInterval = time.Nanosecond
|
||||||
|
|
||||||
|
go sqsHandler.PollAndHandleMessages(sqsHandlerContext)
|
||||||
|
}
|
||||||
|
|
||||||
|
Convey("does nothing if there are no open queues.", func() {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
|
||||||
|
defer wg.Wait()
|
||||||
|
|
||||||
|
mockSQSClient.EXPECT().ListQueues(sqsHandlerContext, &sqs.ListQueuesInput{
|
||||||
|
QueueNamePrefix: aws.String("snowflake-client-"),
|
||||||
|
MaxResults: aws.Int32(1000),
|
||||||
|
NextToken: nil,
|
||||||
|
}).DoAndReturn(func(ctx context.Context, input *sqs.ListQueuesInput, optFns ...func(*sqs.Options)) (*sqs.ListQueuesOutput, error) {
|
||||||
|
wg.Done()
|
||||||
|
// Cancel the handler context since we are only interested in testing one iteration of the cleanup
|
||||||
|
sqsCancelFunc()
|
||||||
|
return &sqs.ListQueuesOutput{
|
||||||
|
QueueUrls: []string{},
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
|
||||||
|
runSQSHandler(sqsHandlerContext)
|
||||||
|
})
|
||||||
|
|
||||||
|
Convey("deletes open queue when there is one open queue.", func(c C) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
clientQueueUrl1 := "https://sqs.us-east-1.amazonaws.com/snowflake-client-1"
|
||||||
|
clientQueueUrl2 := "https://sqs.us-east-1.amazonaws.com/snowflake-client-2"
|
||||||
|
|
||||||
|
gomock.InOrder(
|
||||||
|
mockSQSClient.EXPECT().ListQueues(sqsHandlerContext, &sqs.ListQueuesInput{
|
||||||
|
QueueNamePrefix: aws.String("snowflake-client-"),
|
||||||
|
MaxResults: aws.Int32(1000),
|
||||||
|
NextToken: nil,
|
||||||
|
}).Times(1).Return(&sqs.ListQueuesOutput{
|
||||||
|
QueueUrls: []string{
|
||||||
|
clientQueueUrl1,
|
||||||
|
clientQueueUrl2,
|
||||||
|
},
|
||||||
|
}, nil),
|
||||||
|
mockSQSClient.EXPECT().ListQueues(sqsHandlerContext, &sqs.ListQueuesInput{
|
||||||
|
QueueNamePrefix: aws.String("snowflake-client-"),
|
||||||
|
MaxResults: aws.Int32(1000),
|
||||||
|
NextToken: nil,
|
||||||
|
}).Times(1).DoAndReturn(func(ctx context.Context, input *sqs.ListQueuesInput, optFns ...func(*sqs.Options)) (*sqs.ListQueuesOutput, error) {
|
||||||
|
// Executed on second iteration of cleanupClientQueues loop. This means that one full iteration has completed and we can verify the results of that iteration
|
||||||
|
wg.Done()
|
||||||
|
sqsCancelFunc()
|
||||||
|
c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: finished running iteration of client queue cleanup. found and deleted 2 client queues.")
|
||||||
|
return &sqs.ListQueuesOutput{
|
||||||
|
QueueUrls: []string{},
|
||||||
|
}, nil
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
|
||||||
|
gomock.InOrder(
|
||||||
|
mockSQSClient.EXPECT().GetQueueAttributes(sqsHandlerContext, &sqs.GetQueueAttributesInput{
|
||||||
|
QueueUrl: aws.String(clientQueueUrl1),
|
||||||
|
AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp},
|
||||||
|
}).Times(1).Return(&sqs.GetQueueAttributesOutput{
|
||||||
|
Attributes: map[string]string{
|
||||||
|
string(types.QueueAttributeNameLastModifiedTimestamp): "0",
|
||||||
|
}}, nil),
|
||||||
|
|
||||||
|
mockSQSClient.EXPECT().GetQueueAttributes(sqsHandlerContext, &sqs.GetQueueAttributesInput{
|
||||||
|
QueueUrl: aws.String(clientQueueUrl2),
|
||||||
|
AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp},
|
||||||
|
}).Times(1).Return(&sqs.GetQueueAttributesOutput{
|
||||||
|
Attributes: map[string]string{
|
||||||
|
string(types.QueueAttributeNameLastModifiedTimestamp): "0",
|
||||||
|
}}, nil),
|
||||||
|
)
|
||||||
|
|
||||||
|
gomock.InOrder(
|
||||||
|
mockSQSClient.EXPECT().DeleteQueue(sqsHandlerContext, &sqs.DeleteQueueInput{
|
||||||
|
QueueUrl: aws.String(clientQueueUrl1),
|
||||||
|
}).Return(&sqs.DeleteQueueOutput{}, nil),
|
||||||
|
mockSQSClient.EXPECT().DeleteQueue(sqsHandlerContext, &sqs.DeleteQueueInput{
|
||||||
|
QueueUrl: aws.String(clientQueueUrl2),
|
||||||
|
}).Return(&sqs.DeleteQueueOutput{}, nil),
|
||||||
|
)
|
||||||
|
|
||||||
|
runSQSHandler(sqsHandlerContext)
|
||||||
|
wg.Wait()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
|
@ -16,8 +16,6 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/proxy"
|
|
||||||
|
|
||||||
pt "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
|
pt "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/goptlib"
|
||||||
|
|
||||||
sf "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/client/lib"
|
sf "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/client/lib"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue