github.com/grafana/pyroscope@v1.18.0/pkg/validation/user_limits_handler_test.go (about)

     1  // SPDX-License-Identifier: AGPL-3.0-only
     2  
     3  package validation
     4  
     5  import (
     6  	"context"
     7  	"encoding/json"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"testing"
    11  
    12  	"github.com/grafana/dskit/user"
    13  	"github.com/stretchr/testify/require"
    14  )
    15  
    16  func TestTenantLimitsHandler(t *testing.T) {
    17  	defaults := Limits{
    18  		IngestionRateMB:      100,
    19  		IngestionBurstSizeMB: 10,
    20  	}
    21  
    22  	tenantLimits := make(map[string]*Limits)
    23  	testLimits := defaults
    24  	testLimits.IngestionRateMB = 200
    25  	tenantLimits["test-with-override"] = &testLimits
    26  
    27  	for _, tc := range []struct {
    28  		name               string
    29  		orgID              string
    30  		expectedStatusCode int
    31  		expectedLimits     TenantLimitsResponse
    32  	}{
    33  		{
    34  			name:               "Authenticated user with override",
    35  			orgID:              "test-with-override",
    36  			expectedStatusCode: http.StatusOK,
    37  			expectedLimits: TenantLimitsResponse{
    38  				IngestionRate:      200,
    39  				IngestionBurstSize: 10,
    40  			},
    41  		},
    42  		{
    43  			name:               "Authenticated user without override",
    44  			orgID:              "test-no-override",
    45  			expectedStatusCode: http.StatusOK,
    46  			expectedLimits: TenantLimitsResponse{
    47  				IngestionRate:      100,
    48  				IngestionBurstSize: 10,
    49  			},
    50  		},
    51  		{
    52  			name:               "Unauthenticated user",
    53  			orgID:              "",
    54  			expectedStatusCode: http.StatusUnauthorized,
    55  			expectedLimits:     TenantLimitsResponse{},
    56  		},
    57  	} {
    58  		t.Run(tc.name, func(t *testing.T) {
    59  			handler := TenantLimitsHandler(defaults, NewMockTenantLimits(tenantLimits))
    60  			request := httptest.NewRequest("GET", "/api/v1/user_limits", nil)
    61  			if tc.orgID != "" {
    62  				ctx := user.InjectOrgID(context.Background(), tc.orgID)
    63  				request = request.WithContext(ctx)
    64  			}
    65  
    66  			recorder := httptest.NewRecorder()
    67  			handler.ServeHTTP(recorder, request)
    68  			require.Equal(t, tc.expectedStatusCode, recorder.Result().StatusCode)
    69  
    70  			if recorder.Result().StatusCode == http.StatusOK {
    71  				var response TenantLimitsResponse
    72  				decoder := json.NewDecoder(recorder.Result().Body)
    73  				require.NoError(t, decoder.Decode(&response))
    74  				require.Equal(t, tc.expectedLimits, response)
    75  			}
    76  		})
    77  	}
    78  }