mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 20:11:19 -04:00
Parse ClientPollRequest version in DecodeClientPollRequest
Instead of IPC.ClientOffers. This makes things consistent with EncodeClientPollRequest which adds the version while serializing.
This commit is contained in:
parent
6fd0f1ae5d
commit
829cacac5f
6 changed files with 52 additions and 57 deletions
|
@ -1,7 +1,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"container/heap"
|
||||
"fmt"
|
||||
"log"
|
||||
|
@ -21,12 +20,6 @@ const (
|
|||
NATUnrestricted = "unrestricted"
|
||||
)
|
||||
|
||||
type clientVersion int
|
||||
|
||||
const (
|
||||
v1 clientVersion = iota
|
||||
)
|
||||
|
||||
type IPC struct {
|
||||
ctx *BrokerContext
|
||||
}
|
||||
|
@ -132,32 +125,16 @@ func sendClientResponse(resp *messages.ClientPollResponse, response *[]byte) err
|
|||
}
|
||||
|
||||
func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
|
||||
var version clientVersion
|
||||
|
||||
startTime := time.Now()
|
||||
body := arg.Body
|
||||
|
||||
parts := bytes.SplitN(body, []byte("\n"), 2)
|
||||
if len(parts) < 2 {
|
||||
// no version number found
|
||||
err := fmt.Errorf("unsupported message version")
|
||||
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
|
||||
}
|
||||
body = parts[1]
|
||||
if string(parts[0]) == "1.0" {
|
||||
version = v1
|
||||
} else {
|
||||
err := fmt.Errorf("unsupported message version")
|
||||
req, err := messages.DecodeClientPollRequest(arg.Body)
|
||||
if err != nil {
|
||||
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
|
||||
}
|
||||
|
||||
var offer *ClientOffer
|
||||
switch version {
|
||||
case v1:
|
||||
req, err := messages.DecodeClientPollRequest(body)
|
||||
if err != nil {
|
||||
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
|
||||
}
|
||||
switch req.Version {
|
||||
case messages.ClientVersion1_0:
|
||||
offer = &ClientOffer{
|
||||
natType: req.NAT,
|
||||
sdp: []byte(req.Offer),
|
||||
|
@ -188,8 +165,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
|
|||
i.ctx.metrics.clientRestrictedDeniedCount++
|
||||
}
|
||||
i.ctx.metrics.lock.Unlock()
|
||||
switch version {
|
||||
case v1:
|
||||
switch req.Version {
|
||||
case messages.ClientVersion1_0:
|
||||
resp := &messages.ClientPollResponse{Error: messages.StrNoProxies}
|
||||
return sendClientResponse(resp, response)
|
||||
default:
|
||||
|
@ -204,8 +181,6 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
|
|||
i.ctx.snowflakeLock.Unlock()
|
||||
snowflake.offerChannel <- offer
|
||||
|
||||
var err error
|
||||
|
||||
// Wait for the answer to be returned on the channel or timeout.
|
||||
select {
|
||||
case answer := <-snowflake.answerChannel:
|
||||
|
@ -213,8 +188,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
|
|||
i.ctx.metrics.clientProxyMatchCount++
|
||||
i.ctx.metrics.promMetrics.ClientPollTotal.With(prometheus.Labels{"nat": offer.natType, "status": "matched"}).Inc()
|
||||
i.ctx.metrics.lock.Unlock()
|
||||
switch version {
|
||||
case v1:
|
||||
switch req.Version {
|
||||
case messages.ClientVersion1_0:
|
||||
resp := &messages.ClientPollResponse{Answer: answer}
|
||||
err = sendClientResponse(resp, response)
|
||||
default:
|
||||
|
@ -224,8 +199,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
|
|||
i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond
|
||||
case <-time.After(time.Second * ClientTimeout):
|
||||
log.Println("Client: Timed out.")
|
||||
switch version {
|
||||
case v1:
|
||||
switch req.Version {
|
||||
case messages.ClientVersion1_0:
|
||||
resp := &messages.ClientPollResponse{Error: messages.StrTimedOut}
|
||||
err = sendClientResponse(resp, response)
|
||||
default:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue