github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/file/s3file/s3transport/transport.go (about)

     1  package s3transport
     2  
import (
	"crypto/tls"
	"fmt"
	"io"
	"math/rand"
	"net"
	"net/http"
	"sort"
	"strings"
	"sync"
	"time"

	"github.com/Schaudge/grailbase/file/s3file/internal/autolog"
	"github.com/Schaudge/grailbase/log"
)
    17  
// T is an http.RoundTripper specialized for S3. See https://github.com/aws/aws-sdk-go/issues/3739.
type T struct {
	// factory builds a new *http.Transport the first time a given hostname is seen.
	factory func() *http.Transport

	// hostRTsMu guards hostRTs, which caches one transport per hostname.
	hostRTsMu sync.Mutex
	hostRTs   map[string]http.RoundTripper

	// nOpenConnsPerIPMu guards nOpenConnsPerIP: per IP address, the number of
	// response bodies returned by RoundTrip that have not yet been closed.
	// The counts are logged by the autolog callback registered in New.
	nOpenConnsPerIPMu sync.Mutex
	nOpenConnsPerIP   map[string]int

	// hostIPs records the IP addresses resolved for each hostname; RoundTrip picks
	// its destination from the set returned by AddAndGet. Presumably entries age
	// out over time (see expiringMap) — TODO confirm.
	hostIPs *expiringMap
}
    30  
    31  var (
    32  	stdDefaultTransport = http.DefaultTransport.(*http.Transport)
    33  	httpTransport       = &http.Transport{
    34  		DialContext: (&net.Dialer{
    35  			Timeout:   30 * time.Second, // Copied from http.DefaultTransport.
    36  			KeepAlive: 30 * time.Second, // Copied from same.
    37  		}).DialContext,
    38  		ForceAttemptHTTP2:     false,                           // S3 doesn't support HTTP2.
    39  		MaxIdleConns:          200,                             // Keep many peers for future bursts.
    40  		MaxIdleConnsPerHost:   4,                               // But limit connections to each.
    41  		IdleConnTimeout:       expireAfter + 2*expireLoopEvery, // Keep until we forget the peer.
    42  		TLSClientConfig:       &tls.Config{},
    43  		TLSHandshakeTimeout:   stdDefaultTransport.TLSHandshakeTimeout,
    44  		ExpectContinueTimeout: stdDefaultTransport.ExpectContinueTimeout,
    45  	}
    46  
    47  	defaultOnce   sync.Once
    48  	defaultT      *T
    49  	defaultClient *http.Client
    50  )
    51  
    52  func defaults() (*T, *http.Client) {
    53  	defaultOnce.Do(func() {
    54  		defaultT = New(httpTransport.Clone)
    55  		defaultClient = &http.Client{Transport: defaultT}
    56  	})
    57  	return defaultT, defaultClient
    58  }
    59  
    60  // Default returns an http.RoundTripper with recommended settings.
    61  func Default() *T { t, _ := defaults(); return t }
    62  
    63  // DefaultClient returns an *http.Client that uses the http.RoundTripper
    64  // returned by Default (suitable for general use, analogous to
    65  // "net/http".DefaultClient).
    66  func DefaultClient() *http.Client { _, c := defaults(); return c }
    67  
    68  // New constructs *T using factory to create internal transports. Each call to factory()
    69  // must return a separate http.Transport and they must not share TLSClientConfig.
    70  func New(factory func() *http.Transport) *T {
    71  	t := T{
    72  		factory:         factory,
    73  		hostRTs:         map[string]http.RoundTripper{},
    74  		hostIPs:         newExpiringMap(runPeriodicForever(), time.Now),
    75  		nOpenConnsPerIP: map[string]int{},
    76  	}
    77  	autolog.Register(func() {
    78  		var nOpen []int
    79  		t.nOpenConnsPerIPMu.Lock()
    80  		for _, n := range t.nOpenConnsPerIP {
    81  			nOpen = append(nOpen, n)
    82  		}
    83  		t.nOpenConnsPerIPMu.Unlock()
    84  		sort.Sort(sort.Reverse(sort.IntSlice(nOpen)))
    85  		log.Printf("s3file transport: open RTs per IP: %v", nOpen)
    86  	})
    87  	return &t
    88  }
    89  
    90  func (t *T) RoundTrip(req *http.Request) (*http.Response, error) {
    91  	host := req.URL.Hostname()
    92  
    93  	ips, err := defaultResolver.LookupIP(host)
    94  	if err != nil {
    95  		if req.Body != nil {
    96  			_ = req.Body.Close()
    97  		}
    98  		return nil, fmt.Errorf("s3transport: lookup ip: %w", err)
    99  	}
   100  	ips = t.hostIPs.AddAndGet(host, ips)
   101  
   102  	hostReq := req.Clone(req.Context())
   103  	hostReq.Host = host
   104  	// TODO: Consider other load balancing strategies.
   105  	ip := ips[rand.Intn(len(ips))].String()
   106  	hostReq.URL.Host = ip
   107  
   108  	hostRT := t.hostRoundTripper(host)
   109  	resp, err := hostRT.RoundTrip(hostReq)
   110  	if resp != nil {
   111  		t.addOpenConnsPerIP(ip, 1)
   112  		resp.Body = &rcOnClose{resp.Body, func() { t.addOpenConnsPerIP(ip, -1) }}
   113  	}
   114  	return resp, err
   115  }
   116  
   117  func (t *T) hostRoundTripper(host string) http.RoundTripper {
   118  	t.hostRTsMu.Lock()
   119  	defer t.hostRTsMu.Unlock()
   120  	if rt, ok := t.hostRTs[host]; ok {
   121  		return rt
   122  	}
   123  	transport := t.factory()
   124  	// We modify request URL to contain an IP, but server certificates list hostnames, so we
   125  	// configure our client to check against original hostname.
   126  	if transport.TLSClientConfig == nil {
   127  		transport.TLSClientConfig = &tls.Config{}
   128  	}
   129  	transport.TLSClientConfig.ServerName = host
   130  	t.hostRTs[host] = transport
   131  	return transport
   132  }
   133  
   134  func (t *T) addOpenConnsPerIP(ip string, add int) {
   135  	t.nOpenConnsPerIPMu.Lock()
   136  	t.nOpenConnsPerIP[ip] += add
   137  	t.nOpenConnsPerIPMu.Unlock()
   138  }
   139  
   140  type rcOnClose struct {
   141  	io.ReadCloser
   142  	onClose func()
   143  }
   144  
   145  func (r *rcOnClose) Close() error {
   146  	// In rare cases, this Close() is called a second time, with a call stack from the AWS SDK's
   147  	// cleanup code.
   148  	if r.onClose != nil {
   149  		defer r.onClose()
   150  	}
   151  	r.onClose = nil
   152  	return r.ReadCloser.Close()
   153  }