github.com/argoproj/argo-cd/v3@v3.2.1/util/grpc/logging.go (about)

     1  package grpc
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/json"
     7  	"fmt"
     8  	"time"
     9  
    10  	"github.com/gogo/protobuf/jsonpb"
    11  	"github.com/gogo/protobuf/proto"
    12  	"github.com/golang-jwt/jwt/v5"
    13  	"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors"
    14  	"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/logging"
    15  	"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/selector"
    16  	"github.com/sirupsen/logrus"
    17  	"google.golang.org/grpc"
    18  )
    19  
    20  func logRequest(ctx context.Context, entry *logrus.Entry, info string, pbMsg any, logClaims bool) {
    21  	if logClaims {
    22  		claims := ctx.Value("claims")
    23  		mapClaims, ok := claims.(jwt.MapClaims)
    24  		if ok {
    25  			claimsCopy := make(map[string]any)
    26  			for k, v := range mapClaims {
    27  				if k != "groups" || entry.Logger.IsLevelEnabled(logrus.DebugLevel) {
    28  					claimsCopy[k] = v
    29  				}
    30  			}
    31  			if data, err := json.Marshal(claimsCopy); err == nil {
    32  				entry = entry.WithField("grpc.request.claims", string(data))
    33  			}
    34  		}
    35  	}
    36  	if p, ok := pbMsg.(proto.Message); ok {
    37  		entry = entry.WithField("grpc.request.content", &jsonpbMarshalleble{p})
    38  	}
    39  	entry.Info(info)
    40  }
    41  
    42  type jsonpbMarshalleble struct {
    43  	proto.Message
    44  }
    45  
    46  func (j *jsonpbMarshalleble) MarshalJSON() ([]byte, error) {
    47  	var b bytes.Buffer
    48  	m := &jsonpb.Marshaler{}
    49  	err := m.Marshal(&b, j.Message)
    50  	if err != nil {
    51  		return nil, fmt.Errorf("jsonpb serializer failed: %w", err)
    52  	}
    53  	return b.Bytes(), nil
    54  }
    55  
    56  type reporter struct {
    57  	ctx       context.Context
    58  	entry     *logrus.Entry
    59  	logClaims bool
    60  	info      string
    61  }
    62  
    63  func (r *reporter) PostCall(_ error, _ time.Duration) {}
    64  
    65  func (r *reporter) PostMsgSend(_ any, _ error, _ time.Duration) {}
    66  
    67  func (r *reporter) PostMsgReceive(payload any, err error, _ time.Duration) {
    68  	if err == nil {
    69  		logRequest(r.ctx, r.entry, r.info, payload, r.logClaims)
    70  	}
    71  }
    72  
    73  func PayloadStreamServerInterceptor(entry *logrus.Entry, logClaims bool, decider func(context.Context, interceptors.CallMeta) bool) grpc.StreamServerInterceptor {
    74  	return selector.StreamServerInterceptor(interceptors.StreamServerInterceptor(reportable(entry, "streaming", logClaims)), selector.MatchFunc(decider))
    75  }
    76  
    77  func PayloadUnaryServerInterceptor(entry *logrus.Entry, logClaims bool, decider func(context.Context, interceptors.CallMeta) bool) grpc.UnaryServerInterceptor {
    78  	return selector.UnaryServerInterceptor(interceptors.UnaryServerInterceptor(reportable(entry, "unary", logClaims)), selector.MatchFunc(decider))
    79  }
    80  
    81  func reportable(entry *logrus.Entry, callType string, logClaims bool) interceptors.CommonReportableFunc {
    82  	return func(ctx context.Context, c interceptors.CallMeta) (interceptors.Reporter, context.Context) {
    83  		return &reporter{
    84  			ctx:       ctx,
    85  			entry:     entry,
    86  			info:      fmt.Sprintf("received %s call %s", callType, c.FullMethod()),
    87  			logClaims: logClaims,
    88  		}, ctx
    89  	}
    90  }
    91  
    92  // InterceptorLogger adapts logrus logger to interceptor logger.
    93  func InterceptorLogger(l logrus.FieldLogger) logging.Logger {
    94  	return logging.LoggerFunc(func(_ context.Context, lvl logging.Level, msg string, fields ...any) {
    95  		f := make(map[string]any, len(fields)/2)
    96  		i := logging.Fields(fields).Iterator()
    97  		for i.Next() {
    98  			k, v := i.At()
    99  			f[k] = v
   100  		}
   101  		l := l.WithFields(f)
   102  
   103  		switch lvl {
   104  		case logging.LevelDebug:
   105  			l.Debug(msg)
   106  		case logging.LevelInfo:
   107  			l.Info(msg)
   108  		case logging.LevelWarn:
   109  			l.Warn(msg)
   110  		case logging.LevelError:
   111  			l.Error(msg)
   112  		default:
   113  			panic(fmt.Sprintf("unknown level %v", lvl))
   114  		}
   115  	})
   116  }