go.uber.org/yarpc@v1.72.1/encoding/protobuf/v2/marshal.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package v2
    22  
    23  import (
    24  	"go.uber.org/yarpc/api/transport"
    25  	"go.uber.org/yarpc/internal/bufferpool"
    26  	"go.uber.org/yarpc/yarpcerrors"
    27  	"google.golang.org/protobuf/encoding/protojson"
    28  	"google.golang.org/protobuf/proto"
    29  	"google.golang.org/protobuf/reflect/protoregistry"
    30  	"io"
    31  	"sync"
    32  )
    33  
    34  var (
    35  	_bufferPool = sync.Pool{
    36  		New: func() interface{} {
    37  			newBuf := make([]byte, 1024)
    38  			return &newBuf
    39  		},
    40  	}
    41  )
    42  
    43  const (
    44  	// Encoding is the name of this encoding.
    45  	Encoding transport.Encoding = "proto"
    46  
    47  	// JSONEncoding is the name of the JSON encoding.
    48  	// Protobuf handlers are able to handle both Encoding and JSONEncoding encodings.
    49  	JSONEncoding transport.Encoding = "json"
    50  )
    51  
    52  // AnyResolver provides interface for looking up or iterating over descriptor types.
    53  type AnyResolver interface {
    54  	protoregistry.ExtensionTypeResolver
    55  	protoregistry.MessageTypeResolver
    56  }
    57  
    58  // codec is a private helper struct used to hold custom marshling behavior of golang protobuf messages.
    59  type codec struct {
    60  	jsonMarshaler   *protojson.MarshalOptions
    61  	jsonUnmarshaler *protojson.UnmarshalOptions
    62  }
    63  
    64  func newCodec(anyResolver AnyResolver) *codec {
    65  	return &codec{
    66  		jsonMarshaler:   &protojson.MarshalOptions{Resolver: anyResolver},
    67  		jsonUnmarshaler: &protojson.UnmarshalOptions{Resolver: anyResolver, DiscardUnknown: true},
    68  	}
    69  }
    70  
    71  func unmarshal(encoding transport.Encoding, reader io.Reader, message proto.Message, codec *codec) error {
    72  	buf := bufferpool.Get()
    73  	defer bufferpool.Put(buf)
    74  	if _, err := buf.ReadFrom(reader); err != nil {
    75  		return err
    76  	}
    77  	body := buf.Bytes()
    78  	if len(body) == 0 {
    79  		return nil
    80  	}
    81  	return unmarshalBytes(encoding, body, message, codec)
    82  }
    83  
    84  func unmarshalBytes(encoding transport.Encoding, body []byte, message proto.Message, codec *codec) error {
    85  	switch encoding {
    86  	case Encoding:
    87  		return unmarshalProto(body, message, codec)
    88  	case JSONEncoding:
    89  		return unmarshalJSON(body, message, codec)
    90  	default:
    91  		return yarpcerrors.Newf(yarpcerrors.CodeInternal, "encoding.Expect should have handled encoding %q but did not", encoding)
    92  	}
    93  }
    94  
    95  func unmarshalProto(body []byte, message proto.Message, _ *codec) error {
    96  	return proto.Unmarshal(body, message)
    97  }
    98  
    99  func unmarshalJSON(body []byte, message proto.Message, codec *codec) error {
   100  	return codec.jsonUnmarshaler.Unmarshal(body, message)
   101  }
   102  
   103  func marshal(encoding transport.Encoding, message proto.Message, codec *codec) ([]byte, func(), error) {
   104  	switch encoding {
   105  	case Encoding:
   106  		return marshalProto(message, codec)
   107  	case JSONEncoding:
   108  		return marshalJSON(message, codec)
   109  	default:
   110  		return nil, nil, yarpcerrors.Newf(yarpcerrors.CodeInternal, "encoding.Expect should have handled encoding %q but did not", encoding)
   111  	}
   112  }
   113  
   114  func marshalProto(message proto.Message, _ *codec) ([]byte, func(), error) {
   115  	buf := getBuffer()
   116  	cleanup := func() { putBuffer(buf) }
   117  	data, err := proto.MarshalOptions{}.MarshalAppend(*buf, message)
   118  	if err != nil {
   119  		cleanup()
   120  		return nil, nil, err
   121  	}
   122  	*buf = data
   123  	return data, cleanup, nil
   124  }
   125  
   126  func marshalJSON(message proto.Message, codec *codec) ([]byte, func(), error) {
   127  	data, err := codec.jsonMarshaler.Marshal(message)
   128  	if err != nil {
   129  		return nil, nil, err
   130  	}
   131  	return data, func() {}, nil
   132  }
   133  
   134  func getBuffer() *[]byte {
   135  	newbuf := _bufferPool.Get().(*[]byte)
   136  	resetBuf(newbuf)
   137  	return newbuf
   138  }
   139  
   140  func putBuffer(buf *[]byte) {
   141  	_bufferPool.Put(buf)
   142  }
   143  
   144  func resetBuf(buf *[]byte) {
   145  	*buf = (*buf)[:0]
   146  }