go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/grpc/grpcutil/interceptors.go (about) 1 // Copyright 2020 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package grpcutil 16 17 import ( 18 "context" 19 20 "google.golang.org/grpc" 21 ) 22 23 // ChainUnaryServerInterceptors chains multiple unary interceptors together. 24 // 25 // The first one becomes the outermost, and the last one becomes the 26 // innermost, i.e. `ChainUnaryServerInterceptors(a, b, c)(h) === a(b(c(h)))`. 27 // 28 // nil-valued interceptors are silently skipped. 29 func ChainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { 30 switch { 31 case len(interceptors) == 0: 32 // Noop interceptor. 33 return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 34 return handler(ctx, req) 35 } 36 case interceptors[0] == nil: 37 // Skip nils. 38 return ChainUnaryServerInterceptors(interceptors[1:]...) 39 case len(interceptors) == 1: 40 // No need to actually chain anything. 41 return interceptors[0] 42 default: 43 return unaryCombinator(interceptors[0], ChainUnaryServerInterceptors(interceptors[1:]...)) 44 } 45 } 46 47 // unaryCombinator is an interceptor that chains just two interceptors together. 48 func unaryCombinator(first, second grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor { 49 return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 50 return first(ctx, req, info, func(ctx context.Context, req any) (any, error) { 51 return second(ctx, req, info, handler) 52 }) 53 } 54 } 55 56 // ChainStreamServerInterceptors chains multiple stream interceptors together. 57 // 58 // The first one becomes the outermost, and the last one becomes the 59 // innermost, i.e. `ChainStreamServerInterceptors(a, b, c)(h) === a(b(c(h)))`. 60 // 61 // nil-valued interceptors are silently skipped. 62 func ChainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor { 63 switch { 64 case len(interceptors) == 0: 65 // Noop interceptor. 66 return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 67 return handler(srv, ss) 68 } 69 case interceptors[0] == nil: 70 // Skip nils. 71 return ChainStreamServerInterceptors(interceptors[1:]...) 72 case len(interceptors) == 1: 73 // No need to actually chain anything. 74 return interceptors[0] 75 default: 76 return streamCombinator(interceptors[0], ChainStreamServerInterceptors(interceptors[1:]...)) 77 } 78 } 79 80 // unaryCombinator is an interceptor that chains just two interceptors together. 81 func streamCombinator(first, second grpc.StreamServerInterceptor) grpc.StreamServerInterceptor { 82 return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 83 return first(srv, ss, info, func(srv any, ss grpc.ServerStream) error { 84 return second(srv, ss, info, handler) 85 }) 86 } 87 } 88 89 // ModifyServerStreamContext returns a ServerStream that fully wraps the given 90 // one except its context is modified based on the result of the given callback. 91 // 92 // This is handy when implementing stream server interceptors that need to 93 // put stuff into the stream's context. 94 // 95 // The callback will be called immediately and only once. It must return a 96 // context derived from the context it receives or nil if the context 97 // modification is not actually necessary. 98 func ModifyServerStreamContext(ss grpc.ServerStream, cb func(context.Context) context.Context) grpc.ServerStream { 99 original := ss.Context() 100 modified := cb(original) 101 if modified == nil || modified == original { 102 return ss 103 } 104 return &wrappedSS{ss, modified} 105 } 106 107 // wrappedSS is a grpc.ServerStream that replaces the context. 108 type wrappedSS struct { 109 grpc.ServerStream 110 ctx context.Context 111 } 112 113 // Context returns the context for this stream. 114 // 115 // This is part of grpc.ServerStream interface. 116 func (ss *wrappedSS) Context() context.Context { 117 return ss.ctx 118 } 119 120 // UnifiedServerInterceptor can be converted into an unary or stream server 121 // interceptor. 122 // 123 // Such interceptor can do something at the start of the request (in particular 124 // modify the request context) and do something with the request error at the 125 // end. It can also skip the request entirely by returning an error without 126 // calling the handler. 127 // 128 // It is handy when implementing simple interceptors that can be used as both 129 // unary and stream ones. 130 type UnifiedServerInterceptor func(ctx context.Context, fullMethod string, handler func(ctx context.Context) error) error 131 132 // Unary returns an unary form of the interceptor. 133 func (u UnifiedServerInterceptor) Unary() grpc.UnaryServerInterceptor { 134 return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 135 var resp any 136 err := u(ctx, info.FullMethod, func(ctx context.Context) (err error) { 137 resp, err = handler(ctx, req) 138 return err 139 }) 140 if err != nil { 141 resp = nil 142 } 143 return resp, err 144 } 145 } 146 147 // Stream returns a stream form of the interceptor. 148 func (u UnifiedServerInterceptor) Stream() grpc.StreamServerInterceptor { 149 return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 150 original := ss.Context() 151 return u(original, info.FullMethod, func(ctx context.Context) error { 152 var wrapped grpc.ServerStream 153 if ctx != original { 154 wrapped = &wrappedSS{ss, ctx} 155 } else { 156 wrapped = ss 157 } 158 return handler(srv, wrapped) 159 }) 160 } 161 }