github.com/lingyao2333/mo-zero@v1.4.1/zrpc/internal/clientinterceptors/tracinginterceptor.go (about)

     1  package clientinterceptors
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  
     7  	ztrace "github.com/lingyao2333/mo-zero/core/trace"
     8  	"go.opentelemetry.io/otel"
     9  	"go.opentelemetry.io/otel/codes"
    10  	"go.opentelemetry.io/otel/trace"
    11  	"google.golang.org/grpc"
    12  	gcodes "google.golang.org/grpc/codes"
    13  	"google.golang.org/grpc/metadata"
    14  	"google.golang.org/grpc/status"
    15  )
    16  
    17  const (
    18  	receiveEndEvent streamEventType = iota
    19  	errorEvent
    20  )
    21  
    22  // UnaryTracingInterceptor returns a grpc.UnaryClientInterceptor for opentelemetry.
    23  func UnaryTracingInterceptor(ctx context.Context, method string, req, reply interface{},
    24  	cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    25  	ctx, span := startSpan(ctx, method, cc.Target())
    26  	defer span.End()
    27  
    28  	ztrace.MessageSent.Event(ctx, 1, req)
    29  	err := invoker(ctx, method, req, reply, cc, opts...)
    30  	ztrace.MessageReceived.Event(ctx, 1, reply)
    31  	if err != nil {
    32  		s, ok := status.FromError(err)
    33  		if ok {
    34  			span.SetStatus(codes.Error, s.Message())
    35  			span.SetAttributes(ztrace.StatusCodeAttr(s.Code()))
    36  		} else {
    37  			span.SetStatus(codes.Error, err.Error())
    38  		}
    39  		return err
    40  	}
    41  
    42  	span.SetAttributes(ztrace.StatusCodeAttr(gcodes.OK))
    43  	return nil
    44  }
    45  
    46  // StreamTracingInterceptor returns a grpc.StreamClientInterceptor for opentelemetry.
    47  func StreamTracingInterceptor(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn,
    48  	method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
    49  	ctx, span := startSpan(ctx, method, cc.Target())
    50  	s, err := streamer(ctx, desc, cc, method, opts...)
    51  	if err != nil {
    52  		st, ok := status.FromError(err)
    53  		if ok {
    54  			span.SetStatus(codes.Error, st.Message())
    55  			span.SetAttributes(ztrace.StatusCodeAttr(st.Code()))
    56  		} else {
    57  			span.SetStatus(codes.Error, err.Error())
    58  		}
    59  		span.End()
    60  		return s, err
    61  	}
    62  
    63  	stream := wrapClientStream(ctx, s, desc)
    64  
    65  	go func() {
    66  		if err := <-stream.Finished; err != nil {
    67  			s, ok := status.FromError(err)
    68  			if ok {
    69  				span.SetStatus(codes.Error, s.Message())
    70  				span.SetAttributes(ztrace.StatusCodeAttr(s.Code()))
    71  			} else {
    72  				span.SetStatus(codes.Error, err.Error())
    73  			}
    74  		} else {
    75  			span.SetAttributes(ztrace.StatusCodeAttr(gcodes.OK))
    76  		}
    77  
    78  		span.End()
    79  	}()
    80  
    81  	return stream, nil
    82  }
    83  
    84  type (
    85  	streamEventType int
    86  
    87  	streamEvent struct {
    88  		Type streamEventType
    89  		Err  error
    90  	}
    91  
    92  	clientStream struct {
    93  		grpc.ClientStream
    94  		Finished          chan error
    95  		desc              *grpc.StreamDesc
    96  		events            chan streamEvent
    97  		eventsDone        chan struct{}
    98  		receivedMessageID int
    99  		sentMessageID     int
   100  	}
   101  )
   102  
   103  func (w *clientStream) CloseSend() error {
   104  	err := w.ClientStream.CloseSend()
   105  	if err != nil {
   106  		w.sendStreamEvent(errorEvent, err)
   107  	}
   108  
   109  	return err
   110  }
   111  
   112  func (w *clientStream) Header() (metadata.MD, error) {
   113  	md, err := w.ClientStream.Header()
   114  	if err != nil {
   115  		w.sendStreamEvent(errorEvent, err)
   116  	}
   117  
   118  	return md, err
   119  }
   120  
   121  func (w *clientStream) RecvMsg(m interface{}) error {
   122  	err := w.ClientStream.RecvMsg(m)
   123  	if err == nil && !w.desc.ServerStreams {
   124  		w.sendStreamEvent(receiveEndEvent, nil)
   125  	} else if err == io.EOF {
   126  		w.sendStreamEvent(receiveEndEvent, nil)
   127  	} else if err != nil {
   128  		w.sendStreamEvent(errorEvent, err)
   129  	} else {
   130  		w.receivedMessageID++
   131  		ztrace.MessageReceived.Event(w.Context(), w.receivedMessageID, m)
   132  	}
   133  
   134  	return err
   135  }
   136  
   137  func (w *clientStream) SendMsg(m interface{}) error {
   138  	err := w.ClientStream.SendMsg(m)
   139  	w.sentMessageID++
   140  	ztrace.MessageSent.Event(w.Context(), w.sentMessageID, m)
   141  	if err != nil {
   142  		w.sendStreamEvent(errorEvent, err)
   143  	}
   144  
   145  	return err
   146  }
   147  
   148  func (w *clientStream) sendStreamEvent(eventType streamEventType, err error) {
   149  	select {
   150  	case <-w.eventsDone:
   151  	case w.events <- streamEvent{Type: eventType, Err: err}:
   152  	}
   153  }
   154  
   155  func startSpan(ctx context.Context, method, target string) (context.Context, trace.Span) {
   156  	md, ok := metadata.FromOutgoingContext(ctx)
   157  	if !ok {
   158  		md = metadata.MD{}
   159  	}
   160  	tr := otel.Tracer(ztrace.TraceName)
   161  	name, attr := ztrace.SpanInfo(method, target)
   162  	ctx, span := tr.Start(ctx, name, trace.WithSpanKind(trace.SpanKindClient),
   163  		trace.WithAttributes(attr...))
   164  	ztrace.Inject(ctx, otel.GetTextMapPropagator(), &md)
   165  	ctx = metadata.NewOutgoingContext(ctx, md)
   166  
   167  	return ctx, span
   168  }
   169  
   170  // wrapClientStream wraps s with given ctx and desc.
   171  func wrapClientStream(ctx context.Context, s grpc.ClientStream, desc *grpc.StreamDesc) *clientStream {
   172  	events := make(chan streamEvent)
   173  	eventsDone := make(chan struct{})
   174  	finished := make(chan error)
   175  
   176  	go func() {
   177  		defer close(eventsDone)
   178  
   179  		for {
   180  			select {
   181  			case event := <-events:
   182  				switch event.Type {
   183  				case receiveEndEvent:
   184  					finished <- nil
   185  					return
   186  				case errorEvent:
   187  					finished <- event.Err
   188  					return
   189  				}
   190  			case <-ctx.Done():
   191  				finished <- ctx.Err()
   192  				return
   193  			}
   194  		}
   195  	}()
   196  
   197  	return &clientStream{
   198  		ClientStream: s,
   199  		desc:         desc,
   200  		events:       events,
   201  		eventsDone:   eventsDone,
   202  		Finished:     finished,
   203  	}
   204  }