github.com/weaveworks/common@v0.0.0-20230728070032-dd9e68f319d5/middleware/grpc_instrumentation.go (about)

     1  package middleware
     2  
import (
	"context"
	"io"
	"strconv"
	"sync"
	"time"

	"github.com/prometheus/client_golang/prometheus"
	grpcUtils "github.com/weaveworks/common/grpc"
	"github.com/weaveworks/common/httpgrpc"
	"github.com/weaveworks/common/instrument"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"
)
    16  
    17  func observe(ctx context.Context, hist *prometheus.HistogramVec, method string, err error, duration time.Duration) {
    18  	respStatus := "success"
    19  	if err != nil {
    20  		if errResp, ok := httpgrpc.HTTPResponseFromError(err); ok {
    21  			respStatus = strconv.Itoa(int(errResp.Code))
    22  		} else if grpcUtils.IsCanceled(err) {
    23  			respStatus = "cancel"
    24  		} else {
    25  			respStatus = "error"
    26  		}
    27  	}
    28  	instrument.ObserveWithExemplar(ctx, hist.WithLabelValues(gRPC, method, respStatus, "false"), duration.Seconds())
    29  }
    30  
    31  // UnaryServerInstrumentInterceptor instruments gRPC requests for errors and latency.
    32  func UnaryServerInstrumentInterceptor(hist *prometheus.HistogramVec) grpc.UnaryServerInterceptor {
    33  	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    34  		begin := time.Now()
    35  		resp, err := handler(ctx, req)
    36  		observe(ctx, hist, info.FullMethod, err, time.Since(begin))
    37  		return resp, err
    38  	}
    39  }
    40  
    41  // StreamServerInstrumentInterceptor instruments gRPC requests for errors and latency.
    42  func StreamServerInstrumentInterceptor(hist *prometheus.HistogramVec) grpc.StreamServerInterceptor {
    43  	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    44  		begin := time.Now()
    45  		err := handler(srv, ss)
    46  		observe(ss.Context(), hist, info.FullMethod, err, time.Since(begin))
    47  		return err
    48  	}
    49  }
    50  
    51  // UnaryClientInstrumentInterceptor records duration of gRPC requests client side.
    52  func UnaryClientInstrumentInterceptor(metric *prometheus.HistogramVec) grpc.UnaryClientInterceptor {
    53  	return func(ctx context.Context, method string, req, resp interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    54  		start := time.Now()
    55  		err := invoker(ctx, method, req, resp, cc, opts...)
    56  		metric.WithLabelValues(method, errorCode(err)).Observe(time.Since(start).Seconds())
    57  		return err
    58  	}
    59  }
    60  
    61  // StreamClientInstrumentInterceptor records duration of streaming gRPC requests client side.
    62  func StreamClientInstrumentInterceptor(metric *prometheus.HistogramVec) grpc.StreamClientInterceptor {
    63  	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string,
    64  		streamer grpc.Streamer, opts ...grpc.CallOption,
    65  	) (grpc.ClientStream, error) {
    66  		start := time.Now()
    67  		stream, err := streamer(ctx, desc, cc, method, opts...)
    68  		return &instrumentedClientStream{
    69  			metric:       metric,
    70  			start:        start,
    71  			method:       method,
    72  			ClientStream: stream,
    73  		}, err
    74  	}
    75  }
    76  
    77  type instrumentedClientStream struct {
    78  	metric *prometheus.HistogramVec
    79  	start  time.Time
    80  	method string
    81  	grpc.ClientStream
    82  }
    83  
    84  func (s *instrumentedClientStream) SendMsg(m interface{}) error {
    85  	err := s.ClientStream.SendMsg(m)
    86  	if err == nil {
    87  		return nil
    88  	}
    89  
    90  	if err == io.EOF {
    91  		s.metric.WithLabelValues(s.method, errorCode(nil)).Observe(time.Since(s.start).Seconds())
    92  	} else {
    93  		s.metric.WithLabelValues(s.method, errorCode(err)).Observe(time.Since(s.start).Seconds())
    94  	}
    95  
    96  	return err
    97  }
    98  
    99  func (s *instrumentedClientStream) RecvMsg(m interface{}) error {
   100  	err := s.ClientStream.RecvMsg(m)
   101  	if err == nil {
   102  		return nil
   103  	}
   104  
   105  	if err == io.EOF {
   106  		s.metric.WithLabelValues(s.method, errorCode(nil)).Observe(time.Since(s.start).Seconds())
   107  	} else {
   108  		s.metric.WithLabelValues(s.method, errorCode(err)).Observe(time.Since(s.start).Seconds())
   109  	}
   110  
   111  	return err
   112  }
   113  
   114  func (s *instrumentedClientStream) Header() (metadata.MD, error) {
   115  	md, err := s.ClientStream.Header()
   116  	if err != nil {
   117  		s.metric.WithLabelValues(s.method, errorCode(err)).Observe(time.Since(s.start).Seconds())
   118  	}
   119  	return md, err
   120  }
   121  
   122  // errorCode converts an error into an error code string.
   123  func errorCode(err error) string {
   124  	if err == nil {
   125  		return "2xx"
   126  	}
   127  
   128  	if errResp, ok := httpgrpc.HTTPResponseFromError(err); ok {
   129  		statusFamily := int(errResp.Code / 100)
   130  		return strconv.Itoa(statusFamily) + "xx"
   131  	} else if grpcUtils.IsCanceled(err) {
   132  		return "cancel"
   133  	} else {
   134  		return "error"
   135  	}
   136  }