From 1180d11a66ed5a6768113287cf22fb83094e7979 Mon Sep 17 00:00:00 2001 From: Cecylia Bocovich Date: Wed, 19 Feb 2025 17:51:37 -0500 Subject: [PATCH] Remove data races from sqs tests Our SQS tests were not concurrency safe and we hadn't noticed until now because we were processing incoming SQS queue messages sequentially rather than in parallel. This fix removes the log output checks, which were prone to error anyway, and relies instead on gomock's expected function calls and strategic use of the context cancel function for each test. --- broker/sqs_test.go | 129 ++++++++++++++++----------------------------- 1 file changed, 46 insertions(+), 83 deletions(-) diff --git a/broker/sqs_test.go b/broker/sqs_test.go index 7c70390..ab2e761 100644 --- a/broker/sqs_test.go +++ b/broker/sqs_test.go @@ -7,6 +7,7 @@ import ( "log" "strconv" "sync" + "sync/atomic" "testing" "time" @@ -25,9 +26,6 @@ func TestSQS(t *testing.T) { ipcCtx := NewBrokerContext(log.New(buf, "", 0), "", "") i := &IPC{ipcCtx} - var logBuffer bytes.Buffer - log.SetOutput(&logBuffer) - Convey("Responds to SQS client offers...", func() { ctrl := gomock.NewController(t) mockSQSClient := sqsclient.NewMockSQSClient(ctrl) @@ -65,12 +63,7 @@ func TestSQS(t *testing.T) { } Convey("by ignoring it if no client id specified", func(c C) { - var wg sync.WaitGroup - wg.Add(1) - sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - defer sqsCancelFunc() - defer wg.Wait() mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(1).DoAndReturn( func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { return &sqs.ReceiveMessageOutput{ @@ -83,41 +76,32 @@ func TestSQS(t *testing.T) { }, nil }, ) - mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).Times(1).Do( + mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).MinTimes(1).Do( func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) { - defer wg.Done() - c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.") - mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes() + sqsCancelFunc() }, ) + // We expect no queues to be created + mockSQSClient.EXPECT().CreateQueue(gomock.Any(), gomock.Any()).Times(0) runSQSHandler(sqsHandlerContext) + <-sqsHandlerContext.Done() }) Convey("by doing nothing if an error occurs upon receipt of the message", func(c C) { - var wg sync.WaitGroup - wg.Add(2) sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - defer sqsCancelFunc() - defer wg.Wait() - numTimes := 0 - // When ReceiveMessage is called for the first time, the error has not had a chance to be logged yet. - // Therefore, we opt to wait for the second call because we are guaranteed that the error was logged - // by then. - mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(2).DoAndReturn( + mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(1).DoAndReturn( func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { - numTimes += 1 - if numTimes <= 2 { - wg.Done() - if numTimes == 2 { - c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: encountered error while polling for messages: error") - } - } + sqsCancelFunc() return nil, errors.New("error") }, ) + // We expect no queues to be created or deleted + mockSQSClient.EXPECT().CreateQueue(gomock.Any(), gomock.Any()).Times(0) + mockSQSClient.EXPECT().DeleteMessage(gomock.Any(), gomock.Any()).Times(0) runSQSHandler(sqsHandlerContext) + <-sqsHandlerContext.Done() }) Convey("by attempting to create a new sqs queue...", func() { @@ -125,68 +109,53 @@ func TestSQS(t *testing.T) { sqsCreateQueueInput := sqs.CreateQueueInput{ QueueName: aws.String("snowflake-client-fake-id"), } + validMessage := &sqs.ReceiveMessageOutput{ + Messages: []types.Message{ + { + Body: messageBody, + MessageAttributes: map[string]types.MessageAttributeValue{ + "ClientID": {StringValue: &clientId}, + }, + ReceiptHandle: &receiptHandle, + }, + }, + } + Convey("and does not attempt to send a message via SQS if queue creation fails.", func(c C) { + sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - expectReceiveMessageReturnsValidMessage := func(sqsHandlerContext context.Context) { mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn( func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { - return &sqs.ReceiveMessageOutput{ - Messages: []types.Message{ - { - Body: messageBody, - MessageAttributes: map[string]types.MessageAttributeValue{ - "ClientID": {StringValue: &clientId}, - }, - ReceiptHandle: &receiptHandle, - }, - }, - }, nil - }, - ) - } - - Convey("and does not attempt to send a message via SQS if queue creation fails.", func(c C) { - var wg sync.WaitGroup - wg.Add(2) - - sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - defer sqsCancelFunc() - defer wg.Wait() - - expectReceiveMessageReturnsValidMessage(sqsHandlerContext) + sqsCancelFunc() + return validMessage, nil + }) mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(nil, errors.New("error")).AnyTimes() - numTimes := 0 - mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).MinTimes(2).Do( - func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) { - numTimes += 1 - if numTimes <= 2 { - wg.Done() - if numTimes == 2 { - c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: error encountered when creating answer queue for client fake-id: error") - } - } - }, - ) + mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes() runSQSHandler(sqsHandlerContext) + <-sqsHandlerContext.Done() }) Convey("and responds with a proxy answer if available.", func(c C) { - var wg sync.WaitGroup - wg.Add(1) - sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) - defer sqsCancelFunc() - defer wg.Wait() - expectReceiveMessageReturnsValidMessage(sqsHandlerContext) + mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn( + func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { + + go func(c C) { + snowflake := ipcCtx.AddSnowflake("fake", "", NATUnrestricted, 0) + <-snowflake.offerChannel + snowflake.answerChannel <- "fake answer" + }(c) + return validMessage, nil + }) mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(&sqs.CreateQueueOutput{ QueueUrl: responseQueueURL, }, nil).AnyTimes() - mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes() - numTimes := 0 + mockSQSClient.EXPECT().DeleteMessage(gomock.Any(), gomock.Any()).AnyTimes() + var numTimes atomic.Uint32 mockSQSClient.EXPECT().SendMessage(sqsHandlerContext, gomock.Any()).MinTimes(1).DoAndReturn( func(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) { - numTimes += 1 - if numTimes == 1 { + n := numTimes.Add(1) + if n == 1 { c.So(input.MessageBody, ShouldEqual, aws.String("{\"answer\":\"fake answer\"}")) // Ensure that match is correctly recorded in metrics ipcCtx.metrics.printMetrics() @@ -201,19 +170,14 @@ client-ampcache-ips client-sqs-count 8 client-sqs-ips ??=8 `) - wg.Done() + sqsCancelFunc() } return &sqs.SendMessageOutput{}, nil }, ) runSQSHandler(sqsHandlerContext) - snowflake := ipcCtx.AddSnowflake("fake", "", NATUnrestricted, 0) - - offer := <-snowflake.offerChannel - So(offer.sdp, ShouldResemble, []byte("fake")) - - snowflake.answerChannel <- "fake answer" + <-sqsHandlerContext.Done() }) }) }) @@ -299,7 +263,6 @@ client-sqs-ips ??=8 // Executed on second iteration of cleanupClientQueues loop. This means that one full iteration has completed and we can verify the results of that iteration wg.Done() sqsCancelFunc() - c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: finished running iteration of client queue cleanup. found and deleted 2 client queues.") return &sqs.ListQueuesOutput{ QueueUrls: []string{}, }, nil