go.uber.org/yarpc@v1.72.1/encoding/protobuf/marshal.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package protobuf
    22  
    23  import (
    24  	"bytes"
    25  	"io"
    26  	"sync"
    27  
    28  	"github.com/gogo/protobuf/jsonpb"
    29  	"github.com/gogo/protobuf/proto"
    30  	"go.uber.org/yarpc/api/transport"
    31  	"go.uber.org/yarpc/internal/bufferpool"
    32  	"go.uber.org/yarpc/yarpcerrors"
    33  )
    34  
    35  var (
    36  	_bufferPool = sync.Pool{
    37  		New: func() interface{} {
    38  			return proto.NewBuffer(make([]byte, 1024))
    39  		},
    40  	}
    41  )
    42  
    43  // codec is a private helper struct used to hold custom marshling behavior.
    44  type codec struct {
    45  	jsonMarshaler   *jsonpb.Marshaler
    46  	jsonUnmarshaler *jsonpb.Unmarshaler
    47  }
    48  
    49  func newCodec(anyResolver jsonpb.AnyResolver) *codec {
    50  	return &codec{
    51  		jsonMarshaler:   &jsonpb.Marshaler{AnyResolver: anyResolver},
    52  		jsonUnmarshaler: &jsonpb.Unmarshaler{AnyResolver: anyResolver, AllowUnknownFields: true},
    53  	}
    54  }
    55  
    56  func unmarshal(encoding transport.Encoding, reader io.Reader, message proto.Message, codec *codec) error {
    57  	buf := bufferpool.Get()
    58  	defer bufferpool.Put(buf)
    59  	if _, err := buf.ReadFrom(reader); err != nil {
    60  		return err
    61  	}
    62  	body := buf.Bytes()
    63  	if len(body) == 0 {
    64  		return nil
    65  	}
    66  	return unmarshalBytes(encoding, body, message, codec)
    67  }
    68  
    69  func unmarshalBytes(encoding transport.Encoding, body []byte, message proto.Message, codec *codec) error {
    70  	switch encoding {
    71  	case Encoding:
    72  		return unmarshalProto(body, message, codec)
    73  	case JSONEncoding:
    74  		return unmarshalJSON(body, message, codec)
    75  	default:
    76  		return yarpcerrors.Newf(yarpcerrors.CodeInternal, "encoding.Expect should have handled encoding %q but did not", encoding)
    77  	}
    78  }
    79  
    80  func unmarshalProto(body []byte, message proto.Message, _ *codec) error {
    81  	return proto.Unmarshal(body, message)
    82  }
    83  
    84  func unmarshalJSON(body []byte, message proto.Message, codec *codec) error {
    85  	return codec.jsonUnmarshaler.Unmarshal(bytes.NewReader(body), message)
    86  }
    87  
    88  func marshal(encoding transport.Encoding, message proto.Message, codec *codec) ([]byte, func(), error) {
    89  	switch encoding {
    90  	case Encoding:
    91  		return marshalProto(message, codec)
    92  	case JSONEncoding:
    93  		return marshalJSON(message, codec)
    94  	default:
    95  		return nil, nil, yarpcerrors.Newf(yarpcerrors.CodeInternal, "encoding.Expect should have handled encoding %q but did not", encoding)
    96  	}
    97  }
    98  
    99  func marshalProto(message proto.Message, _ *codec) ([]byte, func(), error) {
   100  	protoBuffer := getBuffer()
   101  	cleanup := func() { putBuffer(protoBuffer) }
   102  	if err := protoBuffer.Marshal(message); err != nil {
   103  		cleanup()
   104  		return nil, nil, err
   105  	}
   106  	return protoBuffer.Bytes(), cleanup, nil
   107  }
   108  
   109  func marshalJSON(message proto.Message, codec *codec) ([]byte, func(), error) {
   110  	buf := bufferpool.Get()
   111  	cleanup := func() { bufferpool.Put(buf) }
   112  	if err := codec.jsonMarshaler.Marshal(buf, message); err != nil {
   113  		cleanup()
   114  		return nil, nil, err
   115  	}
   116  	return buf.Bytes(), cleanup, nil
   117  }
   118  
   119  func getBuffer() *proto.Buffer {
   120  	buf := _bufferPool.Get().(*proto.Buffer)
   121  	buf.Reset()
   122  	return buf
   123  }
   124  
   125  func putBuffer(buf *proto.Buffer) {
   126  	_bufferPool.Put(buf)
   127  }