trpc.group/trpc-go/trpc-go@v1.0.2/restful/serialize_jsonpb.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package restful 15 16 import ( 17 "bytes" 18 "fmt" 19 "reflect" 20 "strconv" 21 22 jsoniter "github.com/json-iterator/go" 23 "google.golang.org/protobuf/encoding/protojson" 24 "google.golang.org/protobuf/proto" 25 "google.golang.org/protobuf/reflect/protoreflect" 26 ) 27 28 func init() { 29 RegisterSerializer(&JSONPBSerializer{}) 30 } 31 32 // JSONPBSerializer is used for content-Type: application/json. 33 // It's based on google.golang.org/protobuf/encoding/protojson. 34 type JSONPBSerializer struct { 35 AllowUnmarshalNil bool // allow unmarshalling nil body 36 } 37 38 // JSONAPI is a copy of jsoniter.ConfigCompatibleWithStandardLibrary. 39 // github.com/json-iterator/go is faster than Go's standard json library. 40 var JSONAPI = jsoniter.ConfigCompatibleWithStandardLibrary 41 42 // Marshaller is a configurable protojson marshaler. 43 var Marshaller = protojson.MarshalOptions{EmitUnpopulated: true} 44 45 // Unmarshaller is a configurable protojson unmarshaler. 46 var Unmarshaller = protojson.UnmarshalOptions{DiscardUnknown: true} 47 48 // Marshal implements Serializer. 49 // Unlike Serializers in trpc-go/codec, Serializers in trpc-go/restful 50 // could be used to marshal a field of a tRPC message. 51 func (*JSONPBSerializer) Marshal(v interface{}) ([]byte, error) { 52 msg, ok := v.(proto.Message) 53 if !ok { // marshal a field of a tRPC message 54 return marshal(v) 55 } 56 // marshal tRPC message 57 return Marshaller.Marshal(msg) 58 } 59 60 // marshal is a helper function that is used to marshal a field of a tRPC message. 61 func marshal(v interface{}) ([]byte, error) { 62 msg, ok := v.(proto.Message) 63 if !ok { // marshal none proto field 64 return marshalNonProtoField(v) 65 } 66 // marshal proto field 67 return Marshaller.Marshal(msg) 68 } 69 70 // wrappedEnum is used to get the name of enum. 71 type wrappedEnum interface { 72 protoreflect.Enum 73 String() string 74 } 75 76 // typeOfProtoMessage is used to avoid multiple reflection and check if the object 77 // implements proto.Message interface. 78 var typeOfProtoMessage = reflect.TypeOf((*proto.Message)(nil)).Elem() 79 80 // marshalNonProtoField marshals none proto fields. 81 // Go's standard json lib or github.com/json-iterator/go doesn't support marshaling 82 // of some types of protobuf message, therefore reflection is needed to support it. 83 // TODO: performance optimization. 84 func marshalNonProtoField(v interface{}) ([]byte, error) { 85 if v == nil { 86 return []byte("null"), nil 87 } 88 89 // reflection 90 rv := reflect.ValueOf(v) 91 92 // get value to which the pointer points 93 for rv.Kind() == reflect.Ptr { 94 if rv.IsNil() { 95 return []byte("null"), nil 96 } 97 rv = rv.Elem() 98 } 99 100 // marshal name but value of enum 101 if enum, ok := rv.Interface().(wrappedEnum); ok && !Marshaller.UseEnumNumbers { 102 return JSONAPI.Marshal(enum.String()) 103 } 104 // marshal map proto message 105 if rv.Kind() == reflect.Map { 106 // make map for marshalling 107 m := make(map[string]*jsoniter.RawMessage) 108 for _, key := range rv.MapKeys() { // range all keys 109 // marshal value 110 out, err := marshal(rv.MapIndex(key).Interface()) 111 if err != nil { 112 return out, err 113 } 114 // assignment 115 m[fmt.Sprintf("%v", key.Interface())] = (*jsoniter.RawMessage)(&out) 116 if Marshaller.Indent != "" { // 指定 indent 117 return JSONAPI.MarshalIndent(v, "", Marshaller.Indent) 118 } 119 return JSONAPI.Marshal(v) 120 } 121 } 122 // marshal slice proto message 123 if rv.Kind() == reflect.Slice { 124 if rv.IsNil() { // nil slice 125 if Marshaller.EmitUnpopulated { 126 return []byte("[]"), nil 127 } 128 return []byte("null"), nil 129 } 130 131 if rv.Type().Elem().Implements(typeOfProtoMessage) { // type is proto 132 var buf bytes.Buffer 133 buf.WriteByte('[') 134 for i := 0; i < rv.Len(); i++ { // marshal one by one 135 out, err := marshal(rv.Index(i).Interface().(proto.Message)) 136 if err != nil { 137 return nil, err 138 } 139 buf.Write(out) 140 if i != rv.Len()-1 { 141 buf.WriteByte(',') 142 } 143 } 144 buf.WriteByte(']') 145 return buf.Bytes(), nil 146 } 147 } 148 149 return JSONAPI.Marshal(v) 150 } 151 152 // Unmarshal implements Serializer. 153 func (j *JSONPBSerializer) Unmarshal(data []byte, v interface{}) error { 154 if len(data) == 0 && j.AllowUnmarshalNil { 155 return nil 156 } 157 msg, ok := v.(proto.Message) 158 if !ok { // unmarshal a field of a tRPC message 159 return unmarshal(data, v) 160 } 161 // unmarshal tRPC message 162 return Unmarshaller.Unmarshal(data, msg) 163 } 164 165 // unmarshal unmarshal a field of a tRPC message. 166 func unmarshal(data []byte, v interface{}) error { 167 msg, ok := v.(proto.Message) 168 if !ok { // unmarshal none proto fields 169 return unmarshalNonProtoField(data, v) 170 } 171 // unmarshal proto fields 172 return Unmarshaller.Unmarshal(data, msg) 173 } 174 175 // unmarshalNonProtoField unmarshals none proto fields. 176 // TODO: performance optimization. 177 func unmarshalNonProtoField(data []byte, v interface{}) error { 178 rv := reflect.ValueOf(v) 179 if rv.Kind() != reflect.Ptr { // Must be pointer type. 180 return fmt.Errorf("%T is not a pointer", v) 181 } 182 // get the value to which the pointer points 183 for rv.Kind() == reflect.Ptr { 184 if rv.IsNil() { // New an object if nil 185 rv.Set(reflect.New(rv.Type().Elem())) 186 } 187 // if the object's type is proto, just unmarshal 188 if msg, ok := rv.Interface().(proto.Message); ok { 189 return Unmarshaller.Unmarshal(data, msg) 190 } 191 rv = rv.Elem() 192 } 193 // can only unmarshal numeric enum 194 if _, ok := rv.Interface().(wrappedEnum); ok { 195 var x interface{} 196 if err := jsoniter.Unmarshal(data, &x); err != nil { 197 return err 198 } 199 switch t := x.(type) { 200 case float64: 201 rv.Set(reflect.ValueOf(int32(t)).Convert(rv.Type())) 202 return nil 203 default: 204 return fmt.Errorf("unmarshalling of %T into %T is not supported", t, rv.Interface()) 205 } 206 } 207 // unmarshal to slice 208 if rv.Kind() == reflect.Slice { 209 // unmarshal to jsoniter.RawMessage first 210 var rms []jsoniter.RawMessage 211 if err := JSONAPI.Unmarshal(data, &rms); err != nil { 212 return err 213 } 214 if rms != nil { // rv MakeSlice 215 rv.Set(reflect.MakeSlice(rv.Type(), 0, 0)) 216 } 217 // unmarshal one by one 218 for _, rm := range rms { 219 rn := reflect.New(rv.Type().Elem()) 220 if err := unmarshal(rm, rn.Interface()); err != nil { 221 return err 222 } 223 rv.Set(reflect.Append(rv, rn.Elem())) 224 } 225 return nil 226 } 227 // unmarshal to map 228 if rv.Kind() == reflect.Map { 229 if rv.IsNil() { // rv MakeMap 230 rv.Set(reflect.MakeMap(rv.Type())) 231 } 232 // unmarshal to map[string]*jsoniter.RawMessage first 233 m := make(map[string]*jsoniter.RawMessage) 234 if err := JSONAPI.Unmarshal(data, &m); err != nil { 235 return err 236 } 237 kind := rv.Type().Key().Kind() 238 for key, value := range m { // unmarshal (k, v) one by one 239 convertedKey, err := convert(key, kind) // convert key 240 if err != nil { 241 return err 242 } 243 // unmarshal value 244 if value == nil { 245 rm := jsoniter.RawMessage("null") 246 value = &rm 247 } 248 rn := reflect.New(rv.Type().Elem()) 249 if err := unmarshal([]byte(*value), rn.Interface()); err != nil { 250 return err 251 } 252 rv.SetMapIndex(reflect.ValueOf(convertedKey), rn.Elem()) 253 } 254 } 255 return JSONAPI.Unmarshal(data, v) 256 } 257 258 // convert converts map key by reflect.Kind. 259 func convert(key string, kind reflect.Kind) (interface{}, error) { 260 switch kind { 261 case reflect.String: 262 return key, nil 263 case reflect.Bool: 264 return strconv.ParseBool(key) 265 case reflect.Int32: 266 v, err := strconv.ParseInt(key, 0, 32) 267 if err != nil { 268 return nil, err 269 } 270 return int32(v), nil 271 case reflect.Uint32: 272 v, err := strconv.ParseUint(key, 0, 32) 273 if err != nil { 274 return nil, err 275 } 276 return uint32(v), nil 277 case reflect.Int64: 278 return strconv.ParseInt(key, 0, 64) 279 case reflect.Uint64: 280 return strconv.ParseUint(key, 0, 64) 281 case reflect.Float32: 282 v, err := strconv.ParseFloat(key, 32) 283 if err != nil { 284 return nil, err 285 } 286 return float32(v), nil 287 case reflect.Float64: 288 return strconv.ParseFloat(key, 64) 289 default: 290 return nil, fmt.Errorf("unsupported kind: %v", kind) 291 } 292 } 293 294 // Name implements Serializer. 295 func (*JSONPBSerializer) Name() string { 296 return "application/json" 297 } 298 299 // ContentType implements Serializer. 300 func (*JSONPBSerializer) ContentType() string { 301 return "application/json" 302 }