github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/transport/internet/transportcommon/httpDialer.go (about) 1 package transportcommon 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "net" 9 "net/http" 10 "sync" 11 "time" 12 13 "golang.org/x/net/http2" 14 15 "github.com/v2fly/v2ray-core/v5/transport/internet/security" 16 ) 17 18 type DialerFunc func(ctx context.Context, addr string) (net.Conn, error) 19 20 // NewALPNAwareHTTPRoundTripper creates an instance of RoundTripper that dial to remote HTTPS endpoint with 21 // an alternative version of TLS implementation. 22 func NewALPNAwareHTTPRoundTripper(ctx context.Context, dialer DialerFunc, 23 backdropTransport http.RoundTripper, 24 ) http.RoundTripper { 25 rtImpl := &alpnAwareHTTPRoundTripperImpl{ 26 connectWithH1: map[string]bool{}, 27 backdropTransport: backdropTransport, 28 pendingConn: map[pendingConnKey]*unclaimedConnection{}, 29 dialer: dialer, 30 ctx: ctx, 31 } 32 rtImpl.init() 33 return rtImpl 34 } 35 36 type alpnAwareHTTPRoundTripperImpl struct { 37 accessConnectWithH1 sync.Mutex 38 connectWithH1 map[string]bool 39 40 httpsH1Transport http.RoundTripper 41 httpsH2Transport http.RoundTripper 42 backdropTransport http.RoundTripper 43 44 accessDialingConnection sync.Mutex 45 pendingConn map[pendingConnKey]*unclaimedConnection 46 47 ctx context.Context 48 dialer DialerFunc 49 } 50 51 type pendingConnKey struct { 52 isH2 bool 53 dest string 54 } 55 56 var ( 57 errEAGAIN = errors.New("incorrect ALPN negotiated, try again with another ALPN") 58 errEAGAINTooMany = errors.New("incorrect ALPN negotiated") 59 errExpired = errors.New("connection have expired") 60 ) 61 62 func (r *alpnAwareHTTPRoundTripperImpl) RoundTrip(req *http.Request) (*http.Response, error) { 63 if req.URL.Scheme != "https" { 64 return r.backdropTransport.RoundTrip(req) 65 } 66 for retryCount := 0; retryCount < 5; retryCount++ { 67 effectivePort := req.URL.Port() 68 if effectivePort == "" { 69 effectivePort = "443" 70 } 71 if r.getShouldConnectWithH1(fmt.Sprintf("%v:%v", req.URL.Hostname(), effectivePort)) { 72 resp, err := r.httpsH1Transport.RoundTrip(req) 73 if errors.Is(err, errEAGAIN) { 74 continue 75 } 76 return resp, err 77 } 78 resp, err := r.httpsH2Transport.RoundTrip(req) 79 if errors.Is(err, errEAGAIN) { 80 continue 81 } 82 return resp, err 83 } 84 return nil, errEAGAINTooMany 85 } 86 87 func (r *alpnAwareHTTPRoundTripperImpl) getShouldConnectWithH1(domainName string) bool { 88 r.accessConnectWithH1.Lock() 89 defer r.accessConnectWithH1.Unlock() 90 if value, set := r.connectWithH1[domainName]; set { 91 return value 92 } 93 return false 94 } 95 96 func (r *alpnAwareHTTPRoundTripperImpl) setShouldConnectWithH1(domainName string) { 97 r.accessConnectWithH1.Lock() 98 defer r.accessConnectWithH1.Unlock() 99 r.connectWithH1[domainName] = true 100 } 101 102 func (r *alpnAwareHTTPRoundTripperImpl) clearShouldConnectWithH1(domainName string) { 103 r.accessConnectWithH1.Lock() 104 defer r.accessConnectWithH1.Unlock() 105 r.connectWithH1[domainName] = false 106 } 107 108 func getPendingConnectionID(dest string, alpnIsH2 bool) pendingConnKey { 109 return pendingConnKey{isH2: alpnIsH2, dest: dest} 110 } 111 112 func (r *alpnAwareHTTPRoundTripperImpl) putConn(addr string, alpnIsH2 bool, conn net.Conn) { 113 connID := getPendingConnectionID(addr, alpnIsH2) 114 r.pendingConn[connID] = NewUnclaimedConnection(conn, time.Minute) 115 } 116 117 func (r *alpnAwareHTTPRoundTripperImpl) getConn(addr string, alpnIsH2 bool) net.Conn { 118 connID := getPendingConnectionID(addr, alpnIsH2) 119 if conn, ok := r.pendingConn[connID]; ok { 120 delete(r.pendingConn, connID) 121 if claimedConnection, err := conn.claimConnection(); err == nil { 122 return claimedConnection 123 } 124 } 125 return nil 126 } 127 128 func (r *alpnAwareHTTPRoundTripperImpl) dialOrGetTLSWithExpectedALPN(ctx context.Context, addr string, expectedH2 bool) (net.Conn, error) { 129 r.accessDialingConnection.Lock() 130 defer r.accessDialingConnection.Unlock() 131 132 if r.getShouldConnectWithH1(addr) == expectedH2 { 133 return nil, errEAGAIN 134 } 135 136 // Get a cached connection if possible to reduce preflight connection closed without sending data 137 if gconn := r.getConn(addr, expectedH2); gconn != nil { 138 return gconn, nil 139 } 140 141 conn, err := r.dialTLS(ctx, addr) 142 if err != nil { 143 return nil, err 144 } 145 146 protocol := "" 147 if connAPLNGetter, ok := conn.(security.ConnectionApplicationProtocol); ok { 148 connectionALPN, err := connAPLNGetter.GetConnectionApplicationProtocol() 149 if err != nil { 150 return nil, newError("failed to get connection ALPN").Base(err).AtWarning() 151 } 152 protocol = connectionALPN 153 } 154 155 protocolIsH2 := protocol == http2.NextProtoTLS 156 157 if protocolIsH2 == expectedH2 { 158 return conn, err 159 } 160 161 r.putConn(addr, protocolIsH2, conn) 162 163 if protocolIsH2 { 164 r.clearShouldConnectWithH1(addr) 165 } else { 166 r.setShouldConnectWithH1(addr) 167 } 168 169 return nil, errEAGAIN 170 } 171 172 func (r *alpnAwareHTTPRoundTripperImpl) dialTLS(ctx context.Context, addr string) (net.Conn, error) { 173 _ = ctx 174 return r.dialer(r.ctx, addr) 175 } 176 177 func (r *alpnAwareHTTPRoundTripperImpl) init() { 178 r.httpsH2Transport = &http2.Transport{ 179 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { 180 return r.dialOrGetTLSWithExpectedALPN(context.Background(), addr, true) 181 }, 182 } 183 r.httpsH1Transport = &http.Transport{ 184 DialTLSContext: func(ctx context.Context, network string, addr string) (net.Conn, error) { 185 return r.dialOrGetTLSWithExpectedALPN(ctx, addr, false) 186 }, 187 } 188 } 189 190 func NewUnclaimedConnection(conn net.Conn, expireTime time.Duration) *unclaimedConnection { 191 c := &unclaimedConnection{ 192 Conn: conn, 193 } 194 time.AfterFunc(expireTime, c.tick) 195 return c 196 } 197 198 type unclaimedConnection struct { 199 net.Conn 200 claimed bool 201 access sync.Mutex 202 } 203 204 func (c *unclaimedConnection) claimConnection() (net.Conn, error) { 205 c.access.Lock() 206 defer c.access.Unlock() 207 if !c.claimed { 208 c.claimed = true 209 return c.Conn, nil 210 } 211 return nil, errExpired 212 } 213 214 func (c *unclaimedConnection) tick() { 215 c.access.Lock() 216 defer c.access.Unlock() 217 if !c.claimed { 218 c.claimed = true 219 c.Conn.Close() 220 c.Conn = nil 221 } 222 }