Add context with timeout for client requests

Client timeouts are currently counted from when the client is matched
with a proxy. Instead, count client timeouts from the moment when the
request is received.

Closes #40449
This commit is contained in:
Cecylia Bocovich 2025-03-18 16:11:42 -04:00
parent db0364ef87
commit 8343bbc336
No known key found for this signature in database
GPG key ID: 009DE379FD9B7B90
6 changed files with 25 additions and 7 deletions

View file

@ -1,9 +1,11 @@
package main package main
import ( import (
"context"
"log" "log"
"net/http" "net/http"
"strings" "strings"
"time"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/amp" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/amp"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
@ -16,6 +18,9 @@ import (
// HTTP request body (because an AMP cache does not support POST), and the // HTTP request body (because an AMP cache does not support POST), and the
// encoded client poll response is sent back as AMP-armored HTML. // encoded client poll response is sent back as AMP-armored HTML.
func ampClientOffers(i *IPC, w http.ResponseWriter, r *http.Request) { func ampClientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), ClientTimeout*time.Second)
defer cancel()
// The encoded client poll message immediately follows the /amp/client/ // The encoded client poll message immediately follows the /amp/client/
// path prefix, so this function unfortunately needs to be aware of and // path prefix, so this function unfortunately needs to be aware of and
// remote its own routing prefix. // remote its own routing prefix.
@ -38,6 +43,7 @@ func ampClientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
Body: encPollReq, Body: encPollReq,
RemoteAddr: util.GetClientIp(r), RemoteAddr: util.GetClientIp(r),
RendezvousMethod: messages.RendezvousAmpCache, RendezvousMethod: messages.RendezvousAmpCache,
Context: ctx,
} }
err = i.ClientOffers(arg, &response) err = i.ClientOffers(arg, &response)
} else { } else {

View file

@ -2,12 +2,14 @@ package main
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log" "log"
"net/http" "net/http"
"os" "os"
"time"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util"
@ -132,6 +134,9 @@ snowflake proxy, which responds with the SDP answer to be sent in
the HTTP response back to the client. the HTTP response back to the client.
*/ */
func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) { func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), ClientTimeout*time.Second)
defer cancel()
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, readLimit)) body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, readLimit))
if err != nil { if err != nil {
log.Printf("Error reading client request: %s", err.Error()) log.Printf("Error reading client request: %s", err.Error())
@ -163,6 +168,7 @@ func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
Body: body, Body: body,
RemoteAddr: util.GetClientIp(r), RemoteAddr: util.GetClientIp(r),
RendezvousMethod: messages.RendezvousHttp, RendezvousMethod: messages.RendezvousHttp,
Context: ctx,
} }
var response []byte var response []byte

View file

@ -217,7 +217,7 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.metrics.lock.Lock() i.ctx.metrics.lock.Lock()
i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond
i.ctx.metrics.lock.Unlock() i.ctx.metrics.lock.Unlock()
case <-time.After(time.Second * ClientTimeout): case <-arg.Context.Done():
i.ctx.metrics.lock.Lock() i.ctx.metrics.lock.Lock()
i.ctx.metrics.UpdateRendezvousStats(arg.RemoteAddr, arg.RendezvousMethod, offer.natType, "timeout") i.ctx.metrics.UpdateRendezvousStats(arg.RemoteAddr, arg.RendezvousMethod, offer.natType, "timeout")
i.ctx.metrics.lock.Unlock() i.ctx.metrics.lock.Unlock()

View file

@ -124,18 +124,21 @@ func (r *sqsHandler) cleanupClientQueues(ctx context.Context) {
} }
} }
func (r *sqsHandler) handleMessage(context context.Context, message *types.Message) { func (r *sqsHandler) handleMessage(mainCtx context.Context, message *types.Message) {
var encPollReq []byte var encPollReq []byte
var response []byte var response []byte
var err error var err error
ctx, cancel := context.WithTimeout(mainCtx, ClientTimeout*time.Second)
defer cancel()
clientID := message.MessageAttributes["ClientID"].StringValue clientID := message.MessageAttributes["ClientID"].StringValue
if clientID == nil { if clientID == nil {
log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.") log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.")
return return
} }
res, err := r.SQSClient.CreateQueue(context, &sqs.CreateQueueInput{ res, err := r.SQSClient.CreateQueue(ctx, &sqs.CreateQueueInput{
QueueName: aws.String("snowflake-client-" + *clientID), QueueName: aws.String("snowflake-client-" + *clientID),
}) })
if err != nil { if err != nil {
@ -167,6 +170,7 @@ func (r *sqsHandler) handleMessage(context context.Context, message *types.Messa
Body: encPollReq, Body: encPollReq,
RemoteAddr: remoteAddr, RemoteAddr: remoteAddr,
RendezvousMethod: messages.RendezvousSqs, RendezvousMethod: messages.RendezvousSqs,
Context: ctx,
} }
err = r.IPC.ClientOffers(arg, &response) err = r.IPC.ClientOffers(arg, &response)
@ -175,7 +179,7 @@ func (r *sqsHandler) handleMessage(context context.Context, message *types.Messa
return return
} }
r.SQSClient.SendMessage(context, &sqs.SendMessageInput{ r.SQSClient.SendMessage(ctx, &sqs.SendMessageInput{
QueueUrl: answerSQSURL, QueueUrl: answerSQSURL,
MessageBody: aws.String(string(response)), MessageBody: aws.String(string(response)),
}) })

View file

@ -138,7 +138,7 @@ func TestSQS(t *testing.T) {
sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background()) sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
var numTimes atomic.Uint32 var numTimes atomic.Uint32
mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn( mockSQSClient.EXPECT().ReceiveMessage(gomock.Any(), &sqsReceiveMessageInput).AnyTimes().DoAndReturn(
func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
n := numTimes.Add(1) n := numTimes.Add(1)
@ -153,11 +153,11 @@ func TestSQS(t *testing.T) {
return nil, errors.New("error") return nil, errors.New("error")
}) })
mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(&sqs.CreateQueueOutput{ mockSQSClient.EXPECT().CreateQueue(gomock.Any(), &sqsCreateQueueInput).Return(&sqs.CreateQueueOutput{
QueueUrl: responseQueueURL, QueueUrl: responseQueueURL,
}, nil).AnyTimes() }, nil).AnyTimes()
mockSQSClient.EXPECT().DeleteMessage(gomock.Any(), gomock.Any()).AnyTimes() mockSQSClient.EXPECT().DeleteMessage(gomock.Any(), gomock.Any()).AnyTimes()
mockSQSClient.EXPECT().SendMessage(sqsHandlerContext, gomock.Any()).Times(1).DoAndReturn( mockSQSClient.EXPECT().SendMessage(gomock.Any(), gomock.Any()).Times(1).DoAndReturn(
func(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) { func(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) {
c.So(input.MessageBody, ShouldEqual, aws.String("{\"answer\":\"fake answer\"}")) c.So(input.MessageBody, ShouldEqual, aws.String("{\"answer\":\"fake answer\"}"))
// Ensure that match is correctly recorded in metrics // Ensure that match is correctly recorded in metrics

View file

@ -1,6 +1,7 @@
package messages package messages
import ( import (
"context"
"errors" "errors"
) )
@ -16,6 +17,7 @@ type Arg struct {
Body []byte Body []byte
RemoteAddr string RemoteAddr string
RendezvousMethod RendezvousMethod RendezvousMethod RendezvousMethod
Context context.Context
} }
var ( var (