diff --git a/broker/sqs_test.go b/broker/sqs_test.go index 7c70390..ab2e761 100644 --- a/broker/sqs_test.go +++ b/broker/sqs_test.go @@ -7,6 +7,7 @@ import ( "log" "strconv" "sync" + "sync/atomic" "testing" "time" @@ -25,9 +26,6 @@ func TestSQS(t *testing.T) { ipcCtx := NewBrokerContext(log.New(buf, "", 0), "", "") i := &IPC{ipcCtx} - var logBuffer bytes.Buffer - log.SetOutput(&logBuffer) - Convey("Responds to SQS client offers...", func() { ctrl := gomock.NewController(t) mockSQSClient := sqsclient.NewMockSQSClient(ctrl) @@ -65,12 +63,7 @@ func TestSQS(t *testing.T) { } Convey("by ignoring it if no client id specified", func(c C) { - var wg sync.WaitGroup - wg.Add(1) - sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - defer sqsCancelFunc() - defer wg.Wait() mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(1).DoAndReturn( func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { return &sqs.ReceiveMessageOutput{ @@ -83,41 +76,32 @@ func TestSQS(t *testing.T) { }, nil }, ) - mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).Times(1).Do( + mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).MinTimes(1).Do( func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) { - defer wg.Done() - c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.") - mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes() + sqsCancelFunc() }, ) + // We expect no queues to be created + mockSQSClient.EXPECT().CreateQueue(gomock.Any(), gomock.Any()).Times(0) runSQSHandler(sqsHandlerContext) + <-sqsHandlerContext.Done() }) Convey("by doing nothing if an error occurs upon receipt of the message", func(c C) { - var wg sync.WaitGroup - wg.Add(2) sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - defer sqsCancelFunc() - defer wg.Wait() - numTimes := 0 - // When ReceiveMessage is called for the first time, the error has not had a chance to be logged yet. - // Therefore, we opt to wait for the second call because we are guaranteed that the error was logged - // by then. - mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(2).DoAndReturn( + mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(1).DoAndReturn( func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { - numTimes += 1 - if numTimes <= 2 { - wg.Done() - if numTimes == 2 { - c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: encountered error while polling for messages: error") - } - } + sqsCancelFunc() return nil, errors.New("error") }, ) + // We expect no queues to be created or deleted + mockSQSClient.EXPECT().CreateQueue(gomock.Any(), gomock.Any()).Times(0) + mockSQSClient.EXPECT().DeleteMessage(gomock.Any(), gomock.Any()).Times(0) runSQSHandler(sqsHandlerContext) + <-sqsHandlerContext.Done() }) Convey("by attempting to create a new sqs queue...", func() { @@ -125,68 +109,53 @@ func TestSQS(t *testing.T) { sqsCreateQueueInput := sqs.CreateQueueInput{ QueueName: aws.String("snowflake-client-fake-id"), } + validMessage := &sqs.ReceiveMessageOutput{ + Messages: []types.Message{ + { + Body: messageBody, + MessageAttributes: map[string]types.MessageAttributeValue{ + "ClientID": {StringValue: &clientId}, + }, + ReceiptHandle: &receiptHandle, + }, + }, + } + Convey("and does not attempt to send a message via SQS if queue creation fails.", func(c C) { + sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - expectReceiveMessageReturnsValidMessage := func(sqsHandlerContext context.Context) { mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn( func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { - return &sqs.ReceiveMessageOutput{ - Messages: []types.Message{ - { - Body: messageBody, - MessageAttributes: map[string]types.MessageAttributeValue{ - "ClientID": {StringValue: &clientId}, - }, - ReceiptHandle: &receiptHandle, - }, - }, - }, nil - }, - ) - } - - Convey("and does not attempt to send a message via SQS if queue creation fails.", func(c C) { - var wg sync.WaitGroup - wg.Add(2) - - sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - defer sqsCancelFunc() - defer wg.Wait() - - expectReceiveMessageReturnsValidMessage(sqsHandlerContext) + sqsCancelFunc() + return validMessage, nil + }) mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(nil, errors.New("error")).AnyTimes() - numTimes := 0 - mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).MinTimes(2).Do( - func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) { - numTimes += 1 - if numTimes <= 2 { - wg.Done() - if numTimes == 2 { - c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: error encountered when creating answer queue for client fake-id: error") - } - } - }, - ) + mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes() runSQSHandler(sqsHandlerContext) + <-sqsHandlerContext.Done() }) Convey("and responds with a proxy answer if available.", func(c C) { - var wg sync.WaitGroup - wg.Add(1) - sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - defer sqsCancelFunc() - defer wg.Wait() - expectReceiveMessageReturnsValidMessage(sqsHandlerContext) + mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn( + func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { + + go func(c C) { + snowflake := ipcCtx.AddSnowflake("fake", "", NATUnrestricted, 0) + <-snowflake.offerChannel + snowflake.answerChannel <- "fake answer" + }(c) + return validMessage, nil + }) mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(&sqs.CreateQueueOutput{ QueueUrl: responseQueueURL, }, nil).AnyTimes() - mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes() - numTimes := 0 + mockSQSClient.EXPECT().DeleteMessage(gomock.Any(), gomock.Any()).AnyTimes() + var numTimes atomic.Uint32 mockSQSClient.EXPECT().SendMessage(sqsHandlerContext, gomock.Any()).MinTimes(1).DoAndReturn( func(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) { - numTimes += 1 - if numTimes == 1 { + n := numTimes.Add(1) + if n == 1 { c.So(input.MessageBody, ShouldEqual, aws.String("{\"answer\":\"fake answer\"}")) // Ensure that match is correctly recorded in metrics ipcCtx.metrics.printMetrics() @@ -201,19 +170,14 @@ client-ampcache-ips client-sqs-count 8 client-sqs-ips ??=8 `) - wg.Done() + sqsCancelFunc() } return &sqs.SendMessageOutput{}, nil }, ) runSQSHandler(sqsHandlerContext) - snowflake := ipcCtx.AddSnowflake("fake", "", NATUnrestricted, 0) - - offer := <-snowflake.offerChannel - So(offer.sdp, ShouldResemble, []byte("fake")) - - snowflake.answerChannel <- "fake answer" + <-sqsHandlerContext.Done() }) }) }) @@ -299,7 +263,6 @@ client-sqs-ips ??=8 // Executed on second iteration of cleanupClientQueues loop. This means that one full iteration has completed and we can verify the results of that iteration wg.Done() sqsCancelFunc() - c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: finished running iteration of client queue cleanup. found and deleted 2 client queues.") return &sqs.ListQueuesOutput{ QueueUrls: []string{}, }, nil