mirror of
https://gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/snowflake.git
synced 2025-10-13 20:11:19 -04:00
Stop counting handlers before terminating.
The requirement to do so is obsolete and has already been removed from other pluggable transports. https://bugs.torproject.org/32046
This commit is contained in:
parent
d8d3170af8
commit
b4f4b29a03
4 changed files with 5 additions and 68 deletions
|
@ -13,17 +13,9 @@ const (
|
||||||
SnowflakeTimeout = 30
|
SnowflakeTimeout = 30
|
||||||
)
|
)
|
||||||
|
|
||||||
// HandlerChan - When a connection handler starts, +1 is written to this channel; when it
|
|
||||||
// ends, -1 is written.
|
|
||||||
var HandlerChan = make(chan int)
|
|
||||||
|
|
||||||
// Given an accepted SOCKS connection, establish a WebRTC connection to the
|
// Given an accepted SOCKS connection, establish a WebRTC connection to the
|
||||||
// remote peer and exchange traffic.
|
// remote peer and exchange traffic.
|
||||||
func Handler(socks SocksConnector, snowflakes SnowflakeCollector) error {
|
func Handler(socks SocksConnector, snowflakes SnowflakeCollector) error {
|
||||||
HandlerChan <- 1
|
|
||||||
defer func() {
|
|
||||||
HandlerChan <- -1
|
|
||||||
}()
|
|
||||||
// Obtain an available WebRTC remote. May block.
|
// Obtain an available WebRTC remote. May block.
|
||||||
snowflake := snowflakes.Pop()
|
snowflake := snowflakes.Pop()
|
||||||
if nil == snowflake {
|
if nil == snowflake {
|
||||||
|
|
|
@ -184,8 +184,6 @@ func main() {
|
||||||
}
|
}
|
||||||
pt.CmethodsDone()
|
pt.CmethodsDone()
|
||||||
|
|
||||||
var numHandlers int
|
|
||||||
var sig os.Signal
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGTERM)
|
||||||
|
|
||||||
|
@ -202,22 +200,12 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// keep track of handlers and wait for a signal
|
// keep track of handlers and wait for a signal
|
||||||
sig = nil
|
<-sigChan
|
||||||
for sig == nil {
|
|
||||||
select {
|
|
||||||
case n := <-sf.HandlerChan:
|
|
||||||
numHandlers += n
|
|
||||||
case sig = <-sigChan:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// signal received, shut down
|
// signal received, shut down
|
||||||
for _, ln := range listeners {
|
for _, ln := range listeners {
|
||||||
ln.Close()
|
ln.Close()
|
||||||
}
|
}
|
||||||
snowflakes.End()
|
snowflakes.End()
|
||||||
for numHandlers > 0 {
|
|
||||||
numHandlers += <-sf.HandlerChan
|
|
||||||
}
|
|
||||||
log.Println("snowflake is done.")
|
log.Println("snowflake is done.")
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,10 +21,6 @@ var ptMethodName = "snowflake"
|
||||||
var ptInfo pt.ServerInfo
|
var ptInfo pt.ServerInfo
|
||||||
var logFile *os.File
|
var logFile *os.File
|
||||||
|
|
||||||
// When a datachannel handler starts, +1 is written to this channel;
|
|
||||||
// when it ends, -1 is written.
|
|
||||||
var handlerChan = make(chan int)
|
|
||||||
|
|
||||||
func copyLoop(WebRTC, ORPort net.Conn) {
|
func copyLoop(WebRTC, ORPort net.Conn) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
|
@ -100,11 +96,6 @@ func (c *webRTCConn) SetWriteDeadline(t time.Time) error {
|
||||||
func datachannelHandler(conn *webRTCConn) {
|
func datachannelHandler(conn *webRTCConn) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
handlerChan <- 1
|
|
||||||
defer func() {
|
|
||||||
handlerChan <- -1
|
|
||||||
}()
|
|
||||||
|
|
||||||
or, err := pt.DialOr(&ptInfo, "", ptMethodName) // TODO: Extended OR
|
or, err := pt.DialOr(&ptInfo, "", ptMethodName) // TODO: Extended OR
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Failed to connect to ORPort: " + err.Error())
|
log.Printf("Failed to connect to ORPort: " + err.Error())
|
||||||
|
@ -246,8 +237,6 @@ func main() {
|
||||||
}
|
}
|
||||||
pt.SmethodsDone()
|
pt.SmethodsDone()
|
||||||
|
|
||||||
var numHandlers int
|
|
||||||
var sig os.Signal
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGTERM)
|
||||||
|
|
||||||
|
@ -263,17 +252,6 @@ func main() {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// keep track of handlers and wait for a signal
|
// wait for a signal
|
||||||
sig = nil
|
<-sigChan
|
||||||
for sig == nil {
|
|
||||||
select {
|
|
||||||
case n := <-handlerChan:
|
|
||||||
numHandlers += n
|
|
||||||
case sig = <-sigChan:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for numHandlers > 0 {
|
|
||||||
numHandlers += <-handlerChan
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,10 +37,6 @@ const listenAndServeErrorTimeout = 100 * time.Millisecond
|
||||||
|
|
||||||
var ptInfo pt.ServerInfo
|
var ptInfo pt.ServerInfo
|
||||||
|
|
||||||
// When a connection handler starts, +1 is written to this channel; when it
|
|
||||||
// ends, -1 is written.
|
|
||||||
var handlerChan = make(chan int)
|
|
||||||
|
|
||||||
func usage() {
|
func usage() {
|
||||||
fmt.Fprintf(os.Stderr, `Usage: %s [OPTIONS]
|
fmt.Fprintf(os.Stderr, `Usage: %s [OPTIONS]
|
||||||
|
|
||||||
|
@ -157,11 +153,6 @@ func webSocketHandler(ws *websocket.WebSocket) {
|
||||||
conn := newWebSocketConn(ws)
|
conn := newWebSocketConn(ws)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
handlerChan <- 1
|
|
||||||
defer func() {
|
|
||||||
handlerChan <- -1
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Pass the address of client as the remote address of incoming connection
|
// Pass the address of client as the remote address of incoming connection
|
||||||
clientIPParam := ws.Request().URL.Query().Get("client_ip")
|
clientIPParam := ws.Request().URL.Query().Get("client_ip")
|
||||||
addr := clientAddr(clientIPParam)
|
addr := clientAddr(clientIPParam)
|
||||||
|
@ -390,8 +381,6 @@ func main() {
|
||||||
}
|
}
|
||||||
pt.SmethodsDone()
|
pt.SmethodsDone()
|
||||||
|
|
||||||
var numHandlers int
|
|
||||||
var sig os.Signal
|
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGTERM)
|
||||||
|
|
||||||
|
@ -407,22 +396,12 @@ func main() {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// keep track of handlers and wait for a signal
|
// wait for a signal
|
||||||
sig = nil
|
sig := <-sigChan
|
||||||
for sig == nil {
|
|
||||||
select {
|
|
||||||
case n := <-handlerChan:
|
|
||||||
numHandlers += n
|
|
||||||
case sig = <-sigChan:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// signal received, shut down
|
// signal received, shut down
|
||||||
log.Printf("caught signal %q, exiting", sig)
|
log.Printf("caught signal %q, exiting", sig)
|
||||||
for _, server := range servers {
|
for _, server := range servers {
|
||||||
server.Close()
|
server.Close()
|
||||||
}
|
}
|
||||||
for numHandlers > 0 {
|
|
||||||
numHandlers += <-handlerChan
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue