Parse ClientPollRequest version in DecodeClientPollRequest

Instead of IPC.ClientOffers.  This makes things consistent with
EncodeClientPollRequest which adds the version while serializing.
This commit is contained in:
Arlo Breault 2022-03-09 19:48:16 -05:00
parent 6fd0f1ae5d
commit 829cacac5f
6 changed files with 52 additions and 57 deletions

View file

@ -4,13 +4,14 @@
package messages
import (
"bytes"
"encoding/json"
"fmt"
"git.torproject.org/pluggable-transports/snowflake.git/v2/common/nat"
)
const ClientVersion = "1.0"
const ClientVersion1_0 = "1.0"
/* Client--Broker protocol v1.x specification:
@ -49,24 +50,41 @@ for the error.
*/
type ClientPollRequest struct {
Offer string `json:"offer"`
NAT string `json:"nat"`
Offer string `json:"offer"`
NAT string `json:"nat"`
Version string `json:"-"`
}
// Encodes a poll message from a snowflake client
func (req *ClientPollRequest) EncodeClientPollRequest() ([]byte, error) {
if req.Version != ClientVersion1_0 {
return nil, fmt.Errorf("unsupported message version")
}
body, err := json.Marshal(req)
if err != nil {
return nil, err
}
return append([]byte(ClientVersion+"\n"), body...), nil
return append([]byte(req.Version+"\n"), body...), nil
}
// Decodes a poll message from a snowflake client
func DecodeClientPollRequest(data []byte) (*ClientPollRequest, error) {
parts := bytes.SplitN(data, []byte("\n"), 2)
if len(parts) < 2 {
// no version number found
return nil, fmt.Errorf("unsupported message version")
}
var message ClientPollRequest
err := json.Unmarshal(data, &message)
if string(parts[0]) == ClientVersion1_0 {
message.Version = ClientVersion1_0
} else {
return nil, fmt.Errorf("unsupported message version")
}
err := json.Unmarshal(parts[1], &message)
if err != nil {
return nil, err
}

View file

@ -1,7 +1,6 @@
package messages
import (
"bytes"
"encoding/json"
"fmt"
"testing"
@ -286,14 +285,16 @@ func TestDecodeClientPollRequest(t *testing.T) {
//version 1.0 client message
"unknown",
"fake",
`{"nat":"unknown","offer":"fake"}`,
`1.0
{"nat":"unknown","offer":"fake"}`,
nil,
},
{
//version 1.0 client message
"unknown",
"fake",
`{"offer":"fake"}`,
`1.0
{"offer":"fake"}`,
nil,
},
{
@ -307,16 +308,17 @@ func TestDecodeClientPollRequest(t *testing.T) {
//no offer
"",
"",
`{"nat":"unknown"}`,
`1.0
{"nat":"unknown"}`,
fmt.Errorf(""),
},
} {
req, err := DecodeClientPollRequest([]byte(test.data))
So(err, ShouldHaveSameTypeAs, test.err)
if test.err == nil {
So(req.NAT, ShouldResemble, test.natType)
So(req.Offer, ShouldResemble, test.offer)
}
So(err, ShouldHaveSameTypeAs, test.err)
}
})
@ -325,15 +327,12 @@ func TestDecodeClientPollRequest(t *testing.T) {
func TestEncodeClientPollRequests(t *testing.T) {
Convey("Context", t, func() {
req1 := &ClientPollRequest{
NAT: "unknown",
Offer: "fake",
NAT: "unknown",
Offer: "fake",
Version: ClientVersion1_0,
}
b, err := req1.EncodeClientPollRequest()
So(err, ShouldEqual, nil)
fmt.Println(string(b))
parts := bytes.SplitN(b, []byte("\n"), 2)
So(string(parts[0]), ShouldEqual, "1.0")
b = parts[1]
req2, err := DecodeClientPollRequest(b)
So(err, ShouldEqual, nil)
So(req2, ShouldResemble, req1)