github.com/grafana/pyroscope@v1.18.0/pkg/tenant/interceptor.go (about) 1 package tenant 2 3 import ( 4 "context" 5 "errors" 6 "net/http" 7 8 "connectrpc.com/connect" 9 "github.com/grafana/dskit/tenant" 10 "github.com/grafana/dskit/user" 11 ) 12 13 // DefaultTenantID is the default tenant ID used when the interceptor is disabled. 14 const DefaultTenantID = "anonymous" 15 16 // NewAuthInterceptor create a new tenant authentication interceptor for the server and client. 17 // 18 // For the server: 19 // 20 // If enabled, the interceptor will check the tenant ID in the request header is present and inject it into the context. 21 // When the interceptor is disabled, it will inject the default tenant ID into the context. 22 // 23 // For the client : 24 // 25 // The interceptor will inject the tenant ID from the context into the request header no matter if the interceptor is enabled or not. 26 func NewAuthInterceptor(enabled bool) connect.Interceptor { 27 return &authInterceptor{enabled: enabled} 28 } 29 30 type authInterceptor struct { 31 enabled bool 32 } 33 34 func (i *authInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { 35 return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { 36 // client side we extract the tenantID from the context and inject it into the request header 37 if req.Spec().IsClient { 38 tenantID, _ := ExtractTenantIDFromContext(ctx) 39 if tenantID != "" { 40 req.Header().Set("X-Scope-OrgID", tenantID) 41 } 42 return next(ctx, req) 43 } 44 // Server side if the interceptor is enabled, we extract the tenantID from the request header and inject it into the context 45 // If the interceptor is disabled, we inject the default tenant ID into the context. 46 if !i.enabled { 47 return next(InjectTenantID(ctx, DefaultTenantID), req) 48 } 49 _, ctx, _ = ExtractTenantIDFromHeaders(ctx, req.Header()) 50 51 resp, err := next(ctx, req) 52 if err != nil && errors.Is(err, ErrNoTenantID) { 53 return resp, connect.NewError(connect.CodeUnauthenticated, err) 54 } 55 return resp, err 56 } 57 } 58 59 func (i *authInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc { 60 return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn { 61 conn := next(ctx, s) 62 tenantID, _ := ExtractTenantIDFromContext(ctx) 63 if tenantID != "" { 64 conn.RequestHeader().Set("X-Scope-OrgID", tenantID) 65 } 66 return conn 67 } 68 } 69 70 func (i *authInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { 71 return func(ctx context.Context, conn connect.StreamingHandlerConn) error { 72 if !i.enabled { 73 return next(InjectTenantID(ctx, DefaultTenantID), conn) 74 } 75 _, ctx, _ = ExtractTenantIDFromHeaders(ctx, conn.RequestHeader()) 76 if err := next(ctx, conn); err != nil { 77 if errors.Is(err, ErrNoTenantID) { 78 return connect.NewError(connect.CodeUnauthenticated, err) 79 } 80 return err 81 } 82 return nil 83 } 84 } 85 86 var defaultResolver tenant.Resolver = tenant.NewMultiResolver() 87 88 // ExtractTenantIDFromHeaders extracts a single TenantID from http headers. 89 func ExtractTenantIDFromHeaders(ctx context.Context, headers http.Header) (string, context.Context, error) { 90 orgID := headers.Get(user.OrgIDHeaderName) 91 if orgID == "" { 92 return "", ctx, ErrNoTenantID 93 } 94 ctx = InjectTenantID(ctx, orgID) 95 96 tenantID, err := defaultResolver.TenantID(ctx) 97 if err != nil { 98 return "", ctx, err 99 } 100 101 return tenantID, ctx, nil 102 } 103 104 // ExtractTenantIDFromContext extracts a single TenantID from the context. 105 func ExtractTenantIDFromContext(ctx context.Context) (string, error) { 106 tenantID, err := defaultResolver.TenantID(ctx) 107 if err != nil { 108 return "", err 109 } 110 111 return tenantID, nil 112 }