go.uber.org/yarpc@v1.72.1/transport/grpc/handler.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package grpc
    22  
    23  import (
    24  	"strings"
    25  	"time"
    26  
    27  	"github.com/opentracing/opentracing-go"
    28  	"go.uber.org/yarpc"
    29  	"go.uber.org/yarpc/api/transport"
    30  	"go.uber.org/yarpc/internal/bufferpool"
    31  	"go.uber.org/yarpc/internal/grpcerrorcodes"
    32  	"go.uber.org/yarpc/yarpcerrors"
    33  	"go.uber.org/zap"
    34  	"golang.org/x/net/context"
    35  	"google.golang.org/grpc"
    36  	"google.golang.org/grpc/codes"
    37  	"google.golang.org/grpc/metadata"
    38  	"google.golang.org/grpc/status"
    39  )
    40  
    41  var (
    42  	// errInvalidGRPCStream is applied before yarpc so it's a raw GRPC error
    43  	errInvalidGRPCStream = status.Error(codes.InvalidArgument, "received grpc request with invalid stream")
    44  	errInvalidGRPCMethod = yarpcerrors.Newf(yarpcerrors.CodeInvalidArgument, "invalid stream method name for request")
    45  )
    46  
// handler serves inbound gRPC streams on behalf of an Inbound, dispatching
// them to the unary or streaming handlers chosen by the inbound's router.
type handler struct {
	i      *Inbound    // inbound whose router and transport options (tracer) are consulted per request
	logger *zap.Logger // logger handed to the transport invoke helpers
}
    51  
    52  func newHandler(i *Inbound, l *zap.Logger) *handler {
    53  	return &handler{i: i, logger: l}
    54  }
    55  
    56  func (h *handler) handle(srv interface{}, serverStream grpc.ServerStream) (err error) {
    57  	defer func() { err = toGRPCError(err) }()
    58  
    59  	start := time.Now()
    60  	ctx := serverStream.Context()
    61  	streamMethod, ok := grpc.MethodFromServerStream(serverStream)
    62  	if !ok {
    63  		return errInvalidGRPCStream
    64  	}
    65  
    66  	transportRequest, err := h.getBasicTransportRequest(ctx, streamMethod)
    67  	if err != nil {
    68  		return err
    69  	}
    70  
    71  	handlerSpec, err := h.i.router.Choose(ctx, transportRequest)
    72  	if err != nil {
    73  		return err
    74  	}
    75  	switch handlerSpec.Type() {
    76  	case transport.Unary:
    77  		return h.handleUnary(ctx, transportRequest, serverStream, streamMethod, start, handlerSpec.Unary())
    78  	case transport.Streaming:
    79  		return h.handleStream(ctx, transportRequest, serverStream, start, handlerSpec.Stream())
    80  	}
    81  	return yarpcerrors.Newf(yarpcerrors.CodeUnimplemented, "transport grpc does not handle %s handlers", handlerSpec.Type().String())
    82  }
    83  
    84  // getBasicTransportRequest converts the grpc request metadata into a
    85  // transport.Request without a body field.
    86  func (h *handler) getBasicTransportRequest(ctx context.Context, streamMethod string) (*transport.Request, error) {
    87  	md, ok := metadata.FromIncomingContext(ctx)
    88  	if md == nil || !ok {
    89  		return nil, yarpcerrors.Newf(yarpcerrors.CodeInternal, "cannot get metadata from ctx: %v", ctx)
    90  	}
    91  	transportRequest, err := metadataToTransportRequest(md)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  	transportRequest.Transport = TransportName
    96  
    97  	procedure, err := procedureFromStreamMethod(streamMethod)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  
   102  	transportRequest.Procedure = procedure
   103  	if err := transport.ValidateRequest(transportRequest); err != nil {
   104  		return nil, err
   105  	}
   106  	return transportRequest, nil
   107  }
   108  
   109  // procedureFromStreamMethod converts a GRPC stream method into a yarpc
   110  // procedure name.  This is mostly copied from the GRPC-go server processing
   111  // logic here:
   112  // https://github.com/grpc/grpc-go/blob/d6723916d2e73e8824d22a1ba5c52f8e6255e6f8/server.go#L931-L956
   113  func procedureFromStreamMethod(streamMethod string) (string, error) {
   114  	if streamMethod != "" && streamMethod[0] == '/' {
   115  		streamMethod = streamMethod[1:]
   116  	}
   117  	pos := strings.LastIndex(streamMethod, "/")
   118  	if pos == -1 {
   119  		return "", errInvalidGRPCMethod
   120  	}
   121  	service := streamMethod[:pos]
   122  	method := streamMethod[pos+1:]
   123  	return procedureToName(service, method)
   124  }
   125  
   126  func (h *handler) handleStream(
   127  	ctx context.Context,
   128  	transportRequest *transport.Request,
   129  	serverStream grpc.ServerStream,
   130  	start time.Time,
   131  	streamHandler transport.StreamHandler,
   132  ) error {
   133  	tracer := h.i.t.options.tracer
   134  	var parentSpanCtx opentracing.SpanContext
   135  	md, ok := metadata.FromIncomingContext(ctx)
   136  	if ok {
   137  		parentSpanCtx, _ = tracer.Extract(opentracing.HTTPHeaders, mdReadWriter(md))
   138  	}
   139  	extractOpenTracingSpan := &transport.ExtractOpenTracingSpan{
   140  		ParentSpanContext: parentSpanCtx,
   141  		Tracer:            tracer,
   142  		TransportName:     TransportName,
   143  		StartTime:         start,
   144  		ExtraTags:         yarpc.OpentracingTags,
   145  	}
   146  	ctx, span := extractOpenTracingSpan.Do(ctx, transportRequest)
   147  	defer span.Finish()
   148  
   149  	stream := newServerStream(ctx, &transport.StreamRequest{Meta: transportRequest.ToRequestMeta()}, serverStream)
   150  	tServerStream, err := transport.NewServerStream(stream)
   151  	if err != nil {
   152  		return err
   153  	}
   154  	apperr := transport.InvokeStreamHandler(transport.StreamInvokeRequest{
   155  		Stream:  tServerStream,
   156  		Handler: streamHandler,
   157  		Logger:  h.logger,
   158  	})
   159  	apperr = handlerErrorToGRPCError(apperr, nil)
   160  	return transport.UpdateSpanWithErr(span, apperr)
   161  }
   162  
   163  func (h *handler) handleUnary(
   164  	ctx context.Context,
   165  	transportRequest *transport.Request,
   166  	serverStream grpc.ServerStream,
   167  	streamMethod string,
   168  	start time.Time,
   169  	handler transport.UnaryHandler,
   170  ) error {
   171  	var requestData []byte
   172  	if err := serverStream.RecvMsg(&requestData); err != nil {
   173  		return err
   174  	}
   175  	// TODO: avoid redundant buffer copy
   176  	requestBuffer := bufferpool.Get()
   177  	defer bufferpool.Put(requestBuffer)
   178  
   179  	// Buffers are documented to always return a nil error.
   180  	_, _ = requestBuffer.Write(requestData)
   181  	transportRequest.Body = requestBuffer
   182  	transportRequest.BodySize = len(requestData)
   183  
   184  	responseWriter := newResponseWriter()
   185  	defer responseWriter.Close()
   186  
   187  	// Echo accepted rpc-service in response header
   188  	responseWriter.AddSystemHeader(ServiceHeader, transportRequest.Service)
   189  
   190  	err := h.handleUnaryBeforeErrorConversion(ctx, transportRequest, responseWriter, start, handler)
   191  	err = handlerErrorToGRPCError(err, responseWriter)
   192  
   193  	// Send the response attributes back and end the stream.
   194  	//
   195  	// Warning: SendMsg() holds onto these bytes after returning. Therefore, we
   196  	// cannot pool this responseWriter.
   197  	//
   198  	// See https://github.com/yarpc/yarpc-go/pull/1738 for details.
   199  	if sendErr := serverStream.SendMsg(responseWriter.Bytes()); sendErr != nil {
   200  		// We couldn't send the response.
   201  		return sendErr
   202  	}
   203  	if responseWriter.md != nil {
   204  		serverStream.SetTrailer(responseWriter.md)
   205  	}
   206  	return err
   207  }
   208  
   209  func (h *handler) handleUnaryBeforeErrorConversion(
   210  	ctx context.Context,
   211  	transportRequest *transport.Request,
   212  	responseWriter *responseWriter,
   213  	start time.Time,
   214  	handler transport.UnaryHandler,
   215  ) error {
   216  	tracer := h.i.t.options.tracer
   217  	var parentSpanCtx opentracing.SpanContext
   218  	md, ok := metadata.FromIncomingContext(ctx)
   219  	if ok {
   220  		parentSpanCtx, _ = tracer.Extract(opentracing.HTTPHeaders, mdReadWriter(md))
   221  	}
   222  	extractOpenTracingSpan := &transport.ExtractOpenTracingSpan{
   223  		ParentSpanContext: parentSpanCtx,
   224  		Tracer:            tracer,
   225  		TransportName:     TransportName,
   226  		StartTime:         start,
   227  		ExtraTags:         yarpc.OpentracingTags,
   228  	}
   229  	ctx, span := extractOpenTracingSpan.Do(ctx, transportRequest)
   230  	defer span.Finish()
   231  
   232  	err := h.callUnary(ctx, transportRequest, handler, responseWriter)
   233  	return transport.UpdateSpanWithErr(span, err)
   234  }
   235  
   236  func (h *handler) callUnary(ctx context.Context, transportRequest *transport.Request, unaryHandler transport.UnaryHandler, responseWriter *responseWriter) error {
   237  	if err := transport.ValidateRequestContext(ctx); err != nil {
   238  		return err
   239  	}
   240  	return transport.InvokeUnaryHandler(transport.UnaryInvokeRequest{
   241  		Context:        ctx,
   242  		StartTime:      time.Now(),
   243  		Request:        transportRequest,
   244  		ResponseWriter: responseWriter,
   245  		Handler:        unaryHandler,
   246  		Logger:         h.logger,
   247  	})
   248  }
   249  
   250  // handlerErrorToGRPCError converts a yarpcerror to gRPC status error,
   251  // taking into account error details.
   252  //
   253  // This method is used from unary and stream handlers. Stream handler passes
   254  // nil responseWriter
   255  func handlerErrorToGRPCError(err error, responseWriter *responseWriter) error {
   256  	if err == nil {
   257  		return nil
   258  	}
   259  	// if this is an error created from grpc-go, return the error
   260  	if _, ok := status.FromError(err); ok {
   261  		return err
   262  	}
   263  	// if this is not a yarpc error, return the error
   264  	// this will result in the error being a grpc-go error with codes.Unknown
   265  	if !yarpcerrors.IsStatus(err) {
   266  		return err
   267  	}
   268  	// we now know we have a yarpc error
   269  	yarpcStatus := yarpcerrors.FromError(err)
   270  	name := yarpcStatus.Name()
   271  	message := yarpcStatus.Message()
   272  	// if the yarpc error has a name, set the header
   273  	if name != "" {
   274  		if responseWriter != nil {
   275  			responseWriter.AddSystemHeader(ErrorNameHeader, name)
   276  		}
   277  		if message == "" {
   278  			// if the message is empty, set the message to the name for grpc compatibility
   279  			message = name
   280  		} else {
   281  			// else, we set the name as the prefix for grpc compatibility
   282  			// we parse this off the front if the name header is set on the client-side
   283  			message = name + ": " + message
   284  		}
   285  	}
   286  
   287  	if body := yarpcStatus.Details(); body != nil {
   288  		return unmarshalError(body)
   289  	}
   290  
   291  	grpcCode, ok := grpcerrorcodes.YARPCCodeToGRPCCode[yarpcStatus.Code()]
   292  	// should only happen if grpcerrorcodes.YARPCCodeToGRPCCode does not cover all codes
   293  	if !ok {
   294  		grpcCode = codes.Unknown
   295  	}
   296  	return status.Error(grpcCode, message)
   297  }
   298  
   299  // toGRPCError converts errors to gRPC status errors. gRPC status errors are
   300  // returned as is.
   301  //
   302  // This MUST NOT be used for coverting YARPC error details to gRPC error
   303  // details. Use toGRPCErrorWithDetails instead.
   304  func toGRPCError(err error) error {
   305  	if err == nil {
   306  		return nil
   307  	}
   308  
   309  	// if this is not a yarpc error, return the error
   310  	// this will result in the error being a grpc-go error with codes.Unknown
   311  	if !yarpcerrors.IsStatus(err) {
   312  		return err
   313  	}
   314  	// we now know we have a yarpc error
   315  	yarpcStatus := yarpcerrors.FromError(err)
   316  	message := yarpcStatus.Message()
   317  	grpcCode, ok := grpcerrorcodes.YARPCCodeToGRPCCode[yarpcStatus.Code()]
   318  	// should only happen if grpcerrorcodes.YARPCCodeToGRPCCode does not cover all codes
   319  	if !ok {
   320  		grpcCode = codes.Unknown
   321  	}
   322  	return status.Error(grpcCode, message)
   323  }