mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 20:11:19 -04:00
Parse ClientPollRequest version in DecodeClientPollRequest
Instead of IPC.ClientOffers. This makes things consistent with EncodeClientPollRequest which adds the version while serializing.
This commit is contained in:
parent
6fd0f1ae5d
commit
829cacac5f
6 changed files with 52 additions and 57 deletions
|
@ -148,6 +148,7 @@ func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
|
||||||
req := messages.ClientPollRequest{
|
req := messages.ClientPollRequest{
|
||||||
Offer: string(body),
|
Offer: string(body),
|
||||||
NAT: r.Header.Get("Snowflake-NAT-Type"),
|
NAT: r.Header.Get("Snowflake-NAT-Type"),
|
||||||
|
Version: messages.ClientVersion1_0,
|
||||||
}
|
}
|
||||||
body, err = req.EncodeClientPollRequest()
|
body, err = req.EncodeClientPollRequest()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"container/heap"
|
"container/heap"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
@ -21,12 +20,6 @@ const (
|
||||||
NATUnrestricted = "unrestricted"
|
NATUnrestricted = "unrestricted"
|
||||||
)
|
)
|
||||||
|
|
||||||
type clientVersion int
|
|
||||||
|
|
||||||
const (
|
|
||||||
v1 clientVersion = iota
|
|
||||||
)
|
|
||||||
|
|
||||||
type IPC struct {
|
type IPC struct {
|
||||||
ctx *BrokerContext
|
ctx *BrokerContext
|
||||||
}
|
}
|
||||||
|
@ -132,32 +125,16 @@ func sendClientResponse(resp *messages.ClientPollResponse, response *[]byte) err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
|
func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
|
||||||
var version clientVersion
|
|
||||||
|
|
||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
body := arg.Body
|
|
||||||
|
|
||||||
parts := bytes.SplitN(body, []byte("\n"), 2)
|
req, err := messages.DecodeClientPollRequest(arg.Body)
|
||||||
if len(parts) < 2 {
|
if err != nil {
|
||||||
// no version number found
|
|
||||||
err := fmt.Errorf("unsupported message version")
|
|
||||||
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
|
|
||||||
}
|
|
||||||
body = parts[1]
|
|
||||||
if string(parts[0]) == "1.0" {
|
|
||||||
version = v1
|
|
||||||
} else {
|
|
||||||
err := fmt.Errorf("unsupported message version")
|
|
||||||
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
|
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
|
||||||
}
|
}
|
||||||
|
|
||||||
var offer *ClientOffer
|
var offer *ClientOffer
|
||||||
switch version {
|
switch req.Version {
|
||||||
case v1:
|
case messages.ClientVersion1_0:
|
||||||
req, err := messages.DecodeClientPollRequest(body)
|
|
||||||
if err != nil {
|
|
||||||
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
|
|
||||||
}
|
|
||||||
offer = &ClientOffer{
|
offer = &ClientOffer{
|
||||||
natType: req.NAT,
|
natType: req.NAT,
|
||||||
sdp: []byte(req.Offer),
|
sdp: []byte(req.Offer),
|
||||||
|
@ -188,8 +165,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
|
||||||
i.ctx.metrics.clientRestrictedDeniedCount++
|
i.ctx.metrics.clientRestrictedDeniedCount++
|
||||||
}
|
}
|
||||||
i.ctx.metrics.lock.Unlock()
|
i.ctx.metrics.lock.Unlock()
|
||||||
switch version {
|
switch req.Version {
|
||||||
case v1:
|
case messages.ClientVersion1_0:
|
||||||
resp := &messages.ClientPollResponse{Error: messages.StrNoProxies}
|
resp := &messages.ClientPollResponse{Error: messages.StrNoProxies}
|
||||||
return sendClientResponse(resp, response)
|
return sendClientResponse(resp, response)
|
||||||
default:
|
default:
|
||||||
|
@ -204,8 +181,6 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
|
||||||
i.ctx.snowflakeLock.Unlock()
|
i.ctx.snowflakeLock.Unlock()
|
||||||
snowflake.offerChannel <- offer
|
snowflake.offerChannel <- offer
|
||||||
|
|
||||||
var err error
|
|
||||||
|
|
||||||
// Wait for the answer to be returned on the channel or timeout.
|
// Wait for the answer to be returned on the channel or timeout.
|
||||||
select {
|
select {
|
||||||
case answer := <-snowflake.answerChannel:
|
case answer := <-snowflake.answerChannel:
|
||||||
|
@ -213,8 +188,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
|
||||||
i.ctx.metrics.clientProxyMatchCount++
|
i.ctx.metrics.clientProxyMatchCount++
|
||||||
i.ctx.metrics.promMetrics.ClientPollTotal.With(prometheus.Labels{"nat": offer.natType, "status": "matched"}).Inc()
|
i.ctx.metrics.promMetrics.ClientPollTotal.With(prometheus.Labels{"nat": offer.natType, "status": "matched"}).Inc()
|
||||||
i.ctx.metrics.lock.Unlock()
|
i.ctx.metrics.lock.Unlock()
|
||||||
switch version {
|
switch req.Version {
|
||||||
case v1:
|
case messages.ClientVersion1_0:
|
||||||
resp := &messages.ClientPollResponse{Answer: answer}
|
resp := &messages.ClientPollResponse{Answer: answer}
|
||||||
err = sendClientResponse(resp, response)
|
err = sendClientResponse(resp, response)
|
||||||
default:
|
default:
|
||||||
|
@ -224,8 +199,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
|
||||||
i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond
|
i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond
|
||||||
case <-time.After(time.Second * ClientTimeout):
|
case <-time.After(time.Second * ClientTimeout):
|
||||||
log.Println("Client: Timed out.")
|
log.Println("Client: Timed out.")
|
||||||
switch version {
|
switch req.Version {
|
||||||
case v1:
|
case messages.ClientVersion1_0:
|
||||||
resp := &messages.ClientPollResponse{Error: messages.StrTimedOut}
|
resp := &messages.ClientPollResponse{Error: messages.StrTimedOut}
|
||||||
err = sendClientResponse(resp, response)
|
err = sendClientResponse(resp, response)
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -124,6 +124,7 @@ func (bc *BrokerChannel) Negotiate(offer *webrtc.SessionDescription) (
|
||||||
req := &messages.ClientPollRequest{
|
req := &messages.ClientPollRequest{
|
||||||
Offer: offerSDP,
|
Offer: offerSDP,
|
||||||
NAT: bc.natType,
|
NAT: bc.natType,
|
||||||
|
Version: messages.ClientVersion1_0,
|
||||||
}
|
}
|
||||||
encReq, err := req.EncodeClientPollRequest()
|
encReq, err := req.EncodeClientPollRequest()
|
||||||
bc.lock.Unlock()
|
bc.lock.Unlock()
|
||||||
|
|
|
@ -45,6 +45,7 @@ func makeEncPollReq(offer string) []byte {
|
||||||
encPollReq, err := (&messages.ClientPollRequest{
|
encPollReq, err := (&messages.ClientPollRequest{
|
||||||
Offer: offer,
|
Offer: offer,
|
||||||
NAT: nat.NATUnknown,
|
NAT: nat.NATUnknown,
|
||||||
|
Version: messages.ClientVersion1_0,
|
||||||
}).EncodeClientPollRequest()
|
}).EncodeClientPollRequest()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
|
|
|
@ -4,13 +4,14 @@
|
||||||
package messages
|
package messages
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"git.torproject.org/pluggable-transports/snowflake.git/v2/common/nat"
|
"git.torproject.org/pluggable-transports/snowflake.git/v2/common/nat"
|
||||||
)
|
)
|
||||||
|
|
||||||
const ClientVersion = "1.0"
|
const ClientVersion1_0 = "1.0"
|
||||||
|
|
||||||
/* Client--Broker protocol v1.x specification:
|
/* Client--Broker protocol v1.x specification:
|
||||||
|
|
||||||
|
@ -51,22 +52,39 @@ for the error.
|
||||||
type ClientPollRequest struct {
|
type ClientPollRequest struct {
|
||||||
Offer string `json:"offer"`
|
Offer string `json:"offer"`
|
||||||
NAT string `json:"nat"`
|
NAT string `json:"nat"`
|
||||||
|
Version string `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Encodes a poll message from a snowflake client
|
// Encodes a poll message from a snowflake client
|
||||||
func (req *ClientPollRequest) EncodeClientPollRequest() ([]byte, error) {
|
func (req *ClientPollRequest) EncodeClientPollRequest() ([]byte, error) {
|
||||||
|
if req.Version != ClientVersion1_0 {
|
||||||
|
return nil, fmt.Errorf("unsupported message version")
|
||||||
|
}
|
||||||
body, err := json.Marshal(req)
|
body, err := json.Marshal(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return append([]byte(ClientVersion+"\n"), body...), nil
|
return append([]byte(req.Version+"\n"), body...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decodes a poll message from a snowflake client
|
// Decodes a poll message from a snowflake client
|
||||||
func DecodeClientPollRequest(data []byte) (*ClientPollRequest, error) {
|
func DecodeClientPollRequest(data []byte) (*ClientPollRequest, error) {
|
||||||
|
parts := bytes.SplitN(data, []byte("\n"), 2)
|
||||||
|
|
||||||
|
if len(parts) < 2 {
|
||||||
|
// no version number found
|
||||||
|
return nil, fmt.Errorf("unsupported message version")
|
||||||
|
}
|
||||||
|
|
||||||
var message ClientPollRequest
|
var message ClientPollRequest
|
||||||
|
|
||||||
err := json.Unmarshal(data, &message)
|
if string(parts[0]) == ClientVersion1_0 {
|
||||||
|
message.Version = ClientVersion1_0
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("unsupported message version")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := json.Unmarshal(parts[1], &message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package messages
|
package messages
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -286,14 +285,16 @@ func TestDecodeClientPollRequest(t *testing.T) {
|
||||||
//version 1.0 client message
|
//version 1.0 client message
|
||||||
"unknown",
|
"unknown",
|
||||||
"fake",
|
"fake",
|
||||||
`{"nat":"unknown","offer":"fake"}`,
|
`1.0
|
||||||
|
{"nat":"unknown","offer":"fake"}`,
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
//version 1.0 client message
|
//version 1.0 client message
|
||||||
"unknown",
|
"unknown",
|
||||||
"fake",
|
"fake",
|
||||||
`{"offer":"fake"}`,
|
`1.0
|
||||||
|
{"offer":"fake"}`,
|
||||||
nil,
|
nil,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -307,16 +308,17 @@ func TestDecodeClientPollRequest(t *testing.T) {
|
||||||
//no offer
|
//no offer
|
||||||
"",
|
"",
|
||||||
"",
|
"",
|
||||||
`{"nat":"unknown"}`,
|
`1.0
|
||||||
|
{"nat":"unknown"}`,
|
||||||
fmt.Errorf(""),
|
fmt.Errorf(""),
|
||||||
},
|
},
|
||||||
} {
|
} {
|
||||||
req, err := DecodeClientPollRequest([]byte(test.data))
|
req, err := DecodeClientPollRequest([]byte(test.data))
|
||||||
|
So(err, ShouldHaveSameTypeAs, test.err)
|
||||||
if test.err == nil {
|
if test.err == nil {
|
||||||
So(req.NAT, ShouldResemble, test.natType)
|
So(req.NAT, ShouldResemble, test.natType)
|
||||||
So(req.Offer, ShouldResemble, test.offer)
|
So(req.Offer, ShouldResemble, test.offer)
|
||||||
}
|
}
|
||||||
So(err, ShouldHaveSameTypeAs, test.err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
|
@ -327,13 +329,10 @@ func TestEncodeClientPollRequests(t *testing.T) {
|
||||||
req1 := &ClientPollRequest{
|
req1 := &ClientPollRequest{
|
||||||
NAT: "unknown",
|
NAT: "unknown",
|
||||||
Offer: "fake",
|
Offer: "fake",
|
||||||
|
Version: ClientVersion1_0,
|
||||||
}
|
}
|
||||||
b, err := req1.EncodeClientPollRequest()
|
b, err := req1.EncodeClientPollRequest()
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
fmt.Println(string(b))
|
|
||||||
parts := bytes.SplitN(b, []byte("\n"), 2)
|
|
||||||
So(string(parts[0]), ShouldEqual, "1.0")
|
|
||||||
b = parts[1]
|
|
||||||
req2, err := DecodeClientPollRequest(b)
|
req2, err := DecodeClientPollRequest(b)
|
||||||
So(err, ShouldEqual, nil)
|
So(err, ShouldEqual, nil)
|
||||||
So(req2, ShouldResemble, req1)
|
So(req2, ShouldResemble, req1)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue