github.com/jpmicrosoft/grab/v3@v3.0.2/pkg/grabtest/handler.go (about)

     1  package grabtest
     2  
     3  import (
     4  	"bufio"
     5  	"fmt"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"testing"
     9  	"time"
    10  )
    11  
    12  var (
    13  	DefaultHandlerContentLength       = 1 << 20
    14  	DefaultHandlerMD5Checksum         = "c35cc7d8d91728a0cb052831bc4ef372"
    15  	DefaultHandlerMD5ChecksumBytes    = MustHexDecodeString(DefaultHandlerMD5Checksum)
    16  	DefaultHandlerSHA256Checksum      = "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83"
    17  	DefaultHandlerSHA256ChecksumBytes = MustHexDecodeString(DefaultHandlerSHA256Checksum)
    18  )
    19  
    20  type StatusCodeFunc func(req *http.Request) int
    21  
// handler is the http.Handler built by NewHandler. It serves a synthetic body
// in which the byte at offset i is byte(i). NewHandler fills in defaults and
// then passes the handler to each HandlerOption, which may modify these fields.
type handler struct {
	// statusCodeFunc picks the status code written for each request.
	statusCodeFunc     StatusCodeFunc
	// methodWhitelist lists the accepted request methods; any other method
	// is answered with 405 Method Not Allowed.
	methodWhitelist    []string
	// headerBlacklist names response headers that are deleted just before
	// the status code is written.
	headerBlacklist    []string
	// contentLength is the full body size in bytes; a ranged request is
	// served from its start offset up to this length.
	contentLength      int
	// acceptRanges, when true, advertises "Accept-Ranges: bytes" and makes
	// the handler honor a "bytes=N-" Range header.
	acceptRanges       bool
	// attachmentFilename, if non-empty, is sent as the filename of a
	// "Content-Disposition: attachment" header.
	attachmentFilename string
	// lastModified is sent as the Last-Modified header; the zero time means
	// "use the current time of each request".
	lastModified       time.Time
	// ttfb is a delay applied before handling each request.
	ttfb               time.Duration
	// rateLimiter, if non-nil, paces the body to one byte per tick; it is
	// stopped by close.
	rateLimiter        *time.Ticker
}
    33  
    34  func NewHandler(options ...HandlerOption) (http.Handler, error) {
    35  	h := &handler{
    36  		statusCodeFunc:  func(req *http.Request) int { return http.StatusOK },
    37  		methodWhitelist: []string{"GET", "HEAD"},
    38  		contentLength:   DefaultHandlerContentLength,
    39  		acceptRanges:    true,
    40  	}
    41  	for _, option := range options {
    42  		if err := option(h); err != nil {
    43  			return nil, err
    44  		}
    45  	}
    46  	return h, nil
    47  }
    48  
    49  func WithTestServer(t *testing.T, f func(url string), options ...HandlerOption) {
    50  	h, err := NewHandler(options...)
    51  	if err != nil {
    52  		t.Fatalf("unable to create test server handler: %v", err)
    53  		return
    54  	}
    55  	s := httptest.NewServer(h)
    56  	defer func() {
    57  		h.(*handler).close()
    58  		s.Close()
    59  	}()
    60  	f(s.URL)
    61  }
    62  
    63  func (h *handler) close() {
    64  	if h.rateLimiter != nil {
    65  		h.rateLimiter.Stop()
    66  	}
    67  }
    68  
    69  func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    70  	// delay response
    71  	if h.ttfb > 0 {
    72  		time.Sleep(h.ttfb)
    73  	}
    74  
    75  	// validate request method
    76  	allowed := false
    77  	for _, m := range h.methodWhitelist {
    78  		if r.Method == m {
    79  			allowed = true
    80  			break
    81  		}
    82  	}
    83  	if !allowed {
    84  		httpError(w, http.StatusMethodNotAllowed)
    85  		return
    86  	}
    87  
    88  	// set server options
    89  	if h.acceptRanges {
    90  		w.Header().Set("Accept-Ranges", "bytes")
    91  	}
    92  
    93  	// set attachment filename
    94  	if h.attachmentFilename != "" {
    95  		w.Header().Set(
    96  			"Content-Disposition",
    97  			fmt.Sprintf("attachment;filename=\"%s\"", h.attachmentFilename),
    98  		)
    99  	}
   100  
   101  	// set last modified timestamp
   102  	lastMod := time.Now()
   103  	if !h.lastModified.IsZero() {
   104  		lastMod = h.lastModified
   105  	}
   106  	w.Header().Set("Last-Modified", lastMod.Format(http.TimeFormat))
   107  
   108  	// set content-length
   109  	offset := 0
   110  	if h.acceptRanges {
   111  		if reqRange := r.Header.Get("Range"); reqRange != "" {
   112  			if _, err := fmt.Sscanf(reqRange, "bytes=%d-", &offset); err != nil {
   113  				httpError(w, http.StatusBadRequest)
   114  				return
   115  			}
   116  			if offset >= h.contentLength {
   117  				httpError(w, http.StatusRequestedRangeNotSatisfiable)
   118  				return
   119  			}
   120  		}
   121  	}
   122  	w.Header().Set("Content-Length", fmt.Sprintf("%d", h.contentLength-offset))
   123  
   124  	// apply header blacklist
   125  	for _, key := range h.headerBlacklist {
   126  		w.Header().Del(key)
   127  	}
   128  
   129  	// send header and status code
   130  	w.WriteHeader(h.statusCodeFunc(r))
   131  
   132  	// send body
   133  	if r.Method == "GET" {
   134  		// use buffered io to reduce overhead on the reader
   135  		bw := bufio.NewWriterSize(w, 4096)
   136  		for i := offset; !isRequestClosed(r) && i < h.contentLength; i++ {
   137  			bw.Write([]byte{byte(i)})
   138  			if h.rateLimiter != nil {
   139  				bw.Flush()
   140  				w.(http.Flusher).Flush() // force the server to send the data to the client
   141  				select {
   142  				case <-h.rateLimiter.C:
   143  				case <-r.Context().Done():
   144  				}
   145  			}
   146  		}
   147  		if !isRequestClosed(r) {
   148  			bw.Flush()
   149  		}
   150  	}
   151  }
   152  
   153  // isRequestClosed returns true if the client request has been canceled.
   154  func isRequestClosed(r *http.Request) bool {
   155  	return r.Context().Err() != nil
   156  }
   157  
   158  func httpError(w http.ResponseWriter, code int) {
   159  	http.Error(w, http.StatusText(code), code)
   160  }