snowflake/broker/sqs.go
Michael Pu 8fb17de152
Implement SQS rendezvous in client and broker
This feature adds an additional rendezvous method to send client offers
and receive proxy answers through the use of Amazon SQS queues.

https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/-/issues/26151
2024-01-22 13:06:42 -05:00

195 lines
6 KiB
Go

package main
import (
"context"
"log"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/sqs"
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
)
// Timing parameters for the background janitor that deletes stale
// per-client answer queues.
const (
	// cleanupInterval is the period between two janitor sweeps.
	cleanupInterval = 30 * time.Second
	// cleanupThreshold is added to the current time to obtain the staleness
	// cutoff: a client queue whose LastModifiedTimestamp is earlier than
	// time.Now().Add(cleanupThreshold) (two minutes ago) gets deleted.
	cleanupThreshold = -(2 * time.Minute)
)
// sqsHandler implements the broker side of the SQS rendezvous: it reads client
// offers from a shared SQS queue, passes each one to the broker through IPC,
// and writes the resulting answer to a per-client "snowflake-client-<ID>" queue.
type sqsHandler struct {
	// SQSClient is the AWS SDK client used for every SQS API call.
	SQSClient *sqs.Client
	// SQSQueueURL is the URL of the queue from which client offers are read.
	SQSQueueURL *string
	// IPC is the broker handle whose ClientOffers method processes each offer.
	IPC *IPC
}
// pollMessages long-polls the broker's offer queue and forwards every received
// message on chn. It runs until ctx is canceled and then closes chn, so a
// consumer ranging over the channel terminates.
//
// ReceiveMessage errors are logged and followed by a short pause before the
// next attempt, so a persistent failure (bad credentials, missing queue, ...)
// does not turn into a tight loop of failing API calls.
func (r *sqsHandler) pollMessages(ctx context.Context, chn chan<- *types.Message) {
	const pollErrorBackoff = time.Second

	// This goroutine is the only sender on chn, so it owns closing it.
	defer close(chn)
	for {
		if ctx.Err() != nil {
			return
		}
		res, err := r.SQSClient.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
			QueueUrl:            r.SQSQueueURL,
			MaxNumberOfMessages: 10,
			WaitTimeSeconds:     15,
			MessageAttributeNames: []string{
				string(types.QueueAttributeNameAll),
			},
		})
		if err != nil {
			log.Printf("SQSHandler: encountered error while polling for messages: %v\n", err)
			// res is nil on error: do not touch it, wait a bit and retry.
			select {
			case <-ctx.Done():
				return
			case <-time.After(pollErrorBackoff):
			}
			continue
		}
		// Take the address of the slice element, not of a range variable:
		// before Go 1.22 a range variable is shared across iterations, so
		// &message would make every send point at the same (last) message.
		for i := range res.Messages {
			select {
			case chn <- &res.Messages[i]:
			case <-ctx.Done():
				return
			}
		}
	}
}
// cleanupClientQueues runs a sweep every cleanupInterval that deletes the
// per-client answer queues ("snowflake-client-*") whose LastModifiedTimestamp
// is older than time.Now().Add(cleanupThreshold). It returns when ctx is
// canceled. Errors are logged and never abort the janitor.
func (r *sqsHandler) cleanupClientQueues(ctx context.Context) {
	// Unlike time.Tick, a Ticker can be stopped, so it is released when ctx ends.
	ticker := time.NewTicker(cleanupInterval)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		queueURLs, err := r.listClientQueueURLs(ctx)
		if err != nil {
			log.Printf("SQSHandler: encountered error while retrieving client queues to clean up: %v\n", err)
			continue
		}
		numDeleted := r.deleteStaleClientQueues(ctx, queueURLs, time.Now().Add(cleanupThreshold))
		log.Printf("SQSHandler: finished running iteration of client queue cleanup. found and deleted %d client queues.\n", numDeleted)
	}
}

// listClientQueueURLs returns the URLs of all queues whose name starts with
// "snowflake-client-", following ListQueues pagination until the last page.
func (r *sqsHandler) listClientQueueURLs(ctx context.Context) ([]string, error) {
	var queueURLs []string
	var nextToken *string
	for {
		res, err := r.SQSClient.ListQueues(ctx, &sqs.ListQueuesInput{
			QueueNamePrefix: aws.String("snowflake-client-"),
			MaxResults:      aws.Int32(1000),
			NextToken:       nextToken,
		})
		if err != nil {
			// res is nil on error; report instead of dereferencing it.
			return nil, err
		}
		queueURLs = append(queueURLs, res.QueueUrls...)
		if res.NextToken == nil {
			return queueURLs, nil
		}
		nextToken = res.NextToken
	}
}

