Send shutdown signal to shutdown open connections

Normally all dangling goroutines are terminated when the main function
exits. However, for projects that use a patched version of snowflake as
a library, these goroutines continued running as long as the main function
had not yet terminated. This commit has all open SOCKS connections close
after receiving a shutdown signal.
This commit is contained in:
Cecylia Bocovich 2020-10-29 16:21:37 -04:00
parent 114df695ce
commit b9cc54b3b7

View file

@ -27,7 +27,7 @@ const (
) )
// Accept local SOCKS connections and pass them to the handler. // Accept local SOCKS connections and pass them to the handler.
func socksAcceptLoop(ln *pt.SocksListener, tongue sf.Tongue) { func socksAcceptLoop(ln *pt.SocksListener, tongue sf.Tongue, shutdown chan struct{}) {
defer ln.Close() defer ln.Close()
for { for {
conn, err := ln.AcceptSocks() conn, err := ln.AcceptSocks()
@ -48,11 +48,23 @@ func socksAcceptLoop(ln *pt.SocksListener, tongue sf.Tongue) {
return return
} }
err = sf.Handler(conn, tongue) handler := make(chan struct{})
if err != nil { go func() {
log.Printf("handler error: %s", err) err = sf.Handler(conn, tongue)
if err != nil {
log.Printf("handler error: %s", err)
}
close(handler)
return return
}()
select {
case <-shutdown:
log.Println("Received shutdown signal")
case <-handler:
log.Println("Handler ended")
} }
return
}() }()
} }
} }
@ -160,6 +172,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
listeners := make([]net.Listener, 0) listeners := make([]net.Listener, 0)
shutdown := make(chan struct{})
for _, methodName := range ptInfo.MethodNames { for _, methodName := range ptInfo.MethodNames {
switch methodName { switch methodName {
case "snowflake": case "snowflake":
@ -170,7 +183,7 @@ func main() {
break break
} }
log.Printf("Started SOCKS listener at %v.", ln.Addr()) log.Printf("Started SOCKS listener at %v.", ln.Addr())
go socksAcceptLoop(ln, dialer) go socksAcceptLoop(ln, dialer, shutdown)
pt.Cmethod(methodName, ln.Version(), ln.Addr()) pt.Cmethod(methodName, ln.Version(), ln.Addr())
listeners = append(listeners, ln) listeners = append(listeners, ln)
default: default:
@ -196,11 +209,13 @@ func main() {
// Wait for a signal. // Wait for a signal.
<-sigChan <-sigChan
log.Println("stopping snowflake")
// Signal received, shut down. // Signal received, shut down.
for _, ln := range listeners { for _, ln := range listeners {
ln.Close() ln.Close()
} }
close(shutdown)
log.Println("snowflake is done.") log.Println("snowflake is done.")
} }