github.com/grafana/pyroscope@v1.18.0/pkg/tenant/interceptor.go (about)

     1  package tenant
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net/http"
     7  
     8  	"connectrpc.com/connect"
     9  	"github.com/grafana/dskit/tenant"
    10  	"github.com/grafana/dskit/user"
    11  )
    12  
    13  // DefaultTenantID is the default tenant ID used when the interceptor is disabled.
    14  const DefaultTenantID = "anonymous"
    15  
    16  // NewAuthInterceptor create a new tenant authentication interceptor for the server and client.
    17  //
    18  // For the server:
    19  //
    20  // If enabled, the interceptor will check the tenant ID in the request header is present and inject it into the context.
    21  // When the interceptor is disabled, it will inject the default tenant ID into the context.
    22  //
    23  // For the client :
    24  //
    25  // The interceptor will inject the tenant ID from the context into the request header no matter if the interceptor is enabled or not.
    26  func NewAuthInterceptor(enabled bool) connect.Interceptor {
    27  	return &authInterceptor{enabled: enabled}
    28  }
    29  
    30  type authInterceptor struct {
    31  	enabled bool
    32  }
    33  
    34  func (i *authInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
    35  	return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
    36  		// client side we extract the tenantID from the context and inject it into the request header
    37  		if req.Spec().IsClient {
    38  			tenantID, _ := ExtractTenantIDFromContext(ctx)
    39  			if tenantID != "" {
    40  				req.Header().Set("X-Scope-OrgID", tenantID)
    41  			}
    42  			return next(ctx, req)
    43  		}
    44  		// Server side if the interceptor is enabled, we extract the tenantID from the request header and inject it into the context
    45  		// If the interceptor is disabled, we inject the default tenant ID into the context.
    46  		if !i.enabled {
    47  			return next(InjectTenantID(ctx, DefaultTenantID), req)
    48  		}
    49  		_, ctx, _ = ExtractTenantIDFromHeaders(ctx, req.Header())
    50  
    51  		resp, err := next(ctx, req)
    52  		if err != nil && errors.Is(err, ErrNoTenantID) {
    53  			return resp, connect.NewError(connect.CodeUnauthenticated, err)
    54  		}
    55  		return resp, err
    56  	}
    57  }
    58  
    59  func (i *authInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
    60  	return func(ctx context.Context, s connect.Spec) connect.StreamingClientConn {
    61  		conn := next(ctx, s)
    62  		tenantID, _ := ExtractTenantIDFromContext(ctx)
    63  		if tenantID != "" {
    64  			conn.RequestHeader().Set("X-Scope-OrgID", tenantID)
    65  		}
    66  		return conn
    67  	}
    68  }
    69  
    70  func (i *authInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
    71  	return func(ctx context.Context, conn connect.StreamingHandlerConn) error {
    72  		if !i.enabled {
    73  			return next(InjectTenantID(ctx, DefaultTenantID), conn)
    74  		}
    75  		_, ctx, _ = ExtractTenantIDFromHeaders(ctx, conn.RequestHeader())
    76  		if err := next(ctx, conn); err != nil {
    77  			if errors.Is(err, ErrNoTenantID) {
    78  				return connect.NewError(connect.CodeUnauthenticated, err)
    79  			}
    80  			return err
    81  		}
    82  		return nil
    83  	}
    84  }
    85  
    86  var defaultResolver tenant.Resolver = tenant.NewMultiResolver()
    87  
    88  // ExtractTenantIDFromHeaders extracts a single TenantID from http headers.
    89  func ExtractTenantIDFromHeaders(ctx context.Context, headers http.Header) (string, context.Context, error) {
    90  	orgID := headers.Get(user.OrgIDHeaderName)
    91  	if orgID == "" {
    92  		return "", ctx, ErrNoTenantID
    93  	}
    94  	ctx = InjectTenantID(ctx, orgID)
    95  
    96  	tenantID, err := defaultResolver.TenantID(ctx)
    97  	if err != nil {
    98  		return "", ctx, err
    99  	}
   100  
   101  	return tenantID, ctx, nil
   102  }
   103  
   104  // ExtractTenantIDFromContext extracts a single TenantID from the context.
   105  func ExtractTenantIDFromContext(ctx context.Context) (string, error) {
   106  	tenantID, err := defaultResolver.TenantID(ctx)
   107  	if err != nil {
   108  		return "", err
   109  	}
   110  
   111  	return tenantID, nil
   112  }