github.com/xraypb/Xray-core@v1.8.1/transport/internet/http/dialer.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	gotls "crypto/tls"
     6  	"io"
     7  	"net/http"
     8  	"net/url"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/xraypb/Xray-core/common"
    13  	"github.com/xraypb/Xray-core/common/buf"
    14  	"github.com/xraypb/Xray-core/common/net"
    15  	"github.com/xraypb/Xray-core/common/net/cnc"
    16  	"github.com/xraypb/Xray-core/common/session"
    17  	"github.com/xraypb/Xray-core/transport/internet"
    18  	"github.com/xraypb/Xray-core/transport/internet/reality"
    19  	"github.com/xraypb/Xray-core/transport/internet/stat"
    20  	"github.com/xraypb/Xray-core/transport/internet/tls"
    21  	"github.com/xraypb/Xray-core/transport/pipe"
    22  	"golang.org/x/net/http2"
    23  )
    24  
    25  type dialerConf struct {
    26  	net.Destination
    27  	*internet.MemoryStreamConfig
    28  }
    29  
    30  var (
    31  	globalDialerMap    map[dialerConf]*http.Client
    32  	globalDialerAccess sync.Mutex
    33  )
    34  
    35  func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*http.Client, error) {
    36  	globalDialerAccess.Lock()
    37  	defer globalDialerAccess.Unlock()
    38  
    39  	if globalDialerMap == nil {
    40  		globalDialerMap = make(map[dialerConf]*http.Client)
    41  	}
    42  
    43  	httpSettings := streamSettings.ProtocolSettings.(*Config)
    44  	tlsConfigs := tls.ConfigFromStreamSettings(streamSettings)
    45  	realityConfigs := reality.ConfigFromStreamSettings(streamSettings)
    46  	if tlsConfigs == nil && realityConfigs == nil {
    47  		return nil, newError("TLS or REALITY must be enabled for http transport.").AtWarning()
    48  	}
    49  	sockopt := streamSettings.SocketSettings
    50  
    51  	if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
    52  		return client, nil
    53  	}
    54  
    55  	transport := &http2.Transport{
    56  		DialTLS: func(network string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
    57  			rawHost, rawPort, err := net.SplitHostPort(addr)
    58  			if err != nil {
    59  				return nil, err
    60  			}
    61  			if len(rawPort) == 0 {
    62  				rawPort = "443"
    63  			}
    64  			port, err := net.PortFromString(rawPort)
    65  			if err != nil {
    66  				return nil, err
    67  			}
    68  			address := net.ParseAddress(rawHost)
    69  
    70  			dctx := context.Background()
    71  			dctx = session.ContextWithID(dctx, session.IDFromContext(ctx))
    72  			dctx = session.ContextWithOutbound(dctx, session.OutboundFromContext(ctx))
    73  
    74  			pconn, err := internet.DialSystem(dctx, net.TCPDestination(address, port), sockopt)
    75  			if err != nil {
    76  				newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    77  				return nil, err
    78  			}
    79  
    80  			if realityConfigs != nil {
    81  				return reality.UClient(pconn, realityConfigs, ctx, dest)
    82  			}
    83  
    84  			var cn tls.Interface
    85  			if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil {
    86  				cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
    87  			} else {
    88  				cn = tls.Client(pconn, tlsConfig).(*tls.Conn)
    89  			}
    90  			if err := cn.Handshake(); err != nil {
    91  				newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    92  				return nil, err
    93  			}
    94  			if !tlsConfig.InsecureSkipVerify {
    95  				if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
    96  					newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    97  					return nil, err
    98  				}
    99  			}
   100  			negotiatedProtocol, negotiatedProtocolIsMutual := cn.NegotiatedProtocol()
   101  			if negotiatedProtocol != http2.NextProtoTLS {
   102  				return nil, newError("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError()
   103  			}
   104  			if !negotiatedProtocolIsMutual {
   105  				return nil, newError("http2: could not negotiate protocol mutually").AtError()
   106  			}
   107  			return cn, nil
   108  		},
   109  	}
   110  
   111  	if tlsConfigs != nil {
   112  		transport.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest))
   113  	}
   114  
   115  	if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 {
   116  		transport.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout)
   117  		transport.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout)
   118  	}
   119  
   120  	client := &http.Client{
   121  		Transport: transport,
   122  	}
   123  
   124  	globalDialerMap[dialerConf{dest, streamSettings}] = client
   125  	return client, nil
   126  }
   127  
   128  // Dial dials a new TCP connection to the given destination.
   129  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
   130  	httpSettings := streamSettings.ProtocolSettings.(*Config)
   131  	client, err := getHTTPClient(ctx, dest, streamSettings)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  
   136  	opts := pipe.OptionsFromContext(ctx)
   137  	preader, pwriter := pipe.New(opts...)
   138  	breader := &buf.BufferedReader{Reader: preader}
   139  
   140  	httpMethod := "PUT"
   141  	if httpSettings.Method != "" {
   142  		httpMethod = httpSettings.Method
   143  	}
   144  
   145  	httpHeaders := make(http.Header)
   146  
   147  	for _, httpHeader := range httpSettings.Header {
   148  		for _, httpHeaderValue := range httpHeader.Value {
   149  			httpHeaders.Set(httpHeader.Name, httpHeaderValue)
   150  		}
   151  	}
   152  
   153  	request := &http.Request{
   154  		Method: httpMethod,
   155  		Host:   httpSettings.getRandomHost(),
   156  		Body:   breader,
   157  		URL: &url.URL{
   158  			Scheme: "https",
   159  			Host:   dest.NetAddr(),
   160  			Path:   httpSettings.getNormalizedPath(),
   161  		},
   162  		Proto:      "HTTP/2",
   163  		ProtoMajor: 2,
   164  		ProtoMinor: 0,
   165  		Header:     httpHeaders,
   166  	}
   167  	// Disable any compression method from server.
   168  	request.Header.Set("Accept-Encoding", "identity")
   169  
   170  	wrc := &WaitReadCloser{Wait: make(chan struct{})}
   171  	go func() {
   172  		response, err := client.Do(request)
   173  		if err != nil {
   174  			newError("failed to dial to ", dest).Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   175  			wrc.Close()
   176  			return
   177  		}
   178  		if response.StatusCode != 200 {
   179  			newError("unexpected status", response.StatusCode).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   180  			wrc.Close()
   181  			return
   182  		}
   183  		wrc.Set(response.Body)
   184  	}()
   185  
   186  	bwriter := buf.NewBufferedWriter(pwriter)
   187  	common.Must(bwriter.SetBuffered(false))
   188  	return cnc.NewConnection(
   189  		cnc.ConnectionOutput(wrc),
   190  		cnc.ConnectionInput(bwriter),
   191  		cnc.ConnectionOnClose(common.ChainedClosable{breader, bwriter, wrc}),
   192  	), nil
   193  }
   194  
   195  func init() {
   196  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
   197  }
   198  
   199  type WaitReadCloser struct {
   200  	Wait chan struct{}
   201  	io.ReadCloser
   202  }
   203  
   204  func (w *WaitReadCloser) Set(rc io.ReadCloser) {
   205  	w.ReadCloser = rc
   206  	defer func() {
   207  		if recover() != nil {
   208  			rc.Close()
   209  		}
   210  	}()
   211  	close(w.Wait)
   212  }
   213  
   214  func (w *WaitReadCloser) Read(b []byte) (int, error) {
   215  	if w.ReadCloser == nil {
   216  		if <-w.Wait; w.ReadCloser == nil {
   217  			return 0, io.ErrClosedPipe
   218  		}
   219  	}
   220  	return w.ReadCloser.Read(b)
   221  }
   222  
   223  func (w *WaitReadCloser) Close() error {
   224  	if w.ReadCloser != nil {
   225  		return w.ReadCloser.Close()
   226  	}
   227  	defer func() {
   228  		if recover() != nil && w.ReadCloser != nil {
   229  			w.ReadCloser.Close()
   230  		}
   231  	}()
   232  	close(w.Wait)
   233  	return nil
   234  }