package main import ( "context" "log" "strconv" "strings" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/sqs" "github.com/aws/aws-sdk-go-v2/service/sqs/types" "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake/v2/common/messages" ) const ( cleanupInterval = time.Second * 30 cleanupThreshold = -2 * time.Minute ) type sqsHandler struct { SQSClient *sqs.Client SQSQueueURL *string IPC *IPC } func (r *sqsHandler) pollMessages(context context.Context, chn chan<- *types.Message) { for { res, err := r.SQSClient.ReceiveMessage(context, &sqs.ReceiveMessageInput{ QueueUrl: r.SQSQueueURL, MaxNumberOfMessages: 10, WaitTimeSeconds: 15, MessageAttributeNames: []string{ string(types.QueueAttributeNameAll), }, }) if err != nil { log.Printf("SQSHandler: encountered error while polling for messages: %v\n", err) } for _, message := range res.Messages { chn <- &message } } } func (r *sqsHandler) cleanupClientQueues(context context.Context) { for range time.Tick(cleanupInterval) { // Runs at fixed intervals to clean up any client queues that were last changed more than 2 minutes ago queueURLsList := []string{} var nextToken *string for { res, err := r.SQSClient.ListQueues(context, &sqs.ListQueuesInput{ QueueNamePrefix: aws.String("snowflake-client-"), MaxResults: aws.Int32(1000), NextToken: nextToken, }) if err != nil { log.Printf("SQSHandler: encountered error while retrieving client queues to clean up: %v\n", err) } queueURLsList = append(queueURLsList, res.QueueUrls...) if res.NextToken == nil { break } else { nextToken = res.NextToken } } numDeleted := 0 cleanupCutoff := time.Now().Add(cleanupThreshold) for _, queueURL := range queueURLsList { if !strings.Contains(queueURL, "snowflake-client-") { continue } res, err := r.SQSClient.GetQueueAttributes(context, &sqs.GetQueueAttributesInput{ QueueUrl: aws.String(queueURL), AttributeNames: []types.QueueAttributeName{types.QueueAttributeNameLastModifiedTimestamp}, }) if err != nil { // According to the AWS SQS docs, the deletion process for a queue can take up to 60 seconds. So the queue // can be in the process of being deleted, but will still be returned by the ListQueues operation, but // fail when we try to GetQueueAttributes for the queue log.Printf("SQSHandler: encountered error while getting attribute of client queue %s. queue may already be deleted.\n", queueURL) continue } lastModifiedInt64, err := strconv.ParseInt(res.Attributes[string(types.QueueAttributeNameLastModifiedTimestamp)], 10, 64) if err != nil { log.Printf("SQSHandler: encountered invalid lastModifiedTimetamp value from client queue %s: %v\n", queueURL, err) continue } lastModified := time.Unix(lastModifiedInt64, 0) if lastModified.Before(cleanupCutoff) { _, err := r.SQSClient.DeleteQueue(context, &sqs.DeleteQueueInput{ QueueUrl: aws.String(queueURL), }) if err != nil { log.Printf("SQSHandler: encountered error when deleting client queue %s: %v\n", queueURL, err) continue } else { numDeleted += 1 } } } log.Printf("SQSHandler: finished running iteration of client queue cleanup. found and deleted %d client queues.\n", numDeleted) } } func (r *sqsHandler) handleMessage(context context.Context, message *types.Message) { var encPollReq []byte var response []byte var err error clientID := message.MessageAttributes["ClientID"].StringValue if clientID == nil { log.Println("SQSHandler: got SDP offer in SQS message with no client ID. ignoring this message.") return } res, err := r.SQSClient.CreateQueue(context, &sqs.CreateQueueInput{ QueueName: aws.String("snowflake-client-" + *clientID), }) answerSQSURL := res.QueueUrl if err != nil { log.Printf("SQSHandler: error encountered when creating answer queue for client %s: %v\n", *clientID, err) } encPollReq = []byte(*message.Body) arg := messages.Arg{ Body: encPollReq, RemoteAddr: "", } err = r.IPC.ClientOffers(arg, &response) if err != nil { log.Printf("SQSHandler: error encountered when handling message: %v\n", err) return } r.SQSClient.SendMessage(context, &sqs.SendMessageInput{ QueueUrl: answerSQSURL, MessageBody: aws.String(string(response)), }) } func (r *sqsHandler) deleteMessage(context context.Context, message *types.Message) { r.SQSClient.DeleteMessage(context, &sqs.DeleteMessageInput{ QueueUrl: r.SQSQueueURL, ReceiptHandle: message.ReceiptHandle, }) } func newSQSHandler(context context.Context, sqsQueueName string, region string, i *IPC) (*sqsHandler, error) { log.Printf("Loading SQSHandler using SQS Queue %s in region %s\n", sqsQueueName, region) cfg, err := config.LoadDefaultConfig(context, config.WithRegion(region)) if err != nil { return nil, err } client := sqs.NewFromConfig(cfg) // Creates the queue if a queue with the same name doesn't exist. If a queue with the same name and attributes // already exists, then nothing will happen. If a queue with the same name, but different attributes exists, then // an error will be returned res, err := client.CreateQueue(context, &sqs.CreateQueueInput{ QueueName: aws.String(sqsQueueName), Attributes: map[string]string{ "MessageRetentionPeriod": strconv.FormatInt(int64((5 * time.Minute).Seconds()), 10), }, }) if err != nil { return nil, err } return &sqsHandler{ SQSClient: client, SQSQueueURL: res.QueueUrl, IPC: i, }, nil } func (r *sqsHandler) PollAndHandleMessages(context context.Context) { log.Println("SQSHandler: Starting to poll for messages at: " + *r.SQSQueueURL) messagesChn := make(chan *types.Message, 2) go r.pollMessages(context, messagesChn) go r.cleanupClientQueues(context) for message := range messagesChn { r.handleMessage(context, message) r.deleteMessage(context, message) } }