github.com/grafana/pyroscope@v1.18.0/pkg/tenant/interceptor_test.go (about)

     1  package tenant
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"testing"
     7  
     8  	"connectrpc.com/connect"
     9  	"github.com/stretchr/testify/require"
    10  )
    11  
    12  func Test_AuthInterceptor(t *testing.T) {
    13  	for testName, testCase := range map[string]func(t *testing.T){
    14  		"client: forward from context": func(t *testing.T) {
    15  			i := NewAuthInterceptor(false)
    16  			resp, err := i.WrapUnary(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) {
    17  				tenantID, _, err := ExtractTenantIDFromHeaders(context.Background(), ar.Header())
    18  				require.NoError(t, err)
    19  				require.Equal(t, tenantID, "foo")
    20  				return nil, nil
    21  			})(InjectTenantID(context.Background(), "foo"), newFakeReq(true))
    22  			require.NoError(t, err)
    23  			require.Nil(t, resp)
    24  		},
    25  		"client: no org forwarded": func(t *testing.T) {
    26  			i := NewAuthInterceptor(false)
    27  			resp, err := i.WrapUnary(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) {
    28  				tenantID, _, err := ExtractTenantIDFromHeaders(context.Background(), ar.Header())
    29  				require.Equal(t, ErrNoTenantID, err)
    30  				require.Equal(t, tenantID, "")
    31  				return nil, nil
    32  			})(context.Background(), newFakeReq(true))
    33  			require.NoError(t, err)
    34  			require.Nil(t, resp)
    35  		},
    36  		"server: disable, static org": func(t *testing.T) {
    37  			i := NewAuthInterceptor(false)
    38  			req := newFakeReq(false)
    39  			req.Header().Set("X-Scope-OrgID", "foo")
    40  			resp, err := i.WrapUnary(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) {
    41  				tenantID, err := ExtractTenantIDFromContext(ctx)
    42  				require.NoError(t, err)
    43  				require.Equal(t, tenantID, DefaultTenantID)
    44  				return nil, nil
    45  			})(context.Background(), req)
    46  			require.NoError(t, err)
    47  			require.Nil(t, resp)
    48  		},
    49  		"server: enable, forward header": func(t *testing.T) {
    50  			i := NewAuthInterceptor(true)
    51  			req := newFakeReq(false)
    52  			req.Header().Set("X-Scope-OrgID", "foo")
    53  			resp, err := i.WrapUnary(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) {
    54  				tenantID, err := ExtractTenantIDFromContext(ctx)
    55  				require.NoError(t, err)
    56  				require.Equal(t, tenantID, "foo")
    57  				return nil, nil
    58  			})(context.Background(), req)
    59  			require.NoError(t, err)
    60  			require.Nil(t, resp)
    61  		},
    62  		"streaming client should forward from context": func(t *testing.T) {
    63  			i := NewAuthInterceptor(false)
    64  			inConn := newFakeClientStreamingConn()
    65  			outConn := i.WrapStreamingClient(func(ctx context.Context, s connect.Spec) connect.StreamingClientConn {
    66  				return inConn
    67  			})(InjectTenantID(context.Background(), "foo"), connect.Spec{})
    68  			require.Equal(t, "foo", outConn.RequestHeader().Get("X-Scope-OrgID"))
    69  		},
    70  		"streaming server should forward from header to context if enabled": func(t *testing.T) {
    71  			i := NewAuthInterceptor(true)
    72  			shc := newFakeClientStreamingConn()
    73  			shc.requestHeaders.Set("X-Scope-OrgID", "foo")
    74  			_ = i.WrapStreamingHandler(func(ctx context.Context, shc connect.StreamingHandlerConn) error {
    75  				tenantID, err := ExtractTenantIDFromContext(ctx)
    76  				require.NoError(t, err)
    77  				require.Equal(t, tenantID, "foo")
    78  				return nil
    79  			})(context.Background(), shc)
    80  		},
    81  		"streaming server should forward default tenant to context if disable": func(t *testing.T) {
    82  			i := NewAuthInterceptor(false)
    83  			shc := newFakeClientStreamingConn()
    84  			_ = i.WrapStreamingHandler(func(ctx context.Context, shc connect.StreamingHandlerConn) error {
    85  				tenantID, err := ExtractTenantIDFromContext(ctx)
    86  				require.NoError(t, err)
    87  				require.Equal(t, tenantID, DefaultTenantID)
    88  				return nil
    89  			})(context.Background(), shc)
    90  		},
    91  	} {
    92  		t.Run(testName, testCase)
    93  	}
    94  }
    95  
    96  type fakeReq struct {
    97  	connect.AnyRequest
    98  	isClient bool
    99  	headers  http.Header
   100  }
   101  
   102  func newFakeReq(isClient bool) fakeReq {
   103  	return fakeReq{
   104  		isClient:   isClient,
   105  		headers:    http.Header{},
   106  		AnyRequest: connect.NewRequest(&http.Request{}),
   107  	}
   108  }
   109  
   110  func (f fakeReq) Spec() connect.Spec {
   111  	return connect.Spec{
   112  		IsClient: f.isClient,
   113  	}
   114  }
   115  
   116  func (f fakeReq) Header() http.Header {
   117  	return f.headers
   118  }
   119  
   120  type fakeClientStreamingConn struct {
   121  	requestHeaders http.Header
   122  }
   123  
   124  func newFakeClientStreamingConn() fakeClientStreamingConn {
   125  	return fakeClientStreamingConn{
   126  		requestHeaders: http.Header{},
   127  	}
   128  }
   129  
   130  func (fakeClientStreamingConn) Peer() connect.Peer           { return connect.Peer{} }
   131  func (fakeClientStreamingConn) Spec() connect.Spec           { return connect.Spec{} }
   132  func (fakeClientStreamingConn) Send(any) error               { return nil }
   133  func (f fakeClientStreamingConn) RequestHeader() http.Header { return f.requestHeaders }
   134  func (fakeClientStreamingConn) CloseRequest() error          { return nil }
   135  func (fakeClientStreamingConn) Receive(any) error            { return nil }
   136  func (fakeClientStreamingConn) ResponseHeader() http.Header  { return nil }
   137  func (fakeClientStreamingConn) ResponseTrailer() http.Header { return nil }
   138  func (fakeClientStreamingConn) CloseResponse() error         { return nil }