github.com/astaguna/popon-core@v0.0.0-20231019235610-96e42d76a5ff/psiphon/net.go (about) 1 /* 2 * Copyright (c) 2015, Psiphon Inc. 3 * All rights reserved. 4 * 5 * This program is free software: you can redistribute it and/or modify 6 * it under the terms of the GNU General Public License as published by 7 * the Free Software Foundation, either version 3 of the License, or 8 * (at your option) any later version. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package psiphon 21 22 import ( 23 "context" 24 "crypto/tls" 25 "crypto/x509" 26 std_errors "errors" 27 "fmt" 28 "io" 29 "io/ioutil" 30 "net" 31 "net/http" 32 "os" 33 "strings" 34 "sync" 35 "sync/atomic" 36 "time" 37 38 "github.com/astaguna/popon-core/psiphon/common" 39 "github.com/astaguna/popon-core/psiphon/common/errors" 40 "github.com/astaguna/popon-core/psiphon/common/fragmentor" 41 "github.com/astaguna/popon-core/psiphon/common/parameters" 42 "github.com/astaguna/popon-core/psiphon/common/prng" 43 "github.com/astaguna/popon-core/psiphon/common/protocol" 44 "github.com/astaguna/popon-core/psiphon/common/resolver" 45 "golang.org/x/net/bpf" 46 ) 47 48 // DialConfig contains parameters to determine the behavior 49 // of a Psiphon dialer (TCPDial, UDPDial, MeekDial, etc.) 50 type DialConfig struct { 51 52 // DiagnosticID is the server ID to record in any diagnostics notices. 53 DiagnosticID string 54 55 // UpstreamProxyURL specifies a proxy to connect through. 
56 // E.g., "http://proxyhost:8080" 57 // "socks5://user:password@proxyhost:1080" 58 // "socks4a://proxyhost:1080" 59 // "http://NTDOMAIN\NTUser:password@proxyhost:3375" 60 // 61 // Certain tunnel protocols require HTTP CONNECT support 62 // when a HTTP proxy is specified. If CONNECT is not 63 // supported, those protocols will not connect. 64 // 65 // UpstreamProxyURL is not used by UDPDial. 66 UpstreamProxyURL string 67 68 // CustomHeaders is a set of additional arbitrary HTTP headers that are 69 // added to all plaintext HTTP requests and requests made through an HTTP 70 // upstream proxy when specified by UpstreamProxyURL. 71 CustomHeaders http.Header 72 73 // BPFProgramInstructions specifies a BPF program to attach to the dial 74 // socket before connecting. 75 BPFProgramInstructions []bpf.RawInstruction 76 77 // DeviceBinder, when not nil, is applied when dialing UDP/TCP. See: 78 // DeviceBinder doc. 79 DeviceBinder DeviceBinder 80 81 // IPv6Synthesizer, when not nil, is applied when dialing UDP/TCP. See: 82 // IPv6Synthesizer doc. 83 IPv6Synthesizer IPv6Synthesizer 84 85 // ResolveIP is used to resolve destination domains. ResolveIP should 86 // return either at least one IP address or an error. 87 ResolveIP func(context.Context, string) ([]net.IP, error) 88 89 // ResolvedIPCallback, when set, is called with the IP address that was 90 // dialed. This is either the specified IP address in the dial address, 91 // or the resolved IP address in the case where the dial address is a 92 // domain name. 93 // The callback may be invoked by a concurrent goroutine. 94 ResolvedIPCallback func(string) 95 96 // TrustedCACertificatesFilename specifies a file containing trusted 97 // CA certs. See Config.TrustedCACertificatesFilename. 98 TrustedCACertificatesFilename string 99 100 // FragmentorConfig specifies whether to layer a fragmentor.Conn on top 101 // of dialed TCP conns, and the fragmentation configuration to use. 
102 FragmentorConfig *fragmentor.Config 103 104 // UpstreamProxyErrorCallback is called when a dial fails due to an upstream 105 // proxy error. As the upstream proxy is user configured, the error message 106 // may need to be relayed to the user. 107 UpstreamProxyErrorCallback func(error) 108 109 // CustomDialer overrides the dialer created by NewNetDialer/NewTCPDialer. 110 // When CustomDialer is set, all other DialConfig parameters are ignored by 111 // NewNetDialer/NewTCPDialer. Other DialConfig consumers may still reference 112 // other DialConfig parameters; for example MeekConfig still uses 113 // TrustedCACertificatesFilename. 114 CustomDialer common.Dialer 115 } 116 117 // WithoutFragmentor returns a copy of the DialConfig with any fragmentor 118 // configuration disabled. The return value is not a deep copy and may be the 119 // input DialConfig; it should not be modified. 120 func (config *DialConfig) WithoutFragmentor() *DialConfig { 121 if config.FragmentorConfig == nil { 122 return config 123 } 124 newConfig := new(DialConfig) 125 *newConfig = *config 126 newConfig.FragmentorConfig = nil 127 return newConfig 128 } 129 130 // NetworkConnectivityChecker defines the interface to the external 131 // HasNetworkConnectivity provider, which call into the host application to 132 // check for network connectivity. 133 type NetworkConnectivityChecker interface { 134 // TODO: change to bool return value once gobind supports that type 135 HasNetworkConnectivity() int 136 } 137 138 // DeviceBinder defines the interface to the external BindToDevice provider 139 // which calls into the host application to bind sockets to specific devices. 140 // This is used for VPN routing exclusion. 141 // The string return value should report device information for diagnostics. 
type DeviceBinder interface {
	BindToDevice(fileDescriptor int) (string, error)
}

type (
	// DNSServerGetter is the interface to the host application's
	// GetDNSServers provider, which discovers the DNS servers configured
	// for the native network.
	DNSServerGetter interface {
		GetDNSServers() []string
	}

	// IPv6Synthesizer is the interface to the host application's
	// IPv6Synthesize provider, which synthesizes IPv6 addresses from IPv4
	// addresses so that lookups work correctly on DNS64/NAT64 networks.
	IPv6Synthesizer interface {
		IPv6Synthesize(IPv4Addr string) string
	}

	// HasIPv6RouteGetter is the interface to the host application's
	// HasIPv6Route provider, which reports whether the host has an IPv6
	// route.
	HasIPv6RouteGetter interface {
		// TODO: change to bool return value once gobind supports that type
		HasIPv6Route() int
	}
)

// NetworkIDGetter is the interface to the host application's GetNetworkID
// provider, which returns an identifier for the host's currently active
// network.
//
// The identifier should convey the network type and identity, e.g.
// "WIFI-<BSSID>" or "MOBILE-<MCC/MNC>". Because this network ID is
// personally identifying, it is only used locally in the client to
// determine network context and is not sent to the Psiphon server. The
// identifier does appear in diagnostic messages, but only the substring
// before the first "-" is logged, so all PII must come after the first "-".
//
// NetworkIDGetter.GetNetworkID should always return a value, since logic
// that depends on it, including tactics, is meant to proceed regardless of
// whether an accurate network identifier can be obtained.
By 184 // convention, the provider should return "UNKNOWN" when an accurate network 185 // identifier cannot be obtained. Best-effort is acceptable: e.g., return just 186 // "WIFI" when only the type of the network but no details can be determined. 187 type NetworkIDGetter interface { 188 GetNetworkID() string 189 } 190 191 // NetDialer implements an interface that matches net.Dialer. 192 // Limitation: only "tcp" Dials are supported. 193 type NetDialer struct { 194 dialTCP common.Dialer 195 } 196 197 // NewNetDialer creates a new NetDialer. 198 func NewNetDialer(config *DialConfig) *NetDialer { 199 return &NetDialer{ 200 dialTCP: NewTCPDialer(config), 201 } 202 } 203 204 func (d *NetDialer) Dial(network, address string) (net.Conn, error) { 205 conn, err := d.DialContext(context.Background(), network, address) 206 if err != nil { 207 return nil, errors.Trace(err) 208 } 209 return conn, nil 210 } 211 212 func (d *NetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 213 switch network { 214 case "tcp": 215 conn, err := d.dialTCP(ctx, "tcp", address) 216 if err != nil { 217 return nil, errors.Trace(err) 218 } 219 return conn, nil 220 default: 221 return nil, errors.Tracef("unsupported network: %s", network) 222 } 223 } 224 225 // LocalProxyRelay sends to remoteConn bytes received from localConn, 226 // and sends to localConn bytes received from remoteConn. 227 // 228 // LocalProxyRelay must close localConn in order to interrupt blocking 229 // I/O calls when the upstream port forward is closed. remoteConn is 230 // also closed before returning. 
231 func LocalProxyRelay(config *Config, proxyType string, localConn, remoteConn net.Conn) { 232 233 closing := int32(0) 234 235 copyWaitGroup := new(sync.WaitGroup) 236 copyWaitGroup.Add(1) 237 238 go func() { 239 defer copyWaitGroup.Done() 240 241 _, err := RelayCopyBuffer(config, localConn, remoteConn) 242 if err != nil && atomic.LoadInt32(&closing) != 1 { 243 NoticeLocalProxyError(proxyType, errors.TraceMsg(err, "Relay failed")) 244 } 245 246 // When the server closes a port forward, ex. due to idle timeout, 247 // remoteConn.Read will return EOF, which causes the downstream io.Copy to 248 // return (with a nil error). To ensure the downstream local proxy 249 // connection also closes at this point, we interrupt the blocking upstream 250 // io.Copy by closing localConn. 251 252 atomic.StoreInt32(&closing, 1) 253 localConn.Close() 254 }() 255 256 _, err := RelayCopyBuffer(config, remoteConn, localConn) 257 if err != nil && atomic.LoadInt32(&closing) != 1 { 258 NoticeLocalProxyError(proxyType, errors.TraceMsg(err, "Relay failed")) 259 } 260 261 // When a local proxy peer connection closes, localConn.Read will return EOF. 262 // As above, close the other end of the relay to ensure immediate shutdown, 263 // as no more data can be relayed. 264 265 atomic.StoreInt32(&closing, 1) 266 remoteConn.Close() 267 268 copyWaitGroup.Wait() 269 } 270 271 // RelayCopyBuffer performs an io.Copy, optionally using a smaller buffer when 272 // config.LimitRelayBufferSizes is set. 273 func RelayCopyBuffer(config *Config, dst io.Writer, src io.Reader) (int64, error) { 274 275 // By default, io.CopyBuffer will allocate a 32K buffer when a nil buffer 276 // is passed in. When configured, make and specify a smaller buffer. But 277 // only if src doesn't implement WriterTo and dst doesn't implement 278 // ReaderFrom, as in those cases io.CopyBuffer entirely avoids a buffer 279 // allocation. 
280 281 var buffer []byte 282 if config.LimitRelayBufferSizes { 283 _, isWT := src.(io.WriterTo) 284 _, isRF := dst.(io.ReaderFrom) 285 if !isWT && !isRF { 286 buffer = make([]byte, 4096) 287 } 288 } 289 290 // Do not wrap any I/O errors 291 return io.CopyBuffer(dst, src, buffer) 292 } 293 294 // WaitForNetworkConnectivity uses a NetworkConnectivityChecker to 295 // periodically check for network connectivity. It returns true if 296 // no NetworkConnectivityChecker is provided (waiting is disabled) 297 // or when NetworkConnectivityChecker.HasNetworkConnectivity() 298 // indicates connectivity. It waits and polls the checker once a second. 299 // When the context is done, false is returned immediately. 300 func WaitForNetworkConnectivity( 301 ctx context.Context, connectivityChecker NetworkConnectivityChecker) bool { 302 303 if connectivityChecker == nil || connectivityChecker.HasNetworkConnectivity() == 1 { 304 return true 305 } 306 307 NoticeInfo("waiting for network connectivity") 308 309 ticker := time.NewTicker(1 * time.Second) 310 defer ticker.Stop() 311 312 for { 313 if connectivityChecker.HasNetworkConnectivity() == 1 { 314 return true 315 } 316 317 select { 318 case <-ticker.C: 319 // Check HasNetworkConnectivity again 320 case <-ctx.Done(): 321 return false 322 } 323 } 324 } 325 326 // New Resolver creates a new resolver using the specified config. 327 // useBindToDevice indicates whether to apply config.BindToDevice, when it 328 // exists; set useBindToDevice to false when the resolve doesn't need to be 329 // excluded from any VPN routing. 
330 func NewResolver(config *Config, useBindToDevice bool) *resolver.Resolver { 331 332 p := config.GetParameters().Get() 333 334 networkConfig := &resolver.NetworkConfig{ 335 LogWarning: func(err error) { NoticeWarning("ResolveIP: %v", err) }, 336 LogHostnames: config.EmitDiagnosticNetworkParameters, 337 CacheExtensionInitialTTL: p.Duration(parameters.DNSResolverCacheExtensionInitialTTL), 338 CacheExtensionVerifiedTTL: p.Duration(parameters.DNSResolverCacheExtensionVerifiedTTL), 339 } 340 341 if config.DNSServerGetter != nil { 342 networkConfig.GetDNSServers = config.DNSServerGetter.GetDNSServers 343 } 344 345 if useBindToDevice && config.DeviceBinder != nil { 346 networkConfig.BindToDevice = config.DeviceBinder.BindToDevice 347 networkConfig.AllowDefaultResolverWithBindToDevice = 348 config.AllowDefaultDNSResolverWithBindToDevice 349 } 350 351 if config.IPv6Synthesizer != nil { 352 networkConfig.IPv6Synthesize = config.IPv6Synthesizer.IPv6Synthesize 353 } 354 355 if config.HasIPv6RouteGetter != nil { 356 networkConfig.HasIPv6Route = func() bool { 357 return config.HasIPv6RouteGetter.HasIPv6Route() == 1 358 } 359 } 360 361 return resolver.NewResolver(networkConfig, config.GetNetworkID()) 362 } 363 364 // UntunneledResolveIP is used to resolve domains for untunneled dials, 365 // including remote server list and upgrade downloads. 366 func UntunneledResolveIP( 367 ctx context.Context, 368 config *Config, 369 resolver *resolver.Resolver, 370 hostname, 371 frontingProviderID string) ([]net.IP, error) { 372 373 // Limitations: for untunneled resolves, there is currently no resolve 374 // parameter replay, and no support for pre-resolved IPs. 
375 376 params, err := resolver.MakeResolveParameters( 377 config.GetParameters().Get(), frontingProviderID) 378 if err != nil { 379 return nil, errors.Trace(err) 380 } 381 382 IPs, err := resolver.ResolveIP( 383 ctx, 384 config.GetNetworkID(), 385 params, 386 hostname) 387 if err != nil { 388 return nil, errors.Trace(err) 389 } 390 391 return IPs, nil 392 } 393 394 // makeUntunneledFrontedHTTPClient returns a net/http.Client which is 395 // configured to use domain fronting and custom dialing features -- including 396 // BindToDevice, etc. One or more fronting specs must be provided, i.e. 397 // len(frontingSpecs) must be greater than 0. A function is returned which, 398 // if non-nil, can be called after each request made with the net/http.Client 399 // completes to retrieve the set of API parameter values applied to the request. 400 // 401 // The context is applied to underlying TCP dials. The caller is responsible 402 // for applying the context to requests made with the returned http.Client. 
403 func makeUntunneledFrontedHTTPClient(ctx context.Context, config *Config, untunneledDialConfig *DialConfig, frontingSpecs parameters.FrontingSpecs, skipVerify, disableSystemRootCAs bool) (*http.Client, func() common.APIParameters, error) { 404 405 frontingProviderID, meekFrontingDialAddress, meekSNIServerName, meekVerifyServerName, meekVerifyPins, meekFrontingHost, err := parameters.FrontingSpecs(frontingSpecs).SelectParameters() 406 if err != nil { 407 return nil, nil, errors.Trace(err) 408 } 409 410 meekDialAddress := net.JoinHostPort(meekFrontingDialAddress, "443") 411 meekHostHeader := meekFrontingHost 412 413 p := config.GetParameters().Get() 414 effectiveTunnelProtocol := protocol.TUNNEL_PROTOCOL_FRONTED_MEEK 415 416 requireTLS12SessionTickets := protocol.TunnelProtocolRequiresTLS12SessionTickets( 417 effectiveTunnelProtocol) 418 requireTLS13Support := protocol.TunnelProtocolRequiresTLS13Support(effectiveTunnelProtocol) 419 isFronted := true 420 421 tlsProfile, tlsVersion, randomizedTLSProfileSeed, err := SelectTLSProfile( 422 requireTLS12SessionTickets, requireTLS13Support, isFronted, frontingProviderID, p) 423 if err != nil { 424 return nil, nil, errors.Trace(err) 425 } 426 427 if tlsProfile == "" && (requireTLS12SessionTickets || requireTLS13Support) { 428 return nil, nil, errors.TraceNew("required TLS profile not found") 429 } 430 431 noDefaultTLSSessionID := p.WeightedCoinFlip( 432 parameters.NoDefaultTLSSessionIDProbability) 433 434 // For a FrontingSpec, an SNI value of "" indicates to disable/omit SNI, so 435 // never transform in that case. 
436 var meekTransformedHostName bool 437 if meekSNIServerName != "" { 438 if p.WeightedCoinFlip(parameters.TransformHostNameProbability) { 439 meekSNIServerName = selectHostName(effectiveTunnelProtocol, p) 440 meekTransformedHostName = true 441 } 442 } 443 444 addPsiphonFrontingHeader := false 445 if frontingProviderID != "" { 446 addPsiphonFrontingHeader = common.Contains( 447 p.LabeledTunnelProtocols( 448 parameters.AddFrontingProviderPsiphonFrontingHeader, frontingProviderID), 449 effectiveTunnelProtocol) 450 } 451 452 networkLatencyMultiplierMin := p.Float(parameters.NetworkLatencyMultiplierMin) 453 networkLatencyMultiplierMax := p.Float(parameters.NetworkLatencyMultiplierMax) 454 455 networkLatencyMultiplier := prng.ExpFloat64Range( 456 networkLatencyMultiplierMin, 457 networkLatencyMultiplierMax, 458 p.Float(parameters.NetworkLatencyMultiplierLambda)) 459 460 meekConfig := &MeekConfig{ 461 DiagnosticID: frontingProviderID, 462 Parameters: config.GetParameters(), 463 Mode: MeekModePlaintextRoundTrip, 464 DialAddress: meekDialAddress, 465 UseHTTPS: true, 466 TLSProfile: tlsProfile, 467 NoDefaultTLSSessionID: noDefaultTLSSessionID, 468 RandomizedTLSProfileSeed: randomizedTLSProfileSeed, 469 SNIServerName: meekSNIServerName, 470 AddPsiphonFrontingHeader: addPsiphonFrontingHeader, 471 HostHeader: meekHostHeader, 472 TransformedHostName: meekTransformedHostName, 473 ClientTunnelProtocol: effectiveTunnelProtocol, 474 NetworkLatencyMultiplier: networkLatencyMultiplier, 475 } 476 477 if !skipVerify { 478 meekConfig.VerifyServerName = meekVerifyServerName 479 meekConfig.VerifyPins = meekVerifyPins 480 meekConfig.DisableSystemRootCAs = disableSystemRootCAs 481 } 482 483 var resolvedIPAddress atomic.Value 484 resolvedIPAddress.Store("") 485 486 // The default untunneled dial config does not support pre-resolved IPs so 487 // redefine the dial config to override ResolveIP with an implementation 488 // that enables their use by passing the fronting provider ID into 489 // 
UntunneledResolveIP. 490 meekDialConfig := &DialConfig{ 491 UpstreamProxyURL: untunneledDialConfig.UpstreamProxyURL, 492 CustomHeaders: untunneledDialConfig.CustomHeaders, 493 DeviceBinder: untunneledDialConfig.DeviceBinder, 494 IPv6Synthesizer: untunneledDialConfig.IPv6Synthesizer, 495 ResolveIP: func(ctx context.Context, hostname string) ([]net.IP, error) { 496 IPs, err := UntunneledResolveIP( 497 ctx, config, config.GetResolver(), hostname, frontingProviderID) 498 if err != nil { 499 return nil, errors.Trace(err) 500 } 501 return IPs, nil 502 }, 503 ResolvedIPCallback: func(IPAddress string) { 504 resolvedIPAddress.Store(IPAddress) 505 }, 506 } 507 508 selectedUserAgent, userAgent := selectUserAgentIfUnset(p, meekDialConfig.CustomHeaders) 509 if selectedUserAgent { 510 if meekDialConfig.CustomHeaders == nil { 511 meekDialConfig.CustomHeaders = make(http.Header) 512 } 513 meekDialConfig.CustomHeaders.Set("User-Agent", userAgent) 514 } 515 516 // Use MeekConn to domain front requests. 517 // 518 // DialMeek will create a TLS connection immediately. We will delay 519 // initializing the MeekConn-based RoundTripper until we know it's needed. 520 // This is implemented by passing in a RoundTripper that establishes a 521 // MeekConn when RoundTrip is called. 522 // 523 // Resources are cleaned up when the response body is closed. 524 roundTrip := func(request *http.Request) (*http.Response, error) { 525 526 conn, err := DialMeek( 527 ctx, meekConfig, meekDialConfig) 528 if err != nil { 529 return nil, errors.Trace(err) 530 } 531 532 response, err := conn.RoundTrip(request) 533 if err != nil { 534 return nil, errors.Trace(err) 535 } 536 537 // Do not read the response body into memory all at once because it may 538 // be large. Instead allow the caller to stream the response. 
539 response.Body = newMeekHTTPResponseReadCloser(conn, response.Body) 540 541 return response, nil 542 } 543 544 params := func() common.APIParameters { 545 params := make(common.APIParameters) 546 547 params["fronting_provider_id"] = frontingProviderID 548 549 if meekConfig.DialAddress != "" { 550 params["meek_dial_address"] = meekConfig.DialAddress 551 } 552 553 meekResolvedIPAddress := resolvedIPAddress.Load() 554 if meekResolvedIPAddress != "" { 555 params["meek_resolved_ip_address"] = meekResolvedIPAddress 556 } 557 558 if meekConfig.SNIServerName != "" { 559 params["meek_sni_server_name"] = meekConfig.SNIServerName 560 } 561 562 if meekConfig.HostHeader != "" { 563 params["meek_host_header"] = meekConfig.HostHeader 564 } 565 566 transformedHostName := "0" 567 if meekTransformedHostName { 568 transformedHostName = "1" 569 } 570 params["meek_transformed_host_name"] = transformedHostName 571 572 if meekConfig.TLSProfile != "" { 573 params["tls_profile"] = meekConfig.TLSProfile 574 } 575 576 if selectedUserAgent { 577 params["user_agent"] = userAgent 578 } 579 580 if tlsVersion != "" { 581 params["tls_version"] = getTLSVersionForMetrics(tlsVersion, meekConfig.NoDefaultTLSSessionID) 582 } 583 584 return params 585 } 586 587 return &http.Client{ 588 Transport: common.NewHTTPRoundTripper(roundTrip), 589 }, params, nil 590 } 591 592 // meekHTTPResponseReadCloser wraps an http.Response.Body received over a 593 // MeekConn in MeekModePlaintextRoundTrip and exposes an io.ReadCloser. Close 594 // closes the meek conn and response body. 595 type meekHTTPResponseReadCloser struct { 596 conn *MeekConn 597 responseBody io.ReadCloser 598 } 599 600 // newMeekHTTPResponseReadCloser creates a meekHTTPResponseReadCloser. 
601 func newMeekHTTPResponseReadCloser(meekConn *MeekConn, responseBody io.ReadCloser) *meekHTTPResponseReadCloser { 602 return &meekHTTPResponseReadCloser{ 603 conn: meekConn, 604 responseBody: responseBody, 605 } 606 } 607 608 // Read implements the io.Reader interface. 609 func (meek *meekHTTPResponseReadCloser) Read(p []byte) (n int, err error) { 610 return meek.responseBody.Read(p) 611 } 612 613 // Read implements the io.Closer interface. 614 func (meek *meekHTTPResponseReadCloser) Close() error { 615 err := meek.responseBody.Close() 616 _ = meek.conn.Close() 617 return err 618 } 619 620 // MakeUntunneledHTTPClient returns a net/http.Client which is configured to 621 // use custom dialing features -- including BindToDevice, etc. A function is 622 // returned which, if non-nil, can be called after each request made with the 623 // net/http.Client completes to retrieve the set of API parameter values 624 // applied to the request. 625 // 626 // The context is applied to underlying TCP dials. The caller is responsible 627 // for applying the context to requests made with the returned http.Client. 628 func MakeUntunneledHTTPClient( 629 ctx context.Context, 630 config *Config, 631 untunneledDialConfig *DialConfig, 632 skipVerify bool, 633 disableSystemRootCAs bool, 634 frontingSpecs parameters.FrontingSpecs) (*http.Client, func() common.APIParameters, error) { 635 636 if len(frontingSpecs) > 0 { 637 638 // Ignore skipVerify because it only applies when there are no 639 // fronting specs. 
640 httpClient, getParams, err := makeUntunneledFrontedHTTPClient(ctx, config, untunneledDialConfig, frontingSpecs, false, disableSystemRootCAs) 641 if err != nil { 642 return nil, nil, errors.Trace(err) 643 } 644 return httpClient, getParams, nil 645 } 646 647 dialer := NewTCPDialer(untunneledDialConfig) 648 649 tlsConfig := &CustomTLSConfig{ 650 Parameters: config.GetParameters(), 651 Dial: dialer, 652 UseDialAddrSNI: true, 653 SNIServerName: "", 654 SkipVerify: skipVerify, 655 DisableSystemRootCAs: disableSystemRootCAs, 656 TrustedCACertificatesFilename: untunneledDialConfig.TrustedCACertificatesFilename, 657 } 658 tlsConfig.EnableClientSessionCache() 659 660 tlsDialer := NewCustomTLSDialer(tlsConfig) 661 662 transport := &http.Transport{ 663 Dial: func(network, addr string) (net.Conn, error) { 664 return dialer(ctx, network, addr) 665 }, 666 DialTLS: func(network, addr string) (net.Conn, error) { 667 return tlsDialer(ctx, network, addr) 668 }, 669 } 670 671 httpClient := &http.Client{ 672 Transport: transport, 673 } 674 675 return httpClient, nil, nil 676 } 677 678 // MakeTunneledHTTPClient returns a net/http.Client which is 679 // configured to use custom dialing features including tunneled 680 // dialing and, optionally, UseTrustedCACertificatesForStockTLS. 681 // This http.Client uses stock TLS for HTTPS. 682 func MakeTunneledHTTPClient( 683 config *Config, 684 tunnel *Tunnel, 685 skipVerify bool) (*http.Client, error) { 686 687 // Note: there is no dial context since SSH port forward dials cannot 688 // be interrupted directly. Closing the tunnel will interrupt the dials. 689 690 tunneledDialer := func(_, addr string) (net.Conn, error) { 691 // Set alwaysTunneled to ensure the http.Client traffic is always tunneled, 692 // even when split tunnel mode is enabled. 
693 conn, _, err := tunnel.DialTCPChannel(addr, true, nil) 694 return conn, errors.Trace(err) 695 } 696 697 transport := &http.Transport{ 698 Dial: tunneledDialer, 699 } 700 701 if skipVerify { 702 703 transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} 704 705 } else if config.TrustedCACertificatesFilename != "" { 706 707 rootCAs := x509.NewCertPool() 708 certData, err := ioutil.ReadFile(config.TrustedCACertificatesFilename) 709 if err != nil { 710 return nil, errors.Trace(err) 711 } 712 rootCAs.AppendCertsFromPEM(certData) 713 transport.TLSClientConfig = &tls.Config{RootCAs: rootCAs} 714 } 715 716 return &http.Client{ 717 Transport: transport, 718 }, nil 719 } 720 721 // MakeDownloadHTTPClient is a helper that sets up a http.Client for use either 722 // untunneled or through a tunnel. True is returned if the http.Client is setup 723 // for use through a tunnel; otherwise it is setup for untunneled use. A 724 // function is returned which, if non-nil, can be called after each request 725 // made with the http.Client completes to retrieve the set of API 726 // parameter values applied to the request. 
727 func MakeDownloadHTTPClient( 728 ctx context.Context, 729 config *Config, 730 tunnel *Tunnel, 731 untunneledDialConfig *DialConfig, 732 skipVerify, 733 disableSystemRootCAs bool, 734 frontingSpecs parameters.FrontingSpecs) (*http.Client, bool, func() common.APIParameters, error) { 735 736 var httpClient *http.Client 737 var getParams func() common.APIParameters 738 var err error 739 740 tunneled := tunnel != nil 741 742 if tunneled { 743 744 httpClient, err = MakeTunneledHTTPClient( 745 config, tunnel, skipVerify || disableSystemRootCAs) 746 if err != nil { 747 return nil, false, nil, errors.Trace(err) 748 } 749 750 } else { 751 httpClient, getParams, err = MakeUntunneledHTTPClient( 752 ctx, config, untunneledDialConfig, skipVerify, disableSystemRootCAs, frontingSpecs) 753 if err != nil { 754 return nil, false, nil, errors.Trace(err) 755 } 756 } 757 758 return httpClient, tunneled, getParams, nil 759 } 760 761 // ResumeDownload is a reusable helper that downloads requestUrl via the 762 // httpClient, storing the result in downloadFilename when the download is 763 // complete. Intermediate, partial downloads state is stored in 764 // downloadFilename.part and downloadFilename.part.etag. 765 // Any existing downloadFilename file will be overwritten. 766 // 767 // In the case where the remote object has changed while a partial download 768 // is to be resumed, the partial state is reset and resumeDownload fails. 769 // The caller must restart the download. 770 // 771 // When ifNoneMatchETag is specified, no download is made if the remote 772 // object has the same ETag. ifNoneMatchETag has an effect only when no 773 // partial download is in progress. 
774 func ResumeDownload( 775 ctx context.Context, 776 httpClient *http.Client, 777 downloadURL string, 778 userAgent string, 779 downloadFilename string, 780 ifNoneMatchETag string) (int64, string, error) { 781 782 partialFilename := fmt.Sprintf("%s.part", downloadFilename) 783 784 partialETagFilename := fmt.Sprintf("%s.part.etag", downloadFilename) 785 786 file, err := os.OpenFile(partialFilename, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) 787 if err != nil { 788 return 0, "", errors.Trace(err) 789 } 790 defer file.Close() 791 792 fileInfo, err := file.Stat() 793 if err != nil { 794 return 0, "", errors.Trace(err) 795 } 796 797 // A partial download should have an ETag which is to be sent with the 798 // Range request to ensure that the source object is the same as the 799 // one that is partially downloaded. 800 var partialETag []byte 801 if fileInfo.Size() > 0 { 802 803 partialETag, err = ioutil.ReadFile(partialETagFilename) 804 805 // When the ETag can't be loaded, delete the partial download. To keep the 806 // code simple, there is no immediate, inline retry here, on the assumption 807 // that the controller's upgradeDownloader will shortly call DownloadUpgrade 808 // again. 
809 if err != nil { 810 811 // On Windows, file must be closed before it can be deleted 812 file.Close() 813 814 tempErr := os.Remove(partialFilename) 815 if tempErr != nil && !os.IsNotExist(tempErr) { 816 NoticeWarning("reset partial download failed: %s", tempErr) 817 } 818 819 tempErr = os.Remove(partialETagFilename) 820 if tempErr != nil && !os.IsNotExist(tempErr) { 821 NoticeWarning("reset partial download ETag failed: %s", tempErr) 822 } 823 824 return 0, "", errors.Tracef( 825 "failed to load partial download ETag: %s", err) 826 } 827 } 828 829 request, err := http.NewRequest("GET", downloadURL, nil) 830 if err != nil { 831 return 0, "", errors.Trace(err) 832 } 833 834 request = request.WithContext(ctx) 835 836 request.Header.Set("User-Agent", userAgent) 837 838 request.Header.Add("Range", fmt.Sprintf("bytes=%d-", fileInfo.Size())) 839 840 if partialETag != nil { 841 842 // Note: not using If-Range, since not all host servers support it. 843 // Using If-Match means we need to check for status code 412 and reset 844 // when the ETag has changed since the last partial download. 845 request.Header.Add("If-Match", string(partialETag)) 846 847 } else if ifNoneMatchETag != "" { 848 849 // Can't specify both If-Match and If-None-Match. Behavior is undefined. 850 // https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.26 851 // So for downloaders that store an ETag and wish to use that to prevent 852 // redundant downloads, that ETag is sent as If-None-Match in the case 853 // where a partial download is not in progress. When a partial download 854 // is in progress, the partial ETag is sent as If-Match: either that's 855 // a version that was never fully received, or it's no longer current in 856 // which case the response will be StatusPreconditionFailed, the partial 857 // download will be discarded, and then the next retry will use 858 // If-None-Match. 
859 860 // Note: in this case, fileInfo.Size() == 0 861 862 request.Header.Add("If-None-Match", ifNoneMatchETag) 863 } 864 865 response, err := httpClient.Do(request) 866 867 // The resumeable download may ask for bytes past the resource range 868 // since it doesn't store the "completed download" state. In this case, 869 // the HTTP server returns 416. Otherwise, we expect 206. We may also 870 // receive 412 on ETag mismatch. 871 if err == nil && 872 (response.StatusCode != http.StatusPartialContent && 873 874 // Certain http servers return 200 OK where we expect 206, so accept that. 875 response.StatusCode != http.StatusOK && 876 877 response.StatusCode != http.StatusRequestedRangeNotSatisfiable && 878 response.StatusCode != http.StatusPreconditionFailed && 879 response.StatusCode != http.StatusNotModified) { 880 response.Body.Close() 881 err = fmt.Errorf("unexpected response status code: %d", response.StatusCode) 882 } 883 if err != nil { 884 885 // Redact URL from "net/http" error message. 886 if !GetEmitNetworkParameters() { 887 errStr := err.Error() 888 err = std_errors.New(strings.Replace(errStr, downloadURL, "[redacted]", -1)) 889 } 890 891 return 0, "", errors.Trace(err) 892 } 893 defer response.Body.Close() 894 895 responseETag := response.Header.Get("ETag") 896 897 if response.StatusCode == http.StatusPreconditionFailed { 898 // When the ETag no longer matches, delete the partial download. As above, 899 // simply failing and relying on the caller's retry schedule. 900 os.Remove(partialFilename) 901 os.Remove(partialETagFilename) 902 return 0, "", errors.TraceNew("partial download ETag mismatch") 903 904 } else if response.StatusCode == http.StatusNotModified { 905 // This status code is possible in the "If-None-Match" case. Don't leave 906 // any partial download in progress. Caller should check that responseETag 907 // matches ifNoneMatchETag. 
908 os.Remove(partialFilename) 909 os.Remove(partialETagFilename) 910 return 0, responseETag, nil 911 } 912 913 // Not making failure to write ETag file fatal, in case the entire download 914 // succeeds in this one request. 915 ioutil.WriteFile(partialETagFilename, []byte(responseETag), 0600) 916 917 // A partial download occurs when this copy is interrupted. The io.Copy 918 // will fail, leaving a partial download in place (.part and .part.etag). 919 n, err := io.Copy(NewSyncFileWriter(file), response.Body) 920 921 // From this point, n bytes are indicated as downloaded, even if there is 922 // an error; the caller may use this to report partial download progress. 923 924 if err != nil { 925 return n, "", errors.Trace(err) 926 } 927 928 // Ensure the file is flushed to disk. The deferred close 929 // will be a noop when this succeeds. 930 err = file.Close() 931 if err != nil { 932 return n, "", errors.Trace(err) 933 } 934 935 // Remove if exists, to enable rename 936 os.Remove(downloadFilename) 937 938 err = os.Rename(partialFilename, downloadFilename) 939 if err != nil { 940 return n, "", errors.Trace(err) 941 } 942 943 os.Remove(partialETagFilename) 944 945 return n, responseETag, nil 946 }