mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 11:11:30 -04:00
Add context with timeout for client requests
Client timeouts are currently counted from when the client is matched with a proxy. Instead, count client timeouts from the moment when the request is received. Closes #40449
This commit is contained in:
parent
db0364ef87
commit
8343bbc336
6 changed files with 25 additions and 7 deletions
|
@ -1,9 +1,11 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/amp"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/amp"
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
|
||||||
|
@ -16,6 +18,9 @@ import (
|
||||||
// HTTP request body (because an AMP cache does not support POST), and the
|
// HTTP request body (because an AMP cache does not support POST), and the
|
||||||
// encoded client poll response is sent back as AMP-armored HTML.
|
// encoded client poll response is sent back as AMP-armored HTML.
|
||||||
func ampClientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
|
func ampClientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), ClientTimeout*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
// The encoded client poll message immediately follows the /amp/client/
|
// The encoded client poll message immediately follows the /amp/client/
|
||||||
// path prefix, so this function unfortunately needs to be aware of and
|
// path prefix, so this function unfortunately needs to be aware of and
|
||||||
// remote its own routing prefix.
|
// remote its own routing prefix.
|
||||||
|
@ -38,6 +43,7 @@ func ampClientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
|
||||||
Body: encPollReq,
|
Body: encPollReq,
|
||||||
RemoteAddr: util.GetClientIp(r),
|
RemoteAddr: util.GetClientIp(r),
|
||||||
RendezvousMethod: messages.RendezvousAmpCache,
|
RendezvousMethod: messages.RendezvousAmpCache,
|
||||||
|
Context: ctx,
|
||||||
}
|
}
|
||||||
err = i.ClientOffers(arg, &response)
|
err = i.ClientOffers(arg, &response)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -2,12 +2,14 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
|
||||||
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util"
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util"
|
||||||
|
@ -132,6 +134,9 @@ snowflake proxy, which responds with the SDP answer to be sent in
|
||||||
the HTTP response back to the client.
|
the HTTP response back to the client.
|
||||||
*/
|
*/
|
||||||
func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
|
func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx, cancel := context.WithTimeout(r.Context(), ClientTimeout*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, readLimit))
|
body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, readLimit))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Error reading client request: %s", err.Error())
|
log.Printf("Error reading client request: %s", err.Error())
|
||||||
|
@ -163,6 +168,7 @@ func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
|
||||||
Body: body,
|
Body: body,
|
||||||
RemoteAddr: util.GetClientIp(r),
|
RemoteAddr: util.GetClientIp(r),
|
||||||
RendezvousMethod: messages.RendezvousHttp,
|
RendezvousMethod: messages.RendezvousHttp,
|
||||||
|
Context: ctx,
|
||||||
}
|
}
|
||||||
|
|
||||||
var response []byte
|
var response []byte
|
||||||
|
|
|
@ -217,7 +217,7 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
|
||||||
i.ctx.metrics.lock.Lock()
|
i.ctx.metrics.lock.Lock()
|
||||||
i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond
|
i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond
|
||||||
i.ctx.metrics.lock.Unlock()
|
i.ctx.metrics.lock.Unlock()
|
||||||
case <-time.After(time.Second * ClientTimeout):
|
case <-arg.Context.Done():
|
||||||
i.ctx.metrics.lock.Lock()
|
i.ctx.metrics.lock.Lock()
|
||||||
i.ctx.metrics.UpdateRendezvousStats(arg.RemoteAddr, arg.RendezvousMethod, offer.natType, "timeout")
|
i.ctx.metrics.UpdateRendezvousStats(arg.RemoteAddr, arg.RendezvousMethod, offer.natType, "timeout")
|
||||||
i.ctx.metrics.lock.Unlock()
|
i.ctx.metrics.lock.Unlock()
|
||||||
|
|
|
@ -124,18 +124,21 @@ func (r *sqsHandler) cleanupClientQueues(ctx context.Context) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *sqsHandler) handleMessage(context context.Context, message *types.Message) {
|
func (r *sqsHandler) handleMessage(mainCtx context.Context, message *types.Message) {
|
||||||
var encPollReq []byte
|
var encPollReq []byte
|
||||||
var response []byte
|
var response []byte
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(mainCtx, ClientTimeout*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
clientID := message.MessageAttributes["ClientID"].StringValue
|
clientID := message.MessageAttributes["ClientID"].StringValue
|
||||||
if clientID == nil {
|
if clientID == nil {
|
||||||
log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.")
|
log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := r.SQSClient.CreateQueue(context, &sqs.CreateQueueInput{
|
res, err := r.SQSClient.CreateQueue(ctx, &sqs.CreateQueueInput{
|
||||||
QueueName: aws.String("snowflake-client-" + *clientID),
|
QueueName: aws.String("snowflake-client-" + *clientID),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -167,6 +170,7 @@ func (r *sqsHandler) handleMessage(context context.Context, message *types.Messa
|
||||||
Body: encPollReq,
|
Body: encPollReq,
|
||||||
RemoteAddr: remoteAddr,
|
RemoteAddr: remoteAddr,
|
||||||
RendezvousMethod: messages.RendezvousSqs,
|
RendezvousMethod: messages.RendezvousSqs,
|
||||||
|
Context: ctx,
|
||||||
}
|
}
|
||||||
err = r.IPC.ClientOffers(arg, &response)
|
err = r.IPC.ClientOffers(arg, &response)
|
||||||
|
|
||||||
|
@ -175,7 +179,7 @@ func (r *sqsHandler) handleMessage(context context.Context, message *types.Messa
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
r.SQSClient.SendMessage(context, &sqs.SendMessageInput{
|
r.SQSClient.SendMessage(ctx, &sqs.SendMessageInput{
|
||||||
QueueUrl: answerSQSURL,
|
QueueUrl: answerSQSURL,
|
||||||
MessageBody: aws.String(string(response)),
|
MessageBody: aws.String(string(response)),
|
||||||
})
|
})
|
||||||
|
|
|
@ -138,7 +138,7 @@ func TestSQS(t *testing.T) {
|
||||||
sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
|
sqsHandlerContext, sqsCancelFunc := context.WithCancel(context.Background())
|
||||||
var numTimes atomic.Uint32
|
var numTimes atomic.Uint32
|
||||||
|
|
||||||
mockSQSClient.EXPECT().ReceiveMessage(sqsHandlerContext, &sqsReceiveMessageInput).AnyTimes().DoAndReturn(
|
mockSQSClient.EXPECT().ReceiveMessage(gomock.Any(), &sqsReceiveMessageInput).AnyTimes().DoAndReturn(
|
||||||
func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
|
func(ctx context.Context, input *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) {
|
||||||
|
|
||||||
n := numTimes.Add(1)
|
n := numTimes.Add(1)
|
||||||
|
@ -153,11 +153,11 @@ func TestSQS(t *testing.T) {
|
||||||
return nil, errors.New("error")
|
return nil, errors.New("error")
|
||||||
|
|
||||||
})
|
})
|
||||||
mockSQSClient.EXPECT().CreateQueue(sqsHandlerContext, &sqsCreateQueueInput).Return(&sqs.CreateQueueOutput{
|
mockSQSClient.EXPECT().CreateQueue(gomock.Any(), &sqsCreateQueueInput).Return(&sqs.CreateQueueOutput{
|
||||||
QueueUrl: responseQueueURL,
|
QueueUrl: responseQueueURL,
|
||||||
}, nil).AnyTimes()
|
}, nil).AnyTimes()
|
||||||
mockSQSClient.EXPECT().DeleteMessage(gomock.Any(), gomock.Any()).AnyTimes()
|
mockSQSClient.EXPECT().DeleteMessage(gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
mockSQSClient.EXPECT().SendMessage(sqsHandlerContext, gomock.Any()).Times(1).DoAndReturn(
|
mockSQSClient.EXPECT().SendMessage(gomock.Any(), gomock.Any()).Times(1).DoAndReturn(
|
||||||
func(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) {
|
func(ctx context.Context, input *sqs.SendMessageInput, optFns ...func(*sqs.Options)) (*sqs.SendMessageOutput, error) {
|
||||||
c.So(input.MessageBody, ShouldEqual, aws.String("{\"answer\":\"fake answer\"}"))
|
c.So(input.MessageBody, ShouldEqual, aws.String("{\"answer\":\"fake answer\"}"))
|
||||||
// Ensure that match is correctly recorded in metrics
|
// Ensure that match is correctly recorded in metrics
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package messages
|
package messages
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,6 +17,7 @@ type Arg struct {
|
||||||
Body []byte
|
Body []byte
|
||||||
RemoteAddr string
|
RemoteAddr string
|
||||||
RendezvousMethod RendezvousMethod
|
RendezvousMethod RendezvousMethod
|
||||||
|
Context context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue