Merge branch 'refactor-proxy-simplify-runSession' into 'main'

refactor(proxy): simplify `runSession()` and `Start()`

See merge request tpo/anti-censorship/pluggable-transports/snowflake!525
This commit is contained in:
WofWca 2025-03-06 14:46:41 +00:00
commit 7457c90bd2

View file

@ -338,29 +338,6 @@ func copyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser, shutdown chan struct
log.Println("copy loop ended") log.Println("copy loop ended")
} }
// We pass conn.RemoteAddr() as an additional parameter, rather than calling
// conn.RemoteAddr() inside this function, as a workaround for a hang that
// otherwise occurs inside conn.pc.RemoteDescription() (called by RemoteAddr).
// https://bugs.torproject.org/18628#comment:8
func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteIP net.IP, relayURL string) {
defer conn.Close()
defer tokens.ret()
if relayURL == "" {
relayURL = sf.RelayURL
}
wsConn, err := connectToRelay(relayURL, remoteIP)
if err != nil {
log.Print(err)
return
}
defer wsConn.Close()
copyLoop(conn, wsConn, sf.shutdown)
log.Printf("datachannelHandler ends")
}
func connectToRelay(relayURL string, remoteIP net.IP) (*websocketconn.Conn, error) { func connectToRelay(relayURL string, remoteIP net.IP) (*websocketconn.Conn, error) {
u, err := url.Parse(relayURL) u, err := url.Parse(relayURL)
if err != nil { if err != nil {
@ -386,15 +363,6 @@ func connectToRelay(relayURL string, remoteIP net.IP) (*websocketconn.Conn, erro
return wsConn, nil return wsConn, nil
} }
type dataChannelHandlerWithRelayURL struct {
RelayURL string
sf *SnowflakeProxy
}
func (d dataChannelHandlerWithRelayURL) datachannelHandler(conn *webRTCConn, remoteIP net.IP) {
d.sf.datachannelHandler(conn, remoteIP, d.RelayURL)
}
func (sf *SnowflakeProxy) makeWebRTCAPI() *webrtc.API { func (sf *SnowflakeProxy) makeWebRTCAPI() *webrtc.API {
settingsEngine := webrtc.SettingEngine{} settingsEngine := webrtc.SettingEngine{}
@ -438,22 +406,25 @@ func (sf *SnowflakeProxy) makeWebRTCAPI() *webrtc.API {
// Create a PeerConnection from an SDP offer. Blocks until the gathering of ICE // Create a PeerConnection from an SDP offer. Blocks until the gathering of ICE
// candidates is complete and the answer is available in LocalDescription. // candidates is complete and the answer is available in LocalDescription.
// Installs an OnDataChannel callback that creates a webRTCConn and passes it to // When the client creates the WebRTC data channel,
// datachannelHandler. // this function wraps the data channel
// into a `webRTCConn` and sends to the returned channel.
// The data channel may never get created, in which case the channel
// will never be sent to.
func (sf *SnowflakeProxy) makePeerConnectionFromOffer( func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
sdp *webrtc.SessionDescription, sdp *webrtc.SessionDescription,
config webrtc.Configuration, dataChan chan struct{}, config webrtc.Configuration,
handler func(conn *webRTCConn, remoteIP net.IP), ) (*webrtc.PeerConnection, chan *webRTCConn, error) {
) (*webrtc.PeerConnection, error) {
api := sf.makeWebRTCAPI() api := sf.makeWebRTCAPI()
pc, err := api.NewPeerConnection(config) pc, err := api.NewPeerConnection(config)
if err != nil { if err != nil {
return nil, fmt.Errorf("accept: NewPeerConnection: %s", err) return nil, nil, fmt.Errorf("accept: NewPeerConnection: %s", err)
} }
// Buffered to avoil blocking `OnDataChannel`.
webRTCConnChan := make(chan *webRTCConn, 1)
pc.OnDataChannel(func(dc *webrtc.DataChannel) { pc.OnDataChannel(func(dc *webrtc.DataChannel) {
log.Printf("New Data Channel %s-%d\n", dc.Label(), dc.ID()) log.Printf("New Data Channel %s-%d\n", dc.Label(), dc.ID())
close(dataChan)
pr, pw := io.Pipe() pr, pw := io.Pipe()
conn := newWebRTCConn(pc, dc, pr, sf.bytesLogger) conn := newWebRTCConn(pc, dc, pr, sf.bytesLogger)
@ -518,7 +489,8 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
} }
}) })
go handler(conn, remoteIP) webRTCConnChan <- conn
close(webRTCConnChan)
}) })
// As of v3.0.0, pion-webrtc uses trickle ICE by default. // As of v3.0.0, pion-webrtc uses trickle ICE by default.
// We have to wait for candidate gathering to complete // We have to wait for candidate gathering to complete
@ -529,7 +501,7 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
if inerr := pc.Close(); inerr != nil { if inerr := pc.Close(); inerr != nil {
log.Printf("unable to call pc.Close after pc.SetRemoteDescription with error: %v", inerr) log.Printf("unable to call pc.Close after pc.SetRemoteDescription with error: %v", inerr)
} }
return nil, fmt.Errorf("accept: SetRemoteDescription: %s", err) return nil, webRTCConnChan, fmt.Errorf("accept: SetRemoteDescription: %s", err)
} }
log.Println("Generating answer...") log.Println("Generating answer...")
@ -538,7 +510,7 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
if inerr := pc.Close(); inerr != nil { if inerr := pc.Close(); inerr != nil {
log.Printf("ICE gathering has generated an error when calling pc.Close: %v", inerr) log.Printf("ICE gathering has generated an error when calling pc.Close: %v", inerr)
} }
return nil, err return nil, webRTCConnChan, err
} }
err = pc.SetLocalDescription(answer) err = pc.SetLocalDescription(answer)
@ -546,7 +518,7 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
if err = pc.Close(); err != nil { if err = pc.Close(); err != nil {
log.Printf("pc.Close after setting local description returned : %v", err) log.Printf("pc.Close after setting local description returned : %v", err)
} }
return nil, err return nil, webRTCConnChan, err
} }
// Wait for ICE candidate gathering to complete, // Wait for ICE candidate gathering to complete,
@ -562,7 +534,7 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
log.Printf("Answer: \n\t%s", strings.ReplaceAll(pc.LocalDescription().SDP, "\n", "\n\t")) log.Printf("Answer: \n\t%s", strings.ReplaceAll(pc.LocalDescription().SDP, "\n", "\n\t"))
return pc, nil return pc, webRTCConnChan, nil
} }
// Create a new PeerConnection. Blocks until the gathering of ICE // Create a new PeerConnection. Blocks until the gathering of ICE
@ -635,19 +607,14 @@ func (sf *SnowflakeProxy) makeNewPeerConnection(
return pc, nil return pc, nil
} }
func (sf *SnowflakeProxy) runSession(sid string) { // runSession connects to the client and to the server,
connectedToClient := false // and pipes the data between them.
defer func() { // It blocks until the session ends.
if !connectedToClient { func (sf *SnowflakeProxy) runSession(
tokens.ret() offer *webrtc.SessionDescription,
} relayURL string,
// Otherwise we'll `tokens.ret()` when the connection finishes. sid string,
}() ) {
offer, relayURL := broker.pollOffer(sid, sf.ProxyType, sf.RelayDomainNamePattern)
if offer == nil {
return
}
log.Printf("Received Offer From Broker: \n\t%s", strings.ReplaceAll(offer.SDP, "\n", "\n\t")) log.Printf("Received Offer From Broker: \n\t%s", strings.ReplaceAll(offer.SDP, "\n", "\n\t"))
if relayURL != "" { if relayURL != "" {
@ -657,35 +624,45 @@ func (sf *SnowflakeProxy) runSession(sid string) {
} }
} }
dataChan := make(chan struct{}) pc, webRTCConnChan, err := sf.makePeerConnectionFromOffer(offer, config)
dataChannelAdaptor := dataChannelHandlerWithRelayURL{RelayURL: relayURL, sf: sf}
pc, err := sf.makePeerConnectionFromOffer(offer, config, dataChan, dataChannelAdaptor.datachannelHandler)
if err != nil { if err != nil {
log.Printf("error making WebRTC connection: %s", err) log.Printf("error making WebRTC connection: %s", err)
return return
} }
defer func() {
err = broker.sendAnswer(sid, pc)
if err != nil {
log.Printf("error sending answer to client through broker: %s", err)
if inerr := pc.Close(); inerr != nil { if inerr := pc.Close(); inerr != nil {
log.Printf("error calling pc.Close: %v", inerr) log.Printf("error calling pc.Close: %v", inerr)
} }
}()
err = broker.sendAnswer(sid, pc)
if err != nil {
return return
} }
// Set a timeout on peerconnection. If the connection state has not var webRTCConn *webRTCConn
// advanced to PeerConnectionStateConnected in this time,
// destroy the peer connection and return the token.
select { select {
case <-dataChan: case webRTCConn = <-webRTCConnChan:
log.Println("Connection successful") log.Println("Connection successful")
connectedToClient = true
case <-time.After(dataChannelTimeout): case <-time.After(dataChannelTimeout):
log.Println("Timed out waiting for client to open data channel.") log.Println("Timed out waiting for client to open data channel.")
if err := pc.Close(); err != nil { return
log.Printf("error calling pc.Close: %v", err)
} }
defer webRTCConn.Close()
if relayURL == "" {
relayURL = sf.RelayURL
} }
remoteIP := webRTCConn.RemoteIP()
websocketConn, err := connectToRelay(relayURL, remoteIP)
if err != nil {
log.Print(err)
return
}
defer websocketConn.Close()
copyLoop(webRTCConn, websocketConn, sf.shutdown)
log.Printf("session with id %v ended", sid)
} }
// Returns nil if the relayURL is acceptable. // Returns nil if the relayURL is acceptable.
@ -850,7 +827,15 @@ func (sf *SnowflakeProxy) Start() error {
default: default:
tokens.get() tokens.get()
sessionID := genSessionID() sessionID := genSessionID()
sf.runSession(sessionID) offer, relayURL := broker.pollOffer(sessionID, sf.ProxyType, sf.RelayDomainNamePattern)
if offer == nil {
tokens.ret()
continue
}
go func() {
sf.runSession(offer, relayURL, sessionID)
tokens.ret()
}()
} }
} }
return nil return nil