github.com/yimialmonte/fabric@v2.1.1+incompatible/common/grpclogging/server.go (about)

     1  /*
     2  Copyright IBM Corp. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package grpclogging
     8  
     9  import (
    10  	"context"
    11  	"strings"
    12  	"time"
    13  
    14  	"go.uber.org/zap"
    15  	"go.uber.org/zap/zapcore"
    16  	"google.golang.org/grpc"
    17  	"google.golang.org/grpc/credentials"
    18  	"google.golang.org/grpc/peer"
    19  	"google.golang.org/grpc/status"
    20  )
    21  
    22  // Leveler returns a zap level to use when logging from a grpc interceptor.
    23  type Leveler interface {
    24  	Level(ctx context.Context, fullMethod string) zapcore.Level
    25  }
    26  
    27  // PayloadLeveler gets the level to use when logging grpc message payloads.
    28  type PayloadLeveler interface {
    29  	PayloadLevel(ctx context.Context, fullMethod string) zapcore.Level
    30  }
    31  
    32  //go:generate counterfeiter -o fakes/leveler.go --fake-name Leveler . LevelerFunc
    33  
    34  type LevelerFunc func(ctx context.Context, fullMethod string) zapcore.Level
    35  
    36  func (l LevelerFunc) Level(ctx context.Context, fullMethod string) zapcore.Level {
    37  	return l(ctx, fullMethod)
    38  }
    39  
    40  func (l LevelerFunc) PayloadLevel(ctx context.Context, fullMethod string) zapcore.Level {
    41  	return l(ctx, fullMethod)
    42  }
    43  
    44  // DefaultPayloadLevel is default level to use when logging payloads
    45  const DefaultPayloadLevel = zapcore.Level(zapcore.DebugLevel - 1)
    46  
    47  type options struct {
    48  	Leveler
    49  	PayloadLeveler
    50  }
    51  
    52  type Option func(o *options)
    53  
    54  func WithLeveler(l Leveler) Option {
    55  	return func(o *options) { o.Leveler = l }
    56  }
    57  
    58  func WithPayloadLeveler(l PayloadLeveler) Option {
    59  	return func(o *options) { o.PayloadLeveler = l }
    60  }
    61  
    62  func applyOptions(opts ...Option) *options {
    63  	o := &options{
    64  		Leveler:        LevelerFunc(func(context.Context, string) zapcore.Level { return zapcore.InfoLevel }),
    65  		PayloadLeveler: LevelerFunc(func(context.Context, string) zapcore.Level { return DefaultPayloadLevel }),
    66  	}
    67  	for _, opt := range opts {
    68  		opt(o)
    69  	}
    70  	return o
    71  }
    72  
    73  // Levelers will be required and should be provided with the full method info
    74  
    75  func UnaryServerInterceptor(logger *zap.Logger, opts ...Option) grpc.UnaryServerInterceptor {
    76  	o := applyOptions(opts...)
    77  
    78  	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    79  		logger := logger
    80  		startTime := time.Now()
    81  
    82  		fields := getFields(ctx, info.FullMethod)
    83  		logger = logger.With(fields...)
    84  		ctx = WithFields(ctx, fields)
    85  
    86  		payloadLogger := logger.Named("payload")
    87  		payloadLevel := o.PayloadLevel(ctx, info.FullMethod)
    88  		if ce := payloadLogger.Check(payloadLevel, "received unary request"); ce != nil {
    89  			ce.Write(ProtoMessage("message", req))
    90  		}
    91  
    92  		resp, err := handler(ctx, req)
    93  
    94  		if ce := payloadLogger.Check(payloadLevel, "sending unary response"); ce != nil && err == nil {
    95  			ce.Write(ProtoMessage("message", resp))
    96  		}
    97  
    98  		if ce := logger.Check(o.Level(ctx, info.FullMethod), "unary call completed"); ce != nil {
    99  			st, _ := status.FromError(err)
   100  			ce.Write(
   101  				Error(err),
   102  				zap.Stringer("grpc.code", st.Code()),
   103  				zap.Duration("grpc.call_duration", time.Since(startTime)),
   104  			)
   105  		}
   106  
   107  		return resp, err
   108  	}
   109  }
   110  
   111  func StreamServerInterceptor(logger *zap.Logger, opts ...Option) grpc.StreamServerInterceptor {
   112  	o := applyOptions(opts...)
   113  
   114  	return func(service interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
   115  		logger := logger
   116  		ctx := stream.Context()
   117  		startTime := time.Now()
   118  
   119  		fields := getFields(ctx, info.FullMethod)
   120  		logger = logger.With(fields...)
   121  		ctx = WithFields(ctx, fields)
   122  
   123  		wrappedStream := &serverStream{
   124  			ServerStream:  stream,
   125  			context:       ctx,
   126  			payloadLogger: logger.Named("payload"),
   127  			payloadLevel:  o.PayloadLevel(ctx, info.FullMethod),
   128  		}
   129  
   130  		err := handler(service, wrappedStream)
   131  		if ce := logger.Check(o.Level(ctx, info.FullMethod), "streaming call completed"); ce != nil {
   132  			st, _ := status.FromError(err)
   133  			ce.Write(
   134  				Error(err),
   135  				zap.Stringer("grpc.code", st.Code()),
   136  				zap.Duration("grpc.call_duration", time.Since(startTime)),
   137  			)
   138  		}
   139  		return err
   140  	}
   141  }
   142  
   143  func getFields(ctx context.Context, method string) []zapcore.Field {
   144  	var fields []zap.Field
   145  	if parts := strings.Split(method, "/"); len(parts) == 3 {
   146  		fields = append(fields, zap.String("grpc.service", parts[1]), zap.String("grpc.method", parts[2]))
   147  	}
   148  	if deadline, ok := ctx.Deadline(); ok {
   149  		fields = append(fields, zap.Time("grpc.request_deadline", deadline))
   150  	}
   151  	if p, ok := peer.FromContext(ctx); ok {
   152  		fields = append(fields, zap.String("grpc.peer_address", p.Addr.String()))
   153  		if ti, ok := p.AuthInfo.(credentials.TLSInfo); ok {
   154  			if len(ti.State.PeerCertificates) > 0 {
   155  				cert := ti.State.PeerCertificates[0]
   156  				fields = append(fields, zap.String("grpc.peer_subject", cert.Subject.String()))
   157  			}
   158  		}
   159  	}
   160  	return fields
   161  }
   162  
   163  type serverStream struct {
   164  	grpc.ServerStream
   165  	context       context.Context
   166  	payloadLogger *zap.Logger
   167  	payloadLevel  zapcore.Level
   168  }
   169  
   170  func (ss *serverStream) Context() context.Context {
   171  	return ss.context
   172  }
   173  
   174  func (ss *serverStream) SendMsg(msg interface{}) error {
   175  	if ce := ss.payloadLogger.Check(ss.payloadLevel, "sending stream message"); ce != nil {
   176  		ce.Write(ProtoMessage("message", msg))
   177  	}
   178  	return ss.ServerStream.SendMsg(msg)
   179  }
   180  
   181  func (ss *serverStream) RecvMsg(msg interface{}) error {
   182  	err := ss.ServerStream.RecvMsg(msg)
   183  	if ce := ss.payloadLogger.Check(ss.payloadLevel, "received stream message"); ce != nil {
   184  		ce.Write(ProtoMessage("message", msg))
   185  	}
   186  	return err
   187  }