Implement limitedRead function for client side

MaxBytesReader is only documented for server side reads, so we're using
a local limitedRead function instead that uses an io.LimitedReader.

Declared limits in a commented constant
This commit is contained in:
Cecylia Bocovich 2019-05-10 17:16:35 -04:00
parent ce3101d016
commit 1d76d3ca2e
3 changed files with 44 additions and 5 deletions

View file

@ -28,6 +28,7 @@ import (
const ( const (
ClientTimeout = 10 ClientTimeout = 10
ProxyTimeout = 10 ProxyTimeout = 10
readLimit = 100000 //Maximum number of bytes to be read from an HTTP request
) )
type BrokerContext struct { type BrokerContext struct {
@ -136,7 +137,7 @@ For snowflake proxies to request a client from the Broker.
*/ */
func proxyPolls(ctx *BrokerContext, w http.ResponseWriter, r *http.Request) { func proxyPolls(ctx *BrokerContext, w http.ResponseWriter, r *http.Request) {
id := r.Header.Get("X-Session-ID") id := r.Header.Get("X-Session-ID")
body, err := ioutil.ReadAll(http.MaxBytesReader(w, r.Body, 100000)) body, err := ioutil.ReadAll(http.MaxBytesReader(w, r.Body, readLimit))
if nil != err { if nil != err {
log.Println("Invalid data.") log.Println("Invalid data.")
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
@ -166,7 +167,7 @@ the HTTP response back to the client.
*/ */
func clientOffers(ctx *BrokerContext, w http.ResponseWriter, r *http.Request) { func clientOffers(ctx *BrokerContext, w http.ResponseWriter, r *http.Request) {
startTime := time.Now() startTime := time.Now()
offer, err := ioutil.ReadAll(http.MaxBytesReader(w, r.Body, 100000)) offer, err := ioutil.ReadAll(http.MaxBytesReader(w, r.Body, readLimit))
if nil != err { if nil != err {
log.Println("Invalid data.") log.Println("Invalid data.")
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
@ -213,7 +214,7 @@ func proxyAnswers(ctx *BrokerContext, w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusGone) w.WriteHeader(http.StatusGone)
return return
} }
body, err := ioutil.ReadAll(http.MaxBytesReader(w, r.Body, 100000)) body, err := ioutil.ReadAll(http.MaxBytesReader(w, r.Body, readLimit))
if nil != err || nil == body || len(body) <= 0 { if nil != err || nil == body || len(body) <= 0 {
log.Println("Invalid data.") log.Println("Invalid data.")
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)

View file

@ -11,6 +11,7 @@ package lib
import ( import (
"bytes" "bytes"
"errors" "errors"
"io"
"io/ioutil" "io/ioutil"
"log" "log"
"net/http" "net/http"
@ -23,6 +24,7 @@ const (
BrokerError503 string = "No snowflake proxies currently available." BrokerError503 string = "No snowflake proxies currently available."
BrokerError400 string = "You sent an invalid offer in the request." BrokerError400 string = "You sent an invalid offer in the request."
BrokerErrorUnexpected string = "Unexpected error, no answer." BrokerErrorUnexpected string = "Unexpected error, no answer."
readLimit = 100000 //Maximum number of bytes to be read from an HTTP response
) )
// Signalling Channel to the Broker. // Signalling Channel to the Broker.
@ -64,6 +66,23 @@ func NewBrokerChannel(broker string, front string, transport http.RoundTripper)
return bc return bc
} }
func limitedRead(r io.Reader, limit int64) ([]byte, error) {
p, err := ioutil.ReadAll(&io.LimitedReader{r, limit})
if err != nil {
return p, err
}
//Check to see if limit was exceeded
var tmp [1]byte
_, err = io.ReadFull(r, tmp[:])
if err == io.EOF {
err = nil
} else if err == nil {
err = io.ErrUnexpectedEOF
}
return p, err
}
// Roundtrip HTTP POST using WebRTC SessionDescriptions. // Roundtrip HTTP POST using WebRTC SessionDescriptions.
// //
// Send an SDP offer to the broker, which assigns a proxy and responds // Send an SDP offer to the broker, which assigns a proxy and responds
@ -91,7 +110,7 @@ func (bc *BrokerChannel) Negotiate(offer *webrtc.SessionDescription) (
switch resp.StatusCode { switch resp.StatusCode {
case http.StatusOK: case http.StatusOK:
body, err := ioutil.ReadAll(http.MaxBytesReader(nil, resp.Body, 100000)) body, err := limitedRead(resp.Body, readLimit)
if nil != err { if nil != err {
return nil, err return nil, err
} }

View file

@ -32,6 +32,8 @@ const pollInterval = 5 * time.Second
//client is not going to connect //client is not going to connect
const dataChannelTimeout = 20 * time.Second const dataChannelTimeout = 20 * time.Second
const readLimit = 100000 //Maximum number of bytes to be read from an HTTP request
var brokerURL *url.URL var brokerURL *url.URL
var relayURL string var relayURL string
@ -137,6 +139,23 @@ func genSessionID() string {
return strings.TrimRight(base64.StdEncoding.EncodeToString(buf), "=") return strings.TrimRight(base64.StdEncoding.EncodeToString(buf), "=")
} }
func limitedRead(r io.Reader, limit int64) ([]byte, error) {
p, err := ioutil.ReadAll(&io.LimitedReader{r, limit})
if err != nil {
return p, err
}
//Check to see if limit was exceeded
var tmp [1]byte
_, err = io.ReadFull(r, tmp[:])
if err == io.EOF {
err = nil
} else if err == nil {
err = io.ErrUnexpectedEOF
}
return p, err
}
func pollOffer(sid string) *webrtc.SessionDescription { func pollOffer(sid string) *webrtc.SessionDescription {
broker := brokerURL.ResolveReference(&url.URL{Path: "proxy"}) broker := brokerURL.ResolveReference(&url.URL{Path: "proxy"})
timeOfNextPoll := time.Now() timeOfNextPoll := time.Now()
@ -162,7 +181,7 @@ func pollOffer(sid string) *webrtc.SessionDescription {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
log.Printf("broker returns: %d", resp.StatusCode) log.Printf("broker returns: %d", resp.StatusCode)
} else { } else {
body, err := ioutil.ReadAll(http.MaxBytesReader(nil, resp.Body, 100000)) body, err := limitedRead(resp.Body, readLimit)
if err != nil { if err != nil {
log.Printf("error reading broker response: %s", err) log.Printf("error reading broker response: %s", err)
} else { } else {