github.com/sandwich-go/boost@v1.3.29/xencoding/pbjson/pbjson.go (about)

     1  package pbjson
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"github.com/golang/protobuf/jsonpb"
     8  	"github.com/golang/protobuf/proto"
     9  	"github.com/sandwich-go/boost/xencoding"
    10  	"io"
    11  )
    12  
    13  var (
    14  	errCodecParam = errors.New("pbjson codec marshal/unmarshal must be proto message")
    15  )
    16  
    17  const (
    18  	// CodecName pbjson 加解码名称,可以通过 encoding2.GetCodec(CodecName) 获取对应的 Codec
    19  	CodecName = "pbjson"
    20  )
    21  
    22  var Codec = codec{}
    23  
    24  var (
    25  	marshaler   = &jsonpb.Marshaler{EnumsAsInts: true}
    26  	unmarshaler = &jsonpb.Unmarshaler{}
    27  )
    28  
    29  // EmitUnpopulated 指定是否使用零值渲染字段
    30  func EmitUnpopulated(emit bool) { marshaler.EmitDefaults = emit }
    31  
    32  // UseEnumNumbers 设置是否将 enum 序列化为数字,默认开启功能
    33  func UseEnumNumbers(b bool) { marshaler.EnumsAsInts = b }
    34  
    35  func init() {
    36  	xencoding.RegisterCodec(Codec)
    37  }
    38  
    39  // codec is a Codec implementation with json
    40  type codec struct{}
    41  
    42  // Name 返回 Codec 名
    43  func (codec) Name() string { return CodecName }
    44  
    45  // Marshal 编码
    46  func (codec) Marshal(_ context.Context, obj interface{}) ([]byte, error) {
    47  	if pm, ok := obj.(proto.Message); ok {
    48  		var buf bytes.Buffer
    49  		err := marshaler.Marshal(&buf, pm)
    50  		return buf.Bytes(), err
    51  	}
    52  	return nil, errCodecParam
    53  }
    54  
    55  // Unmarshal 解码
    56  func (codec) Unmarshal(_ context.Context, data []byte, v interface{}) error {
    57  	if pm, ok := v.(proto.Message); ok {
    58  		err := unmarshaler.Unmarshal(bytes.NewBuffer(data), pm)
    59  		if err == io.EOF {
    60  			err = nil
    61  		}
    62  		return err
    63  	}
    64  	return errCodecParam
    65  }