Factor out a function to extract the client IP address.

This commit is contained in:
David Fifield 2017-10-17 21:39:04 -07:00
parent 9e5eb7f5ee
commit 83f8712078

View file

@ -128,6 +128,16 @@ func proxy(local *net.TCPConn, conn *webSocketConn) {
wg.Wait() wg.Wait()
} }
// Return an address string suitable to pass into pt.DialOr.
func clientAddr(clientIPParam string) string {
// Check if client addr is a valid IP
clientIP := net.ParseIP(clientIPParam)
if clientIP == nil {
return ""
}
return clientIPParam
}
func webSocketHandler(ws *websocket.WebSocket) { func webSocketHandler(ws *websocket.WebSocket) {
// Undo timeouts on HTTP request handling. // Undo timeouts on HTTP request handling.
ws.Conn.SetDeadline(time.Time{}) ws.Conn.SetDeadline(time.Time{})
@ -139,17 +149,9 @@ func webSocketHandler(ws *websocket.WebSocket) {
handlerChan <- -1 handlerChan <- -1
}() }()
// Check if client addr is a valid IP
addr := ws.Request().URL.Query().Get("client_ip")
clientIP := net.ParseIP(addr)
if clientIP == nil {
// Set client addr to empty
addr = ""
}
// Pass the address of client as the remote address of incoming connection // Pass the address of client as the remote address of incoming connection
or, err := pt.DialOr(&ptInfo, addr, ptMethodName) clientIPParam := ws.Request().URL.Query().Get("client_ip")
or, err := pt.DialOr(&ptInfo, clientIPParam, ptMethodName)
if err != nil { if err != nil {
log.Printf("failed to connect to ORPort: %s", err) log.Printf("failed to connect to ORPort: %s", err)