snowflake/broker/sqs.go
Cecylia Bocovich 2250bc86f6
Process and read broker SQS messages more quickly
We're losing a lot of messages from the broker SQS queue because they
are exceeding their maximum lifetime before being read and processed by
the broker. This change speeds up that process by increasing the size of
messagesChn and processing the messages within a go routine.
2025-02-20 09:37:18 -05:00

232 lines
7.1 KiB
Go

package main
import (
"context"
"log"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/util"
)
// cleanupThreshold is the (negative) offset added to the current time to
// compute the cleanup cutoff: client queues whose last-modified timestamp is
// before now+cleanupThreshold, i.e. more than two minutes ago, are deleted.
const cleanupThreshold = -2 * time.Minute
// sqsHandler receives client poll requests from an SQS queue, forwards them
// to the broker through IPC, and sends the answers back through per-client
// SQS queues.
type sqsHandler struct {
// SQSClient is the client used for every SQS operation of the handler.
SQSClient sqsclient.SQSClient
// SQSQueueURL is the URL of the queue that incoming messages are received
// from and deleted from.
SQSQueueURL *string
// IPC is where decoded client offers are submitted (via ClientOffers).
IPC *IPC
// cleanupInterval is the period of the ticker that drives
// cleanupClientQueues.
cleanupInterval time.Duration
}
// pollMessages long-polls the handler's SQS queue and forwards every message
// it receives to chn until ctx is cancelled.
//
// Each ReceiveMessage call asks for up to 10 messages and waits up to 15
// seconds for them. Receive errors are logged and the poll is retried. The
// channel is not closed when the function returns.
func (r *sqsHandler) pollMessages(ctx context.Context, chn chan<- *types.Message) {
	for {
		select {
		case <-ctx.Done():
			// if context is cancelled
			return
		default:
		}

		res, err := r.SQSClient.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
			QueueUrl:            r.SQSQueueURL,
			MaxNumberOfMessages: 10,
			WaitTimeSeconds:     15,
			MessageAttributeNames: []string{
				string(types.QueueAttributeNameAll),
			},
		})
		if err != nil {
			log.Printf("SQSHandler: encountered error while polling for messages: %v\n", err)
			continue
		}

		for i := range res.Messages {
			// Send a pointer to the slice element rather than to the range
			// variable: before Go 1.22 the range variable is shared by all
			// iterations, so every queued pointer would alias the same
			// (repeatedly overwritten) message.
			select {
			case chn <- &res.Messages[i]:
			case <-ctx.Done():
				// Do not block forever on a consumer that has gone away.
				return
			}
		}
	}
}
// cleanupClientQueues runs a cleanup pass every r.cleanupInterval until ctx is
// cancelled. Each pass deletes the client queues that were last changed more
// than 2 minutes ago.
//
// The ticker is stopped on return, and cancellation is observed immediately
// instead of only at the next tick.
func (r *sqsHandler) cleanupClientQueues(ctx context.Context) {
	ticker := time.NewTicker(r.cleanupInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			// if context is cancelled
			return
		case <-ticker.C:
		}
		// select picks randomly when both cases are ready; never start a
		// pass once the context has been cancelled.
		if ctx.Err() != nil {
			return
		}
		r.cleanupClientQueuesOnce(ctx)
	}
}

// listClientQueueURLs returns the URLs of all queues whose name starts with
// "snowflake-client-", following pagination. On a ListQueues error the URLs
// collected so far are returned.
func (r *sqsHandler) listClientQueueURLs(ctx context.Context) []string {
	var queueURLs []string
	var nextToken *string
	for {
		res, err := r.SQSClient.ListQueues(ctx, &sqs.ListQueuesInput{
			QueueNamePrefix: aws.String("snowflake-client-"),
			MaxResults:      aws.Int32(1000),
			NextToken:       nextToken,
		})
		if err != nil {
			log.Printf("SQSHandler: encountered error while retrieving client queues to clean up: %v\n", err)
			// client queues will be cleaned up the next time the cleanup operation is triggered automatically
			return queueURLs
		}
		queueURLs = append(queueURLs, res.QueueUrls...)
		if res.NextToken == nil {
			return queueURLs
		}
		nextToken = res.NextToken
	}
}

