Stop storing version in ClientPollRequest

This continues to asserts the known version while decoding.  The client
will only ever generate the latest version while encoding and if the
response needs to change, the impetus will be a new feature, set in the
deserialized request, which can be used as a distinguisher.
This commit is contained in:
Arlo Breault 2022-03-16 20:26:40 -04:00
parent b73add1550
commit 281d917beb
6 changed files with 22 additions and 53 deletions

View file

@ -148,7 +148,6 @@ func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
req := messages.ClientPollRequest{ req := messages.ClientPollRequest{
Offer: string(body), Offer: string(body),
NAT: r.Header.Get("Snowflake-NAT-Type"), NAT: r.Header.Get("Snowflake-NAT-Type"),
Version: messages.ClientVersion1_0,
} }
body, err = req.EncodeClientPollRequest() body, err = req.EncodeClientPollRequest()
if err != nil { if err != nil {

View file

@ -129,16 +129,10 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response) return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
} }
var offer *ClientOffer offer := &ClientOffer{
switch req.Version {
case messages.ClientVersion1_0:
offer = &ClientOffer{
natType: req.NAT, natType: req.NAT,
sdp: []byte(req.Offer), sdp: []byte(req.Offer),
} }
default:
panic("unknown version")
}
// Only hand out known restricted snowflakes to unrestricted clients // Only hand out known restricted snowflakes to unrestricted clients
var snowflakeHeap *SnowflakeHeap var snowflakeHeap *SnowflakeHeap
@ -162,13 +156,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.metrics.clientRestrictedDeniedCount++ i.ctx.metrics.clientRestrictedDeniedCount++
} }
i.ctx.metrics.lock.Unlock() i.ctx.metrics.lock.Unlock()
switch req.Version {
case messages.ClientVersion1_0:
resp := &messages.ClientPollResponse{Error: messages.StrNoProxies} resp := &messages.ClientPollResponse{Error: messages.StrNoProxies}
return sendClientResponse(resp, response) return sendClientResponse(resp, response)
default:
panic("unknown version")
}
} }
// Otherwise, find the most available snowflake proxy, and pass the offer to it. // Otherwise, find the most available snowflake proxy, and pass the offer to it.
@ -185,24 +174,14 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.metrics.clientProxyMatchCount++ i.ctx.metrics.clientProxyMatchCount++
i.ctx.metrics.promMetrics.ClientPollTotal.With(prometheus.Labels{"nat": offer.natType, "status": "matched"}).Inc() i.ctx.metrics.promMetrics.ClientPollTotal.With(prometheus.Labels{"nat": offer.natType, "status": "matched"}).Inc()
i.ctx.metrics.lock.Unlock() i.ctx.metrics.lock.Unlock()
switch req.Version {
case messages.ClientVersion1_0:
resp := &messages.ClientPollResponse{Answer: answer} resp := &messages.ClientPollResponse{Answer: answer}
err = sendClientResponse(resp, response) err = sendClientResponse(resp, response)
default:
panic("unknown version")
}
// Initial tracking of elapsed time. // Initial tracking of elapsed time.
i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond
case <-time.After(time.Second * ClientTimeout): case <-time.After(time.Second * ClientTimeout):
log.Println("Client: Timed out.") log.Println("Client: Timed out.")
switch req.Version {
case messages.ClientVersion1_0:
resp := &messages.ClientPollResponse{Error: messages.StrTimedOut} resp := &messages.ClientPollResponse{Error: messages.StrTimedOut}
err = sendClientResponse(resp, response) err = sendClientResponse(resp, response)
default:
panic("unknown version")
}
} }
i.ctx.snowflakeLock.Lock() i.ctx.snowflakeLock.Lock()

View file

@ -118,7 +118,6 @@ func (bc *BrokerChannel) Negotiate(offer *webrtc.SessionDescription) (
req := &messages.ClientPollRequest{ req := &messages.ClientPollRequest{
Offer: offerSDP, Offer: offerSDP,
NAT: bc.natType, NAT: bc.natType,
Version: messages.ClientVersion1_0,
} }
encReq, err := req.EncodeClientPollRequest() encReq, err := req.EncodeClientPollRequest()
bc.lock.Unlock() bc.lock.Unlock()

View file

@ -45,7 +45,6 @@ func makeEncPollReq(offer string) []byte {
encPollReq, err := (&messages.ClientPollRequest{ encPollReq, err := (&messages.ClientPollRequest{
Offer: offer, Offer: offer,
NAT: nat.NATUnknown, NAT: nat.NATUnknown,
Version: messages.ClientVersion1_0,
}).EncodeClientPollRequest() }).EncodeClientPollRequest()
if err != nil { if err != nil {
panic(err) panic(err)

View file

@ -11,7 +11,7 @@ import (
"git.torproject.org/pluggable-transports/snowflake.git/v2/common/nat" "git.torproject.org/pluggable-transports/snowflake.git/v2/common/nat"
) )
const ClientVersion1_0 = "1.0" const ClientVersion = "1.0"
/* Client--Broker protocol v1.x specification: /* Client--Broker protocol v1.x specification:
@ -52,19 +52,15 @@ for the error.
type ClientPollRequest struct { type ClientPollRequest struct {
Offer string `json:"offer"` Offer string `json:"offer"`
NAT string `json:"nat"` NAT string `json:"nat"`
Version string `json:"-"`
} }
// Encodes a poll message from a snowflake client // Encodes a poll message from a snowflake client
func (req *ClientPollRequest) EncodeClientPollRequest() ([]byte, error) { func (req *ClientPollRequest) EncodeClientPollRequest() ([]byte, error) {
if req.Version != ClientVersion1_0 {
return nil, fmt.Errorf("unsupported message version")
}
body, err := json.Marshal(req) body, err := json.Marshal(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return append([]byte(req.Version+"\n"), body...), nil return append([]byte(ClientVersion+"\n"), body...), nil
} }
// Decodes a poll message from a snowflake client // Decodes a poll message from a snowflake client
@ -78,9 +74,7 @@ func DecodeClientPollRequest(data []byte) (*ClientPollRequest, error) {
var message ClientPollRequest var message ClientPollRequest
if string(parts[0]) == ClientVersion1_0 { if string(parts[0]) != ClientVersion {
message.Version = ClientVersion1_0
} else {
return nil, fmt.Errorf("unsupported message version") return nil, fmt.Errorf("unsupported message version")
} }

View file

@ -329,7 +329,6 @@ func TestEncodeClientPollRequests(t *testing.T) {
req1 := &ClientPollRequest{ req1 := &ClientPollRequest{
NAT: "unknown", NAT: "unknown",
Offer: "fake", Offer: "fake",
Version: ClientVersion1_0,
} }
b, err := req1.EncodeClientPollRequest() b, err := req1.EncodeClientPollRequest()
So(err, ShouldEqual, nil) So(err, ShouldEqual, nil)