github.com/authzed/spicedb@v1.32.1-0.20240520085336-ebda56537386/pkg/middleware/requestid/requestid.go (about) 1 package requestid 2 3 import ( 4 "context" 5 6 log "github.com/authzed/spicedb/internal/logging" 7 8 "github.com/authzed/authzed-go/pkg/requestmeta" 9 "github.com/authzed/authzed-go/pkg/responsemeta" 10 "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors" 11 "github.com/rs/xid" 12 "google.golang.org/grpc" 13 "google.golang.org/grpc/metadata" 14 ) 15 16 const metadataKey = string(requestmeta.RequestIDKey) 17 18 // Option instances control how the middleware is initialized. 19 type Option func(*handleRequestID) 20 21 // GenerateIfMissing will instruct the middleware to create a request ID if one 22 // isn't already on the incoming request. 23 // 24 // default: false 25 func GenerateIfMissing(enable bool) Option { 26 return func(reporter *handleRequestID) { 27 reporter.generateIfMissing = enable 28 } 29 } 30 31 // IDGenerator functions are used to generate request IDs if a new one is needed. 32 type IDGenerator func() string 33 34 // GenerateRequestID generates a new request ID. 35 func GenerateRequestID() string { 36 return xid.New().String() 37 } 38 39 type handleRequestID struct { 40 generateIfMissing bool 41 requestIDGenerator IDGenerator 42 } 43 44 func (r *handleRequestID) ClientReporter(ctx context.Context, meta interceptors.CallMeta) (interceptors.Reporter, context.Context) { 45 haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx) 46 47 if haveRequestID { 48 ctx = requestmeta.SetRequestHeaders(ctx, map[requestmeta.RequestMetadataHeaderKey]string{ 49 requestmeta.RequestIDKey: requestID, 50 }) 51 } 52 53 return interceptors.NoopReporter{}, ctx 54 } 55 56 func (r *handleRequestID) ServerReporter(ctx context.Context, _ interceptors.CallMeta) (interceptors.Reporter, context.Context) { 57 haveRequestID, requestID, ctx := r.fromContextOrGenerate(ctx) 58 59 if haveRequestID { 60 err := responsemeta.SetResponseHeaderMetadata(ctx, map[responsemeta.ResponseMetadataHeaderKey]string{ 61 responsemeta.RequestID: requestID, 62 }) 63 // if context is cancelled, the stream will be closed, and gRPC will return ErrIllegalHeaderWrite 64 // this prevents logging unnecessary error messages 65 if ctx.Err() != nil { 66 return interceptors.NoopReporter{}, ctx 67 } 68 if err != nil { 69 log.Ctx(ctx).Warn().Err(err).Msg("requestid: could not report metadata") 70 } 71 } 72 73 return interceptors.NoopReporter{}, ctx 74 } 75 76 func (r *handleRequestID) fromContextOrGenerate(ctx context.Context) (bool, string, context.Context) { 77 haveRequestID, requestID, md := fromContext(ctx) 78 79 if !haveRequestID && r.generateIfMissing { 80 requestID = r.requestIDGenerator() 81 haveRequestID = true 82 83 // Inject the newly generated request ID into the metadata 84 if md == nil { 85 md = metadata.New(nil) 86 } 87 88 md.Set(metadataKey, requestID) 89 ctx = metadata.NewIncomingContext(ctx, md) 90 } 91 92 return haveRequestID, requestID, ctx 93 } 94 95 func fromContext(ctx context.Context) (bool, string, metadata.MD) { 96 var requestID string 97 var haveRequestID bool 98 md, ok := metadata.FromIncomingContext(ctx) 99 if ok { 100 var requestIDs []string 101 requestIDs, haveRequestID = md[metadataKey] 102 if haveRequestID { 103 requestID = requestIDs[0] 104 } 105 } 106 107 return haveRequestID, requestID, md 108 } 109 110 // PropagateIfExists copies the request ID from the source context to the target context if it exists. 111 // The updated target context is returned. 112 func PropagateIfExists(source, target context.Context) context.Context { 113 exists, requestID, _ := fromContext(source) 114 115 if exists { 116 targetMD, _ := metadata.FromIncomingContext(target) 117 if targetMD == nil { 118 targetMD = metadata.New(nil) 119 } 120 121 targetMD.Set(metadataKey, requestID) 122 return metadata.NewIncomingContext(target, targetMD) 123 } 124 125 return target 126 } 127 128 // UnaryServerInterceptor returns a new interceptor which handles server request IDs according 129 // to the provided options. 130 func UnaryServerInterceptor(opts ...Option) grpc.UnaryServerInterceptor { 131 return interceptors.UnaryServerInterceptor(createReporter(opts)) 132 } 133 134 // StreamServerInterceptor returns a new interceptor which handles server request IDs according 135 // to the provided options. 136 func StreamServerInterceptor(opts ...Option) grpc.StreamServerInterceptor { 137 return interceptors.StreamServerInterceptor(createReporter(opts)) 138 } 139 140 // UnaryClientInterceptor returns a new interceptor which handles client request IDs according 141 // to the provided options. 142 func UnaryClientInterceptor(opts ...Option) grpc.UnaryClientInterceptor { 143 return interceptors.UnaryClientInterceptor(createReporter(opts)) 144 } 145 146 // StreamClientInterceptor returns a new interceptor which handles client requestIDs according 147 // to the provided options. 148 func StreamClientInterceptor(opts ...Option) grpc.StreamClientInterceptor { 149 return interceptors.StreamClientInterceptor(createReporter(opts)) 150 } 151 152 func createReporter(opts []Option) *handleRequestID { 153 reporter := &handleRequestID{ 154 requestIDGenerator: GenerateRequestID, 155 } 156 157 for _, opt := range opts { 158 opt(reporter) 159 } 160 161 return reporter 162 }