github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/transport/internet/http/dialer.go (about) 1 package http 2 3 import ( 4 "context" 5 gotls "crypto/tls" 6 "io" 7 "net/http" 8 "net/url" 9 "sync" 10 "time" 11 12 "github.com/xtls/xray-core/common" 13 "github.com/xtls/xray-core/common/buf" 14 "github.com/xtls/xray-core/common/net" 15 "github.com/xtls/xray-core/common/net/cnc" 16 "github.com/xtls/xray-core/common/session" 17 "github.com/xtls/xray-core/transport/internet" 18 "github.com/xtls/xray-core/transport/internet/reality" 19 "github.com/xtls/xray-core/transport/internet/stat" 20 "github.com/xtls/xray-core/transport/internet/tls" 21 "github.com/xtls/xray-core/transport/pipe" 22 "golang.org/x/net/http2" 23 ) 24 25 type dialerConf struct { 26 net.Destination 27 *internet.MemoryStreamConfig 28 } 29 30 var ( 31 globalDialerMap map[dialerConf]*http.Client 32 globalDialerAccess sync.Mutex 33 ) 34 35 func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (*http.Client, error) { 36 globalDialerAccess.Lock() 37 defer globalDialerAccess.Unlock() 38 39 if globalDialerMap == nil { 40 globalDialerMap = make(map[dialerConf]*http.Client) 41 } 42 43 httpSettings := streamSettings.ProtocolSettings.(*Config) 44 tlsConfigs := tls.ConfigFromStreamSettings(streamSettings) 45 realityConfigs := reality.ConfigFromStreamSettings(streamSettings) 46 if tlsConfigs == nil && realityConfigs == nil { 47 return nil, newError("TLS or REALITY must be enabled for http transport.").AtWarning() 48 } 49 sockopt := streamSettings.SocketSettings 50 51 if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found { 52 return client, nil 53 } 54 55 transport := &http2.Transport{ 56 DialTLSContext: func(hctx context.Context, string, addr string, tlsConfig *gotls.Config) (net.Conn, error) { 57 rawHost, rawPort, err := net.SplitHostPort(addr) 58 if err != nil { 59 return nil, err 60 } 61 if len(rawPort) == 0 { 62 rawPort = "443" 63 } 64 port, err := net.PortFromString(rawPort) 65 if err != nil { 66 return nil, err 67 } 68 address := net.ParseAddress(rawHost) 69 70 hctx = session.ContextWithID(hctx, session.IDFromContext(ctx)) 71 hctx = session.ContextWithOutbounds(hctx, session.OutboundsFromContext(ctx)) 72 hctx = session.ContextWithTimeoutOnly(hctx, true) 73 74 pconn, err := internet.DialSystem(hctx, net.TCPDestination(address, port), sockopt) 75 if err != nil { 76 newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() 77 return nil, err 78 } 79 80 if realityConfigs != nil { 81 return reality.UClient(pconn, realityConfigs, hctx, dest) 82 } 83 84 var cn tls.Interface 85 if fingerprint := tls.GetFingerprint(tlsConfigs.Fingerprint); fingerprint != nil { 86 cn = tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn) 87 } else { 88 cn = tls.Client(pconn, tlsConfig).(*tls.Conn) 89 } 90 if err := cn.HandshakeContext(ctx); err != nil { 91 newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() 92 return nil, err 93 } 94 if !tlsConfig.InsecureSkipVerify { 95 if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil { 96 newError("failed to dial to " + addr).Base(err).AtError().WriteToLog() 97 return nil, err 98 } 99 } 100 negotiatedProtocol := cn.NegotiatedProtocol() 101 if negotiatedProtocol != http2.NextProtoTLS { 102 return nil, newError("http2: unexpected ALPN protocol " + negotiatedProtocol + "; want q" + http2.NextProtoTLS).AtError() 103 } 104 return cn, nil 105 }, 106 } 107 108 if tlsConfigs != nil { 109 transport.TLSClientConfig = tlsConfigs.GetTLSConfig(tls.WithDestination(dest)) 110 } 111 112 if httpSettings.IdleTimeout > 0 || httpSettings.HealthCheckTimeout > 0 { 113 transport.ReadIdleTimeout = time.Second * time.Duration(httpSettings.IdleTimeout) 114 transport.PingTimeout = time.Second * time.Duration(httpSettings.HealthCheckTimeout) 115 } 116 117 client := &http.Client{ 118 Transport: transport, 119 } 120 121 globalDialerMap[dialerConf{dest, streamSettings}] = client 122 return client, nil 123 } 124 125 // Dial dials a new TCP connection to the given destination. 126 func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { 127 httpSettings := streamSettings.ProtocolSettings.(*Config) 128 client, err := getHTTPClient(ctx, dest, streamSettings) 129 if err != nil { 130 return nil, err 131 } 132 133 opts := pipe.OptionsFromContext(ctx) 134 preader, pwriter := pipe.New(opts...) 135 breader := &buf.BufferedReader{Reader: preader} 136 137 httpMethod := "PUT" 138 if httpSettings.Method != "" { 139 httpMethod = httpSettings.Method 140 } 141 142 httpHeaders := make(http.Header) 143 144 for _, httpHeader := range httpSettings.Header { 145 for _, httpHeaderValue := range httpHeader.Value { 146 httpHeaders.Set(httpHeader.Name, httpHeaderValue) 147 } 148 } 149 150 request := &http.Request{ 151 Method: httpMethod, 152 Host: httpSettings.getRandomHost(), 153 Body: breader, 154 URL: &url.URL{ 155 Scheme: "https", 156 Host: dest.NetAddr(), 157 Path: httpSettings.getNormalizedPath(), 158 }, 159 Proto: "HTTP/2", 160 ProtoMajor: 2, 161 ProtoMinor: 0, 162 Header: httpHeaders, 163 } 164 // Disable any compression method from server. 165 request.Header.Set("Accept-Encoding", "identity") 166 167 wrc := &WaitReadCloser{Wait: make(chan struct{})} 168 go func() { 169 response, err := client.Do(request) 170 if err != nil { 171 newError("failed to dial to ", dest).Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 172 wrc.Close() 173 { 174 // Abandon `client` if `client.Do(request)` failed 175 // See https://github.com/golang/go/issues/30702 176 globalDialerAccess.Lock() 177 if globalDialerMap[dialerConf{dest, streamSettings}] == client { 178 delete(globalDialerMap, dialerConf{dest, streamSettings}) 179 } 180 globalDialerAccess.Unlock() 181 } 182 return 183 } 184 if response.StatusCode != 200 { 185 newError("unexpected status", response.StatusCode).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 186 wrc.Close() 187 return 188 } 189 wrc.Set(response.Body) 190 }() 191 192 bwriter := buf.NewBufferedWriter(pwriter) 193 common.Must(bwriter.SetBuffered(false)) 194 return cnc.NewConnection( 195 cnc.ConnectionOutput(wrc), 196 cnc.ConnectionInput(bwriter), 197 cnc.ConnectionOnClose(common.ChainedClosable{breader, bwriter, wrc}), 198 ), nil 199 } 200 201 func init() { 202 common.Must(internet.RegisterTransportDialer(protocolName, Dial)) 203 } 204 205 type WaitReadCloser struct { 206 Wait chan struct{} 207 io.ReadCloser 208 } 209 210 func (w *WaitReadCloser) Set(rc io.ReadCloser) { 211 w.ReadCloser = rc 212 defer func() { 213 if recover() != nil { 214 rc.Close() 215 } 216 }() 217 close(w.Wait) 218 } 219 220 func (w *WaitReadCloser) Read(b []byte) (int, error) { 221 if w.ReadCloser == nil { 222 if <-w.Wait; w.ReadCloser == nil { 223 return 0, io.ErrClosedPipe 224 } 225 } 226 return w.ReadCloser.Read(b) 227 } 228 229 func (w *WaitReadCloser) Close() error { 230 if w.ReadCloser != nil { 231 return w.ReadCloser.Close() 232 } 233 defer func() { 234 if recover() != nil && w.ReadCloser != nil { 235 w.ReadCloser.Close() 236 } 237 }() 238 close(w.Wait) 239 return nil 240 }