github.com/openfga/openfga@v1.5.4-rc1/pkg/middleware/storeid/storeid.go (about)

     1  package storeid
     2  
     3  import (
     4  	"context"
     5  	"time"
     6  
     7  	grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
     8  	"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
     9  	"go.opentelemetry.io/otel/attribute"
    10  	"go.opentelemetry.io/otel/trace"
    11  	"google.golang.org/grpc"
    12  	"google.golang.org/grpc/metadata"
    13  )
    14  
    15  type ctxKey string
    16  
    17  const (
    18  	storeIDCtxKey ctxKey = "store-id-context-key"
    19  	storeIDKey    string = "store_id"
    20  
    21  	// StoreIDHeader represents the HTTP header name used to
    22  	// specify the OpenFGA store identifier in API requests.
    23  	StoreIDHeader string = "Openfga-Store-Id"
    24  )
    25  
    26  type storeidHandle struct {
    27  	storeid string
    28  }
    29  
    30  // StoreIDFromContext retrieves the store ID stored in the provided context.
    31  func StoreIDFromContext(ctx context.Context) (string, bool) {
    32  	if c := ctx.Value(storeIDCtxKey); c != nil {
    33  		handle := c.(*storeidHandle)
    34  		return handle.storeid, true
    35  	}
    36  
    37  	return "", false
    38  }
    39  
    40  func contextWithHandle(ctx context.Context) context.Context {
    41  	return context.WithValue(ctx, storeIDCtxKey, &storeidHandle{})
    42  }
    43  
    44  // SetStoreIDInContext sets the store ID in the provided context based on information from the request.
    45  func SetStoreIDInContext(ctx context.Context, req interface{}) {
    46  	handle := ctx.Value(storeIDCtxKey)
    47  	if handle == nil {
    48  		return
    49  	}
    50  
    51  	if r, ok := req.(hasGetStoreID); ok {
    52  		handle.(*storeidHandle).storeid = r.GetStoreId()
    53  	}
    54  }
    55  
    56  type hasGetStoreID interface {
    57  	GetStoreId() string
    58  }
    59  
    60  // NewUnaryInterceptor creates a grpc.UnaryServerInterceptor which injects
    61  // store_id metadata into the RPC context if an RPC message is received with
    62  // a GetStoreId method.
    63  func NewUnaryInterceptor() grpc.UnaryServerInterceptor {
    64  	return interceptors.UnaryServerInterceptor(reportable())
    65  }
    66  
    67  // NewStreamingInterceptor creates a grpc.StreamServerInterceptor which injects
    68  // store_id metadata into the RPC context if an RPC message is received with a
    69  // GetStoreId method.
    70  func NewStreamingInterceptor() grpc.StreamServerInterceptor {
    71  	return interceptors.StreamServerInterceptor(reportable())
    72  }
    73  
    74  type reporter struct {
    75  	ctx context.Context
    76  }
    77  
    78  // PostCall is a placeholder for handling actions after a gRPC call.
    79  func (r *reporter) PostCall(error, time.Duration) {}
    80  
    81  // PostMsgSend is a placeholder for handling actions after sending a message in streaming requests.
    82  func (r *reporter) PostMsgSend(interface{}, error, time.Duration) {}
    83  
    84  // PostMsgReceive is invoked after receiving a message in streaming requests.
    85  func (r *reporter) PostMsgReceive(msg interface{}, _ error, _ time.Duration) {
    86  	if m, ok := msg.(hasGetStoreID); ok {
    87  		storeID := m.GetStoreId()
    88  
    89  		SetStoreIDInContext(r.ctx, msg)
    90  		trace.SpanFromContext(r.ctx).SetAttributes(attribute.String(storeIDKey, storeID))
    91  
    92  		grpc_ctxtags.Extract(r.ctx).Set(storeIDKey, storeID)
    93  
    94  		_ = grpc.SetHeader(r.ctx, metadata.Pairs(StoreIDHeader, storeID))
    95  	}
    96  }
    97  
    98  func reportable() interceptors.CommonReportableFunc {
    99  	return func(ctx context.Context, c interceptors.CallMeta) (interceptors.Reporter, context.Context) {
   100  		ctx = contextWithHandle(ctx)
   101  
   102  		r := reporter{ctx}
   103  		return &r, r.ctx
   104  	}
   105  }