github.com/eagleql/xray-core@v1.4.4/transport/internet/http/dialer.go (about)

     1  package http
     2  
     3  import (
     4  	"context"
     5  	gotls "crypto/tls"
     6  	"net/http"
     7  	"net/url"
     8  	"sync"
     9  
    10  	"github.com/eagleql/xray-core/common"
    11  	"github.com/eagleql/xray-core/common/buf"
    12  	"github.com/eagleql/xray-core/common/net"
    13  	"github.com/eagleql/xray-core/common/net/cnc"
    14  	"github.com/eagleql/xray-core/common/session"
    15  	"github.com/eagleql/xray-core/transport/internet"
    16  	"github.com/eagleql/xray-core/transport/internet/tls"
    17  	"github.com/eagleql/xray-core/transport/pipe"
    18  	"golang.org/x/net/http2"
    19  )
    20  
    21  type dialerConf struct {
    22  	net.Destination
    23  	*internet.SocketConfig
    24  	*tls.Config
    25  }
    26  
    27  var (
    28  	globalDialerMap    map[dialerConf]*http.Client
    29  	globalDialerAccess sync.Mutex
    30  )
    31  
    32  func getHTTPClient(ctx context.Context, dest net.Destination, tlsSettings *tls.Config, sockopt *internet.SocketConfig) (*http.Client, error) {
    33  	globalDialerAccess.Lock()
    34  	defer globalDialerAccess.Unlock()
    35  
    36  	if globalDialerMap == nil {
    37  		globalDialerMap = make(map[dialerConf]*http.Client)
    38  	}
    39  
    40  	if client, found := globalDialerMap[dialerConf{dest, sockopt, tlsSettings}]; found {
    41  		return client, nil
    42  	}
    43  
    44  	transport := &http2.Transport{
    45  		DialTLS: func(network string, addr string, tlsConfig *gotls.Config) (net.Conn, error) {
    46  			rawHost, rawPort, err := net.SplitHostPort(addr)
    47  			if err != nil {
    48  				return nil, err
    49  			}
    50  			if len(rawPort) == 0 {
    51  				rawPort = "443"
    52  			}
    53  			port, err := net.PortFromString(rawPort)
    54  			if err != nil {
    55  				return nil, err
    56  			}
    57  			address := net.ParseAddress(rawHost)
    58  
    59  			dctx := context.Background()
    60  			dctx = session.ContextWithID(dctx, session.IDFromContext(ctx))
    61  			dctx = session.ContextWithOutbound(dctx, session.OutboundFromContext(ctx))
    62  
    63  			pconn, err := internet.DialSystem(dctx, net.TCPDestination(address, port), sockopt)
    64  			if err != nil {
    65  				newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    66  				return nil, err
    67  			}
    68  
    69  			cn := gotls.Client(pconn, tlsConfig)
    70  			if err := cn.Handshake(); err != nil {
    71  				newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    72  				return nil, err
    73  			}
    74  			if !tlsConfig.InsecureSkipVerify {
    75  				if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
    76  					newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    77  					return nil, err
    78  				}
    79  			}
    80  			state := cn.ConnectionState()
    81  			if p := state.NegotiatedProtocol; p != http2.NextProtoTLS {
    82  				return nil, newError("http2: unexpected ALPN protocol " + p + "; want q" + http2.NextProtoTLS).AtError()
    83  			}
    84  			if !state.NegotiatedProtocolIsMutual {
    85  				return nil, newError("http2: could not negotiate protocol mutually").AtError()
    86  			}
    87  			return cn, nil
    88  		},
    89  		TLSClientConfig: tlsSettings.GetTLSConfig(tls.WithDestination(dest)),
    90  	}
    91  
    92  	client := &http.Client{
    93  		Transport: transport,
    94  	}
    95  
    96  	globalDialerMap[dialerConf{dest, sockopt, tlsSettings}] = client
    97  	return client, nil
    98  }
    99  
   100  // Dial dials a new TCP connection to the given destination.
   101  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
   102  	httpSettings := streamSettings.ProtocolSettings.(*Config)
   103  	tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
   104  	if tlsConfig == nil {
   105  		return nil, newError("TLS must be enabled for http transport.").AtWarning()
   106  	}
   107  	client, err := getHTTPClient(ctx, dest, tlsConfig, streamSettings.SocketSettings)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	opts := pipe.OptionsFromContext(ctx)
   113  	preader, pwriter := pipe.New(opts...)
   114  	breader := &buf.BufferedReader{Reader: preader}
   115  	request := &http.Request{
   116  		Method: "PUT",
   117  		Host:   httpSettings.getRandomHost(),
   118  		Body:   breader,
   119  		URL: &url.URL{
   120  			Scheme: "https",
   121  			Host:   dest.NetAddr(),
   122  			Path:   httpSettings.getNormalizedPath(),
   123  		},
   124  		Proto:      "HTTP/2",
   125  		ProtoMajor: 2,
   126  		ProtoMinor: 0,
   127  		Header:     make(http.Header),
   128  	}
   129  	// Disable any compression method from server.
   130  	request.Header.Set("Accept-Encoding", "identity")
   131  
   132  	response, err := client.Do(request)
   133  	if err != nil {
   134  		return nil, newError("failed to dial to ", dest).Base(err).AtWarning()
   135  	}
   136  	if response.StatusCode != 200 {
   137  		return nil, newError("unexpected status", response.StatusCode).AtWarning()
   138  	}
   139  
   140  	bwriter := buf.NewBufferedWriter(pwriter)
   141  	common.Must(bwriter.SetBuffered(false))
   142  	return cnc.NewConnection(
   143  		cnc.ConnectionOutput(response.Body),
   144  		cnc.ConnectionInput(bwriter),
   145  		cnc.ConnectionOnClose(common.ChainedClosable{breader, bwriter, response.Body}),
   146  	), nil
   147  }
   148  
   149  func init() {
   150  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
   151  }