trpc.group/trpc-go/trpc-go@v1.0.3/restful/serialize_jsonpb.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package restful
    15  
    16  import (
    17  	"bytes"
    18  	"fmt"
    19  	"reflect"
    20  	"strconv"
    21  
    22  	jsoniter "github.com/json-iterator/go"
    23  	"google.golang.org/protobuf/encoding/protojson"
    24  	"google.golang.org/protobuf/proto"
    25  	"google.golang.org/protobuf/reflect/protoreflect"
    26  )
    27  
    28  func init() {
    29  	RegisterSerializer(&JSONPBSerializer{})
    30  }
    31  
// JSONPBSerializer is used for content-Type: application/json.
// It's based on google.golang.org/protobuf/encoding/protojson.
//
// This serializer will firstly try jsonpb's serialization. If the object does
// not implement the protobuf proto.Message interface, the serialization
// switches to json-iterator.
type JSONPBSerializer struct {
	// AllowUnmarshalNil, when true, makes Unmarshal return nil without touching
	// the target if the body is empty (len(data) == 0).
	AllowUnmarshalNil bool
}
    41  
// JSONAPI is a copy of jsoniter.ConfigCompatibleWithStandardLibrary.
// github.com/json-iterator/go is faster than Go's standard json library.
//
// Deprecated: This global variable is exported only for backward compatibility
// and should not be modified. If users want to change the default behavior of
// internal JSON serialization, please register your customized serializer
// like:
//
//	restful.RegisterSerializer(yourOwnJSONSerializer)
var JSONAPI = jsoniter.ConfigCompatibleWithStandardLibrary
    52  
// Marshaller is a configurable protojson marshaler used for every
// proto.Message marshalled by this file. It defaults to emitting unpopulated
// fields. Its UseEnumNumbers, Indent and EmitUnpopulated options are also
// consulted when marshalling non-proto fields (enums, maps and nil slices).
var Marshaller = protojson.MarshalOptions{EmitUnpopulated: true}
    55  
// Unmarshaller is a configurable protojson unmarshaler used for every
// proto.Message unmarshalled by this file. It defaults to discarding unknown
// fields.
var Unmarshaller = protojson.UnmarshalOptions{DiscardUnknown: true}
    58  
    59  // Marshal implements Serializer.
    60  // Unlike Serializers in trpc-go/codec, Serializers in trpc-go/restful
    61  // could be used to marshal a field of a tRPC message.
    62  func (*JSONPBSerializer) Marshal(v interface{}) ([]byte, error) {
    63  	msg, ok := v.(proto.Message)
    64  	if !ok { // marshal a field of a tRPC message
    65  		return marshal(v)
    66  	}
    67  	// marshal tRPC message
    68  	return Marshaller.Marshal(msg)
    69  }
    70  
    71  // marshal is a helper function that is used to marshal a field of a tRPC message.
    72  func marshal(v interface{}) ([]byte, error) {
    73  	msg, ok := v.(proto.Message)
    74  	if !ok { // marshal none proto field
    75  		return marshalNonProtoField(v)
    76  	}
    77  	// marshal proto field
    78  	return Marshaller.Marshal(msg)
    79  }
    80  
// wrappedEnum is used to get the name of an enum. It is satisfied by values
// that implement protoreflect.Enum and also provide a String method; the
// marshalling code uses String() as the JSON representation of the enum.
type wrappedEnum interface {
	protoreflect.Enum
	String() string
}
    86  
