Add connection expire time for uTLS pendingConn

This commit is contained in:
Shelikhoo 2022-02-16 11:11:37 +00:00
parent 8d5998b744
commit 3132f68012
No known key found for this signature in database
GPG key ID: C4D5E79D22B25316

View file

@ -7,6 +7,7 @@ import (
"net"
"net/http"
"sync"
"time"
utls "github.com/refraction-networking/utls"
"golang.org/x/net/http2"
@ -19,7 +20,7 @@ func NewUTLSHTTPRoundTripper(clientHelloID utls.ClientHelloID, uTlsConfig *utls.
config: uTlsConfig,
connectWithH1: map[string]bool{},
backdropTransport: backdropTransport,
pendingConn: map[pendingConnKey]net.Conn{},
pendingConn: map[pendingConnKey]*unclaimedConnection{},
removeSNI: removeSNI,
}
rtImpl.init()
@ -38,7 +39,7 @@ type uTLSHTTPRoundTripperImpl struct {
backdropTransport http.RoundTripper
accessDialingConnection sync.Mutex
pendingConn map[pendingConnKey]net.Conn
pendingConn map[pendingConnKey]*unclaimedConnection
removeSNI bool
}
@ -50,6 +51,7 @@ type pendingConnKey struct {
var errEAGAIN = errors.New("incorrect ALPN negotiated, try again with another ALPN")
var errEAGAINTooMany = errors.New("incorrect ALPN negotiated")
var errExpired = errors.New("connection have expired")
func (r *uTLSHTTPRoundTripperImpl) RoundTrip(req *http.Request) (*http.Response, error) {
if req.URL.Scheme != "https" {
@ -99,12 +101,15 @@ func getPendingConnectionID(dest string, alpnIsH2 bool) pendingConnKey {
func (r *uTLSHTTPRoundTripperImpl) putConn(addr string, alpnIsH2 bool, conn net.Conn) {
connId := getPendingConnectionID(addr, alpnIsH2)
r.pendingConn[connId] = conn
r.pendingConn[connId] = NewUnclaimedConnection(conn, time.Minute)
}
func (r *uTLSHTTPRoundTripperImpl) getConn(addr string, alpnIsH2 bool) net.Conn {
connId := getPendingConnectionID(addr, alpnIsH2)
if conn, ok := r.pendingConn[connId]; ok {
return conn
delete(r.pendingConn, connId)
if claimedConnection, err := conn.claimConnection(); err == nil {
return claimedConnection
}
}
return nil
}
@ -189,3 +194,37 @@ func (r *uTLSHTTPRoundTripperImpl) init() {
},
}
}
func NewUnclaimedConnection(conn net.Conn, expireTime time.Duration) *unclaimedConnection {
c := &unclaimedConnection{
Conn: conn,
}
time.AfterFunc(expireTime, c.tick)
return c
}
type unclaimedConnection struct {
net.Conn
claimed bool
access sync.Mutex
}
func (c *unclaimedConnection) claimConnection() (net.Conn, error) {
c.access.Lock()
defer c.access.Unlock()
if !c.claimed {
c.claimed = true
return c.Conn, nil
}
return nil, errExpired
}
func (c *unclaimedConnection) tick() {
c.access.Lock()
defer c.access.Unlock()
if !c.claimed {
c.claimed = true
c.Conn.Close()
c.Conn = nil
}
}