Parse ClientPollRequest version in DecodeClientPollRequest

Instead of IPC.ClientOffers.  This makes things consistent with
EncodeClientPollRequest which adds the version while serializing.
This commit is contained in:
Arlo Breault 2022-03-09 19:48:16 -05:00
parent 6fd0f1ae5d
commit 829cacac5f
6 changed files with 52 additions and 57 deletions

View file

@ -146,8 +146,9 @@ func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
if len(body) > 0 && body[0] == '{' { if len(body) > 0 && body[0] == '{' {
isLegacy = true isLegacy = true
req := messages.ClientPollRequest{ req := messages.ClientPollRequest{
Offer: string(body), Offer: string(body),
NAT: r.Header.Get("Snowflake-NAT-Type"), NAT: r.Header.Get("Snowflake-NAT-Type"),
Version: messages.ClientVersion1_0,
} }
body, err = req.EncodeClientPollRequest() body, err = req.EncodeClientPollRequest()
if err != nil { if err != nil {

View file

@ -1,7 +1,6 @@
package main package main
import ( import (
"bytes"
"container/heap" "container/heap"
"fmt" "fmt"
"log" "log"
@ -21,12 +20,6 @@ const (
NATUnrestricted = "unrestricted" NATUnrestricted = "unrestricted"
) )
type clientVersion int
const (
v1 clientVersion = iota
)
type IPC struct { type IPC struct {
ctx *BrokerContext ctx *BrokerContext
} }
@ -132,32 +125,16 @@ func sendClientResponse(resp *messages.ClientPollResponse, response *[]byte) err
} }
func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error { func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
var version clientVersion
startTime := time.Now() startTime := time.Now()
body := arg.Body
parts := bytes.SplitN(body, []byte("\n"), 2) req, err := messages.DecodeClientPollRequest(arg.Body)
if len(parts) < 2 { if err != nil {
// no version number found
err := fmt.Errorf("unsupported message version")
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
}
body = parts[1]
if string(parts[0]) == "1.0" {
version = v1
} else {
err := fmt.Errorf("unsupported message version")
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response) return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
} }
var offer *ClientOffer var offer *ClientOffer
switch version { switch req.Version {
case v1: case messages.ClientVersion1_0:
req, err := messages.DecodeClientPollRequest(body)
if err != nil {
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
}
offer = &ClientOffer{ offer = &ClientOffer{
natType: req.NAT, natType: req.NAT,
sdp: []byte(req.Offer), sdp: []byte(req.Offer),
@ -188,8 +165,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.metrics.clientRestrictedDeniedCount++ i.ctx.metrics.clientRestrictedDeniedCount++
} }
i.ctx.metrics.lock.Unlock() i.ctx.metrics.lock.Unlock()
switch version { switch req.Version {
case v1: case messages.ClientVersion1_0:
resp := &messages.ClientPollResponse{Error: messages.StrNoProxies} resp := &messages.ClientPollResponse{Error: messages.StrNoProxies}
return sendClientResponse(resp, response) return sendClientResponse(resp, response)
default: default:
@ -204,8 +181,6 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.snowflakeLock.Unlock() i.ctx.snowflakeLock.Unlock()
snowflake.offerChannel <- offer snowflake.offerChannel <- offer
var err error
// Wait for the answer to be returned on the channel or timeout. // Wait for the answer to be returned on the channel or timeout.
select { select {
case answer := <-snowflake.answerChannel: case answer := <-snowflake.answerChannel:
@ -213,8 +188,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.metrics.clientProxyMatchCount++ i.ctx.metrics.clientProxyMatchCount++
i.ctx.metrics.promMetrics.ClientPollTotal.With(prometheus.Labels{"nat": offer.natType, "status": "matched"}).Inc() i.ctx.metrics.promMetrics.ClientPollTotal.With(prometheus.Labels{"nat": offer.natType, "status": "matched"}).Inc()
i.ctx.metrics.lock.Unlock() i.ctx.metrics.lock.Unlock()
switch version { switch req.Version {
case v1: case messages.ClientVersion1_0:
resp := &messages.ClientPollResponse{Answer: answer} resp := &messages.ClientPollResponse{Answer: answer}
err = sendClientResponse(resp, response) err = sendClientResponse(resp, response)
default: default:
@ -224,8 +199,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond
case <-time.After(time.Second * ClientTimeout): case <-time.After(time.Second * ClientTimeout):
log.Println("Client: Timed out.") log.Println("Client: Timed out.")
switch version { switch req.Version {
case v1: case messages.ClientVersion1_0:
resp := &messages.ClientPollResponse{Error: messages.StrTimedOut} resp := &messages.ClientPollResponse{Error: messages.StrTimedOut}
err = sendClientResponse(resp, response) err = sendClientResponse(resp, response)
default: default:

View file

@ -122,8 +122,9 @@ func (bc *BrokerChannel) Negotiate(offer *webrtc.SessionDescription) (
// Encode the client poll request. // Encode the client poll request.
bc.lock.Lock() bc.lock.Lock()
req := &messages.ClientPollRequest{ req := &messages.ClientPollRequest{
Offer: offerSDP, Offer: offerSDP,
NAT: bc.natType, NAT: bc.natType,
Version: messages.ClientVersion1_0,
} }
encReq, err := req.EncodeClientPollRequest() encReq, err := req.EncodeClientPollRequest()
bc.lock.Unlock() bc.lock.Unlock()

View file

@ -43,8 +43,9 @@ func (t errorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// offer. // offer.
func makeEncPollReq(offer string) []byte { func makeEncPollReq(offer string) []byte {
encPollReq, err := (&messages.ClientPollRequest{ encPollReq, err := (&messages.ClientPollRequest{
Offer: offer, Offer: offer,
NAT: nat.NATUnknown, NAT: nat.NATUnknown,
Version: messages.ClientVersion1_0,
}).EncodeClientPollRequest() }).EncodeClientPollRequest()
if err != nil { if err != nil {
panic(err) panic(err)

View file

@ -4,13 +4,14 @@
package messages package messages
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"git.torproject.org/pluggable-transports/snowflake.git/v2/common/nat" "git.torproject.org/pluggable-transports/snowflake.git/v2/common/nat"
) )
const ClientVersion = "1.0" const ClientVersion1_0 = "1.0"
/* Client--Broker protocol v1.x specification: /* Client--Broker protocol v1.x specification:
@ -49,24 +50,41 @@ for the error.
*/ */
type ClientPollRequest struct { type ClientPollRequest struct {
Offer string `json:"offer"` Offer string `json:"offer"`
NAT string `json:"nat"` NAT string `json:"nat"`
Version string `json:"-"`
} }
// Encodes a poll message from a snowflake client // Encodes a poll message from a snowflake client
func (req *ClientPollRequest) EncodeClientPollRequest() ([]byte, error) { func (req *ClientPollRequest) EncodeClientPollRequest() ([]byte, error) {
if req.Version != ClientVersion1_0 {
return nil, fmt.Errorf("unsupported message version")
}
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return append([]byte(ClientVersion+"\n"), body...), nil return append([]byte(req.Version+"\n"), body...), nil
} }
// Decodes a poll message from a snowflake client // Decodes a poll message from a snowflake client
func DecodeClientPollRequest(data []byte) (*ClientPollRequest, error) { func DecodeClientPollRequest(data []byte) (*ClientPollRequest, error) {
parts := bytes.SplitN(data, []byte("\n"), 2)
if len(parts) < 2 {
// no version number found
return nil, fmt.Errorf("unsupported message version")
}
var message ClientPollRequest var message ClientPollRequest
err := json.Unmarshal(data, &message) if string(parts[0]) == ClientVersion1_0 {
message.Version = ClientVersion1_0
} else {
return nil, fmt.Errorf("unsupported message version")
}
err := json.Unmarshal(parts[1], &message)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -1,7 +1,6 @@
package messages package messages
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt" "fmt"
"testing" "testing"
@ -286,14 +285,16 @@ func TestDecodeClientPollRequest(t *testing.T) {
//version 1.0 client message //version 1.0 client message
"unknown", "unknown",
"fake", "fake",
`{"nat":"unknown","offer":"fake"}`, `1.0
{"nat":"unknown","offer":"fake"}`,
nil, nil,
}, },
{ {
//version 1.0 client message //version 1.0 client message
"unknown", "unknown",
"fake", "fake",
`{"offer":"fake"}`, `1.0
{"offer":"fake"}`,
nil, nil,
}, },
{ {
@ -307,16 +308,17 @@ func TestDecodeClientPollRequest(t *testing.T) {
//no offer //no offer
"", "",
"", "",
`{"nat":"unknown"}`, `1.0
{"nat":"unknown"}`,
fmt.Errorf(""), fmt.Errorf(""),
}, },
} { } {
req, err := DecodeClientPollRequest([]byte(test.data)) req, err := DecodeClientPollRequest([]byte(test.data))
So(err, ShouldHaveSameTypeAs, test.err)
if test.err == nil { if test.err == nil {
So(req.NAT, ShouldResemble, test.natType) So(req.NAT, ShouldResemble, test.natType)
So(req.Offer, ShouldResemble, test.offer) So(req.Offer, ShouldResemble, test.offer)
} }
So(err, ShouldHaveSameTypeAs, test.err)
} }
}) })
@ -325,15 +327,12 @@ func TestDecodeClientPollRequest(t *testing.T) {
func TestEncodeClientPollRequests(t *testing.T) { func TestEncodeClientPollRequests(t *testing.T) {
Convey("Context", t, func() { Convey("Context", t, func() {
req1 := &ClientPollRequest{ req1 := &ClientPollRequest{
NAT: "unknown", NAT: "unknown",
Offer: "fake", Offer: "fake",
Version: ClientVersion1_0,
} }
b, err := req1.EncodeClientPollRequest() b, err := req1.EncodeClientPollRequest()
So(err, ShouldEqual, nil) So(err, ShouldEqual, nil)
fmt.Println(string(b))
parts := bytes.SplitN(b, []byte("\n"), 2)
So(string(parts[0]), ShouldEqual, "1.0")
b = parts[1]
req2, err := DecodeClientPollRequest(b) req2, err := DecodeClientPollRequest(b)
So(err, ShouldEqual, nil) So(err, ShouldEqual, nil)
So(req2, ShouldResemble, req1) So(req2, ShouldResemble, req1)