github.com/blend/go-sdk@v1.20220411.3/grpcutil/tracer.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package grpcutil
     9  
    10  import (
    11  	"context"
    12  
    13  	"google.golang.org/grpc"
    14  )
    15  
    16  // Tracer is the full tracer.
    17  type Tracer interface {
    18  	ServerTracer
    19  	ClientTracer
    20  }
    21  
    22  // ServerTracer is a type that starts traces.
    23  type ServerTracer interface {
    24  	StartServerUnary(ctx context.Context, method string) (context.Context, TraceFinisher, error)
    25  	StartServerStream(ctx context.Context, method string) (context.Context, TraceFinisher, error)
    26  }
    27  
    28  // ClientTracer is a type that starts traces.
    29  type ClientTracer interface {
    30  	StartClientUnary(ctx context.Context, remoteAddr, method string) (context.Context, TraceFinisher, error)
    31  	StartClientStream(ctx context.Context, remoteAddr, method string) (context.Context, TraceFinisher, error)
    32  }
    33  
    34  // TraceFinisher is a finisher for traces
    35  type TraceFinisher interface {
    36  	Finish(err error)
    37  }
    38  
    39  // TracedServerUnary returns a unary server interceptor.
    40  func TracedServerUnary(tracer ServerTracer) grpc.UnaryServerInterceptor {
    41  	return func(ctx context.Context, args interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (result interface{}, err error) {
    42  		if tracer == nil {
    43  			return handler(ctx, args)
    44  		}
    45  		var finisher TraceFinisher
    46  		ctx, finisher, err = tracer.StartServerUnary(ctx, info.FullMethod)
    47  		if err != nil {
    48  			return nil, err
    49  		}
    50  		defer func() {
    51  			finisher.Finish(err)
    52  		}()
    53  		result, err = handler(ctx, args)
    54  		return
    55  	}
    56  }
    57  
    58  // TracedServerStream returns a grpc streaming interceptor.
    59  func TracedServerStream(tracer ServerTracer) grpc.StreamServerInterceptor {
    60  	return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
    61  		if tracer == nil {
    62  			return handler(srv, ss)
    63  		}
    64  		var finisher TraceFinisher
    65  		var err error
    66  		var ctx context.Context
    67  		ctx, finisher, err = tracer.StartServerStream(ss.Context(), info.FullMethod)
    68  		if err != nil {
    69  			return err
    70  		}
    71  		defer func() {
    72  			finisher.Finish(err)
    73  		}()
    74  		err = handler(srv, &contextServerStream{ServerStream: ss, ctx: ctx})
    75  		return err
    76  	}
    77  }
    78  
    79  // spanServerStream wraps around the embedded grpc.ServerStream, and
    80  // intercepts calls to `Context()` returning a context with the span information injected.
    81  //
    82  // NOTE: you can extend this type to intercept calls to `SendMsg` and `RecvMsg` if you want to
    83  // add tracing handling for individual stream calls.
    84  type contextServerStream struct {
    85  	grpc.ServerStream
    86  	ctx context.Context
    87  }
    88  
    89  func (cs *contextServerStream) Context() context.Context {
    90  	if cs.ctx != nil {
    91  		return cs.ctx
    92  	}
    93  	return cs.ServerStream.Context()
    94  }
    95  
    96  // TracedClientUnary implements the unary client interceptor based on a tracer.
    97  func TracedClientUnary(tracer ClientTracer) grpc.UnaryClientInterceptor {
    98  	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) {
    99  		if tracer == nil {
   100  			err = invoker(ctx, method, req, reply, cc, opts...)
   101  			return
   102  		}
   103  		var finisher TraceFinisher
   104  		ctx, finisher, err = tracer.StartClientUnary(ctx, cc.Target(), method)
   105  		if err != nil {
   106  			return
   107  		}
   108  		defer func() {
   109  			finisher.Finish(err)
   110  		}()
   111  		err = invoker(ctx, method, req, reply, cc, opts...)
   112  		return
   113  	}
   114  }
   115  
   116  // TracedClientStream implements the stream client interceptor based on a tracer.
   117  func TracedClientStream(tracer ClientTracer) grpc.StreamClientInterceptor {
   118  	return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (cs grpc.ClientStream, err error) {
   119  		if tracer == nil {
   120  			cs, err = streamer(ctx, desc, cc, method, opts...)
   121  			return
   122  		}
   123  		var finisher TraceFinisher
   124  		ctx, finisher, err = tracer.StartClientStream(ctx, cc.Target(), method)
   125  		if err != nil {
   126  			return
   127  		}
   128  		defer func() {
   129  			finisher.Finish(err)
   130  		}()
   131  		cs, err = streamer(ctx, desc, cc, method, opts...)
   132  		return
   133  	}
   134  }