// deleteStaleClientQueues deletes every client queue in queueURLs whose
// LastModifiedTimestamp is before cutoff and returns the number deleted.
func (r *sqsHandler) deleteStaleClientQueues(ctx context.Context, queueURLs []string, cutoff time.Time) int {
	numDeleted := 0
	for _, queueURL := range queueURLs {
		if !strings.Contains(queueURL, "snowflake-client-") {
			continue
		}
		res, err := r.SQSClient.GetQueueAttributes(ctx, &sqs.GetQueueAttributesInput{
			QueueUrl:       aws.String(queueURL),
			AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp},
		})
		if err != nil {
			// According to the AWS SQS docs, the deletion process for a queue can take up to 60 seconds. So the queue
			// can be in the process of being deleted, but will still be returned by the ListQueues operation, but
			// fail when we try to GetQueueAttributes for the queue
			log.Printf("SQSHandler: encountered error while getting attribute of client queue %s. queue may already be deleted: %v\n", queueURL, err)
			continue
		}
		lastModifiedInt64, err := strconv.ParseInt(res.Attributes[string(types.QueueAttributeNameLastModifiedTimestamp)], 10, 64)
		if err != nil {
			log.Printf("SQSHandler: encountered invalid lastModifiedTimetamp value from client queue %s: %v\n", queueURL, err)
			continue
		}
		if !time.Unix(lastModifiedInt64, 0).Before(cutoff) {
			continue
		}
		if _, err := r.SQSClient.DeleteQueue(ctx, &sqs.DeleteQueueInput{
			QueueUrl: aws.String(queueURL),
		}); err != nil {
			log.Printf("SQSHandler: encountered error when deleting client queue %s: %v\n", queueURL, err)
			continue
		}
		numDeleted++
	}
	return numDeleted
}
// handleMessage processes one client offer received from the broker's SQS
// queue. The client identifies itself with the "ClientID" message attribute.
// The message body is passed to IPC.ClientOffers, and the resulting answer is
// sent to the queue "snowflake-client-<ClientID>", which is created here if
// needed. Messages without a ClientID or body are ignored. Every failure is
// logged, and processing of this message stops at that point.
func (r *sqsHandler) handleMessage(ctx context.Context, message *types.Message) {
	clientID := message.MessageAttributes["ClientID"].StringValue
	if clientID == nil {
		log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.")
		return
	}
	if message.Body == nil {
		log.Printf("SQSHandler: got SQS message with no body from client %s. ignoring this message.\n", *clientID)
		return
	}

	res, err := r.SQSClient.CreateQueue(ctx, &sqs.CreateQueueInput{
		QueueName: aws.String("snowflake-client-" + *clientID),
	})
	if err != nil {
		// res is nil on error, and without an answer queue the client could
		// never receive a reply, so there is no point handling the offer.
		log.Printf("SQSHandler: error encountered when creating answer queue for client %s: %v\n", *clientID, err)
		return
	}
	answerSQSURL := res.QueueUrl

	var response []byte
	arg := messages.Arg{
		Body:       []byte(*message.Body),
		RemoteAddr: "",
	}
	if err := r.IPC.ClientOffers(arg, &response); err != nil {
		log.Printf("SQSHandler: error encountered when handling message: %v\n", err)
		return
	}

	if _, err := r.SQSClient.SendMessage(ctx, &sqs.SendMessageInput{
		QueueUrl:    answerSQSURL,
		MessageBody: aws.String(string(response)),
	}); err != nil {
		log.Printf("SQSHandler: error encountered when sending answer to client %s: %v\n", *clientID, err)
	}
}
// deleteMessage removes message from the broker's offer queue (SQSQueueURL)
// using the message's receipt handle. A failed deletion is logged rather than
// silently dropped.
func (r *sqsHandler) deleteMessage(ctx context.Context, message *types.Message) {
	_, err := r.SQSClient.DeleteMessage(ctx, &sqs.DeleteMessageInput{
		QueueUrl:      r.SQSQueueURL,
		ReceiptHandle: message.ReceiptHandle,
	})
	if err != nil {
		log.Printf("SQSHandler: encountered error when deleting message from queue: %v\n", err)
	}
}
// newSQSHandler builds an sqsHandler for the offer queue named sqsQueueName in
// the given AWS region.
//
// It loads the default AWS configuration and then issues CreateQueue with a
// MessageRetentionPeriod of 300 seconds (5 minutes). If a queue with the same
// name and attributes already exists, this is a no-op. If one exists with
// different attributes, SQS returns an error. Any error from loading the
// config or creating the queue is returned unchanged.
func newSQSHandler(ctx context.Context, sqsQueueName string, region string, i *IPC) (*sqsHandler, error) {
	log.Printf("Loading SQSHandler using SQS Queue %s in region %s\n", sqsQueueName, region)

	awsCfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
	if err != nil {
		return nil, err
	}
	sqsClient := sqs.NewFromConfig(awsCfg)

	// Retention period is expressed in whole seconds.
	retentionSeconds := strconv.Itoa(int((5 * time.Minute) / time.Second))
	created, err := sqsClient.CreateQueue(ctx, &sqs.CreateQueueInput{
		QueueName:  aws.String(sqsQueueName),
		Attributes: map[string]string{"MessageRetentionPeriod": retentionSeconds},
	})
	if err != nil {
		return nil, err
	}

	handler := &sqsHandler{
		IPC:         i,
		SQSClient:   sqsClient,
		SQSQueueURL: created.QueueUrl,
	}
	return handler, nil
}
// PollAndHandleMessages starts the background offer poller and the
// client-queue janitor, then processes received offers one at a time. Each
// message is handled and afterwards deleted from the offer queue, regardless
// of whether handling succeeded. It returns when ctx is canceled or when the
// message channel is closed.
func (r *sqsHandler) PollAndHandleMessages(ctx context.Context) {
	log.Println("SQSHandler: Starting to poll for messages at: " + *r.SQSQueueURL)
	messagesChn := make(chan *types.Message, 2)
	go r.pollMessages(ctx, messagesChn)
	go r.cleanupClientQueues(ctx)
	for {
		select {
		case <-ctx.Done():
			return
		case message, ok := <-messagesChn:
			if !ok {
				return
			}
			r.handleMessage(ctx, message)
			r.deleteMessage(ctx, message)
		}
	}
}