Use http.RoundTripper for connections to broker

This change makes it easier for us to write tests with mock transports
This commit is contained in:
Cecylia Bocovich 2019-06-12 14:15:21 -04:00
parent 574c57cc98
commit 446f39a9e5

View file

@ -36,7 +36,7 @@ const dataChannelTimeout = 20 * time.Second
const readLimit = 100000 //Maximum number of bytes to be read from an HTTP request const readLimit = 100000 //Maximum number of bytes to be read from an HTTP request
var brokerURL *url.URL var broker *Broker
var relayURL string var relayURL string
const ( const (
@ -68,6 +68,11 @@ func remoteIPFromSDP(sdp string) net.IP {
return nil return nil
} }
type Broker struct {
url *url.URL
transport http.RoundTripper
}
type webRTCConn struct { type webRTCConn struct {
dc *webrtc.DataChannel dc *webrtc.DataChannel
pc *webrtc.PeerConnection pc *webrtc.PeerConnection
@ -154,8 +159,8 @@ func limitedRead(r io.Reader, limit int64) ([]byte, error) {
return p, err return p, err
} }
func pollOffer(sid string) *webrtc.SessionDescription { func (b *Broker) pollOffer(sid string) *webrtc.SessionDescription {
broker := brokerURL.ResolveReference(&url.URL{Path: "proxy"}) brokerPath := b.url.ResolveReference(&url.URL{Path: "proxy"})
timeOfNextPoll := time.Now() timeOfNextPoll := time.Now()
for { for {
// Sleep until we're scheduled to poll again. // Sleep until we're scheduled to poll again.
@ -169,14 +174,13 @@ func pollOffer(sid string) *webrtc.SessionDescription {
timeOfNextPoll = now timeOfNextPoll = now
} }
b, err := messages.EncodePollRequest(sid) body, err := messages.EncodePollRequest(sid)
if err != nil { if err != nil {
log.Printf("Error encoding poll message: %s", err.Error()) log.Printf("Error encoding poll message: %s", err.Error())
return nil return nil
} }
req, _ := http.NewRequest("POST", broker.String(), bytes.NewBuffer(b)) req, _ := http.NewRequest("POST", brokerPath.String(), bytes.NewBuffer(body))
req.Header.Set("X-Session-ID", sid) resp, err := b.transport.RoundTrip(req)
resp, err := client.Do(req)
if err != nil { if err != nil {
log.Printf("error polling broker: %s", err) log.Printf("error polling broker: %s", err)
} else { } else {
@ -204,15 +208,15 @@ func pollOffer(sid string) *webrtc.SessionDescription {
} }
} }
func sendAnswer(sid string, pc *webrtc.PeerConnection) error { func (b *Broker) sendAnswer(sid string, pc *webrtc.PeerConnection) error {
broker := brokerURL.ResolveReference(&url.URL{Path: "answer"}) brokerPath := b.url.ResolveReference(&url.URL{Path: "answer"})
answer := string([]byte(serializeSessionDescription(pc.LocalDescription()))) answer := string([]byte(serializeSessionDescription(pc.LocalDescription())))
b, err := messages.EncodeAnswerRequest(answer, sid) body, err := messages.EncodeAnswerRequest(answer, sid)
if err != nil { if err != nil {
return err return err
} }
req, _ := http.NewRequest("POST", broker.String(), bytes.NewBuffer(b)) req, _ := http.NewRequest("POST", brokerPath.String(), bytes.NewBuffer(body))
resp, err := client.Do(req) resp, err := b.transport.RoundTrip(req)
if err != nil { if err != nil {
return err return err
} }
@ -220,7 +224,7 @@ func sendAnswer(sid string, pc *webrtc.PeerConnection) error {
return fmt.Errorf("broker returned %d", resp.StatusCode) return fmt.Errorf("broker returned %d", resp.StatusCode)
} }
body, err := limitedRead(resp.Body, readLimit) body, err = limitedRead(resp.Body, readLimit)
if err != nil { if err != nil {
return fmt.Errorf("error reading broker response: %s", err) return fmt.Errorf("error reading broker response: %s", err)
} }
@ -364,7 +368,7 @@ func makePeerConnectionFromOffer(sdp *webrtc.SessionDescription, config webrtc.C
} }
func runSession(sid string) { func runSession(sid string) {
offer := pollOffer(sid) offer := broker.pollOffer(sid)
if offer == nil { if offer == nil {
log.Printf("bad offer from broker") log.Printf("bad offer from broker")
retToken() retToken()
@ -377,7 +381,7 @@ func runSession(sid string) {
retToken() retToken()
return return
} }
err = sendAnswer(sid, pc) err = broker.sendAnswer(sid, pc)
if err != nil { if err != nil {
log.Printf("error sending answer to client through broker: %s", err) log.Printf("error sending answer to client through broker: %s", err)
if inerr := pc.Close(); inerr != nil { if inerr := pc.Close(); inerr != nil {
@ -430,7 +434,8 @@ func main() {
log.Println("starting") log.Println("starting")
var err error var err error
brokerURL, err = url.Parse(rawBrokerURL) broker = new(Broker)
broker.url, err = url.Parse(rawBrokerURL)
if err != nil { if err != nil {
log.Fatalf("invalid broker url: %s", err) log.Fatalf("invalid broker url: %s", err)
} }
@ -443,6 +448,7 @@ func main() {
log.Fatalf("invalid relay url: %s", err) log.Fatalf("invalid relay url: %s", err)
} }
broker.transport = http.DefaultTransport.(*http.Transport)
config = webrtc.Configuration{ config = webrtc.Configuration{
ICEServers: []webrtc.ICEServer{ ICEServers: []webrtc.ICEServer{
{ {