Merge branch 'refactor-proxy-simplify-runSession' into 'main'

refactor(proxy): simplify `runSession()` and `Start()`

See merge request tpo/anti-censorship/pluggable-transports/snowflake!525
This commit is contained in:
WofWca 2025-03-06 14:46:41 +00:00
commit 7457c90bd2

View file

@ -338,29 +338,6 @@ func copyLoop(c1 io.ReadWriteCloser, c2 io.ReadWriteCloser, shutdown chan struct
log.Println("copy loop ended")
}
// We pass conn.RemoteAddr() as an additional parameter, rather than calling
// conn.RemoteAddr() inside this function, as a workaround for a hang that
// otherwise occurs inside conn.pc.RemoteDescription() (called by RemoteAddr).
// https://bugs.torproject.org/18628#comment:8
func (sf *SnowflakeProxy) datachannelHandler(conn *webRTCConn, remoteIP net.IP, relayURL string) {
defer conn.Close()
defer tokens.ret()
if relayURL == "" {
relayURL = sf.RelayURL
}
wsConn, err := connectToRelay(relayURL, remoteIP)
if err != nil {
log.Print(err)
return
}
defer wsConn.Close()
copyLoop(conn, wsConn, sf.shutdown)
log.Printf("datachannelHandler ends")
}
func connectToRelay(relayURL string, remoteIP net.IP) (*websocketconn.Conn, error) {
u, err := url.Parse(relayURL)
if err != nil {
@ -386,15 +363,6 @@ func connectToRelay(relayURL string, remoteIP net.IP) (*websocketconn.Conn, erro
return wsConn, nil
}
type dataChannelHandlerWithRelayURL struct {
RelayURL string
sf *SnowflakeProxy
}
func (d dataChannelHandlerWithRelayURL) datachannelHandler(conn *webRTCConn, remoteIP net.IP) {
d.sf.datachannelHandler(conn, remoteIP, d.RelayURL)
}
func (sf *SnowflakeProxy) makeWebRTCAPI() *webrtc.API {
settingsEngine := webrtc.SettingEngine{}
@ -438,22 +406,25 @@ func (sf *SnowflakeProxy) makeWebRTCAPI() *webrtc.API {
// Create a PeerConnection from an SDP offer. Blocks until the gathering of ICE
// candidates is complete and the answer is available in LocalDescription.
// Installs an OnDataChannel callback that creates a webRTCConn and passes it to
// datachannelHandler.
// When the client creates the WebRTC data channel,
// this function wraps the data channel
// into a `webRTCConn` and sends to the returned channel.
// The data channel may never get created, in which case the channel
// will never be sent to.
func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
sdp *webrtc.SessionDescription,
config webrtc.Configuration, dataChan chan struct{},
handler func(conn *webRTCConn, remoteIP net.IP),
) (*webrtc.PeerConnection, error) {
config webrtc.Configuration,
) (*webrtc.PeerConnection, chan *webRTCConn, error) {
api := sf.makeWebRTCAPI()
pc, err := api.NewPeerConnection(config)
if err != nil {
return nil, fmt.Errorf("accept: NewPeerConnection: %s", err)
return nil, nil, fmt.Errorf("accept: NewPeerConnection: %s", err)
}
// Buffered to avoil blocking `OnDataChannel`.
webRTCConnChan := make(chan *webRTCConn, 1)
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
log.Printf("New Data Channel %s-%d\n", dc.Label(), dc.ID())
close(dataChan)
pr, pw := io.Pipe()
conn := newWebRTCConn(pc, dc, pr, sf.bytesLogger)
@ -518,7 +489,8 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
}
})
go handler(conn, remoteIP)
webRTCConnChan <- conn
close(webRTCConnChan)
})
// As of v3.0.0, pion-webrtc uses trickle ICE by default.
// We have to wait for candidate gathering to complete
@ -529,7 +501,7 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
if inerr := pc.Close(); inerr != nil {
log.Printf("unable to call pc.Close after pc.SetRemoteDescription with error: %v", inerr)
}
return nil, fmt.Errorf("accept: SetRemoteDescription: %s", err)
return nil, webRTCConnChan, fmt.Errorf("accept: SetRemoteDescription: %s", err)
}
log.Println("Generating answer...")
@ -538,7 +510,7 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
if inerr := pc.Close(); inerr != nil {
log.Printf("ICE gathering has generated an error when calling pc.Close: %v", inerr)
}
return nil, err
return nil, webRTCConnChan, err
}
err = pc.SetLocalDescription(answer)
@ -546,7 +518,7 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
if err = pc.Close(); err != nil {
log.Printf("pc.Close after setting local description returned : %v", err)
}
return nil, err
return nil, webRTCConnChan, err
}
// Wait for ICE candidate gathering to complete,
@ -562,7 +534,7 @@ func (sf *SnowflakeProxy) makePeerConnectionFromOffer(
log.Printf("Answer: \n\t%s", strings.ReplaceAll(pc.LocalDescription().SDP, "\n", "\n\t"))
return pc, nil
return pc, webRTCConnChan, nil
}
// Create a new PeerConnection. Blocks until the gathering of ICE
@ -635,19 +607,14 @@ func (sf *SnowflakeProxy) makeNewPeerConnection(
return pc, nil
}
func (sf *SnowflakeProxy) runSession(sid string) {
connectedToClient := false
defer func() {
if !connectedToClient {
tokens.ret()
}
// Otherwise we'll `tokens.ret()` when the connection finishes.
}()
offer, relayURL := broker.pollOffer(sid, sf.ProxyType, sf.RelayDomainNamePattern)
if offer == nil {
return
}
// runSession connects to the client and to the server,
// and pipes the data between them.
// It blocks until the session ends.
func (sf *SnowflakeProxy) runSession(
offer *webrtc.SessionDescription,
relayURL string,
sid string,
) {
log.Printf("Received Offer From Broker: \n\t%s", strings.ReplaceAll(offer.SDP, "\n", "\n\t"))
if relayURL != "" {
@ -657,35 +624,45 @@ func (sf *SnowflakeProxy) runSession(sid string) {
}
}
dataChan := make(chan struct{})
dataChannelAdaptor := dataChannelHandlerWithRelayURL{RelayURL: relayURL, sf: sf}
pc, err := sf.makePeerConnectionFromOffer(offer, config, dataChan, dataChannelAdaptor.datachannelHandler)
pc, webRTCConnChan, err := sf.makePeerConnectionFromOffer(offer, config)
if err != nil {
log.Printf("error making WebRTC connection: %s", err)
return
}
err = broker.sendAnswer(sid, pc)
if err != nil {
log.Printf("error sending answer to client through broker: %s", err)
defer func() {
if inerr := pc.Close(); inerr != nil {
log.Printf("error calling pc.Close: %v", inerr)
}
}()
err = broker.sendAnswer(sid, pc)
if err != nil {
return
}
// Set a timeout on peerconnection. If the connection state has not
// advanced to PeerConnectionStateConnected in this time,
// destroy the peer connection and return the token.
var webRTCConn *webRTCConn
select {
case <-dataChan:
case webRTCConn = <-webRTCConnChan:
log.Println("Connection successful")
connectedToClient = true
case <-time.After(dataChannelTimeout):
log.Println("Timed out waiting for client to open data channel.")
if err := pc.Close(); err != nil {
log.Printf("error calling pc.Close: %v", err)
return
}
defer webRTCConn.Close()
if relayURL == "" {
relayURL = sf.RelayURL
}
remoteIP := webRTCConn.RemoteIP()
websocketConn, err := connectToRelay(relayURL, remoteIP)
if err != nil {
log.Print(err)
return
}
defer websocketConn.Close()
copyLoop(webRTCConn, websocketConn, sf.shutdown)
log.Printf("session with id %v ended", sid)
}
// Returns nil if the relayURL is acceptable.
@ -850,7 +827,15 @@ func (sf *SnowflakeProxy) Start() error {
default:
tokens.get()
sessionID := genSessionID()
sf.runSession(sessionID)
offer, relayURL := broker.pollOffer(sessionID, sf.ProxyType, sf.RelayDomainNamePattern)
if offer == nil {
tokens.ret()
continue
}
go func() {
sf.runSession(offer, relayURL, sessionID)
tokens.ret()
}()
}
}
return nil