// cleanupClientQueuesOnce performs a single cleanup pass: it lists the client
// queues, deletes those whose LastModifiedTimestamp is older than the cleanup
// cutoff, and logs how many were deleted.
func (r *sqsHandler) cleanupClientQueuesOnce(ctx context.Context) {
	queueURLs := r.listClientQueueURLs(ctx)

	numDeleted := 0
	cleanupCutoff := time.Now().Add(cleanupThreshold)
	for _, queueURL := range queueURLs {
		if !strings.Contains(queueURL, "snowflake-client-") {
			continue
		}
		res, err := r.SQSClient.GetQueueAttributes(ctx, &sqs.GetQueueAttributesInput{
			QueueUrl:       aws.String(queueURL),
			AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp},
		})
		if err != nil {
			// According to the AWS SQS docs, the deletion process for a queue can take up to 60 seconds. So the queue
			// can be in the process of being deleted, but will still be returned by the ListQueues operation, but
			// fail when we try to GetQueueAttributes for the queue
			log.Printf("SQSHandler: encountered error while getting attribute of client queue %s. queue may already be deleted.\n", queueURL)
			continue
		}
		lastModifiedInt64, err := strconv.ParseInt(res.Attributes[string(types.QueueAttributeNameLastModifiedTimestamp)], 10, 64)
		if err != nil {
			log.Printf("SQSHandler: encountered invalid lastModifiedTimetamp value from client queue %s: %v\n", queueURL, err)
			continue
		}
		// The timestamp is interpreted as seconds since the Unix epoch.
		if !time.Unix(lastModifiedInt64, 0).Before(cleanupCutoff) {
			continue
		}
		if _, err := r.SQSClient.DeleteQueue(ctx, &sqs.DeleteQueueInput{
			QueueUrl: aws.String(queueURL),
		}); err != nil {
			log.Printf("SQSHandler: encountered error when deleting client queue %s: %v\n", queueURL, err)
			continue
		}
		numDeleted++
	}
	log.Printf("SQSHandler: finished running iteration of client queue cleanup. found and deleted %d client queues.\n", numDeleted)
}
// handleMessage processes one client poll request received over SQS.
//
// It reads the "ClientID" message attribute, creates (or reuses) the answer
// queue "snowflake-client-<ClientID>", passes the message body to the broker
// through IPC.ClientOffers, and sends the broker's response to the answer
// queue. Messages without a client ID or without a body are logged and
// ignored. Failures are logged; nothing is returned to the caller.
func (r *sqsHandler) handleMessage(context context.Context, message *types.Message) {
	clientID := message.MessageAttributes["ClientID"].StringValue
	if clientID == nil {
		log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.")
		return
	}
	if message.Body == nil {
		// Guard against a nil dereference below; there is nothing to
		// forward to the broker without a body.
		log.Printf("SQSHandler: got SQS message with no body from client %s. ignoring this message.\n", *clientID)
		return
	}

	res, err := r.SQSClient.CreateQueue(context, &sqs.CreateQueueInput{
		QueueName: aws.String("snowflake-client-" + *clientID),
	})
	if err != nil {
		log.Printf("SQSHandler: error encountered when creating answer queue for client %s: %v\n", *clientID, err)
		return
	}
	answerSQSURL := res.QueueUrl

	encPollReq := []byte(*message.Body)

	// Get best guess Client IP for geolocating. A decoding failure here is
	// only logged: the raw request is still handed to the broker below, with
	// an empty remote address.
	remoteAddr := ""
	req, err := messages.DecodeClientPollRequest(encPollReq)
	if err != nil {
		log.Printf("SQSHandler: error encounted when decoding client poll request %s: %v\n", *clientID, err)
	} else {
		sdp, err := util.DeserializeSessionDescription(req.Offer)
		if err != nil {
			log.Printf("SQSHandler: error encounted when deserializing session desc %s: %v\n", *clientID, err)
		} else {
			candidateAddrs := util.GetCandidateAddrs(sdp.SDP)
			if len(candidateAddrs) > 0 {
				remoteAddr = candidateAddrs[0].String()
			}
		}
	}

	arg := messages.Arg{
		Body:             encPollReq,
		RemoteAddr:       remoteAddr,
		RendezvousMethod: messages.RendezvousSqs,
	}

	var response []byte
	if err := r.IPC.ClientOffers(arg, &response); err != nil {
		log.Printf("SQSHandler: error encountered when handling message: %v\n", err)
		return
	}

	// Report a failed send instead of silently dropping the error.
	if _, err := r.SQSClient.SendMessage(context, &sqs.SendMessageInput{
		QueueUrl:    answerSQSURL,
		MessageBody: aws.String(string(response)),
	}); err != nil {
		log.Printf("SQSHandler: error encountered when sending answer to client %s: %v\n", *clientID, err)
	}
}
// deleteMessage removes message from the handler's SQS queue using its
// receipt handle. A failure is logged; nothing is returned to the caller.
func (r *sqsHandler) deleteMessage(context context.Context, message *types.Message) {
	_, err := r.SQSClient.DeleteMessage(context, &sqs.DeleteMessageInput{
		QueueUrl:      r.SQSQueueURL,
		ReceiptHandle: message.ReceiptHandle,
	})
	if err != nil {
		// Previously the error was dropped silently.
		log.Printf("SQSHandler: error encountered when deleting message: %v\n", err)
	}
}
// newSQSHandler creates the queue named sqsQueueName with a message retention
// period of five minutes and returns a handler bound to the resulting queue
// URL, using client for SQS calls, i for IPC and a 30 second cleanup
// interval. The region argument is not used by this function.
//
// Per the original note on CreateQueue: the queue is created if no queue with
// that name exists; an existing queue with the same name and attributes is
// left as is; an existing queue with the same name but different attributes
// results in an error, which is returned unchanged.
func newSQSHandler(context context.Context, client sqsclient.SQSClient, sqsQueueName string, region string, i *IPC) (*sqsHandler, error) {
	retentionSeconds := strconv.Itoa(int((5 * time.Minute) / time.Second))
	input := &sqs.CreateQueueInput{
		QueueName:  aws.String(sqsQueueName),
		Attributes: map[string]string{"MessageRetentionPeriod": retentionSeconds},
	}

	created, err := client.CreateQueue(context, input)
	if err != nil {
		return nil, err
	}

	handler := &sqsHandler{
		SQSClient:       client,
		SQSQueueURL:     created.QueueUrl,
		IPC:             i,
		cleanupInterval: 30 * time.Second,
	}
	return handler, nil
}
// PollAndHandleMessages starts the background poller and the client-queue
// cleanup loop, then dispatches every received message to its own goroutine,
// which handles the message and afterwards deletes it from the queue.
//
// It blocks until ctx is cancelled. The message channel is never closed, so
// the loop selects on ctx.Done() directly; ranging over the channel would
// leave this goroutine blocked after cancellation.
func (r *sqsHandler) PollAndHandleMessages(ctx context.Context) {
	log.Println("SQSHandler: Starting to poll for messages at: " + *r.SQSQueueURL)

	// Buffered so the poller can queue messages while handlers are starting.
	messagesChn := make(chan *types.Message, 20)
	go r.pollMessages(ctx, messagesChn)
	go r.cleanupClientQueues(ctx)

	for {
		select {
		case <-ctx.Done():
			// if context is cancelled
			return
		case message := <-messagesChn:
			// select picks randomly when both cases are ready; do not start
			// new work once the context has been cancelled.
			if ctx.Err() != nil {
				return
			}
			go func(msg *types.Message) {
				r.handleMessage(ctx, msg)
				r.deleteMessage(ctx, msg)
			}(message)
		}
	}
}