go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/grpc/grpcutil/interceptors.go (about)

     1  // Copyright 2020 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package grpcutil
    16  
    17  import (
    18  	"context"
    19  
    20  	"google.golang.org/grpc"
    21  )
    22  
    23  // ChainUnaryServerInterceptors chains multiple unary interceptors together.
    24  //
    25  // The first one becomes the outermost, and the last one becomes the
    26  // innermost, i.e. `ChainUnaryServerInterceptors(a, b, c)(h) === a(b(c(h)))`.
    27  //
    28  // nil-valued interceptors are silently skipped.
    29  func ChainUnaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
    30  	switch {
    31  	case len(interceptors) == 0:
    32  		// Noop interceptor.
    33  		return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
    34  			return handler(ctx, req)
    35  		}
    36  	case interceptors[0] == nil:
    37  		// Skip nils.
    38  		return ChainUnaryServerInterceptors(interceptors[1:]...)
    39  	case len(interceptors) == 1:
    40  		// No need to actually chain anything.
    41  		return interceptors[0]
    42  	default:
    43  		return unaryCombinator(interceptors[0], ChainUnaryServerInterceptors(interceptors[1:]...))
    44  	}
    45  }
    46  
    47  // unaryCombinator is an interceptor that chains just two interceptors together.
    48  func unaryCombinator(first, second grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
    49  	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
    50  		return first(ctx, req, info, func(ctx context.Context, req any) (any, error) {
    51  			return second(ctx, req, info, handler)
    52  		})
    53  	}
    54  }
    55  
    56  // ChainStreamServerInterceptors chains multiple stream interceptors together.
    57  //
    58  // The first one becomes the outermost, and the last one becomes the
    59  // innermost, i.e. `ChainStreamServerInterceptors(a, b, c)(h) === a(b(c(h)))`.
    60  //
    61  // nil-valued interceptors are silently skipped.
    62  func ChainStreamServerInterceptors(interceptors ...grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
    63  	switch {
    64  	case len(interceptors) == 0:
    65  		// Noop interceptor.
    66  		return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    67  			return handler(srv, ss)
    68  		}
    69  	case interceptors[0] == nil:
    70  		// Skip nils.
    71  		return ChainStreamServerInterceptors(interceptors[1:]...)
    72  	case len(interceptors) == 1:
    73  		// No need to actually chain anything.
    74  		return interceptors[0]
    75  	default:
    76  		return streamCombinator(interceptors[0], ChainStreamServerInterceptors(interceptors[1:]...))
    77  	}
    78  }
    79  
    80  // unaryCombinator is an interceptor that chains just two interceptors together.
    81  func streamCombinator(first, second grpc.StreamServerInterceptor) grpc.StreamServerInterceptor {
    82  	return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    83  		return first(srv, ss, info, func(srv any, ss grpc.ServerStream) error {
    84  			return second(srv, ss, info, handler)
    85  		})
    86  	}
    87  }
    88  
    89  // ModifyServerStreamContext returns a ServerStream that fully wraps the given
    90  // one except its context is modified based on the result of the given callback.
    91  //
    92  // This is handy when implementing stream server interceptors that need to
    93  // put stuff into the stream's context.
    94  //
    95  // The callback will be called immediately and only once. It must return a
    96  // context derived from the context it receives or nil if the context
    97  // modification is not actually necessary.
    98  func ModifyServerStreamContext(ss grpc.ServerStream, cb func(context.Context) context.Context) grpc.ServerStream {
    99  	original := ss.Context()
   100  	modified := cb(original)
   101  	if modified == nil || modified == original {
   102  		return ss
   103  	}
   104  	return &wrappedSS{ss, modified}
   105  }
   106  
   107  // wrappedSS is a grpc.ServerStream that replaces the context.
   108  type wrappedSS struct {
   109  	grpc.ServerStream
   110  	ctx context.Context
   111  }
   112  
   113  // Context returns the context for this stream.
   114  //
   115  // This is part of grpc.ServerStream interface.
   116  func (ss *wrappedSS) Context() context.Context {
   117  	return ss.ctx
   118  }
   119  
   120  // UnifiedServerInterceptor can be converted into an unary or stream server
   121  // interceptor.
   122  //
   123  // Such interceptor can do something at the start of the request (in particular
   124  // modify the request context) and do something with the request error at the
   125  // end. It can also skip the request entirely by returning an error without
   126  // calling the handler.
   127  //
   128  // It is handy when implementing simple interceptors that can be used as both
   129  // unary and stream ones.
   130  type UnifiedServerInterceptor func(ctx context.Context, fullMethod string, handler func(ctx context.Context) error) error
   131  
   132  // Unary returns an unary form of the interceptor.
   133  func (u UnifiedServerInterceptor) Unary() grpc.UnaryServerInterceptor {
   134  	return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
   135  		var resp any
   136  		err := u(ctx, info.FullMethod, func(ctx context.Context) (err error) {
   137  			resp, err = handler(ctx, req)
   138  			return err
   139  		})
   140  		if err != nil {
   141  			resp = nil
   142  		}
   143  		return resp, err
   144  	}
   145  }
   146  
   147  // Stream returns a stream form of the interceptor.
   148  func (u UnifiedServerInterceptor) Stream() grpc.StreamServerInterceptor {
   149  	return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   150  		original := ss.Context()
   151  		return u(original, info.FullMethod, func(ctx context.Context) error {
   152  			var wrapped grpc.ServerStream
   153  			if ctx != original {
   154  				wrapped = &wrappedSS{ss, ctx}
   155  			} else {
   156  				wrapped = ss
   157  			}
   158  			return handler(srv, wrapped)
   159  		})
   160  	}
   161  }