github.com/grafana/pyroscope@v1.18.0/pkg/tenant/interceptor_test.go (about) 1 package tenant 2 3 import ( 4 "context" 5 "net/http" 6 "testing" 7 8 "connectrpc.com/connect" 9 "github.com/stretchr/testify/require" 10 ) 11 12 func Test_AuthInterceptor(t *testing.T) { 13 for testName, testCase := range map[string]func(t *testing.T){ 14 "client: forward from context": func(t *testing.T) { 15 i := NewAuthInterceptor(false) 16 resp, err := i.WrapUnary(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) { 17 tenantID, _, err := ExtractTenantIDFromHeaders(context.Background(), ar.Header()) 18 require.NoError(t, err) 19 require.Equal(t, tenantID, "foo") 20 return nil, nil 21 })(InjectTenantID(context.Background(), "foo"), newFakeReq(true)) 22 require.NoError(t, err) 23 require.Nil(t, resp) 24 }, 25 "client: no org forwarded": func(t *testing.T) { 26 i := NewAuthInterceptor(false) 27 resp, err := i.WrapUnary(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) { 28 tenantID, _, err := ExtractTenantIDFromHeaders(context.Background(), ar.Header()) 29 require.Equal(t, ErrNoTenantID, err) 30 require.Equal(t, tenantID, "") 31 return nil, nil 32 })(context.Background(), newFakeReq(true)) 33 require.NoError(t, err) 34 require.Nil(t, resp) 35 }, 36 "server: disable, static org": func(t *testing.T) { 37 i := NewAuthInterceptor(false) 38 req := newFakeReq(false) 39 req.Header().Set("X-Scope-OrgID", "foo") 40 resp, err := i.WrapUnary(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) { 41 tenantID, err := ExtractTenantIDFromContext(ctx) 42 require.NoError(t, err) 43 require.Equal(t, tenantID, DefaultTenantID) 44 return nil, nil 45 })(context.Background(), req) 46 require.NoError(t, err) 47 require.Nil(t, resp) 48 }, 49 "server: enable, forward header": func(t *testing.T) { 50 i := NewAuthInterceptor(true) 51 req := newFakeReq(false) 52 req.Header().Set("X-Scope-OrgID", "foo") 53 resp, err := i.WrapUnary(func(ctx context.Context, ar connect.AnyRequest) (connect.AnyResponse, error) { 54 tenantID, err := ExtractTenantIDFromContext(ctx) 55 require.NoError(t, err) 56 require.Equal(t, tenantID, "foo") 57 return nil, nil 58 })(context.Background(), req) 59 require.NoError(t, err) 60 require.Nil(t, resp) 61 }, 62 "streaming client should forward from context": func(t *testing.T) { 63 i := NewAuthInterceptor(false) 64 inConn := newFakeClientStreamingConn() 65 outConn := i.WrapStreamingClient(func(ctx context.Context, s connect.Spec) connect.StreamingClientConn { 66 return inConn 67 })(InjectTenantID(context.Background(), "foo"), connect.Spec{}) 68 require.Equal(t, "foo", outConn.RequestHeader().Get("X-Scope-OrgID")) 69 }, 70 "streaming server should forward from header to context if enabled": func(t *testing.T) { 71 i := NewAuthInterceptor(true) 72 shc := newFakeClientStreamingConn() 73 shc.requestHeaders.Set("X-Scope-OrgID", "foo") 74 _ = i.WrapStreamingHandler(func(ctx context.Context, shc connect.StreamingHandlerConn) error { 75 tenantID, err := ExtractTenantIDFromContext(ctx) 76 require.NoError(t, err) 77 require.Equal(t, tenantID, "foo") 78 return nil 79 })(context.Background(), shc) 80 }, 81 "streaming server should forward default tenant to context if disable": func(t *testing.T) { 82 i := NewAuthInterceptor(false) 83 shc := newFakeClientStreamingConn() 84 _ = i.WrapStreamingHandler(func(ctx context.Context, shc connect.StreamingHandlerConn) error { 85 tenantID, err := ExtractTenantIDFromContext(ctx) 86 require.NoError(t, err) 87 require.Equal(t, tenantID, DefaultTenantID) 88 return nil 89 })(context.Background(), shc) 90 }, 91 } { 92 t.Run(testName, testCase) 93 } 94 } 95 96 type fakeReq struct { 97 connect.AnyRequest 98 isClient bool 99 headers http.Header 100 } 101 102 func newFakeReq(isClient bool) fakeReq { 103 return fakeReq{ 104 isClient: isClient, 105 headers: http.Header{}, 106 AnyRequest: connect.NewRequest(&http.Request{}), 107 } 108 } 109 110 func (f fakeReq) Spec() connect.Spec { 111 return connect.Spec{ 112 IsClient: f.isClient, 113 } 114 } 115 116 func (f fakeReq) Header() http.Header { 117 return f.headers 118 } 119 120 type fakeClientStreamingConn struct { 121 requestHeaders http.Header 122 } 123 124 func newFakeClientStreamingConn() fakeClientStreamingConn { 125 return fakeClientStreamingConn{ 126 requestHeaders: http.Header{}, 127 } 128 } 129 130 func (fakeClientStreamingConn) Peer() connect.Peer { return connect.Peer{} } 131 func (fakeClientStreamingConn) Spec() connect.Spec { return connect.Spec{} } 132 func (fakeClientStreamingConn) Send(any) error { return nil } 133 func (f fakeClientStreamingConn) RequestHeader() http.Header { return f.requestHeaders } 134 func (fakeClientStreamingConn) CloseRequest() error { return nil } 135 func (fakeClientStreamingConn) Receive(any) error { return nil } 136 func (fakeClientStreamingConn) ResponseHeader() http.Header { return nil } 137 func (fakeClientStreamingConn) ResponseTrailer() http.Header { return nil } 138 func (fakeClientStreamingConn) CloseResponse() error { return nil }