Get rid of legacy version

Move the logic for the legacy version into the http handlers and use a
shim when doing ipc.
This commit is contained in:
Arlo Breault 2021-06-03 17:04:58 -04:00
parent 0ced1cc324
commit 87ad06a5e2
3 changed files with 53 additions and 49 deletions

View file

@ -102,7 +102,6 @@ func proxyPolls(i *IPC, w http.ResponseWriter, r *http.Request) {
arg := messages.Arg{ arg := messages.Arg{
Body: body, Body: body,
RemoteAddr: r.RemoteAddr, RemoteAddr: r.RemoteAddr,
NatType: "",
} }
var response []byte var response []byte
@ -138,28 +137,57 @@ func clientOffers(i *IPC, w http.ResponseWriter, r *http.Request) {
return return
} }
// Handle the legacy version
isLegacy := false
if len(body) > 0 && body[0] == '{' {
isLegacy = true
req := messages.ClientPollRequest{
Offer: string(body),
NAT: r.Header.Get("Snowflake-NAT-Type"),
}
body, err = req.EncodePollRequest()
if err != nil {
log.Printf("Error shimming the legacy request: %s", err.Error())
w.WriteHeader(http.StatusInternalServerError)
return
}
}
arg := messages.Arg{ arg := messages.Arg{
Body: body, Body: body,
RemoteAddr: "", RemoteAddr: "",
NatType: r.Header.Get("Snowflake-NAT-Type"),
} }
var response []byte var response []byte
err = i.ClientOffers(arg, &response) err = i.ClientOffers(arg, &response)
switch { if err != nil {
case err == nil: // Assert err == messages.ErrInternal
case errors.Is(err, messages.ErrUnavailable):
w.WriteHeader(http.StatusServiceUnavailable)
return
case errors.Is(err, messages.ErrTimeout):
w.WriteHeader(http.StatusGatewayTimeout)
return
default:
log.Println(err) log.Println(err)
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
return return
} }
if isLegacy {
resp, err := messages.DecodeClientPollResponse(response)
if err != nil {
log.Println(err)
w.WriteHeader(http.StatusInternalServerError)
return
}
switch resp.Error {
case "":
response = []byte(resp.Answer)
case "no snowflake proxies currently available":
w.WriteHeader(http.StatusServiceUnavailable)
return
case "timed out waiting for answer!":
w.WriteHeader(http.StatusGatewayTimeout)
return
default:
panic("unknown error")
}
}
if _, err := w.Write(response); err != nil { if _, err := w.Write(response); err != nil {
log.Printf("clientOffers unable to write answer with error: %v", err) log.Printf("clientOffers unable to write answer with error: %v", err)
} }
@ -181,7 +209,6 @@ func proxyAnswers(i *IPC, w http.ResponseWriter, r *http.Request) {
arg := messages.Arg{ arg := messages.Arg{
Body: body, Body: body,
RemoteAddr: "", RemoteAddr: "",
NatType: "",
} }
var response []byte var response []byte

View file

@ -21,14 +21,10 @@ const (
NATUnrestricted = "unrestricted" NATUnrestricted = "unrestricted"
) )
// We support two client message formats. The legacy format is for backwards
// combatability and relies heavily on HTTP headers and status codes to convey
// information.
type clientVersion int type clientVersion int
const ( const (
v0 clientVersion = iota //legacy version v1 clientVersion = iota
v1
) )
type IPC struct { type IPC struct {
@ -141,9 +137,6 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
startTime := time.Now() startTime := time.Now()
body := arg.Body body := arg.Body
if len(body) > 0 && body[0] == '{' {
version = v0
} else {
parts := bytes.SplitN(body, []byte("\n"), 2) parts := bytes.SplitN(body, []byte("\n"), 2)
if len(parts) < 2 { if len(parts) < 2 {
// no version number found // no version number found
@ -153,20 +146,13 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
body = parts[1] body = parts[1]
if string(parts[0]) == "1.0" { if string(parts[0]) == "1.0" {
version = v1 version = v1
} else { } else {
err := fmt.Errorf("unsupported message version") err := fmt.Errorf("unsupported message version")
return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response) return sendClientResponse(&messages.ClientPollResponse{Error: err.Error()}, response)
} }
}
var offer *ClientOffer var offer *ClientOffer
switch version { switch version {
case v0:
offer = &ClientOffer{
natType: arg.NatType,
sdp: body,
}
case v1: case v1:
req, err := messages.DecodeClientPollRequest(body) req, err := messages.DecodeClientPollRequest(body)
if err != nil { if err != nil {
@ -203,8 +189,6 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
} }
i.ctx.metrics.lock.Unlock() i.ctx.metrics.lock.Unlock()
switch version { switch version {
case v0:
return messages.ErrUnavailable
case v1: case v1:
resp := &messages.ClientPollResponse{Error: "no snowflake proxies currently available"} resp := &messages.ClientPollResponse{Error: "no snowflake proxies currently available"}
return sendClientResponse(resp, response) return sendClientResponse(resp, response)
@ -230,8 +214,6 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
i.ctx.metrics.promMetrics.ClientPollTotal.With(prometheus.Labels{"nat": offer.natType, "status": "matched"}).Inc() i.ctx.metrics.promMetrics.ClientPollTotal.With(prometheus.Labels{"nat": offer.natType, "status": "matched"}).Inc()
i.ctx.metrics.lock.Unlock() i.ctx.metrics.lock.Unlock()
switch version { switch version {
case v0:
*response = []byte(answer)
case v1: case v1:
resp := &messages.ClientPollResponse{Answer: answer} resp := &messages.ClientPollResponse{Answer: answer}
err = sendClientResponse(resp, response) err = sendClientResponse(resp, response)
@ -243,8 +225,6 @@ func (i *IPC) ClientOffers(arg messages.Arg, response *[]byte) error {
case <-time.After(time.Second * ClientTimeout): case <-time.After(time.Second * ClientTimeout):
log.Println("Client: Timed out.") log.Println("Client: Timed out.")
switch version { switch version {
case v0:
err = messages.ErrTimeout
case v1: case v1:
resp := &messages.ClientPollResponse{ resp := &messages.ClientPollResponse{
Error: "timed out waiting for answer!"} Error: "timed out waiting for answer!"}

View file

@ -7,12 +7,9 @@ import (
type Arg struct { type Arg struct {
Body []byte Body []byte
RemoteAddr string RemoteAddr string
NatType string
} }
var ( var (
ErrBadRequest = errors.New("bad request") ErrBadRequest = errors.New("bad request")
ErrInternal = errors.New("internal error") ErrInternal = errors.New("internal error")
ErrUnavailable = errors.New("service unavailable")
ErrTimeout = errors.New("timeout")
) )