github.com/grafana/pyroscope@v1.18.0/pkg/util/body/limit_test.go (about) 1 package body 2 3 import ( 4 "errors" 5 "io" 6 "net/http" 7 "net/http/httptest" 8 "strings" 9 "testing" 10 11 "github.com/stretchr/testify/assert" 12 13 "github.com/grafana/pyroscope/pkg/tenant" 14 httputil "github.com/grafana/pyroscope/pkg/util/http" 15 "github.com/grafana/pyroscope/pkg/validation" 16 ) 17 18 // Test handler that records what happened 19 type testHandler struct { 20 called bool 21 } 22 23 func (h *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 24 h.called = true 25 n, err := io.Copy(io.Discard, r.Body) 26 println(n) 27 var maxBytesError *http.MaxBytesError 28 if errors.As(err, &maxBytesError) { 29 httputil.ErrorWithStatus(w, err, http.StatusRequestEntityTooLarge) 30 return 31 } 32 33 w.WriteHeader(http.StatusOK) 34 } 35 36 func TestRequestBodyLimitMiddleware(t *testing.T) { 37 tenantID := "my-tenant" 38 anyByte := string('0') 39 tests := []struct { 40 name string 41 bodyLimit int64 42 bodySize int 43 expectedError bool 44 }{ 45 { 46 name: "body size below limit", 47 bodyLimit: 10, 48 bodySize: 9, 49 expectedError: false, 50 }, 51 { 52 name: "body size matches limit", 53 bodyLimit: 10, 54 bodySize: 10, 55 expectedError: false, 56 }, 57 { 58 name: "body exceeds limit", 59 bodyLimit: 10, 60 bodySize: 11, 61 expectedError: true, 62 }, 63 { 64 name: "no limit set", 65 bodyLimit: 0, 66 bodySize: 11, 67 expectedError: false, 68 }, 69 } 70 for _, tt := range tests { 71 t.Run(tt.name, func(t *testing.T) { 72 limits := validation.MockLimits{ 73 IngestionBodyLimitBytesValue: tt.bodyLimit, 74 } 75 middleware := NewSizeLimitHandler(limits) 76 77 var handler testHandler 78 req := httptest.NewRequest(http.MethodPost, "/test", strings.NewReader(strings.Repeat(anyByte, tt.bodySize))) 79 req = req.WithContext(tenant.InjectTenantID(req.Context(), tenantID)) 80 w := httptest.NewRecorder() 81 82 middleware(&handler).ServeHTTP(w, req) 83 84 // Verify handler was called 85 assert.True(t, handler.called) 86 87 if tt.expectedError { 88 assert.Equal(t, http.StatusRequestEntityTooLarge, w.Code) 89 } else { 90 assert.Equal(t, http.StatusOK, w.Code) 91 } 92 }) 93 } 94 }