github.com/grafana/pyroscope@v1.18.0/pkg/util/body/limit_test.go (about)

     1  package body
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"strings"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  
    13  	"github.com/grafana/pyroscope/pkg/tenant"
    14  	httputil "github.com/grafana/pyroscope/pkg/util/http"
    15  	"github.com/grafana/pyroscope/pkg/validation"
    16  )
    17  
    18  // Test handler that records what happened
    19  type testHandler struct {
    20  	called bool
    21  }
    22  
    23  func (h *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    24  	h.called = true
    25  	n, err := io.Copy(io.Discard, r.Body)
    26  	println(n)
    27  	var maxBytesError *http.MaxBytesError
    28  	if errors.As(err, &maxBytesError) {
    29  		httputil.ErrorWithStatus(w, err, http.StatusRequestEntityTooLarge)
    30  		return
    31  	}
    32  
    33  	w.WriteHeader(http.StatusOK)
    34  }
    35  
    36  func TestRequestBodyLimitMiddleware(t *testing.T) {
    37  	tenantID := "my-tenant"
    38  	anyByte := string('0')
    39  	tests := []struct {
    40  		name          string
    41  		bodyLimit     int64
    42  		bodySize      int
    43  		expectedError bool
    44  	}{
    45  		{
    46  			name:          "body size below limit",
    47  			bodyLimit:     10,
    48  			bodySize:      9,
    49  			expectedError: false,
    50  		},
    51  		{
    52  			name:          "body size matches limit",
    53  			bodyLimit:     10,
    54  			bodySize:      10,
    55  			expectedError: false,
    56  		},
    57  		{
    58  			name:          "body exceeds limit",
    59  			bodyLimit:     10,
    60  			bodySize:      11,
    61  			expectedError: true,
    62  		},
    63  		{
    64  			name:          "no limit set",
    65  			bodyLimit:     0,
    66  			bodySize:      11,
    67  			expectedError: false,
    68  		},
    69  	}
    70  	for _, tt := range tests {
    71  		t.Run(tt.name, func(t *testing.T) {
    72  			limits := validation.MockLimits{
    73  				IngestionBodyLimitBytesValue: tt.bodyLimit,
    74  			}
    75  			middleware := NewSizeLimitHandler(limits)
    76  
    77  			var handler testHandler
    78  			req := httptest.NewRequest(http.MethodPost, "/test", strings.NewReader(strings.Repeat(anyByte, tt.bodySize)))
    79  			req = req.WithContext(tenant.InjectTenantID(req.Context(), tenantID))
    80  			w := httptest.NewRecorder()
    81  
    82  			middleware(&handler).ServeHTTP(w, req)
    83  
    84  			// Verify handler was called
    85  			assert.True(t, handler.called)
    86  
    87  			if tt.expectedError {
    88  				assert.Equal(t, http.StatusRequestEntityTooLarge, w.Code)
    89  			} else {
    90  				assert.Equal(t, http.StatusOK, w.Code)
    91  			}
    92  		})
    93  	}
    94  }