github.com/sandwich-go/boost@v1.3.29/xencoding/protobuf/protobuf.go (about)

     1  package protobuf
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"github.com/sandwich-go/boost/xencoding"
     8  	"github.com/sandwich-go/boost/xerror"
     9  	"math"
    10  	"reflect"
    11  	"sync"
    12  
    13  	"github.com/golang/protobuf/jsonpb"
    14  	"github.com/golang/protobuf/proto"
    15  )
    16  
    17  var (
    18  	Codec          = codec{usingPool: false, name: CodecName}
    19  	CodecUsingPool = codec{usingPool: true, name: UsingPoolCodecName}
    20  )
    21  
    22  const (
    23  	// CodecName proto 压缩效果名称,可以通过 encoding2.GetCodec(CodecName) 获取对应的 Codec
    24  	CodecName = "proto"
    25  	// UsingPoolCodecName 带对象池的 proto 压缩效果名称,可以通过 encoding2.GetCodec(UsingPoolCodecName) 获取对应的 Codec
    26  	UsingPoolCodecName = "proto_using_pool"
    27  )
    28  
    29  func init() {
    30  	xencoding.RegisterCodec(Codec)
    31  	xencoding.RegisterCodec(CodecUsingPool)
    32  }
    33  
    34  // codec is a Codec implementation with protobuf. It is the default codec.
    35  type codec struct {
    36  	usingPool bool
    37  	name      string
    38  }
    39  
    40  // Name 返回 Codec 名
    41  func (p codec) Name() string { return p.name }
    42  
    43  // Marshal 编码
    44  func (p codec) Marshal(_ context.Context, v interface{}) ([]byte, error) {
    45  	if pm, ok := v.(proto.Marshaler); ok {
    46  		// object can marshal itself, no need for buffer
    47  		return pm.Marshal()
    48  	}
    49  	if pm, ok := v.(proto.Message); ok {
    50  		if p.usingPool {
    51  			cb := protoBufferPool.Get().(*cachedProtoBuffer)
    52  			out, err := marshal(pm, cb)
    53  			// put back buffer and lose the ref to the slice
    54  			cb.SetBuf(nil)
    55  			protoBufferPool.Put(cb)
    56  			return out, err
    57  		}
    58  		return proto.Marshal(pm)
    59  	}
    60  	return nil, xerror.NewText("%T is not a proto.Marshaler", v)
    61  }
    62  
    63  // Uri 获取 Message Name
    64  func (codec) Uri(t interface{}) string { return proto.MessageName(t.(proto.Message)) }
    65  
    66  // Type 获取 Message Type
    67  func (codec) Type(uri string) reflect.Type { return proto.MessageType(uri) }
    68  
    69  // Unmarshal 解码
    70  func (p codec) Unmarshal(ctx context.Context, data []byte, v interface{}) error {
    71  	if pu, ok := v.(proto.Unmarshaler); ok {
    72  		// object can unmarshal itself, no need for buffer
    73  		return pu.Unmarshal(data)
    74  	}
    75  
    76  	if m, ok := v.(proto.Message); ok {
    77  		m.Reset()
    78  		if p.usingPool {
    79  			cb := protoBufferPool.Get().(*cachedProtoBuffer)
    80  			cb.SetBuf(data)
    81  			err := cb.Unmarshal(m)
    82  			cb.SetBuf(nil)
    83  			protoBufferPool.Put(cb)
    84  			return err
    85  		}
    86  		return proto.Unmarshal(data, m)
    87  	}
    88  
    89  	return xerror.NewText("%T is not a proto.Unmarshaler", v)
    90  }
    91  
    92  func (codec) JSONMarshal(obj interface{}) ([]byte, error) {
    93  	if pm, ok := obj.(proto.Message); ok {
    94  		m := jsonpb.Marshaler{EmitDefaults: false}
    95  		var buf bytes.Buffer
    96  		return buf.Bytes(), m.Marshal(&buf, pm)
    97  	}
    98  	return nil, errors.New("not proto message")
    99  }
   100  
   101  func marshal(pm proto.Message, cb *cachedProtoBuffer) ([]byte, error) {
   102  	newSlice := make([]byte, 0, cb.lastMarshaledSize)
   103  
   104  	cb.SetBuf(newSlice)
   105  	cb.Reset()
   106  	if err := cb.Marshal(pm); err != nil {
   107  		return nil, err
   108  	}
   109  	out := cb.Bytes()
   110  	cb.lastMarshaledSize = capToMaxInt32(len(out))
   111  	return out, nil
   112  }
   113  
   114  func capToMaxInt32(val int) uint32 {
   115  	if val > math.MaxInt32 {
   116  		return uint32(math.MaxInt32)
   117  	}
   118  	return uint32(val)
   119  }
   120  
// cachedProtoBuffer is the unit stored in protoBufferPool: an embedded
// proto.Buffer plus a capacity hint for the next marshal through it.
type cachedProtoBuffer struct {
	// lastMarshaledSize is the encoded length from the most recent successful
	// marshal through this buffer (clamped to math.MaxInt32). It sizes the
	// next output slice.
	lastMarshaledSize uint32
	// Buffer is embedded so SetBuf/Reset/Marshal/Unmarshal/Bytes can be called
	// directly on the pooled value.
	proto.Buffer
}
   125  
   126  var protoBufferPool = &sync.Pool{
   127  	New: func() interface{} {
   128  		return &cachedProtoBuffer{
   129  			Buffer:            proto.Buffer{},
   130  			lastMarshaledSize: 16,
   131  		}
   132  	},
   133  }