github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/transport/internet/transportcommon/httpDialer.go (about)

     1  package transportcommon
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"net/http"
    10  	"sync"
    11  	"time"
    12  
    13  	"golang.org/x/net/http2"
    14  
    15  	"github.com/v2fly/v2ray-core/v5/transport/internet/security"
    16  )
    17  
// DialerFunc opens a connection to addr. The round tripper in this file calls
// it with the "host:port" address handed to the HTTP transports' TLS dial
// hooks and inspects the returned connection for its negotiated ALPN.
type DialerFunc func(ctx context.Context, addr string) (net.Conn, error)
    19  
    20  // NewALPNAwareHTTPRoundTripper creates an instance of RoundTripper that dial to remote HTTPS endpoint with
    21  // an alternative version of TLS implementation.
    22  func NewALPNAwareHTTPRoundTripper(ctx context.Context, dialer DialerFunc,
    23  	backdropTransport http.RoundTripper,
    24  ) http.RoundTripper {
    25  	rtImpl := &alpnAwareHTTPRoundTripperImpl{
    26  		connectWithH1:     map[string]bool{},
    27  		backdropTransport: backdropTransport,
    28  		pendingConn:       map[pendingConnKey]*unclaimedConnection{},
    29  		dialer:            dialer,
    30  		ctx:               ctx,
    31  	}
    32  	rtImpl.init()
    33  	return rtImpl
    34  }
    35  
// alpnAwareHTTPRoundTripperImpl is the http.RoundTripper returned by
// NewALPNAwareHTTPRoundTripper.
type alpnAwareHTTPRoundTripperImpl struct {
	// accessConnectWithH1 guards connectWithH1.
	accessConnectWithH1 sync.Mutex
	// connectWithH1 records, per dial address, whether the last dialed
	// connection negotiated something other than HTTP/2 (true) or HTTP/2
	// (false). Addresses that have never been dialed are absent.
	connectWithH1 map[string]bool

	// httpsH1Transport and httpsH2Transport are created in init; both dial
	// through dialOrGetTLSWithExpectedALPN.
	httpsH1Transport http.RoundTripper
	httpsH2Transport http.RoundTripper
	// backdropTransport handles every request whose scheme is not "https".
	backdropTransport http.RoundTripper

	// accessDialingConnection is held for the whole of
	// dialOrGetTLSWithExpectedALPN, which is the only place pendingConn is
	// read or written.
	accessDialingConnection sync.Mutex
	// pendingConn parks connections whose negotiated ALPN did not match the
	// transport that dialed them, so the other transport can pick them up.
	pendingConn map[pendingConnKey]*unclaimedConnection

	// ctx is the context passed to dialer for every dial.
	ctx    context.Context
	dialer DialerFunc
}
    50  
// pendingConnKey identifies a parked connection by destination address and by
// whether it negotiated HTTP/2.
type pendingConnKey struct {
	isH2 bool   // true when the connection negotiated HTTP/2
	dest string // dial address the connection was opened to
}
    55  
var (
	// errEAGAIN is returned by the dial hook when the connection does not
	// match the ALPN the calling transport needs; RoundTrip reacts by retrying.
	errEAGAIN = errors.New("incorrect ALPN negotiated, try again with another ALPN")
	// errEAGAINTooMany is returned by RoundTrip once its retries are used up.
	errEAGAINTooMany = errors.New("incorrect ALPN negotiated")
	// errExpired is returned when an unclaimedConnection has already been
	// claimed or was closed by its expiry timer.
	errExpired = errors.New("connection have expired")
)
    61  
    62  func (r *alpnAwareHTTPRoundTripperImpl) RoundTrip(req *http.Request) (*http.Response, error) {
    63  	if req.URL.Scheme != "https" {
    64  		return r.backdropTransport.RoundTrip(req)
    65  	}
    66  	for retryCount := 0; retryCount < 5; retryCount++ {
    67  		effectivePort := req.URL.Port()
    68  		if effectivePort == "" {
    69  			effectivePort = "443"
    70  		}
    71  		if r.getShouldConnectWithH1(fmt.Sprintf("%v:%v", req.URL.Hostname(), effectivePort)) {
    72  			resp, err := r.httpsH1Transport.RoundTrip(req)
    73  			if errors.Is(err, errEAGAIN) {
    74  				continue
    75  			}
    76  			return resp, err
    77  		}
    78  		resp, err := r.httpsH2Transport.RoundTrip(req)
    79  		if errors.Is(err, errEAGAIN) {
    80  			continue
    81  		}
    82  		return resp, err
    83  	}
    84  	return nil, errEAGAINTooMany
    85  }
    86  
    87  func (r *alpnAwareHTTPRoundTripperImpl) getShouldConnectWithH1(domainName string) bool {
    88  	r.accessConnectWithH1.Lock()
    89  	defer r.accessConnectWithH1.Unlock()
    90  	if value, set := r.connectWithH1[domainName]; set {
    91  		return value
    92  	}
    93  	return false
    94  }
    95  
// setShouldConnectWithH1 marks domainName as a destination that should be
// reached through the HTTP/1.1 transport.
func (r *alpnAwareHTTPRoundTripperImpl) setShouldConnectWithH1(domainName string) {
	r.accessConnectWithH1.Lock()
	defer r.accessConnectWithH1.Unlock()
	r.connectWithH1[domainName] = true
}
   101  
