github.com/jpmicrosoft/grab/v3@v3.0.2/pkg/grabtest/handler.go (about) 1 package grabtest 2 3 import ( 4 "bufio" 5 "fmt" 6 "net/http" 7 "net/http/httptest" 8 "testing" 9 "time" 10 ) 11 12 var ( 13 DefaultHandlerContentLength = 1 << 20 14 DefaultHandlerMD5Checksum = "c35cc7d8d91728a0cb052831bc4ef372" 15 DefaultHandlerMD5ChecksumBytes = MustHexDecodeString(DefaultHandlerMD5Checksum) 16 DefaultHandlerSHA256Checksum = "fbbab289f7f94b25736c58be46a994c441fd02552cc6022352e3d86d2fab7c83" 17 DefaultHandlerSHA256ChecksumBytes = MustHexDecodeString(DefaultHandlerSHA256Checksum) 18 ) 19 20 type StatusCodeFunc func(req *http.Request) int 21 22 type handler struct { 23 statusCodeFunc StatusCodeFunc 24 methodWhitelist []string 25 headerBlacklist []string 26 contentLength int 27 acceptRanges bool 28 attachmentFilename string 29 lastModified time.Time 30 ttfb time.Duration 31 rateLimiter *time.Ticker 32 } 33 34 func NewHandler(options ...HandlerOption) (http.Handler, error) { 35 h := &handler{ 36 statusCodeFunc: func(req *http.Request) int { return http.StatusOK }, 37 methodWhitelist: []string{"GET", "HEAD"}, 38 contentLength: DefaultHandlerContentLength, 39 acceptRanges: true, 40 } 41 for _, option := range options { 42 if err := option(h); err != nil { 43 return nil, err 44 } 45 } 46 return h, nil 47 } 48 49 func WithTestServer(t *testing.T, f func(url string), options ...HandlerOption) { 50 h, err := NewHandler(options...) 51 if err != nil { 52 t.Fatalf("unable to create test server handler: %v", err) 53 return 54 } 55 s := httptest.NewServer(h) 56 defer func() { 57 h.(*handler).close() 58 s.Close() 59 }() 60 f(s.URL) 61 } 62 63 func (h *handler) close() { 64 if h.rateLimiter != nil { 65 h.rateLimiter.Stop() 66 } 67 } 68 69 func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 70 // delay response 71 if h.ttfb > 0 { 72 time.Sleep(h.ttfb) 73 } 74 75 // validate request method 76 allowed := false 77 for _, m := range h.methodWhitelist { 78 if r.Method == m { 79 allowed = true 80 break 81 } 82 } 83 if !allowed { 84 httpError(w, http.StatusMethodNotAllowed) 85 return 86 } 87 88 // set server options 89 if h.acceptRanges { 90 w.Header().Set("Accept-Ranges", "bytes") 91 } 92 93 // set attachment filename 94 if h.attachmentFilename != "" { 95 w.Header().Set( 96 "Content-Disposition", 97 fmt.Sprintf("attachment;filename=\"%s\"", h.attachmentFilename), 98 ) 99 } 100 101 // set last modified timestamp 102 lastMod := time.Now() 103 if !h.lastModified.IsZero() { 104 lastMod = h.lastModified 105 } 106 w.Header().Set("Last-Modified", lastMod.Format(http.TimeFormat)) 107 108 // set content-length 109 offset := 0 110 if h.acceptRanges { 111 if reqRange := r.Header.Get("Range"); reqRange != "" { 112 if _, err := fmt.Sscanf(reqRange, "bytes=%d-", &offset); err != nil { 113 httpError(w, http.StatusBadRequest) 114 return 115 } 116 if offset >= h.contentLength { 117 httpError(w, http.StatusRequestedRangeNotSatisfiable) 118 return 119 } 120 } 121 } 122 w.Header().Set("Content-Length", fmt.Sprintf("%d", h.contentLength-offset)) 123 124 // apply header blacklist 125 for _, key := range h.headerBlacklist { 126 w.Header().Del(key) 127 } 128 129 // send header and status code 130 w.WriteHeader(h.statusCodeFunc(r)) 131 132 // send body 133 if r.Method == "GET" { 134 // use buffered io to reduce overhead on the reader 135 bw := bufio.NewWriterSize(w, 4096) 136 for i := offset; !isRequestClosed(r) && i < h.contentLength; i++ { 137 bw.Write([]byte{byte(i)}) 138 if h.rateLimiter != nil { 139 bw.Flush() 140 w.(http.Flusher).Flush() // force the server to send the data to the client 141 select { 142 case <-h.rateLimiter.C: 143 case <-r.Context().Done(): 144 } 145 } 146 } 147 if !isRequestClosed(r) { 148 bw.Flush() 149 } 150 } 151 } 152 153 // isRequestClosed returns true if the client request has been canceled. 154 func isRequestClosed(r *http.Request) bool { 155 return r.Context().Err() != nil 156 } 157 158 func httpError(w http.ResponseWriter, code int) { 159 http.Error(w, http.StatusText(code), code) 160 }