From 8343bbc336968c9b6adbb0c2f5d7958b59eab609 Mon Sep 17 00:00:00 2001 From: Cecylia Bocovich Date: Tue, 18 Mar 2025 16:11:42 -0400 Subject: [PATCH] Add context with timeout for client requests Client timeouts are currently counted from when the client is matched with a proxy. Instead, count client timeouts from the moment when the request is received. Closes #40449 --- broker/amp.go | 6 ++++++ broker/http.go | 6 ++++++ broker/ipc.go | 2 +- broker/sqs.go | 10 +++++++--- broker/sqs_test.go | 6 +++--- common/messages/ipc.go | 2 ++ 6 files changed, 25 insertions(+), 7 deletions(-) diff --git a/broker/amp.go b/broker/amp.go index 99289de..2bfcb71 100644 --- a/broker/amp.go +++ b/broker/amp.go @@ -1,9 +1,11 @@ package main import ( + "context" "log" "net/http" "strings" + "time" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/amp" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" @@ -16,6 +18,9 @@ import ( // HTTP request body (because an AMP cache does not support POST), and the // encoded client poll response is sent back as AMP-armored HTML. func ampClientOffers(i *IPC, w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), ClientTimeout*time.Second) + defer cancel() + // The encoded client poll message immediately follows the /amp/client/ // path prefix, so this function unfortunately needs to be aware of and // remote its own routing prefix. @@ -38,6 +43,7 @@ func ampClientOffers(i *IPC, w http.ResponseWriter, r *http.Request) { Body: encPollReq, RemoteAddr: util.GetClientIp(r), RendezvousMethod: messages.RendezvousAmpCache, + Context: ctx, } err = i.ClientOffers(arg, &response) } else { diff --git a/broker/http.go b/broker/http.go index b6f449d..0cfbe64 100644 --- a/broker/http.go +++ b/broker/http.go @@ -2,12 +2,14 @@ package main import ( "bytes" + "context" "errors" "fmt" "io" "log" "net/http" "os" + "time" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util" @@ -132,6 +134,9 @@ snowflake proxy, which responds with the SDP answer to be sent in the HTTP response back to the client. */ func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), ClientTimeout*time.Second) + defer cancel() + body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, readLimit)) if err != nil { log.Printf("Error reading client request: %s", err.Error()) @@ -163,6 +168,7 @@ func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) { Body: body, RemoteAddr: util.GetClientIp(r), RendezvousMethod: messages.RendezvousHttp, + Context: ctx, } var response []byte diff --git a/broker/ipc.go b/broker/ipc.go index 64cefcd..3a194f1 100644 --- a/broker/ipc.go +++ b/broker/ipc.go @@ -217,7 +217,7 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error { i.ctx.metrics.lock.Lock() i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond i.ctx.metrics.lock.Unlock() - case <-time.After(time.Second * ClientTimeout): + case <-arg.Context.Done(): i.ctx.metrics.lock.Lock() i.ctx.metrics.UpdateRendezvousStats(arg.RemoteAddr, arg.RendezvousMethod, offer.natType, "timeout") i.ctx.metrics.lock.Unlock() diff --git a/broker/sqs.go b/broker/sqs.go index fb1164e..cb77ba6 100644 --- a/broker/sqs.go +++ b/broker/sqs.go @@ -124,18 +124,21 @@ func (r *sqsHandler) cleanupClientQueues(ctx context.Context) { } } -func (r *sqsHandler) handleMessage(context context.Context, message *types.Message) { +func (r *sqsHandler) handleMessage(mainCtx context.Context, message *types.Message) { var encPollReq []byte var response []byte var err error + ctx, cancel := context.WithTimeout(mainCtx, ClientTimeout*time.Second) + defer cancel() + clientID := message.MessageAttributes["ClientID"].StringValue if clientID == nil { log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.") return } - res, err := r.SQSClient.CreateQueue(context, &sqs.CreateQueueInput{ + res, err := r.SQSClient.CreateQueue(ctx, &sqs.CreateQueueInput{ QueueName: aws.String("snowflake-client-" + *clientID), }) if err != nil { @@ -167,6 +170,7 @@ func (r *sqsHandler) handleMessage(context context.Context, message *types.Messa Body: encPollReq, RemoteAddr: remoteAddr, RendezvousMethod: messages.RendezvousSqs, + Context: ctx, } err = r.IPC.ClientOffers(arg, &response) @@ -175,7 +179,7 @@ func (r *sqsHandler) handleMessage(context context.Context, message *types.Messa return } - r.SQSClient.SendMessage(context, &sqs.SendMessageInput{ + r.SQSClient.SendMessage(ctx, &sqs.SendMessageInput{ QueueUrl: answerSQSURL, MessageBody: aws.String(string(response)), }) diff --git a/broker/sqs_test.go b/broker/sqs_test.go index 33e38f1..708e3ef 100644 --- a/broker/sqs_test.go +++ b/broker/sqs_test.go @@ -138,7 +138,7 @@ func TestSQS(t *testing.T) { sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) var numTimes atomic.Uint32 - mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn( + mockSQSClient.EXPECT().ReceiveMessage(gomock.Any(), &sqsReceiveMessageInput).AnyTimes().DoAndReturn( func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { n := numTimes.Add(1) @@ -153,11 +153,11 @@ func TestSQS(t *testing.T) { return nil, errors.New("error") }) - mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(&sqs.CreateQueueOutput{ + mockSQSClient.EXPECT().CreateQueue(gomock.Any(), &sqsCreateQueueInput).Return(&sqs.CreateQueueOutput{ QueueUrl: responseQueueURL, }, nil).AnyTimes() mockSQSClient.EXPECT().DeleteMessage(gomock.Any(), gomock.Any()).AnyTimes() - mockSQSClient.EXPECT().SendMessage(sqsHandlerContext, gomock.Any()).Times(1).DoAndReturn( + mockSQSClient.EXPECT().SendMessage(gomock.Any(), gomock.Any()).Times(1).DoAndReturn( func(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) { c.So(input.MessageBody, ShouldEqual, aws.String("{\"answer\":\"fake answer\"}")) // Ensure that match is correctly recorded in metrics diff --git a/common/messages/ipc.go b/common/messages/ipc.go index 2a61b9d..91eccdb 100644 --- a/common/messages/ipc.go +++ b/common/messages/ipc.go @@ -1,6 +1,7 @@ package messages import ( + "context" "errors" ) @@ -16,6 +17,7 @@ type Arg struct { Body []byte RemoteAddr string RendezvousMethod RendezvousMethod + Context context.Context } var (