Parse ClientPollRequest version in DecodeClientPollRequest

Instead of IPC.ClientOffers.  This makes things consistent with
EncodeClientPollRequest which adds the version while serializing.
This commit is contained in:
Arlo Breault 2022-03-09 19:48:16 -05:00
parent 6fd0f1ae5d
commit 829cacac5f
6 changed files with 52 additions and 57 deletions

View file

@ -146,8 +146,9 @@ func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
if len(body) > 0 && body[0] == '{' {
isLegacy = true
req := messages.ClientPollRequest{
Offer: string(body),
NAT: r.Header.Get("Snowflake-NAT-Type"),
Offer: string(body),
NAT: r.Header.Get("Snowflake-NAT-Type"),
Version: messages.ClientVersion1_0,
}
body, err = req.EncodeClientPollRequest()
if err != nil {

View file

@ -1,7 +1,6 @@
package main
import (
"bytes"
"container/heap"
"fmt"
"log"
@ -21,12 +20,6 @@ const (
NATUnrestricted = "unrestricted"
)
type clientVersion int
const (
v1 clientVersion = iota
)
type IPC struct {
ctx *BrokerContext
}
@ -132,32 +125,16 @@ func sendClientResponse(resp *messages.ClientPollResponse, response *[]byte) err
}
func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
var version clientVersion
startTime := time.Now()
body := arg.Body
parts := bytes.SplitN(body, []byte("\n"), 2)
if len(parts) < 2 {
// no version number found
err := fmt.Errorf("unsupported message version")
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
}
body = parts[1]
if string(parts[0]) == "1.0" {
version = v1
} else {
err := fmt.Errorf("unsupported message version")
req, err := messages.DecodeClientPollRequest(arg.Body)
if err != nil {
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
}
var offer *ClientOffer
switch version {
case v1:
req, err := messages.DecodeClientPollRequest(body)
if err != nil {
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
}
switch req.Version {
case messages.ClientVersion1_0:
offer = &ClientOffer{
natType: req.NAT,
sdp: []byte(req.Offer),
@ -188,8 +165,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.metrics.clientRestrictedDeniedCount++
}
i.ctx.metrics.lock.Unlock()
switch version {
case v1:
switch req.Version {
case messages.ClientVersion1_0:
resp := &messages.ClientPollResponse{Error: messages.StrNoProxies}
return sendClientResponse(resp, response)
default:
@ -204,8 +181,6 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.snowflakeLock.Unlock()
snowflake.offerChannel <- offer
var err error
// Wait for the answer to be returned on the channel or timeout.
select {
case answer := <-snowflake.answerChannel:
@ -213,8 +188,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.metrics.clientProxyMatchCount++
i.ctx.metrics.promMetrics.ClientPollTotal.With(prometheus.Labels{"nat": offer.natType, "status": "matched"}).Inc()
i.ctx.metrics.lock.Unlock()
switch version {
case v1:
switch req.Version {
case messages.ClientVersion1_0:
resp := &messages.ClientPollResponse{Answer: answer}
err = sendClientResponse(resp, response)
default:
@ -224,8 +199,8 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.metrics.clientRoundtripEstimate = time.Since(startTime) / time.Millisecond
case <-time.After(time.Second * ClientTimeout):
log.Println("Client: Timed out.")
switch version {
case v1:
switch req.Version {
case messages.ClientVersion1_0:
resp := &messages.ClientPollResponse{Error: messages.StrTimedOut}
err = sendClientResponse(resp, response)
default: