github.com/xraypb/xray-core@v1.6.6/transport/internet/http/dialer.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	gotls "crypto/tls"
     6  	"net/http"
     7  	"net/url"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/xraypb/xray-core/common"
    12  	"github.com/xraypb/xray-core/common/buf"
    13  	"github.com/xraypb/xray-core/common/net"
    14  	"github.com/xraypb/xray-core/common/net/cnc"
    15  	"github.com/xraypb/xray-core/common/session"
    16  	"github.com/xraypb/xray-core/transport/internet"
    17  	"github.com/xraypb/xray-core/transport/internet/stat"
    18  	"github.com/xraypb/xray-core/transport/internet/tls"
    19  	"github.com/xraypb/xray-core/transport/pipe"
    20  	"golang.org/x/net/http2"
    21  )
    22  
    23  type dialerConf struct {
    24  	net.Destination
    25  	*internet.MemoryStreamConfig
    26  }
    27  
    28  var (
    29  	globalDialerMap    map[dialerConf]*http.Client
    30  	globalDialerAccess sync.Mutex
    31  )
    32  
    33  func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*http.Client, error) {
    34  	globalDialerAccess.Lock()
    35  	defer globalDialerAccess.Unlock()
    36  
    37  	if globalDialerMap == nil {
    38  		globalDialerMap = make(map[dialerConf]*http.Client)
    39  	}
    40  
    41  	httpSettings := streamSettings.ProtocolSettings.(*Config)
    42  	tlsConfigs := tls.ConfigFromStreamSettings(streamSettings)
    43  	if tlsConfigs == nil {
    44  		return nil, newError("TLS must be enabled for http transport.").AtWarning()
    45  	}
    46  	sockopt := streamSettings.SocketSettings
    47  
    48  	if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
    49  		return client, nil
    50  	}
    51  
    52  	transport := &http2.Transport{
    53  		DialTLS: func(network string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
    54  			rawHost, rawPort, err := net.SplitHostPort(addr)
    55  			if err != nil {
    56  				return nil, err
    57  			}
    58  			if len(rawPort) == 0 {
    59  				rawPort = "443"
    60  			}
    61  			port, err := net.PortFromString(rawPort)
    62  			if err != nil {
    63  				return nil, err
    64  			}
    65  			address := net.ParseAddress(rawHost)
    66  
    67  			dctx := context.Background()
    68  			dctx = session.ContextWithID(dctx, session.IDFromContext(ctx))
    69  			dctx = session.ContextWithOutbound(dctx, session.OutboundFromContext(ctx))
    70  
    71  			pconn, err := internet.DialSystem(dctx, net.TCPDestination(address, port), sockopt)
    72  			if err != nil {
    73  				newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    74  				return nil, err
    75  			}
    76  
    77  			var cn tls.Interface
    78  			if fingerprint, ok := tls.Fingerprints[tlsConfigs.Fingerprint]; ok {
    79  				cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
    80  			} else {
    81  				cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
    82  			}
    83  			if err := cn.Handshake(); err != nil {
    84  				newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    85  				return nil, err
    86  			}
    87  			if !tlsConfig.InsecureSkipVerify {
    88  				if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
    89  					newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    90  					return nil, err
    91  				}
    92  			}
    93  			negotiatedProtocol, negotiatedProtocolIsMutual := cn.NegotiatedProtocol()
    94  			if negotiatedProtocol != http2.NextProtoTLS {
    95  				return nil, newError("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError()
    96  			}
    97  			if !negotiatedProtocolIsMutual {
    98  				return nil, newError("http2: could not negotiate protocol mutually").AtError()
    99  			}
   100  			return cn, nil
   101  		},
   102  		TLSClientConfig: tlsConfigs.GetTLSConfig(tls.WithDestination(dest)),
   103  	}
   104  
   105  	if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
   106  		transport.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
   107  		transport.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)
   108  	}
   109  
   110  	client := &http.Client{
   111  		Transport: transport,
   112  	}
   113  
   114  	globalDialerMap[dialerConf{dest, streamSettings}] = client
   115  	return client, nil
   116  }
   117  
   118  // Dial dials a new TCP connection to the given destination.
   119  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
   120  	httpSettings := streamSettings.ProtocolSettings.(*Config)
   121  	client, err := getHTTPClient(ctx, dest, streamSettings)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  
   126  	opts := pipe.OptionsFromContext(ctx)
   127  	preader, pwriter := pipe.New(opts...)
   128  	breader := &buf.BufferedReader{Reader: preader}
   129  
   130  	httpMethod := "PUT"
   131  	if httpSettings.Method != "" {
   132  		httpMethod = httpSettings.Method
   133  	}
   134  
   135  	httpHeaders := make(http.Header)
   136  
   137  	for _, httpHeader := range httpSettings.Header {
   138  		for _, httpHeaderValue := range httpHeader.Value {
   139  			httpHeaders.Set(httpHeader.Name, httpHeaderValue)
   140  		}
   141  	}
   142  
   143  	request := &http.Request{
   144  		Method: httpMethod,
   145  		Host:   httpSettings.getRandomHost(),
   146  		Body:   breader,
   147  		URL: &url.URL{
   148  			Scheme: "https",
   149  			Host:   dest.NetAddr(),
   150  			Path:   httpSettings.getNormalizedPath(),
   151  		},
   152  		Proto:      "HTTP/2",
   153  		ProtoMajor: 2,
   154  		ProtoMinor: 0,
   155  		Header:     httpHeaders,
   156  	}
   157  	// Disable any compression method from server.
   158  	request.Header.Set("Accept-Encoding", "identity")
   159  
   160  	response, err := client.Do(request)
   161  	if err != nil {
   162  		return nil, newError("failed to dial to ", dest).Base(err).AtWarning()
   163  	}
   164  	if response.StatusCode != 200 {
   165  		return nil, newError("unexpected status", response.StatusCode).AtWarning()
   166  	}
   167  
   168  	bwriter := buf.NewBufferedWriter(pwriter)
   169  	common.Must(bwriter.SetBuffered(false))
   170  	return cnc.NewConnection(
   171  		cnc.ConnectionOutput(response.Body),
   172  		cnc.ConnectionInput(bwriter),
   173  		cnc.ConnectionOnClose(common.ChainedClosable{breader, bwriter, response.Body}),
   174  	), nil
   175  }
   176  
   177  func init() {
   178  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
   179  }