github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/file/s3file/s3transport/transport.go (about) 1 package s3transport 2 3 import ( 4 "crypto/tls" 5 "fmt" 6 "io" 7 "math/rand" 8 "net" 9 "net/http" 10 "sort" 11 "sync" 12 "time" 13 14 "github.com/Schaudge/grailbase/file/s3file/internal/autolog" 15 "github.com/Schaudge/grailbase/log" 16 ) 17 18 // T is an http.RoundTripper specialized for S3. See https://github.com/aws/aws-sdk-go/issues/3739. 19 type T struct { 20 factory func() *http.Transport 21 22 hostRTsMu sync.Mutex 23 hostRTs map[string]http.RoundTripper 24 25 nOpenConnsPerIPMu sync.Mutex 26 nOpenConnsPerIP map[string]int 27 28 hostIPs *expiringMap 29 } 30 31 var ( 32 stdDefaultTransport = http.DefaultTransport.(*http.Transport) 33 httpTransport = &http.Transport{ 34 DialContext: (&net.Dialer{ 35 Timeout: 30 * time.Second, // Copied from http.DefaultTransport. 36 KeepAlive: 30 * time.Second, // Copied from same. 37 }).DialContext, 38 ForceAttemptHTTP2: false, // S3 doesn't support HTTP2. 39 MaxIdleConns: 200, // Keep many peers for future bursts. 40 MaxIdleConnsPerHost: 4, // But limit connections to each. 41 IdleConnTimeout: expireAfter + 2*expireLoopEvery, // Keep until we forget the peer. 42 TLSClientConfig: &tls.Config{}, 43 TLSHandshakeTimeout: stdDefaultTransport.TLSHandshakeTimeout, 44 ExpectContinueTimeout: stdDefaultTransport.ExpectContinueTimeout, 45 } 46 47 defaultOnce sync.Once 48 defaultT *T 49 defaultClient *http.Client 50 ) 51 52 func defaults() (*T, *http.Client) { 53 defaultOnce.Do(func() { 54 defaultT = New(httpTransport.Clone) 55 defaultClient = &http.Client{Transport: defaultT} 56 }) 57 return defaultT, defaultClient 58 } 59 60 // Default returns an http.RoundTripper with recommended settings. 61 func Default() *T { t, _ := defaults(); return t } 62 63 // DefaultClient returns an *http.Client that uses the http.RoundTripper 64 // returned by Default (suitable for general use, analogous to 65 // "net/http".DefaultClient). 66 func DefaultClient() *http.Client { _, c := defaults(); return c } 67 68 // New constructs *T using factory to create internal transports. Each call to factory() 69 // must return a separate http.Transport and they must not share TLSClientConfig. 70 func New(factory func() *http.Transport) *T { 71 t := T{ 72 factory: factory, 73 hostRTs: map[string]http.RoundTripper{}, 74 hostIPs: newExpiringMap(runPeriodicForever(), time.Now), 75 nOpenConnsPerIP: map[string]int{}, 76 } 77 autolog.Register(func() { 78 var nOpen []int 79 t.nOpenConnsPerIPMu.Lock() 80 for _, n := range t.nOpenConnsPerIP { 81 nOpen = append(nOpen, n) 82 } 83 t.nOpenConnsPerIPMu.Unlock() 84 sort.Sort(sort.Reverse(sort.IntSlice(nOpen))) 85 log.Printf("s3file transport: open RTs per IP: %v", nOpen) 86 }) 87 return &t 88 } 89 90 func (t *T) RoundTrip(req *http.Request) (*http.Response, error) { 91 host := req.URL.Hostname() 92 93 ips, err := defaultResolver.LookupIP(host) 94 if err != nil { 95 if req.Body != nil { 96 _ = req.Body.Close() 97 } 98 return nil, fmt.Errorf("s3transport: lookup ip: %w", err) 99 } 100 ips = t.hostIPs.AddAndGet(host, ips) 101 102 hostReq := req.Clone(req.Context()) 103 hostReq.Host = host 104 // TODO: Consider other load balancing strategies. 105 ip := ips[rand.Intn(len(ips))].String() 106 hostReq.URL.Host = ip 107 108 hostRT := t.hostRoundTripper(host) 109 resp, err := hostRT.RoundTrip(hostReq) 110 if resp != nil { 111 t.addOpenConnsPerIP(ip, 1) 112 resp.Body = &rcOnClose{resp.Body, func() { t.addOpenConnsPerIP(ip, -1) }} 113 } 114 return resp, err 115 } 116 117 func (t *T) hostRoundTripper(host string) http.RoundTripper { 118 t.hostRTsMu.Lock() 119 defer t.hostRTsMu.Unlock() 120 if rt, ok := t.hostRTs[host]; ok { 121 return rt 122 } 123 transport := t.factory() 124 // We modify request URL to contain an IP, but server certificates list hostnames, so we 125 // configure our client to check against original hostname. 126 if transport.TLSClientConfig == nil { 127 transport.TLSClientConfig = &tls.Config{} 128 } 129 transport.TLSClientConfig.ServerName = host 130 t.hostRTs[host] = transport 131 return transport 132 } 133 134 func (t *T) addOpenConnsPerIP(ip string, add int) { 135 t.nOpenConnsPerIPMu.Lock() 136 t.nOpenConnsPerIP[ip] += add 137 t.nOpenConnsPerIPMu.Unlock() 138 } 139 140 type rcOnClose struct { 141 io.ReadCloser 142 onClose func() 143 } 144 145 func (r *rcOnClose) Close() error { 146 // In rare cases, this Close() is called a second time, with a call stack from the AWS SDK's 147 // cleanup code. 148 if r.onClose != nil { 149 defer r.onClose() 150 } 151 r.onClose = nil 152 return r.ReadCloser.Close() 153 }