Remove data races from sqs tests

Our SQS tests were not concurrency safe and we hadn't noticed until now
because we were processing incoming SQS queue messages sequentially
rather than in parallel.

This fix removes the log output checks, which were prone to error
anyway, and relies instead on gomock's expected function calls and
strategic use of the context cancel function for each test.
This commit is contained in:
Cecylia Bocovich 2025-02-19 17:51:37 -05:00
parent 2250bc86f6
commit 1180d11a66
No known key found for this signature in database
GPG key ID: 009DE379FD9B7B90

View file

@ -7,6 +7,7 @@ import (
"log" "log"
"strconv" "strconv"
"sync" "sync"
"sync/atomic"
"testing" "testing"
"time" "time"
@ -25,9 +26,6 @@ func TestSQS(t *testing.T) {
ipcCtx := NewBrokerContext(log.New(buf, "", 0), "", "") ipcCtx := NewBrokerContext(log.New(buf, "", 0), "", "")
i := &IPC{ipcCtx} i := &IPC{ipcCtx}
var logBuffer bytes.Buffer
log.SetOutput(&logBuffer)
Convey("Responds to SQS client offers...", func() { Convey("Responds to SQS client offers...", func() {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
mockSQSClient := sqsclient.NewMockSQSClient(ctrl) mockSQSClient := sqsclient.NewMockSQSClient(ctrl)
@ -65,12 +63,7 @@ func TestSQS(t *testing.T) {
} }
Convey("by ignoring it if no client id specified", func(c C) { Convey("by ignoring it if no client id specified", func(c C) {
var wg sync.WaitGroup
wg.Add(1)
sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
defer sqsCancelFunc()
defer wg.Wait()
mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(1).DoAndReturn( mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(1).DoAndReturn(
func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
return &sqs.ReceiveMessageOutput{ return &sqs.ReceiveMessageOutput{
@ -83,41 +76,32 @@ func TestSQS(t *testing.T) {
}, nil }, nil
}, },
) )
mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).Times(1).Do( mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).MinTimes(1).Do(
func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) { func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) {
defer wg.Done() sqsCancelFunc()
c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.")
mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes()
}, },
) )
// We expect no queues to be created
mockSQSClient.EXPECT().CreateQueue(gomock.Any(), gomock.Any()).Times(0)
runSQSHandler(sqsHandlerContext) runSQSHandler(sqsHandlerContext)
<-sqsHandlerContext.Done()
}) })
Convey("by doing nothing if an error occurs upon receipt of the message", func(c C) { Convey("by doing nothing if an error occurs upon receipt of the message", func(c C) {
var wg sync.WaitGroup
wg.Add(2)
sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
defer sqsCancelFunc()
defer wg.Wait()
numTimes := 0 mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(1).DoAndReturn(
// When ReceiveMessage is called for the first time, the error has not had a chance to be logged yet.
// Therefore, we opt to wait for the second call because we are guaranteed that the error was logged
// by then.
mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).MinTimes(2).DoAndReturn(
func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
numTimes += 1 sqsCancelFunc()
if numTimes <= 2 {
wg.Done()
if numTimes == 2 {
c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: encountered error while polling for messages: error")
}
}
return nil, errors.New("error") return nil, errors.New("error")
}, },
) )
// We expect no queues to be created or deleted
mockSQSClient.EXPECT().CreateQueue(gomock.Any(), gomock.Any()).Times(0)
mockSQSClient.EXPECT().DeleteMessage(gomock.Any(), gomock.Any()).Times(0)
runSQSHandler(sqsHandlerContext) runSQSHandler(sqsHandlerContext)
<-sqsHandlerContext.Done()
}) })
Convey("by attempting to create a new sqs queue...", func() { Convey("by attempting to create a new sqs queue...", func() {
@ -125,11 +109,7 @@ func TestSQS(t *testing.T) {
sqsCreateQueueInput := sqs.CreateQueueInput{ sqsCreateQueueInput := sqs.CreateQueueInput{
QueueName: aws.String("snowflake-client-fake-id"), QueueName: aws.String("snowflake-client-fake-id"),
} }
validMessage := &sqs.ReceiveMessageOutput{
expectReceiveMessageReturnsValidMessage := func(sqsHandlerContext context.Context) {
mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn(
func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
return &sqs.ReceiveMessageOutput{
Messages: []types.Message{ Messages: []types.Message{
{ {
Body: messageBody, Body: messageBody,
@ -139,54 +119,43 @@ func TestSQS(t *testing.T) {
ReceiptHandle: &receiptHandle, ReceiptHandle: &receiptHandle,
}, },
}, },
}, nil
},
)
} }
Convey("and does not attempt to send a message via SQS if queue creation fails.", func(c C) { Convey("and does not attempt to send a message via SQS if queue creation fails.", func(c C) {
var wg sync.WaitGroup
wg.Add(2)
sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
defer sqsCancelFunc()
defer wg.Wait()
expectReceiveMessageReturnsValidMessage(sqsHandlerContext) mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn(
func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
sqsCancelFunc()
return validMessage, nil
})
mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(nil, errors.New("error")).AnyTimes() mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(nil, errors.New("error")).AnyTimes()
numTimes := 0 mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes()
mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).MinTimes(2).Do(
func(ctx context.Context, input *sqs.DeleteMessageInput, optFns ...func(*sqs.Options)) {
numTimes += 1
if numTimes <= 2 {
wg.Done()
if numTimes == 2 {
c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: error encountered when creating answer queue for client fake-id: error")
}
}
},
)
runSQSHandler(sqsHandlerContext) runSQSHandler(sqsHandlerContext)
<-sqsHandlerContext.Done()
}) })
Convey("and responds with a proxy answer if available.", func(c C) { Convey("and responds with a proxy answer if available.", func(c C) {
var wg sync.WaitGroup
wg.Add(1)
sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
defer sqsCancelFunc()
defer wg.Wait()
expectReceiveMessageReturnsValidMessage(sqsHandlerContext) mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn(
func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
go func(c C) {
snowflake := ipcCtx.AddSnowflake("fake", "", NATUnrestricted, 0)
<-snowflake.offerChannel
snowflake.answerChannel <- "fake answer"
}(c)
return validMessage, nil
})
mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(&sqs.CreateQueueOutput{ mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(&sqs.CreateQueueOutput{
QueueUrl: responseQueueURL, QueueUrl: responseQueueURL,
}, nil).AnyTimes() }, nil).AnyTimes()
mockSQSClient.EXPECT().DeleteMessage(sqsHandlerContext, &sqsDeleteMessageInput).AnyTimes() mockSQSClient.EXPECT().DeleteMessage(gomock.Any(), gomock.Any()).AnyTimes()
numTimes := 0 var numTimes atomic.Uint32
mockSQSClient.EXPECT().SendMessage(sqsHandlerContext, gomock.Any()).MinTimes(1).DoAndReturn( mockSQSClient.EXPECT().SendMessage(sqsHandlerContext, gomock.Any()).MinTimes(1).DoAndReturn(
func(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) { func(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) {
numTimes += 1 n := numTimes.Add(1)
if numTimes == 1 { if n == 1 {
c.So(input.MessageBody, ShouldEqual, aws.String("{\"answer\":\"fake answer\"}")) c.So(input.MessageBody, ShouldEqual, aws.String("{\"answer\":\"fake answer\"}"))
// Ensure that match is correctly recorded in metrics // Ensure that match is correctly recorded in metrics
ipcCtx.metrics.printMetrics() ipcCtx.metrics.printMetrics()
@ -201,19 +170,14 @@ client-ampcache-ips
client-sqs-count 8 client-sqs-count 8
client-sqs-ips ??=8 client-sqs-ips ??=8
`) `)
wg.Done() sqsCancelFunc()
} }
return &sqs.SendMessageOutput{}, nil return &sqs.SendMessageOutput{}, nil
}, },
) )
runSQSHandler(sqsHandlerContext) runSQSHandler(sqsHandlerContext)
snowflake := ipcCtx.AddSnowflake("fake", "", NATUnrestricted, 0) <-sqsHandlerContext.Done()
offer := <-snowflake.offerChannel
So(offer.sdp, ShouldResemble, []byte("fake"))
snowflake.answerChannel <- "fake answer"
}) })
}) })
}) })
@ -299,7 +263,6 @@ client-sqs-ips ??=8
// Executed on second iteration of cleanupClientQueues loop. This means that one full iteration has completed and we can verify the results of that iteration // Executed on second iteration of cleanupClientQueues loop. This means that one full iteration has completed and we can verify the results of that iteration
wg.Done() wg.Done()
sqsCancelFunc() sqsCancelFunc()
c.So(logBuffer.String(), ShouldContainSubstring, "SQSHandler: finished running iteration of client queue cleanup. found and deleted 2 client queues.")
return &sqs.ListQueuesOutput{ return &sqs.ListQueuesOutput{
QueueUrls: []string{}, QueueUrls: []string{},
}, nil }, nil