github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/remotestorage/internal/reliable/http.go

// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package reliable

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"sync/atomic"
	"time"

	"github.com/cenkalti/backoff/v4"
)

// HTTPFetcher abstracts the HTTP client used to make range requests.
type HTTPFetcher interface {
	Do(req *http.Request) (*http.Response, error)
}

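// *http.Client satisfies HTTPFetcher; this compile-time assertion, added
// here for illustration, documents that relationship.
var _ HTTPFetcher = (*http.Client)(nil)

// UrlFactoryFunc returns the URL for the next download attempt. It receives
// the error from the previous attempt, or nil on the first one, so an
// implementation can, for example, refresh an expired presigned URL. A
// minimal sketch for a stable URL (|fixedURL| is a hypothetical stand-in):
//
//	func(_ error) (string, error) { return fixedURL, nil }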
type UrlFactoryFunc func(error) (string, error)

// StreamingResponse carries the body of an in-progress download. |Body|
// yields the requested bytes; the consumer must |Close| the response when
// done with it, whether or not |Body| has been fully read.
type StreamingResponse struct {
	Body   io.Reader
	cancel func()
}

// Close cancels the context of the download backing this response, releasing
// its goroutine and any in-flight HTTP request.
func (r StreamingResponse) Close() error {
	r.cancel()
	return nil
}

// StatsRecorder receives observability callbacks about download attempts.
// Sizes and offsets are in bytes; |retry| is the zero-based attempt number.
type StatsRecorder interface {
	RecordTimeToFirstByte(retry int, size uint64, d time.Duration)
	RecordDownloadAttemptStart(retry int, offset, size uint64)
	RecordDownloadComplete(retry int, size uint64, d time.Duration)
}

// HealthRecorder is told about the success or failure of each HTTP attempt,
// allowing callers to track the health of the remote endpoint.
type HealthRecorder interface {
	RecordSuccess()
	RecordFailure()
}

var ErrThroughputTooLow = errors.New("throughput below minimum threshold")
var ErrHttpStatus = errors.New("http status")

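// MinimumThroughputCheck configures EnforceThroughput: the monitor samples a
// byte count every |CheckInterval|, and if fewer than
// |BytesPerCheck * NumIntervals| bytes arrive across the most recent
// |NumIntervals| intervals, the attempt is canceled with ErrThroughputTooLow.
// With illustrative values only:
//
//	MinimumThroughputCheck{
//		CheckInterval: time.Second,
//		BytesPerCheck: 16 * 1024,
//		NumIntervals:  5,
//	}
//
// an attempt that moves fewer than 80 KiB over a 5-second window is canceled.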
type MinimumThroughputCheck struct {
	CheckInterval time.Duration
	BytesPerCheck int
	NumIntervals  int
}

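// BackOffFactory constructs the retry policy for a single download, scoped
// to that download's context. A minimal sketch using cenkalti/backoff/v4:
//
//	func(ctx context.Context) backoff.BackOff {
//		return backoff.WithContext(backoff.NewExponentialBackOff(), ctx)
//	}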
type BackOffFactory func(context.Context) backoff.BackOff

// StreamingRangeRequest describes a ranged GET of |Length| bytes starting at
// |Offset|, along with the collaborators used to perform it reliably.
type StreamingRangeRequest struct {
	Fetcher            HTTPFetcher
	Offset             uint64
	Length             uint64
	UrlFact            UrlFactoryFunc
	Stats              StatsRecorder
	Health             HealthRecorder
	BackOffFact        BackOffFactory
	Throughput         MinimumThroughputCheck
	RespHeadersTimeout time.Duration
}

// |StreamingRangeDownload| makes an immediate GET request to the URL returned
// from |req.UrlFact|, returning a |StreamingResponse| object which can be
// used to consume the body of the response. A |StreamingResponse| should be
// |Close|d by the consumer, and it is safe to do so before the entire
// response has been read, if a condition arises where the response is no
// longer needed.
//
// This function kicks off a goroutine which is responsible for making the
// HTTP request(s) associated with fulfilling the request. Only one HTTP
// request will be in flight at a time, but if errors are encountered while
// making the requests or reading the response body, further requests may be
// made for as-yet-undelivered bytes from the requested byte range.
//
// As a result, the bytes read from |StreamingResponse.Body| may be the
// concatenation of responses to multiple requests made to the URLs returned
// from |req.UrlFact|. Thus, those URLs should represent the same immutable
// remote resource which is guaranteed to return the same bytes for
// overlapping byte ranges.
//
// If there is a fatal error when making the requests, it is delivered through
// the error returned from reads of |StreamingResponse.Body|.
//
// Reads of |StreamingResponse.Body| can (and often will) return short reads.
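//
// A minimal usage sketch; |urlFact|, |stats|, |health|, |backOffFact|, and
// |throughput| are hypothetical stand-ins for caller-provided values:
//
//	resp := StreamingRangeDownload(ctx, StreamingRangeRequest{
//		Fetcher:            http.DefaultClient,
//		Offset:             0,
//		Length:             1 << 20,
//		UrlFact:            urlFact,
//		Stats:              stats,
//		Health:             health,
//		BackOffFact:        backOffFact,
//		Throughput:         throughput,
//		RespHeadersTimeout: 15 * time.Second,
//	})
//	defer resp.Close()
//	bs, err := io.ReadAll(resp.Body)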
func StreamingRangeDownload(ctx context.Context, req StreamingRangeRequest) StreamingResponse {
	// |rangeEnd| never changes, but the offset at which retries are made
	// is based on how much of the response has already been delivered.
	rangeEnd := req.Offset + req.Length - 1

	// |StreamingResponse| is getting the read side of this pipe to read
	// the body and/or any terminal error encountered. The goroutine making
	// the retried HTTP requests will be writing to |w|.
	r, w := io.Pipe()

	// This is the overall context for the operation, encompassing all of
	// its retries. When the StreamingResponse is closed, this is canceled.
	ctx, cancel := context.WithCancel(ctx)

	// This naked goroutine makes retried HTTP requests for the byte range,
	// writing the HTTP response bodies to |w|.
	go func() {
		origOffset := req.Offset
		// |offset| starts at |req.Offset| but may be updated if we
		// make retries and have already delivered some bytes.
		offset := req.Offset
		var retry int
		// |lastError| is passed to |req.UrlFact| on each attempt.
		var lastError error
		op := func() (rerr error) {
			defer func() { retry += 1 }()
			defer func() { lastError = rerr }()
			// This is the per-call context. It can be canceled by
			// EnforceThroughput, for example, without canceling
			// the entire operation. It is canceled when this
			// attempt returns so its resources are released.
			ctx, cCause := context.WithCancelCause(ctx)
			defer cCause(nil)

			url, err := req.UrlFact(lastError)
			if err != nil {
				return err
			}

			httpReq, err := http.NewRequest(http.MethodGet, url, nil)
			if err != nil {
				return err
			}

			// Range headers are inclusive on both ends.
			rangeHeaderVal := fmt.Sprintf("bytes=%d-%d", offset, rangeEnd)
			httpReq.Header.Set("Range", rangeHeaderVal)

			// We use a TimeoutController to enforce a timeout for
			// receiving the response headers. If the request is
			// successful, the "timeout" on the overall request
			// will be managed by |EnforceThroughput| on reading
			// the response body. But we still need to impose a
			// timeout on receiving the response headers, since we
			// don't want to block for an indefinite or unspecified
			// amount of time. Here we set things up so we will
			// manually cancel the request context if the response
			// headers are not received in time, but we can cancel
			// this timeout immediately after the response headers
			// are received.
			tc := NewTimeoutController()
			defer tc.Close()
			go func() {
				err := tc.Run()
				if err != nil {
					cCause(err)
				}
			}()

			httpReq = httpReq.WithContext(ctx)

			req.Stats.RecordDownloadAttemptStart(retry, offset-origOffset, req.Length)
			start := time.Now()

			tc.SetTimeout(ctx, req.RespHeadersTimeout)
			resp, err := req.Fetcher.Do(httpReq)
			tc.SetTimeout(ctx, 0)
			if err != nil {
				req.Health.RecordFailure()
				return err
			}
			defer resp.Body.Close()

			if resp.StatusCode/100 != 2 {
				req.Health.RecordFailure()
				return fmt.Errorf("%w: %d", ErrHttpStatus, resp.StatusCode)
			}
			req.Stats.RecordTimeToFirstByte(retry, req.Length, time.Since(start))

			reader := &AtomicCountingReader{r: resp.Body}
			cleanup := EnforceThroughput(reader.Count, req.Throughput, func(err error) {
				cCause(err)
			})
			n, err := io.Copy(w, reader)
			cleanup()
			// We successfully wrote this many bytes to |w|. Update |offset|.
			offset += uint64(n)
			if err == nil {
				// Success! We read until Body returned EOF.
				req.Health.RecordSuccess()
				return nil
			} else if errors.Is(err, io.ErrClosedPipe) || errors.Is(err, io.ErrShortWrite) {
				// The read side of the pipe was closed; bail
				// without retrying.
				return backoff.Permanent(err)
			} else {
				if cerr := context.Cause(ctx); errors.Is(err, context.Canceled) && cerr != nil {
					// The HTTP Body reader will return
					// context.Canceled even if we cancel
					// with a cause. Convert to the cause
					// here, if we have one.
					err = cerr
				}
				// Let backoff decide when and if we retry.
				req.Health.RecordFailure()
				return err
			}
		}
		start := time.Now()
		err := backoff.Retry(op, req.BackOffFact(ctx))
		if err != nil {
			w.CloseWithError(err)
		} else {
			req.Stats.RecordDownloadComplete(retry, req.Length, time.Since(start))
			w.Close()
		}
	}()

	return StreamingResponse{
		Body:   r,
		cancel: cancel,
	}
}

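// AtomicCountingReader wraps an io.Reader and keeps an atomic count of the
// bytes read through it, so that a concurrent monitor such as
// EnforceThroughput can observe download progress.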
type AtomicCountingReader struct {
	r io.Reader
	c atomic.Uint64
}

func (r *AtomicCountingReader) Read(bs []byte) (int, error) {
	n, err := r.r.Read(bs)
	r.c.Add(uint64(n))
	return n, err
}

func (r *AtomicCountingReader) Count() uint64 {
	return r.c.Load()
}

// EnforceThroughput will spawn a naked goroutine that will watch a |cnt|
// source. If the rate by which |cnt| is increasing drops below the configured
// threshold for too long, it will call |cancel|. EnforceThroughput should be
// cleaned up by calling |cleanup| once whatever it is monitoring is finished.
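//
// A minimal usage sketch (|reader| is a hypothetical *AtomicCountingReader
// and |cCause| a hypothetical context.CancelCauseFunc):
//
//	cleanup := EnforceThroughput(reader.Count, params, func(err error) {
//		cCause(err)
//	})
//	defer cleanup()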
func EnforceThroughput(cnt func() uint64, params MinimumThroughputCheck, cancel func(error)) (cleanup func()) {
	done := make(chan struct{})
	go func() {
		n := params.NumIntervals
		var counts []uint64
		// Note: We don't look at the clock when we take these
		// observations. If we make late observations, then we may see
		// higher numbers than we should have and think our throughput
		// is higher than it is.
		tooSlow := func() bool {
			// We need n+1 observations to cover n full intervals.
			if len(counts) < n+1 {
				return false
			}
			// Keep only the most recent n+1 observations.
			copy(counts[:n+1], counts[len(counts)-n-1:])
			counts = counts[:n+1]
			delta := counts[n] - counts[0]
			return int(delta) < params.BytesPerCheck*n
		}
		for {
			select {
			case <-time.After(params.CheckInterval):
				counts = append(counts, cnt())
				if tooSlow() {
					cancel(fmt.Errorf("%w: needed %d bytes per interval across %d intervals, went from %d to %d instead",
						ErrThroughputTooLow, params.BytesPerCheck, n, counts[0], counts[n]))
					return
				}
			case <-done:
				return
			}
		}
	}()
	return func() {
		close(done)
	}
}