// clearShouldConnectWithH1 marks domainName as a destination that should not
// be reached through the HTTP/1.1 transport. The entry is kept with the value
// false rather than removed.
func (r *alpnAwareHTTPRoundTripperImpl) clearShouldConnectWithH1(domainName string) {
	r.accessConnectWithH1.Lock()
	defer r.accessConnectWithH1.Unlock()
	r.connectWithH1[domainName] = false
}
   107  
// getPendingConnectionID builds the pendingConn map key for a connection to
// dest that negotiated HTTP/2 (alpnIsH2 true) or something else (false).
func getPendingConnectionID(dest string, alpnIsH2 bool) pendingConnKey {
	return pendingConnKey{isH2: alpnIsH2, dest: dest}
}
   111  
   112  func (r *alpnAwareHTTPRoundTripperImpl) putConn(addr string, alpnIsH2 bool, conn net.Conn) {
   113  	connID := getPendingConnectionID(addr, alpnIsH2)
   114  	r.pendingConn[connID] = NewUnclaimedConnection(conn, time.Minute)
   115  }
   116  
   117  func (r *alpnAwareHTTPRoundTripperImpl) getConn(addr string, alpnIsH2 bool) net.Conn {
   118  	connID := getPendingConnectionID(addr, alpnIsH2)
   119  	if conn, ok := r.pendingConn[connID]; ok {
   120  		delete(r.pendingConn, connID)
   121  		if claimedConnection, err := conn.claimConnection(); err == nil {
   122  			return claimedConnection
   123  		}
   124  	}
   125  	return nil
   126  }
   127  
   128  func (r *alpnAwareHTTPRoundTripperImpl) dialOrGetTLSWithExpectedALPN(ctx context.Context, addr string, expectedH2 bool) (net.Conn, error) {
   129  	r.accessDialingConnection.Lock()
   130  	defer r.accessDialingConnection.Unlock()
   131  
   132  	if r.getShouldConnectWithH1(addr) == expectedH2 {
   133  		return nil, errEAGAIN
   134  	}
   135  
   136  	// Get a cached connection if possible to reduce preflight connection closed without sending data
   137  	if gconn := r.getConn(addr, expectedH2); gconn != nil {
   138  		return gconn, nil
   139  	}
   140  
   141  	conn, err := r.dialTLS(ctx, addr)
   142  	if err != nil {
   143  		return nil, err
   144  	}
   145  
   146  	protocol := ""
   147  	if connAPLNGetter, ok := conn.(security.ConnectionApplicationProtocol); ok {
   148  		connectionALPN, err := connAPLNGetter.GetConnectionApplicationProtocol()
   149  		if err != nil {
   150  			return nil, newError("failed to get connection ALPN").Base(err).AtWarning()
   151  		}
   152  		protocol = connectionALPN
   153  	}
   154  
   155  	protocolIsH2 := protocol == http2.NextProtoTLS
   156  
   157  	if protocolIsH2 == expectedH2 {
   158  		return conn, err
   159  	}
   160  
   161  	r.putConn(addr, protocolIsH2, conn)
   162  
   163  	if protocolIsH2 {
   164  		r.clearShouldConnectWithH1(addr)
   165  	} else {
   166  		r.setShouldConnectWithH1(addr)
   167  	}
   168  
   169  	return nil, errEAGAIN
   170  }
   171  
// dialTLS opens a new connection to addr through the configured dialer.
//
// The per-call ctx is deliberately discarded; the dialer always receives the
// context captured at construction time.
// NOTE(review): presumably because the dialer depends on values carried by
// that context, which the transports' per-request contexts would not have;
// confirm before threading ctx through.
func (r *alpnAwareHTTPRoundTripperImpl) dialTLS(ctx context.Context, addr string) (net.Conn, error) {
	_ = ctx
	return r.dialer(r.ctx, addr)
}
   176  
   177  func (r *alpnAwareHTTPRoundTripperImpl) init() {
   178  	r.httpsH2Transport = &http2.Transport{
   179  		DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
   180  			return r.dialOrGetTLSWithExpectedALPN(context.Background(), addr, true)
   181  		},
   182  	}
   183  	r.httpsH1Transport = &http.Transport{
   184  		DialTLSContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
   185  			return r.dialOrGetTLSWithExpectedALPN(ctx, addr, false)
   186  		},
   187  	}
   188  }
   189  
   190  func NewUnclaimedConnection(conn net.Conn, expireTime time.Duration) *unclaimedConnection {
   191  	c := &unclaimedConnection{
   192  		Conn: conn,
   193  	}
   194  	time.AfterFunc(expireTime, c.tick)
   195  	return c
   196  }
   197  
// unclaimedConnection holds a connection that is waiting to be claimed. It is
// handed out at most once by claimConnection, or closed by tick if the expiry
// timer fires first.
type unclaimedConnection struct {
	net.Conn
	// claimed is set once the connection was either handed out or closed.
	claimed bool
	// access guards claimed and Conn.
	access sync.Mutex
}
   203  
   204  func (c *unclaimedConnection) claimConnection() (net.Conn, error) {
   205  	c.access.Lock()
   206  	defer c.access.Unlock()
   207  	if !c.claimed {
   208  		c.claimed = true
   209  		return c.Conn, nil
   210  	}
   211  	return nil, errExpired
   212  }
   213  
   214  func (c *unclaimedConnection) tick() {
   215  	c.access.Lock()
   216  	defer c.access.Unlock()
   217  	if !c.claimed {
   218  		c.claimed = true
   219  		c.Conn.Close()
   220  		c.Conn = nil
   221  	}
   222  }