github.com/xmplusdev/xray-core@v1.8.10/transport/internet/http/dialer.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	gotls "crypto/tls"
     6  	"io"
     7  	"net/http"
     8  	"net/url"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/xmplusdev/xray-core/common"
    13  	"github.com/xmplusdev/xray-core/common/buf"
    14  	"github.com/xmplusdev/xray-core/common/net"
    15  	"github.com/xmplusdev/xray-core/common/net/cnc"
    16  	"github.com/xmplusdev/xray-core/common/session"
    17  	"github.com/xmplusdev/xray-core/transport/internet"
    18  	"github.com/xmplusdev/xray-core/transport/internet/reality"
    19  	"github.com/xmplusdev/xray-core/transport/internet/stat"
    20  	"github.com/xmplusdev/xray-core/transport/internet/tls"
    21  	"github.com/xmplusdev/xray-core/transport/pipe"
    22  	"golang.org/x/net/http2"
    23  )
    24  
// dialerConf is the key type of globalDialerMap: it pairs a destination with
// the stream settings used to reach it. The settings are held by pointer, so
// two keys are equal only when they refer to the same MemoryStreamConfig value.
type dialerConf struct {
	net.Destination
	*internet.MemoryStreamConfig
}
    29  
var (
	// globalDialerMap caches one HTTP client per dialerConf. It is created
	// lazily by getHTTPClient.
	globalDialerMap    map[dialerConf]*http.Client
	// globalDialerAccess guards every read and write of globalDialerMap.
	globalDialerAccess sync.Mutex
)
    34  
    35  func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*http.Client, error) {
    36  	globalDialerAccess.Lock()
    37  	defer globalDialerAccess.Unlock()
    38  
    39  	if globalDialerMap == nil {
    40  		globalDialerMap = make(map[dialerConf]*http.Client)
    41  	}
    42  
    43  	httpSettings := streamSettings.ProtocolSettings.(*Config)
    44  	tlsConfigs := tls.ConfigFromStreamSettings(streamSettings)
    45  	realityConfigs := reality.ConfigFromStreamSettings(streamSettings)
    46  	if tlsConfigs == nil && realityConfigs == nil {
    47  		return nil, newError("TLS or REALITY must be enabled for http transport.").AtWarning()
    48  	}
    49  	sockopt := streamSettings.SocketSettings
    50  
    51  	if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
    52  		return client, nil
    53  	}
    54  
    55  	transport := &http2.Transport{
    56  		DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
    57  			rawHost, rawPort, err := net.SplitHostPort(addr)
    58  			if err != nil {
    59  				return nil, err
    60  			}
    61  			if len(rawPort) == 0 {
    62  				rawPort = "443"
    63  			}
    64  			port, err := net.PortFromString(rawPort)
    65  			if err != nil {
    66  				return nil, err
    67  			}
    68  			address := net.ParseAddress(rawHost)
    69  
    70  			hctx = session.ContextWithID(hctx, session.IDFromContext(ctx))
    71  			hctx = session.ContextWithOutbound(hctx, session.OutboundFromContext(ctx))
    72  			hctx = session.ContextWithTimeoutOnly(hctx, true)
    73  
    74  			pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt)
    75  			if err != nil {
    76  				newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    77  				return nil, err
    78  			}
    79  
    80  			if realityConfigs != nil {
    81  				return reality.UClient(pconn, realityConfigs, hctx, dest)
    82  			}
    83  
    84  			var cn tls.Interface
    85  			if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil {
    86  				cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
    87  			} else {
    88  				cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
    89  			}
    90  			if err := cn.HandshakeContext(ctx); err != nil {
    91  				newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    92  				return nil, err
    93  			}
    94  			if !tlsConfig.InsecureSkipVerify {
    95  				if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
    96  					newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    97  					return nil, err
    98  				}
    99  			}
   100  			negotiatedProtocol, negotiatedProtocolIsMutual := cn.NegotiatedProtocol()
   101  			if negotiatedProtocol != http2.NextProtoTLS {
   102  				return nil, newError("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError()
   103  			}
   104  			if !negotiatedProtocolIsMutual {
   105  				return nil, newError("http2: could not negotiate protocol mutually").AtError()
   106  			}
   107  			return cn, nil
   108  		},
   109  	}
   110  
   111  	if tlsConfigs != nil {
   112  		transport.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest))
   113  	}
   114  
   115  	if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
   116  		transport.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
   117  		transport.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)
   118  	}
   119  
   120  	client := &http.Client{
   121  		Transport: transport,
   122  	}
   123  
   124  	globalDialerMap[dialerConf{dest, streamSettings}] = client
   125  	return client, nil
   126  }
   127  
   128  // Dial dials a new TCP connection to the given destination.
   129  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
   130  	httpSettings := streamSettings.ProtocolSettings.(*Config)
   131  	client, err := getHTTPClient(ctx, dest, streamSettings)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  
   136  	opts := pipe.OptionsFromContext(ctx)
   137  	preader, pwriter := pipe.New(opts...)
   138  	breader := &buf.BufferedReader{Reader: preader}
   139  
   140  	httpMethod := "PUT"
   141  	if httpSettings.Method != "" {
   142  		httpMethod = httpSettings.Method
   143  	}
   144  
   145  	httpHeaders := make(http.Header)
   146  
   147  	for _, httpHeader := range httpSettings.Header {
   148  		for _, httpHeaderValue := range httpHeader.Value {
   149  			httpHeaders.Set(httpHeader.Name, httpHeaderValue)
   150  		}
   151  	}
   152  
   153  	request := &http.Request{
   154  		Method: httpMethod,
   155  		Host:   httpSettings.getRandomHost(),
   156  		Body:   breader,
   157  		URL: &url.URL{
   158  			Scheme: "https",
   159  			Host:   dest.NetAddr(),
   160  			Path:   httpSettings.getNormalizedPath(),
   161  		},
   162  		Proto:      "HTTP/2",
   163  		ProtoMajor: 2,
   164  		ProtoMinor: 0,
   165  		Header:     httpHeaders,
   166  	}
   167  	// Disable any compression method from server.
   168  	request.Header.Set("Accept-Encoding", "identity")
   169  
   170  	wrc := &WaitReadCloser{Wait: make(chan struct{})}
   171  	go func() {
   172  		response, err := client.Do(request)
   173  		if err != nil {
   174  			newError("failed to dial to ", dest).Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   175  			wrc.Close()
   176  			{
   177  				// Abandon `client` if `client.Do(request)` failed
   178  				// See https://github.com/golang/go/issues/30702
   179  				globalDialerAccess.Lock()
   180  				if globalDialerMap[dialerConf{dest, streamSettings}] == client {
   181  					delete(globalDialerMap, dialerConf{dest, streamSettings})
   182  				}
   183  				globalDialerAccess.Unlock()
   184  			}
   185  			return
   186  		}
   187  		if response.StatusCode != 200 {
   188  			newError("unexpected status", response.StatusCode).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   189  			wrc.Close()
   190  			return
   191  		}
   192  		wrc.Set(response.Body)
   193  	}()
   194  
   195  	bwriter := buf.NewBufferedWriter(pwriter)
   196  	common.Must(bwriter.SetBuffered(false))
   197  	return cnc.NewConnection(
   198  		cnc.ConnectionOutput(wrc),
   199  		cnc.ConnectionInput(bwriter),
   200  		cnc.ConnectionOnClose(common.ChainedClosable{breader, bwriter, wrc}),
   201  	), nil
   202  }
   203  
   204  func init() {
   205  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
   206  }
   207  
   208  type WaitReadCloser struct {
   209  	Wait chan struct{}
   210  	io.ReadCloser
   211  }
   212  
   213  func (w *WaitReadCloser) Set(rc io.ReadCloser) {
   214  	w.ReadCloser = rc
   215  	defer func() {
   216  		if recover() != nil {
   217  			rc.Close()
   218  		}
   219  	}()
   220  	close(w.Wait)
   221  }
   222  
   223  func (w *WaitReadCloser) Read(b []byte) (int, error) {
   224  	if w.ReadCloser == nil {
   225  		if <-w.Wait; w.ReadCloser == nil {
   226  			return 0, io.ErrClosedPipe
   227  		}
   228  	}
   229  	return w.ReadCloser.Read(b)
   230  }
   231  
   232  func (w *WaitReadCloser) Close() error {
   233  	if w.ReadCloser != nil {
   234  		return w.ReadCloser.Close()
   235  	}
   236  	defer func() {
   237  		if recover() != nil && w.ReadCloser != nil {
   238  			w.ReadCloser.Close()
   239  		}
   240  	}()
   241  	close(w.Wait)
   242  	return nil
   243  }