github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/middleware/requestid/requestid.go (about)

     1  package requestid
     2  
     3  import (
     4  	"context"
     5  
     6  	log "github.com/authzed/spicedb/internal/logging"
     7  
     8  	"github.com/authzed/authzed-go/pkg/requestmeta"
     9  	"github.com/authzed/authzed-go/pkg/responsemeta"
    10  	"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
    11  	"github.com/rs/xid"
    12  	"google.golang.org/grpc"
    13  	"google.golang.org/grpc/metadata"
    14  )
    15  
    16  const metadataKey = string(requestmeta.RequestIDKey)
    17  
    18  // Option instances control how the middleware is initialized.
    19  type Option func(*handleRequestID)
    20  
    21  // GenerateIfMissing will instruct the middleware to create a request ID if one
    22  // isn't already on the incoming request.
    23  //
    24  // default: false
    25  func GenerateIfMissing(enable bool) Option {
    26  	return func(reporter *handleRequestID) {
    27  		reporter.generateIfMissing = enable
    28  	}
    29  }
    30  
    31  // IDGenerator functions are used to generate request IDs if a new one is needed.
    32  type IDGenerator func() string
    33  
    34  // GenerateRequestID generates a new request ID.
    35  func GenerateRequestID() string {
    36  	return xid.New().String()
    37  }
    38  
    39  type handleRequestID struct {
    40  	generateIfMissing  bool
    41  	requestIDGenerator IDGenerator
    42  }
    43  
    44  func (r *handleRequestID) ClientReporter(ctx context.Context, meta interceptors.CallMeta) (interceptors.Reporter, context.Context) {
    45  	haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx)
    46  
    47  	if haveRequestID {
    48  		ctx = requestmeta.SetRequestHeaders(ctx, map[requestmeta.RequestMetadataHeaderKey]string{
    49  			requestmeta.RequestIDKey: requestID,
    50  		})
    51  	}
    52  
    53  	return interceptors.NoopReporter{}, ctx
    54  }
    55  
    56  func (r *handleRequestID) ServerReporter(ctx context.Context, _ interceptors.CallMeta) (interceptors.Reporter, context.Context) {
    57  	haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx)
    58  
    59  	if haveRequestID {
    60  		err := responsemeta.SetResponseHeaderMetadata(ctx, map[responsemeta.ResponseMetadataHeaderKey]string{
    61  			responsemeta.RequestID: requestID,
    62  		})
    63  		// if context is cancelled, the stream will be closed, and gRPC will return ErrIllegalHeaderWrite
    64  		// this prevents logging unnecessary error messages
    65  		if ctx.Err() != nil {
    66  			return interceptors.NoopReporter{}, ctx
    67  		}
    68  		if err != nil {
    69  			log.Ctx(ctx).Warn().Err(err).Msg("requestid: could not report metadata")
    70  		}
    71  	}
    72  
    73  	return interceptors.NoopReporter{}, ctx
    74  }
    75  
    76  func (r *handleRequestID) fromContextOrGenerate(ctx context.Context) (bool, string, context.Context) {
    77  	haveRequestID, requestID, md := fromContext(ctx)
    78  
    79  	if !haveRequestID && r.generateIfMissing {
    80  		requestID = r.requestIDGenerator()
    81  		haveRequestID = true
    82  
    83  		// Inject the newly generated request ID into the metadata
    84  		if md == nil {
    85  			md = metadata.New(nil)
    86  		}
    87  
    88  		md.Set(metadataKey, requestID)
    89  		ctx = metadata.NewIncomingContext(ctx, md)
    90  	}
    91  
    92  	return haveRequestID, requestID, ctx
    93  }
    94  
    95  func fromContext(ctx context.Context) (bool, string, metadata.MD) {
    96  	var requestID string
    97  	var haveRequestID bool
    98  	md, ok := metadata.FromIncomingContext(ctx)
    99  	if ok {
   100  		var requestIDs []string
   101  		requestIDs, haveRequestID = md[metadataKey]
   102  		if haveRequestID {
   103  			requestID = requestIDs[0]
   104  		}
   105  	}
   106  
   107  	return haveRequestID, requestID, md
   108  }
   109  
   110  // PropagateIfExists copies the request ID from the source context to the target context if it exists.
   111  // The updated target context is returned.
   112  func PropagateIfExists(source, target context.Context) context.Context {
   113  	exists, requestID, _ := fromContext(source)
   114  
   115  	if exists {
   116  		targetMD, _ := metadata.FromIncomingContext(target)
   117  		if targetMD == nil {
   118  			targetMD = metadata.New(nil)
   119  		}
   120  
   121  		targetMD.Set(metadataKey, requestID)
   122  		return metadata.NewIncomingContext(target, targetMD)
   123  	}
   124  
   125  	return target
   126  }
   127  
   128  // UnaryServerInterceptor returns a new interceptor which handles server request IDs according
   129  // to the provided options.
   130  func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor {
   131  	return interceptors.UnaryServerInterceptor(createReporter(opts))
   132  }
   133  
   134  // StreamServerInterceptor returns a new interceptor which handles server request IDs according
   135  // to the provided options.
   136  func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor {
   137  	return interceptors.StreamServerInterceptor(createReporter(opts))
   138  }
   139  
   140  // UnaryClientInterceptor returns a new interceptor which handles client request IDs according
   141  // to the provided options.
   142  func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor {
   143  	return interceptors.UnaryClientInterceptor(createReporter(opts))
   144  }
   145  
   146  // StreamClientInterceptor returns a new interceptor which handles client requestIDs according
   147  // to the provided options.
   148  func StreamClientInterceptor(opts ...Option) grpc.StreamClientInterceptor {
   149  	return interceptors.StreamClientInterceptor(createReporter(opts))
   150  }
   151  
   152  func createReporter(opts []Option) *handleRequestID {
   153  	reporter := &handleRequestID{
   154  		requestIDGenerator: GenerateRequestID,
   155  	}
   156  
   157  	for _, opt := range opts {
   158  		opt(reporter)
   159  	}
   160  
   161  	return reporter
   162  }