trpc.group/trpc-go/trpc-go@v1.0.2/restful/serialize_jsonpb.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package restful
    15  
    16  import (
    17  	"bytes"
    18  	"fmt"
    19  	"reflect"
    20  	"strconv"
    21  
    22  	jsoniter "github.com/json-iterator/go"
    23  	"google.golang.org/protobuf/encoding/protojson"
    24  	"google.golang.org/protobuf/proto"
    25  	"google.golang.org/protobuf/reflect/protoreflect"
    26  )
    27  
    28  func init() {
    29  	RegisterSerializer(&JSONPBSerializer{})
    30  }
    31  
    32  // JSONPBSerializer is used for content-Type: application/json.
    33  // It's based on google.golang.org/protobuf/encoding/protojson.
    34  type JSONPBSerializer struct {
    35  	AllowUnmarshalNil bool // allow unmarshalling nil body
    36  }
    37  
    38  // JSONAPI is a copy of jsoniter.ConfigCompatibleWithStandardLibrary.
    39  // github.com/json-iterator/go is faster than Go's standard json library.
    40  var JSONAPI = jsoniter.ConfigCompatibleWithStandardLibrary
    41  
    42  // Marshaller is a configurable protojson marshaler.
    43  var Marshaller = protojson.MarshalOptions{EmitUnpopulated: true}
    44  
    45  // Unmarshaller is a configurable protojson unmarshaler.
    46  var Unmarshaller = protojson.UnmarshalOptions{DiscardUnknown: true}
    47  
    48  // Marshal implements Serializer.
    49  // Unlike Serializers in trpc-go/codec, Serializers in trpc-go/restful
    50  // could be used to marshal a field of a tRPC message.
    51  func (*JSONPBSerializer) Marshal(v interface{}) ([]byte, error) {
    52  	msg, ok := v.(proto.Message)
    53  	if !ok { // marshal a field of a tRPC message
    54  		return marshal(v)
    55  	}
    56  	// marshal tRPC message
    57  	return Marshaller.Marshal(msg)
    58  }
    59  
    60  // marshal is a helper function that is used to marshal a field of a tRPC message.
    61  func marshal(v interface{}) ([]byte, error) {
    62  	msg, ok := v.(proto.Message)
    63  	if !ok { // marshal none proto field
    64  		return marshalNonProtoField(v)
    65  	}
    66  	// marshal proto field
    67  	return Marshaller.Marshal(msg)
    68  }
    69  
    70  // wrappedEnum is used to get the name of enum.
    71  type wrappedEnum interface {
    72  	protoreflect.Enum
    73  	String() string
    74  }
    75  
    76  // typeOfProtoMessage is used to avoid multiple reflection and check if the object
    77  // implements proto.Message interface.
    78  var typeOfProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem()
    79  
    80  // marshalNonProtoField marshals none proto fields.
    81  // Go's standard json lib or github.com/json-iterator/go doesn't support marshaling
    82  // of some types of protobuf message, therefore reflection is needed to support it.
    83  // TODO: performance optimization.
    84  func marshalNonProtoField(v interface{}) ([]byte, error) {
    85  	if v == nil {
    86  		return []byte("null"), nil
    87  	}
    88  
    89  	// reflection
    90  	rv := reflect.ValueOf(v)
    91  
    92  	// get value to which the pointer points
    93  	for rv.Kind() == reflect.Ptr {
    94  		if rv.IsNil() {
    95  			return []byte("null"), nil
    96  		}
    97  		rv = rv.Elem()
    98  	}
    99  
   100  	// marshal name but value of enum
   101  	if enum, ok := rv.Interface().(wrappedEnum); ok && !Marshaller.UseEnumNumbers {
   102  		return JSONAPI.Marshal(enum.String())
   103  	}
   104  	// marshal map proto message
   105  	if rv.Kind() == reflect.Map {
   106  		// make map for marshalling
   107  		m := make(map[string]*jsoniter.RawMessage)
   108  		for _, key := range rv.MapKeys() { // range all keys
   109  			// marshal value
   110  			out, err := marshal(rv.MapIndex(key).Interface())
   111  			if err != nil {
   112  				return out, err
   113  			}
   114  			// assignment
   115  			m[fmt.Sprintf("%v", key.Interface())] = (*jsoniter.RawMessage)(&out)
   116  			if Marshaller.Indent != "" { // 指定 indent
   117  				return JSONAPI.MarshalIndent(v, "", Marshaller.Indent)
   118  			}
   119  			return JSONAPI.Marshal(v)
   120  		}
   121  	}
   122  	// marshal slice proto message
   123  	if rv.Kind() == reflect.Slice {
   124  		if rv.IsNil() { // nil slice
   125  			if Marshaller.EmitUnpopulated {
   126  				return []byte("[]"), nil
   127  			}
   128  			return []byte("null"), nil
   129  		}
   130  
   131  		if rv.Type().Elem().Implements(typeOfProtoMessage) { // type is proto
   132  			var buf bytes.Buffer
   133  			buf.WriteByte('[')
   134  			for i := 0; i < rv.Len(); i++ { // marshal one by one
   135  				out, err := marshal(rv.Index(i).Interface().(proto.Message))
   136  				if err != nil {
   137  					return nil, err
   138  				}
   139  				buf.Write(out)
   140  				if i != rv.Len()-1 {
   141  					buf.WriteByte(',')
   142  				}
   143  			}
   144  			buf.WriteByte(']')
   145  			return buf.Bytes(), nil
   146  		}
   147  	}
   148  
   149  	return JSONAPI.Marshal(v)
   150  }
   151  
   152  // Unmarshal implements Serializer.
   153  func (j *JSONPBSerializer) Unmarshal(data []byte, v interface{}) error {
   154  	if len(data) == 0 && j.AllowUnmarshalNil {
   155  		return nil
   156  	}
   157  	msg, ok := v.(proto.Message)
   158  	if !ok { // unmarshal a field of a tRPC message
   159  		return unmarshal(data, v)
   160  	}
   161  	// unmarshal tRPC message
   162  	return Unmarshaller.Unmarshal(data, msg)
   163  }
   164  
   165  // unmarshal unmarshal a field of a tRPC message.
   166  func unmarshal(data []byte, v interface{}) error {
   167  	msg, ok := v.(proto.Message)
   168  	if !ok { // unmarshal none proto fields
   169  		return unmarshalNonProtoField(data, v)
   170  	}
   171  	// unmarshal proto fields
   172  	return Unmarshaller.Unmarshal(data, msg)
   173  }
   174  
   175  // unmarshalNonProtoField unmarshals none proto fields.
   176  // TODO: performance optimization.
   177  func unmarshalNonProtoField(data []byte, v interface{}) error {
   178  	rv := reflect.ValueOf(v)
   179  	if rv.Kind() != reflect.Ptr { // Must be pointer type.
   180  		return fmt.Errorf("%T is not a pointer", v)
   181  	}
   182  	// get the value to which the pointer points
   183  	for rv.Kind() == reflect.Ptr {
   184  		if rv.IsNil() { // New an object if nil
   185  			rv.Set(reflect.New(rv.Type().Elem()))
   186  		}
   187  		// if the object's type is proto, just unmarshal
   188  		if msg, ok := rv.Interface().(proto.Message); ok {
   189  			return Unmarshaller.Unmarshal(data, msg)
   190  		}
   191  		rv = rv.Elem()
   192  	}
   193  	// can only unmarshal numeric enum
   194  	if _, ok := rv.Interface().(wrappedEnum); ok {
   195  		var x interface{}
   196  		if err := jsoniter.Unmarshal(data, &x); err != nil {
   197  			return err
   198  		}
   199  		switch t := x.(type) {
   200  		case float64:
   201  			rv.Set(reflect.ValueOf(int32(t)).Convert(rv.Type()))
   202  			return nil
   203  		default:
   204  			return fmt.Errorf("unmarshalling of %T into %T is not supported", t, rv.Interface())
   205  		}
   206  	}
   207  	// unmarshal to slice
   208  	if rv.Kind() == reflect.Slice {
   209  		// unmarshal to jsoniter.RawMessage first
   210  		var rms []jsoniter.RawMessage
   211  		if err := JSONAPI.Unmarshal(data, &rms); err != nil {
   212  			return err
   213  		}
   214  		if rms != nil { // rv MakeSlice
   215  			rv.Set(reflect.MakeSlice(rv.Type(), 0, 0))
   216  		}
   217  		// unmarshal one by one
   218  		for _, rm := range rms {
   219  			rn := reflect.New(rv.Type().Elem())
   220  			if err := unmarshal(rm, rn.Interface()); err != nil {
   221  				return err
   222  			}
   223  			rv.Set(reflect.Append(rv, rn.Elem()))
   224  		}
   225  		return nil
   226  	}
   227  	// unmarshal to map
   228  	if rv.Kind() == reflect.Map {
   229  		if rv.IsNil() { // rv MakeMap
   230  			rv.Set(reflect.MakeMap(rv.Type()))
   231  		}
   232  		// unmarshal to map[string]*jsoniter.RawMessage first
   233  		m := make(map[string]*jsoniter.RawMessage)
   234  		if err := JSONAPI.Unmarshal(data, &m); err != nil {
   235  			return err
   236  		}
   237  		kind := rv.Type().Key().Kind()
   238  		for key, value := range m { // unmarshal (k, v) one by one
   239  			convertedKey, err := convert(key, kind) // convert key
   240  			if err != nil {
   241  				return err
   242  			}
   243  			// unmarshal value
   244  			if value == nil {
   245  				rm := jsoniter.RawMessage("null")
   246  				value = &rm
   247  			}
   248  			rn := reflect.New(rv.Type().Elem())
   249  			if err := unmarshal([]byte(*value), rn.Interface()); err != nil {
   250  				return err
   251  			}
   252  			rv.SetMapIndex(reflect.ValueOf(convertedKey), rn.Elem())
   253  		}
   254  	}
   255  	return JSONAPI.Unmarshal(data, v)
   256  }
   257  
   258  // convert converts map key by reflect.Kind.
   259  func convert(key string, kind reflect.Kind) (interface{}, error) {
   260  	switch kind {
   261  	case reflect.String:
   262  		return key, nil
   263  	case reflect.Bool:
   264  		return strconv.ParseBool(key)
   265  	case reflect.Int32:
   266  		v, err := strconv.ParseInt(key, 0, 32)
   267  		if err != nil {
   268  			return nil, err
   269  		}
   270  		return int32(v), nil
   271  	case reflect.Uint32:
   272  		v, err := strconv.ParseUint(key, 0, 32)
   273  		if err != nil {
   274  			return nil, err
   275  		}
   276  		return uint32(v), nil
   277  	case reflect.Int64:
   278  		return strconv.ParseInt(key, 0, 64)
   279  	case reflect.Uint64:
   280  		return strconv.ParseUint(key, 0, 64)
   281  	case reflect.Float32:
   282  		v, err := strconv.ParseFloat(key, 32)
   283  		if err != nil {
   284  			return nil, err
   285  		}
   286  		return float32(v), nil
   287  	case reflect.Float64:
   288  		return strconv.ParseFloat(key, 64)
   289  	default:
   290  		return nil, fmt.Errorf("unsupported kind: %v", kind)
   291  	}
   292  }
   293  
   294  // Name implements Serializer.
   295  func (*JSONPBSerializer) Name() string {
   296  	return "application/json"
   297  }
   298  
   299  // ContentType implements Serializer.
   300  func (*JSONPBSerializer) ContentType() string {
   301  	return "application/json"
   302  }