github.com/blend/go-sdk@v1.20220411.3/grpcutil/tracer.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package grpcutil 9 10 import ( 11 "context" 12 13 "google.golang.org/grpc" 14 ) 15 16 // Tracer is the full tracer. 17 type Tracer interface { 18 ServerTracer 19 ClientTracer 20 } 21 22 // ServerTracer is a type that starts traces. 23 type ServerTracer interface { 24 StartServerUnary(ctx context.Context, method string) (context.Context, TraceFinisher, error) 25 StartServerStream(ctx context.Context, method string) (context.Context, TraceFinisher, error) 26 } 27 28 // ClientTracer is a type that starts traces. 29 type ClientTracer interface { 30 StartClientUnary(ctx context.Context, remoteAddr, method string) (context.Context, TraceFinisher, error) 31 StartClientStream(ctx context.Context, remoteAddr, method string) (context.Context, TraceFinisher, error) 32 } 33 34 // TraceFinisher is a finisher for traces 35 type TraceFinisher interface { 36 Finish(err error) 37 } 38 39 // TracedServerUnary returns a unary server interceptor. 40 func TracedServerUnary(tracer ServerTracer) grpc.UnaryServerInterceptor { 41 return func(ctx context.Context, args interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (result interface{}, err error) { 42 if tracer == nil { 43 return handler(ctx, args) 44 } 45 var finisher TraceFinisher 46 ctx, finisher, err = tracer.StartServerUnary(ctx, info.FullMethod) 47 if err != nil { 48 return nil, err 49 } 50 defer func() { 51 finisher.Finish(err) 52 }() 53 result, err = handler(ctx, args) 54 return 55 } 56 } 57 58 // TracedServerStream returns a grpc streaming interceptor. 59 func TracedServerStream(tracer ServerTracer) grpc.StreamServerInterceptor { 60 return func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 61 if tracer == nil { 62 return handler(srv, ss) 63 } 64 var finisher TraceFinisher 65 var err error 66 var ctx context.Context 67 ctx, finisher, err = tracer.StartServerStream(ss.Context(), info.FullMethod) 68 if err != nil { 69 return err 70 } 71 defer func() { 72 finisher.Finish(err) 73 }() 74 err = handler(srv, &contextServerStream{ServerStream: ss, ctx: ctx}) 75 return err 76 } 77 } 78 79 // spanServerStream wraps around the embedded grpc.ServerStream, and 80 // intercepts calls to `Context()` returning a context with the span information injected. 81 // 82 // NOTE: you can extend this type to intercept calls to `SendMsg` and `RecvMsg` if you want to 83 // add tracing handling for individual stream calls. 84 type contextServerStream struct { 85 grpc.ServerStream 86 ctx context.Context 87 } 88 89 func (cs *contextServerStream) Context() context.Context { 90 if cs.ctx != nil { 91 return cs.ctx 92 } 93 return cs.ServerStream.Context() 94 } 95 96 // TracedClientUnary implements the unary client interceptor based on a tracer. 97 func TracedClientUnary(tracer ClientTracer) grpc.UnaryClientInterceptor { 98 return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) { 99 if tracer == nil { 100 err = invoker(ctx, method, req, reply, cc, opts...) 101 return 102 } 103 var finisher TraceFinisher 104 ctx, finisher, err = tracer.StartClientUnary(ctx, cc.Target(), method) 105 if err != nil { 106 return 107 } 108 defer func() { 109 finisher.Finish(err) 110 }() 111 err = invoker(ctx, method, req, reply, cc, opts...) 112 return 113 } 114 } 115 116 // TracedClientStream implements the stream client interceptor based on a tracer. 117 func TracedClientStream(tracer ClientTracer) grpc.StreamClientInterceptor { 118 return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (cs grpc.ClientStream, err error) { 119 if tracer == nil { 120 cs, err = streamer(ctx, desc, cc, method, opts...) 121 return 122 } 123 var finisher TraceFinisher 124 ctx, finisher, err = tracer.StartClientStream(ctx, cc.Target(), method) 125 if err != nil { 126 return 127 } 128 defer func() { 129 finisher.Finish(err) 130 }() 131 cs, err = streamer(ctx, desc, cc, method, opts...) 132 return 133 } 134 }