github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/transport/internet/http/dialer.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	gotls "crypto/tls"
     6  	"io"
     7  	"net/http"
     8  	"net/url"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/xtls/xray-core/common"
    13  	"github.com/xtls/xray-core/common/buf"
    14  	"github.com/xtls/xray-core/common/net"
    15  	"github.com/xtls/xray-core/common/net/cnc"
    16  	"github.com/xtls/xray-core/common/session"
    17  	"github.com/xtls/xray-core/transport/internet"
    18  	"github.com/xtls/xray-core/transport/internet/reality"
    19  	"github.com/xtls/xray-core/transport/internet/stat"
    20  	"github.com/xtls/xray-core/transport/internet/tls"
    21  	"github.com/xtls/xray-core/transport/pipe"
    22  	"golang.org/x/net/http2"
    23  )
    24  
    25  type dialerConf struct {
    26  	net.Destination
    27  	*internet.MemoryStreamConfig
    28  }
    29  
    30  var (
    31  	globalDialerMap    map[dialerConf]*http.Client
    32  	globalDialerAccess sync.Mutex
    33  )
    34  
    35  func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*http.Client, error) {
    36  	globalDialerAccess.Lock()
    37  	defer globalDialerAccess.Unlock()
    38  
    39  	if globalDialerMap == nil {
    40  		globalDialerMap = make(map[dialerConf]*http.Client)
    41  	}
    42  
    43  	httpSettings := streamSettings.ProtocolSettings.(*Config)
    44  	tlsConfigs := tls.ConfigFromStreamSettings(streamSettings)
    45  	realityConfigs := reality.ConfigFromStreamSettings(streamSettings)
    46  	if tlsConfigs == nil && realityConfigs == nil {
    47  		return nil, newError("TLS or REALITY must be enabled for http transport.").AtWarning()
    48  	}
    49  	sockopt := streamSettings.SocketSettings
    50  
    51  	if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
    52  		return client, nil
    53  	}
    54  
    55  	transport := &http2.Transport{
    56  		DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
    57  			rawHost, rawPort, err := net.SplitHostPort(addr)
    58  			if err != nil {
    59  				return nil, err
    60  			}
    61  			if len(rawPort) == 0 {
    62  				rawPort = "443"
    63  			}
    64  			port, err := net.PortFromString(rawPort)
    65  			if err != nil {
    66  				return nil, err
    67  			}
    68  			address := net.ParseAddress(rawHost)
    69  
    70  			hctx = session.ContextWithID(hctx, session.IDFromContext(ctx))
    71  			hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx))
    72  			hctx = session.ContextWithTimeoutOnly(hctx, true)
    73  
    74  			pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)
    75  			if err != nil {
    76  				newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    77  				return nil, err
    78  			}
    79  
    80  			if realityConfigs != nil {
    81  				return reality.UClient(pconn, realityConfigs, hctx, dest)
    82  			}
    83  
    84  			var cn tls.Interface
    85  			if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil {
    86  				cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
    87  			} else {
    88  				cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
    89  			}
    90  			if err := cn.HandshakeContext(ctx); err != nil {
    91  				newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    92  				return nil, err
    93  			}
    94  			if !tlsConfig.InsecureSkipVerify {
    95  				if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
    96  					newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    97  					return nil, err
    98  				}
    99  			}
   100  			negotiatedProtocol := cn.NegotiatedProtocol()
   101  			if negotiatedProtocol != http2.NextProtoTLS {
   102  				return nil, newError("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError()
   103  			}
   104  			return cn, nil
   105  		},
   106  	}
   107  
   108  	if tlsConfigs != nil {
   109  		transport.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest))
   110  	}
   111  
   112  	if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
   113  		transport.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
   114  		transport.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)
   115  	}
   116  
   117  	client := &http.Client{
   118  		Transport: transport,
   119  	}
   120  
   121  	globalDialerMap[dialerConf{dest, streamSettings}] = client
   122  	return client, nil
   123  }
   124  
   125  // Dial dials a new TCP connection to the given destination.
   126  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
   127  	httpSettings := streamSettings.ProtocolSettings.(*Config)
   128  	client, err := getHTTPClient(ctx, dest, streamSettings)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	opts := pipe.OptionsFromContext(ctx)
   134  	preader, pwriter := pipe.New(opts...)
   135  	breader := &buf.BufferedReader{Reader: preader}
   136  
   137  	httpMethod := "PUT"
   138  	if httpSettings.Method != "" {
   139  		httpMethod = httpSettings.Method
   140  	}
   141  
   142  	httpHeaders := make(http.Header)
   143  
   144  	for _, httpHeader := range httpSettings.Header {
   145  		for _, httpHeaderValue := range httpHeader.Value {
   146  			httpHeaders.Set(httpHeader.Name, httpHeaderValue)
   147  		}
   148  	}
   149  
   150  	request := &http.Request{
   151  		Method: httpMethod,
   152  		Host:   httpSettings.getRandomHost(),
   153  		Body:   breader,
   154  		URL: &url.URL{
   155  			Scheme: "https",
   156  			Host:   dest.NetAddr(),
   157  			Path:   httpSettings.getNormalizedPath(),
   158  		},
   159  		Proto:      "HTTP/2",
   160  		ProtoMajor: 2,
   161  		ProtoMinor: 0,
   162  		Header:     httpHeaders,
   163  	}
   164  	// Disable any compression method from server.
   165  	request.Header.Set("Accept-Encoding", "identity")
   166  
   167  	wrc := &WaitReadCloser{Wait: make(chan struct{})}
   168  	go func() {
   169  		response, err := client.Do(request)
   170  		if err != nil {
   171  			newError("failed to dial to ", dest).Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   172  			wrc.Close()
   173  			{
   174  				// Abandon `client` if `client.Do(request)` failed
   175  				// See https://github.com/golang/go/issues/30702
   176  				globalDialerAccess.Lock()
   177  				if globalDialerMap[dialerConf{dest, streamSettings}] == client {
   178  					delete(globalDialerMap, dialerConf{dest, streamSettings})
   179  				}
   180  				globalDialerAccess.Unlock()
   181  			}
   182  			return
   183  		}
   184  		if response.StatusCode != 200 {
   185  			newError("unexpected status", response.StatusCode).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   186  			wrc.Close()
   187  			return
   188  		}
   189  		wrc.Set(response.Body)
   190  	}()
   191  
   192  	bwriter := buf.NewBufferedWriter(pwriter)
   193  	common.Must(bwriter.SetBuffered(false))
   194  	return cnc.NewConnection(
   195  		cnc.ConnectionOutput(wrc),
   196  		cnc.ConnectionInput(bwriter),
   197  		cnc.ConnectionOnClose(common.ChainedClosable{breader, bwriter, wrc}),
   198  	), nil
   199  }
   200  
   201  func init() {
   202  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
   203  }
   204  
   205  type WaitReadCloser struct {
   206  	Wait chan struct{}
   207  	io.ReadCloser
   208  }
   209  
   210  func (w *WaitReadCloser) Set(rc io.ReadCloser) {
   211  	w.ReadCloser = rc
   212  	defer func() {
   213  		if recover() != nil {
   214  			rc.Close()
   215  		}
   216  	}()
   217  	close(w.Wait)
   218  }
   219  
   220  func (w *WaitReadCloser) Read(b []byte) (int, error) {
   221  	if w.ReadCloser == nil {
   222  		if <-w.Wait; w.ReadCloser == nil {
   223  			return 0, io.ErrClosedPipe
   224  		}
   225  	}
   226  	return w.ReadCloser.Read(b)
   227  }
   228  
   229  func (w *WaitReadCloser) Close() error {
   230  	if w.ReadCloser != nil {
   231  		return w.ReadCloser.Close()
   232  	}
   233  	defer func() {
   234  		if recover() != nil && w.ReadCloser != nil {
   235  			w.ReadCloser.Close()
   236  		}
   237  	}()
   238  	close(w.Wait)
   239  	return nil
   240  }