go.uber.org/yarpc@v1.72.1/transport/grpc/stream.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package grpc
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"io"
    27  	"io/ioutil"
    28  
    29  	"github.com/gogo/status"
    30  	"github.com/opentracing/opentracing-go"
    31  	"go.uber.org/atomic"
    32  	"go.uber.org/yarpc/api/transport"
    33  	"go.uber.org/yarpc/internal/grpcerrorcodes"
    34  	"go.uber.org/yarpc/yarpcerrors"
    35  	"google.golang.org/grpc"
    36  	"google.golang.org/grpc/metadata"
    37  )
    38  
    39  var (
    40  	_ transport.StreamHeadersSender = (*serverStream)(nil)
    41  	_ transport.StreamHeadersReader = (*clientStream)(nil)
    42  )
    43  
    44  type serverStream struct {
    45  	ctx    context.Context
    46  	req    *transport.StreamRequest
    47  	stream grpc.ServerStream
    48  }
    49  
    50  func newServerStream(ctx context.Context, req *transport.StreamRequest, stream grpc.ServerStream) *serverStream {
    51  	return &serverStream{
    52  		ctx:    ctx,
    53  		req:    req,
    54  		stream: stream,
    55  	}
    56  }
    57  
    58  func (ss *serverStream) Context() context.Context {
    59  	return ss.ctx
    60  }
    61  
    62  func (ss *serverStream) Request() *transport.StreamRequest {
    63  	return ss.req
    64  }
    65  
    66  func (ss *serverStream) SendMessage(_ context.Context, m *transport.StreamMessage) error {
    67  	// TODO pool buffers for performance.
    68  	msg, err := ioutil.ReadAll(m.Body)
    69  	_ = m.Body.Close()
    70  	if err != nil {
    71  		return err
    72  	}
    73  	return toYARPCStreamError(ss.stream.SendMsg(msg))
    74  }
    75  
    76  func (ss *serverStream) ReceiveMessage(_ context.Context) (*transport.StreamMessage, error) {
    77  	var msg []byte
    78  	if err := ss.stream.RecvMsg(&msg); err != nil {
    79  		return nil, toYARPCStreamError(err)
    80  	}
    81  	return &transport.StreamMessage{
    82  		Body:     readCloser{bytes.NewReader(msg)},
    83  		BodySize: len(msg),
    84  	}, nil
    85  }
    86  
    87  type readCloser struct {
    88  	*bytes.Reader
    89  }
    90  
    91  func (r readCloser) Close() error {
    92  	return nil
    93  }
    94  
    95  func (ss *serverStream) SendHeaders(headers transport.Headers) error {
    96  	md := make(metadata.MD, headers.Len())
    97  	for k, v := range headers.Items() {
    98  		md.Set(k, v)
    99  	}
   100  	return ss.stream.SendHeader(md)
   101  }
   102  
   103  type clientStream struct {
   104  	ctx     context.Context
   105  	req     *transport.StreamRequest
   106  	stream  grpc.ClientStream
   107  	span    opentracing.Span
   108  	closed  atomic.Bool
   109  	release func(error)
   110  }
   111  
   112  func newClientStream(ctx context.Context, req *transport.StreamRequest, stream grpc.ClientStream, span opentracing.Span, release func(error)) *clientStream {
   113  	return &clientStream{
   114  		ctx:     ctx,
   115  		req:     req,
   116  		stream:  stream,
   117  		span:    span,
   118  		release: release,
   119  	}
   120  }
   121  
   122  func (cs *clientStream) Context() context.Context {
   123  	return cs.ctx
   124  }
   125  
   126  func (cs *clientStream) Request() *transport.StreamRequest {
   127  	return cs.req
   128  }
   129  
   130  func (cs *clientStream) SendMessage(_ context.Context, m *transport.StreamMessage) error {
   131  	if cs.closed.Load() { // If the stream is closed, we should not be sending messages on it.
   132  		return io.EOF
   133  	}
   134  	// TODO can we make a "Bytes" interface to get direct access to the bytes
   135  	// (instead of resorting to ReadAll (which is not necessarily performant))
   136  	msg, err := ioutil.ReadAll(m.Body)
   137  	_ = m.Body.Close()
   138  	if err != nil {
   139  		return toYARPCStreamError(err)
   140  	}
   141  	if err := cs.stream.SendMsg(msg); err != nil {
   142  		return toYARPCStreamError(cs.closeWithErr(err))
   143  	}
   144  	return nil
   145  }
   146  
   147  func (cs *clientStream) ReceiveMessage(context.Context) (*transport.StreamMessage, error) {
   148  	// TODO use buffers for performance reasons.
   149  	var msg []byte
   150  	if err := cs.stream.RecvMsg(&msg); err != nil {
   151  		return nil, toYARPCStreamError(cs.closeWithErr(err))
   152  	}
   153  	return &transport.StreamMessage{Body: ioutil.NopCloser(bytes.NewReader(msg))}, nil
   154  }
   155  
   156  func (cs *clientStream) Close(context.Context) error {
   157  	_ = cs.closeWithErr(nil)
   158  	return cs.stream.CloseSend()
   159  }
   160  
   161  func (cs *clientStream) Headers() (transport.Headers, error) {
   162  	md, err := cs.stream.Header()
   163  	if err != nil {
   164  		return transport.NewHeaders(), err
   165  	}
   166  	headers := transport.NewHeadersWithCapacity(len(md))
   167  	for k, vs := range md {
   168  		if len(vs) > 0 {
   169  			headers = headers.With(k, vs[0])
   170  		}
   171  	}
   172  	return headers, nil
   173  }
   174  
   175  func (cs *clientStream) closeWithErr(err error) error {
   176  	if !cs.closed.Swap(true) {
   177  		err = transport.UpdateSpanWithErr(cs.span, err)
   178  		cs.span.Finish()
   179  		cs.release(err)
   180  	}
   181  	return err
   182  }
   183  
   184  func toYARPCStreamError(err error) error {
   185  	if err == nil {
   186  		return nil
   187  	}
   188  	if err == io.EOF {
   189  		return err
   190  	}
   191  	if yarpcerrors.IsStatus(err) {
   192  		return err
   193  	}
   194  	status, ok := status.FromError(err)
   195  	// if not a yarpc error or grpc error, just return a wrapped error
   196  	if !ok {
   197  		return yarpcerrors.FromError(err)
   198  	}
   199  	code, ok := grpcerrorcodes.GRPCCodeToYARPCCode[status.Code()]
   200  	if !ok {
   201  		code = yarpcerrors.CodeUnknown
   202  	}
   203  	yarpcerr := yarpcerrors.Newf(code, status.Message())
   204  	details, err := marshalError(status)
   205  	if err != nil {
   206  		return yarpcerrors.FromError(err)
   207  	}
   208  	if details != nil {
   209  		yarpcerr = yarpcerr.WithDetails(details)
   210  	}
   211  	return yarpcerr
   212  }