mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 11:11:30 -04:00
Now that both SQS and AMP cache are pulling remote addresses from the SDP, avoid duplicate decodings of the ClientPollRequest by extracting the remote addr in ClientOffers.
217 lines
6.4 KiB
Go
217 lines
6.4 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"log"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/aws/aws-sdk-go-v2/aws"
|
|
"github.com/aws/aws-sdk-go-v2/service/sqs"
|
|
"github.com/aws/aws-sdk-go-v2/service/sqs/types"
|
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages"
|
|
"gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/sqsclient"
|
|
)
|
|
|
|
// cleanupThreshold is the (negative) age offset used when sweeping client
// queues: a client queue whose last modification is before
// time.Now().Add(cleanupThreshold) — i.e. more than two minutes ago — is
// eligible for deletion.
const cleanupThreshold = -2 * time.Minute
|
|
|
|
// sqsHandler implements the SQS rendezvous path of the broker: it reads client
// offers from the broker's SQS queue, hands them to IPC, and replies on
// per-client "snowflake-client-<ClientID>" answer queues, which it also
// periodically deletes once they go stale.
type sqsHandler struct {
	// SQSClient is used for every SQS API call made by the handler.
	SQSClient sqsclient.SQSClient
	// SQSQueueURL is the URL of the broker's queue from which client offers
	// are received and deleted.
	SQSQueueURL *string
	// IPC processes each offer via ClientOffers and produces the answer that
	// is sent back to the client.
	IPC *IPC
	// cleanupInterval is the period between sweeps for stale client queues
	// (passed to time.NewTicker, so it must be positive).
	cleanupInterval time.Duration
}
|
|
|
|
func (r *sqsHandler) pollMessages(ctx context.Context, chn chan<- *types.Message) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
// if context is cancelled
|
|
return
|
|
default:
|
|
res, err := r.SQSClient.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{
|
|
QueueUrl: r.SQSQueueURL,
|
|
MaxNumberOfMessages: 10,
|
|
WaitTimeSeconds: 15,
|
|
MessageAttributeNames: []string{
|
|
string(types.QueueAttributeNameAll),
|
|
},
|
|
})
|
|
|
|
if err != nil {
|
|
log.Printf("SQSHandler: encountered error while polling for messages: %v\n", err)
|
|
continue
|
|
}
|
|
|
|
for _, message := range res.Messages {
|
|
chn <- &message
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *sqsHandler) cleanupClientQueues(ctx context.Context) {
|
|
for range time.NewTicker(r.cleanupInterval).C {
|
|
// Runs at fixed intervals to clean up any client queues that were last changed more than 2 minutes ago
|
|
select {
|
|
case <-ctx.Done():
|
|
// if context is cancelled
|
|
return
|
|
default:
|
|
queueURLsList := []string{}
|
|
var nextToken *string
|
|
for {
|
|
res, err := r.SQSClient.ListQueues(ctx, &sqs.ListQueuesInput{
|
|
QueueNamePrefix: aws.String("snowflake-client-"),
|
|
MaxResults: aws.Int32(1000),
|
|
NextToken: nextToken,
|
|
})
|
|
if err != nil {
|
|
log.Printf("SQSHandler: encountered error while retrieving client queues to clean up: %v\n", err)
|
|
// client queues will be cleaned up the next time the cleanup operation is triggered automatically
|
|
break
|
|
}
|
|
queueURLsList = append(queueURLsList, res.QueueUrls...)
|
|
if res.NextToken == nil {
|
|
break
|
|
} else {
|
|
nextToken = res.NextToken
|
|
}
|
|
}
|
|
|
|
numDeleted := 0
|
|
cleanupCutoff := time.Now().Add(cleanupThreshold)
|
|
for _, queueURL := range queueURLsList {
|
|
if !strings.Contains(queueURL, "snowflake-client-") {
|
|
continue
|
|
}
|
|
res, err := r.SQSClient.GetQueueAttributes(ctx, &sqs.GetQueueAttributesInput{
|
|
QueueUrl: aws.String(queueURL),
|
|
AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp},
|
|
})
|
|
if err != nil {
|
|
// According to the AWS SQS docs, the deletion process for a queue can take up to 60 seconds. So the queue
|
|
// can be in the process of being deleted, but will still be returned by the ListQueues operation, but
|
|
// fail when we try to GetQueueAttributes for the queue
|
|
log.Printf("SQSHandler: encountered error while getting attribute of client queue %s. queue may already be deleted.\n", queueURL)
|
|
continue
|
|
}
|
|
lastModifiedInt64, err := strconv.ParseInt(res.Attributes[string(types.QueueAttributeNameLastModifiedTimestamp)], 10, 64)
|
|
if err != nil {
|
|
log.Printf("SQSHandler: encountered invalid lastModifiedTimetamp value from client queue %s: %v\n", queueURL, err)
|
|
continue
|
|
}
|
|
lastModified := time.Unix(lastModifiedInt64, 0)
|
|
if lastModified.Before(cleanupCutoff) {
|
|
_, err := r.SQSClient.DeleteQueue(ctx, &sqs.DeleteQueueInput{
|
|
QueueUrl: aws.String(queueURL),
|
|
})
|
|
if err != nil {
|
|
log.Printf("SQSHandler: encountered error when deleting client queue %s: %v\n", queueURL, err)
|
|
continue
|
|
} else {
|
|
numDeleted += 1
|
|
}
|
|
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *sqsHandler) handleMessage(mainCtx context.Context, message *types.Message) {
|
|
var encPollReq []byte
|
|
var response []byte
|
|
var err error
|
|
|
|
ctx, cancel := context.WithTimeout(mainCtx, ClientTimeout*time.Second)
|
|
defer cancel()
|
|
|
|
clientID := message.MessageAttributes["ClientID"].StringValue
|
|
if clientID == nil {
|
|
log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.")
|
|
return
|
|
}
|
|
|
|
res, err := r.SQSClient.CreateQueue(ctx, &sqs.CreateQueueInput{
|
|
QueueName: aws.String("snowflake-client-" + *clientID),
|
|
})
|
|
if err != nil {
|
|
log.Printf("SQSHandler: error encountered when creating answer queue for client %s: %v\n", *clientID, err)
|
|
return
|
|
}
|
|
answerSQSURL := res.QueueUrl
|
|
|
|
encPollReq = []byte(*message.Body)
|
|
|
|
arg := messages.Arg{
|
|
Body: encPollReq,
|
|
RemoteAddr: "",
|
|
RendezvousMethod: messages.RendezvousSqs,
|
|
Context: ctx,
|
|
}
|
|
err = r.IPC.ClientOffers(arg, &response)
|
|
|
|
if err != nil {
|
|
log.Printf("SQSHandler: error encountered when handling message: %v\n", err)
|
|
return
|
|
}
|
|
|
|
r.SQSClient.SendMessage(ctx, &sqs.SendMessageInput{
|
|
QueueUrl: answerSQSURL,
|
|
MessageBody: aws.String(string(response)),
|
|
})
|
|
}
|
|
|
|
func (r *sqsHandler) deleteMessage(context context.Context, message *types.Message) {
|
|
r.SQSClient.DeleteMessage(context, &sqs.DeleteMessageInput{
|
|
QueueUrl: r.SQSQueueURL,
|
|
ReceiptHandle: message.ReceiptHandle,
|
|
})
|
|
}
|
|
|
|
func newSQSHandler(context context.Context, client sqsclient.SQSClient, sqsQueueName string, region string, i *IPC) (*sqsHandler, error) {
|
|
// Creates the queue if a queue with the same name doesn't exist. If a queue with the same name and attributes
|
|
// already exists, then nothing will happen. If a queue with the same name, but different attributes exists, then
|
|
// an error will be returned
|
|
res, err := client.CreateQueue(context, &sqs.CreateQueueInput{
|
|
QueueName: aws.String(sqsQueueName),
|
|
Attributes: map[string]string{
|
|
"MessageRetentionPeriod": strconv.FormatInt(int64((5 * time.Minute).Seconds()), 10),
|
|
},
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &sqsHandler{
|
|
SQSClient: client,
|
|
SQSQueueURL: res.QueueUrl,
|
|
IPC: i,
|
|
cleanupInterval: time.Second * 30,
|
|
}, nil
|
|
}
|
|
|
|
func (r *sqsHandler) PollAndHandleMessages(ctx context.Context) {
|
|
log.Println("SQSHandler: Starting to poll for messages at: " + *r.SQSQueueURL)
|
|
messagesChn := make(chan *types.Message, 20)
|
|
go r.pollMessages(ctx, messagesChn)
|
|
go r.cleanupClientQueues(ctx)
|
|
|
|
for message := range messagesChn {
|
|
select {
|
|
case <-ctx.Done():
|
|
// if context is cancelled
|
|
return
|
|
default:
|
|
go func(msg *types.Message) {
|
|
r.handleMessage(ctx, msg)
|
|
r.deleteMessage(ctx, msg)
|
|
}(message)
|
|
}
|
|
}
|
|
}
|