// typeOfProtoMessage is the reflect.Type of the proto.Message interface,
// computed once so that checking whether a type implements proto.Message
// does not repeat the reflection on every call.
var typeOfProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem()
    90  
    91  // marshalNonProtoField marshals none proto fields.
    92  // Go's standard json lib or github.com/json-iterator/go doesn't support marshaling
    93  // of some types of protobuf message, therefore reflection is needed to support it.
    94  // TODO: performance optimization.
    95  func marshalNonProtoField(v interface{}) ([]byte, error) {
    96  	if v == nil {
    97  		return []byte("null"), nil
    98  	}
    99  
   100  	// reflection
   101  	rv := reflect.ValueOf(v)
   102  
   103  	// get value to which the pointer points
   104  	for rv.Kind() == reflect.Ptr {
   105  		if rv.IsNil() {
   106  			return []byte("null"), nil
   107  		}
   108  		rv = rv.Elem()
   109  	}
   110  
   111  	// marshal name but value of enum
   112  	if enum, ok := rv.Interface().(wrappedEnum); ok && !Marshaller.UseEnumNumbers {
   113  		return JSONAPI.Marshal(enum.String())
   114  	}
   115  	// marshal map proto message
   116  	if rv.Kind() == reflect.Map {
   117  		// make map for marshalling
   118  		m := make(map[string]*jsoniter.RawMessage)
   119  		for _, key := range rv.MapKeys() { // range all keys
   120  			// marshal value
   121  			out, err := marshal(rv.MapIndex(key).Interface())
   122  			if err != nil {
   123  				return out, err
   124  			}
   125  			// assignment
   126  			m[fmt.Sprintf("%v", key.Interface())] = (*jsoniter.RawMessage)(&out)
   127  			if Marshaller.Indent != "" { // 指定 indent
   128  				return JSONAPI.MarshalIndent(v, "", Marshaller.Indent)
   129  			}
   130  			return JSONAPI.Marshal(v)
   131  		}
   132  	}
   133  	// marshal slice proto message
   134  	if rv.Kind() == reflect.Slice {
   135  		if rv.IsNil() { // nil slice
   136  			if Marshaller.EmitUnpopulated {
   137  				return []byte("[]"), nil
   138  			}
   139  			return []byte("null"), nil
   140  		}
   141  
   142  		if rv.Type().Elem().Implements(typeOfProtoMessage) { // type is proto
   143  			var buf bytes.Buffer
   144  			buf.WriteByte('[')
   145  			for i := 0; i < rv.Len(); i++ { // marshal one by one
   146  				out, err := marshal(rv.Index(i).Interface().(proto.Message))
   147  				if err != nil {
   148  					return nil, err
   149  				}
   150  				buf.Write(out)
   151  				if i != rv.Len()-1 {
   152  					buf.WriteByte(',')
   153  				}
   154  			}
   155  			buf.WriteByte(']')
   156  			return buf.Bytes(), nil
   157  		}
   158  	}
   159  
   160  	return JSONAPI.Marshal(v)
   161  }
   162  
   163  // Unmarshal implements Serializer.
   164  func (j *JSONPBSerializer) Unmarshal(data []byte, v interface{}) error {
   165  	if len(data) == 0 && j.AllowUnmarshalNil {
   166  		return nil
   167  	}
   168  	msg, ok := v.(proto.Message)
   169  	if !ok { // unmarshal a field of a tRPC message
   170  		return unmarshal(data, v)
   171  	}
   172  	// unmarshal tRPC message
   173  	return Unmarshaller.Unmarshal(data, msg)
   174  }
   175  
   176  // unmarshal unmarshal a field of a tRPC message.
   177  func unmarshal(data []byte, v interface{}) error {
   178  	msg, ok := v.(proto.Message)
   179  	if !ok { // unmarshal none proto fields
   180  		return unmarshalNonProtoField(data, v)
   181  	}
   182  	// unmarshal proto fields
   183  	return Unmarshaller.Unmarshal(data, msg)
   184  }
   185  
   186  // unmarshalNonProtoField unmarshals none proto fields.
   187  // TODO: performance optimization.
   188  func unmarshalNonProtoField(data []byte, v interface{}) error {
   189  	rv := reflect.ValueOf(v)
   190  	if rv.Kind() != reflect.Ptr { // Must be pointer type.
   191  		return fmt.Errorf("%T is not a pointer", v)
   192  	}
   193  	// get the value to which the pointer points
   194  	for rv.Kind() == reflect.Ptr {
   195  		if rv.IsNil() { // New an object if nil
   196  			rv.Set(reflect.New(rv.Type().Elem()))
   197  		}
   198  		// if the object's type is proto, just unmarshal
   199  		if msg, ok := rv.Interface().(proto.Message); ok {
   200  			return Unmarshaller.Unmarshal(data, msg)
   201  		}
   202  		rv = rv.Elem()
   203  	}
   204  	// can only unmarshal numeric enum
   205  	if _, ok := rv.Interface().(wrappedEnum); ok {
   206  		var x interface{}
   207  		if err := jsoniter.Unmarshal(data, &x); err != nil {
   208  			return err
   209  		}
   210  		switch t := x.(type) {
   211  		case float64:
   212  			rv.Set(reflect.ValueOf(int32(t)).Convert(rv.Type()))
   213  			return nil
   214  		default:
   215  			return fmt.Errorf("unmarshalling of %T into %T is not supported", t, rv.Interface())
   216  		}
   217  	}
   218  	// unmarshal to slice
   219  	if rv.Kind() == reflect.Slice {
   220  		// unmarshal to jsoniter.RawMessage first
   221  		var rms []jsoniter.RawMessage
   222  		if err := JSONAPI.Unmarshal(data, &rms); err != nil {
   223  			return err
   224  		}
   225  		if rms != nil { // rv MakeSlice
   226  			rv.Set(reflect.MakeSlice(rv.Type(), 0, 0))
   227  		}
   228  		// unmarshal one by one
   229  		for _, rm := range rms {
   230  			rn := reflect.New(rv.Type().Elem())
   231  			if err := unmarshal(rm, rn.Interface()); err != nil {
   232  				return err
   233  			}
   234  			rv.Set(reflect.Append(rv, rn.Elem()))
   235  		}
   236  		return nil
   237  	}
   238  	// unmarshal to map
   239  	if rv.Kind() == reflect.Map {
   240  		if rv.IsNil() { // rv MakeMap
   241  			rv.Set(reflect.MakeMap(rv.Type()))
   242  		}
   243  		// unmarshal to map[string]*jsoniter.RawMessage first
   244  		m := make(map[string]*jsoniter.RawMessage)
   245  		if err := JSONAPI.Unmarshal(data, &m); err != nil {
   246  			return err
   247  		}
   248  		kind := rv.Type().Key().Kind()
   249  		for key, value := range m { // unmarshal (k, v) one by one
   250  			convertedKey, err := convert(key, kind) // convert key
   251  			if err != nil {
   252  				return err
   253  			}
   254  			// unmarshal value
   255  			if value == nil {
   256  				rm := jsoniter.RawMessage("null")
   257  				value = &rm
   258  			}
   259  			rn := reflect.New(rv.Type().Elem())
   260  			if err := unmarshal([]byte(*value), rn.Interface()); err != nil {
   261  				return err
   262  			}
   263  			rv.SetMapIndex(reflect.ValueOf(convertedKey), rn.Elem())
   264  		}
   265  	}
   266  	return JSONAPI.Unmarshal(data, v)
   267  }
   268  
   269  // convert converts map key by reflect.Kind.
   270  func convert(key string, kind reflect.Kind) (interface{}, error) {
   271  	switch kind {
   272  	case reflect.String:
   273  		return key, nil
   274  	case reflect.Bool:
   275  		return strconv.ParseBool(key)
   276  	case reflect.Int32:
   277  		v, err := strconv.ParseInt(key, 0, 32)
   278  		if err != nil {
   279  			return nil, err
   280  		}
   281  		return int32(v), nil
   282  	case reflect.Uint32:
   283  		v, err := strconv.ParseUint(key, 0, 32)
   284  		if err != nil {
   285  			return nil, err
   286  		}
   287  		return uint32(v), nil
   288  	case reflect.Int64:
   289  		return strconv.ParseInt(key, 0, 64)
   290  	case reflect.Uint64:
   291  		return strconv.ParseUint(key, 0, 64)
   292  	case reflect.Float32:
   293  		v, err := strconv.ParseFloat(key, 32)
   294  		if err != nil {
   295  			return nil, err
   296  		}
   297  		return float32(v), nil
   298  	case reflect.Float64:
   299  		return strconv.ParseFloat(key, 64)
   300  	default:
   301  		return nil, fmt.Errorf("unsupported kind: %v", kind)
   302  	}
   303  }
   304  
   305  // Name implements Serializer.
   306  func (*JSONPBSerializer) Name() string {
   307  	return "application/json"
   308  }
   309  
   310  // ContentType implements Serializer.
   311  func (*JSONPBSerializer) ContentType() string {
   312  	return "application/json"
   313  }