trpc.group/trpc-go/trpc-go@v1.0.3/restful/populate_util.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package restful 15 16 import ( 17 "encoding/base64" 18 "errors" 19 "fmt" 20 "strconv" 21 "strings" 22 "time" 23 24 "google.golang.org/protobuf/proto" 25 "google.golang.org/protobuf/reflect/protoreflect" 26 "google.golang.org/protobuf/reflect/protoregistry" 27 "google.golang.org/protobuf/types/known/durationpb" 28 "google.golang.org/protobuf/types/known/fieldmaskpb" 29 "google.golang.org/protobuf/types/known/timestamppb" 30 "google.golang.org/protobuf/types/known/wrapperspb" 31 ) 32 33 var ( 34 // ErrTraverseNotFound is the error which indicates the field is 35 // not found after traversing the proto message. 36 ErrTraverseNotFound = errors.New("field not found") 37 ) 38 39 // PopulateMessage populates a proto message. 40 func PopulateMessage(msg proto.Message, fieldPath []string, values []string) error { 41 // empty check 42 if len(fieldPath) == 0 || len(values) == 0 { 43 return fmt.Errorf("fieldPath: %v or values: %v is empty", fieldPath, values) 44 } 45 46 // proto reflection 47 message := msg.ProtoReflect() 48 49 // traverse for leaf field by field path 50 message, fd, err := traverse(message, fieldPath) 51 if err != nil { 52 return fmt.Errorf("failed to traverse for leaf field by fieldPath %v: %w", fieldPath, err) 53 } 54 55 // populate the field 56 switch { 57 case fd.IsList(): // repeated field 58 return populateRepeatedField(fd, message.Mutable(fd).List(), values) 59 case fd.IsMap(): // map field 60 return populateMapField(fd, message.Mutable(fd).Map(), values) 61 default: // normal field 62 return populateField(fd, message, values) 63 } 64 } 65 66 // fdByName returns field descriptor by field name. 67 func fdByName(message protoreflect.Message, name string) (protoreflect.FieldDescriptor, error) { 68 if message == nil { 69 return nil, errors.New("get field descriptor from nil message") 70 } 71 72 field := message.Descriptor().Fields().ByJSONName(name) 73 if field == nil { 74 field = message.Descriptor().Fields().ByName(protoreflect.Name(name)) 75 } 76 if field == nil { 77 return nil, fmt.Errorf("%w: %v", ErrTraverseNotFound, name) 78 } 79 return field, nil 80 } 81 82 // traverse traverses the nested proto message by names and returns the descriptor of the leaf field. 83 func traverse( 84 message protoreflect.Message, 85 fieldPath []string, 86 ) (protoreflect.Message, protoreflect.FieldDescriptor, error) { 87 field, err := fdByName(message, fieldPath[0]) 88 if err != nil { 89 return nil, nil, err 90 } 91 92 // leaf field 93 if len(fieldPath) == 1 { 94 return message, field, nil 95 } 96 97 // haven't reached the leaf field, need to continue traversing, 98 // and type of current field must be proto message 99 if field.Message() == nil || field.Cardinality() == protoreflect.Repeated { 100 return nil, nil, fmt.Errorf("type of field %s is not proto message", fieldPath[0]) 101 } 102 103 // recursion 104 return traverse(message.Mutable(field).Message(), fieldPath[1:]) 105 } 106 107 // populateField populates normal fields. 108 func populateField(fd protoreflect.FieldDescriptor, msg protoreflect.Message, values []string) error { 109 // len of values should be 1 110 if len(values) != 1 { 111 return fmt.Errorf("tried to populate field %s with values %v", fd.FullName().Name(), values) 112 } 113 114 // parse value into protoreflect.Value 115 v, err := parseField(fd, values[0]) 116 if err != nil { 117 return fmt.Errorf("failed to parse field %s: %w", fd.FullName().Name(), err) 118 } 119 120 // do the population 121 msg.Set(fd, v) 122 return nil 123 } 124 125 // populateRepeatedField populates repeated fields. 126 func populateRepeatedField(fd protoreflect.FieldDescriptor, list protoreflect.List, values []string) error { 127 for _, value := range values { 128 // parse value into protoreflect.Value 129 v, err := parseField(fd, value) 130 if err != nil { 131 return fmt.Errorf("failed to parse repeated field %s: %w", fd.FullName().Name(), err) 132 } 133 // do the population 134 list.Append(v) 135 } 136 return nil 137 } 138 139 // populateMapField populates map fields. 140 func populateMapField(fd protoreflect.FieldDescriptor, m protoreflect.Map, values []string) error { 141 // len of values should be 2 142 if len(values) != 2 { 143 return fmt.Errorf("tried to populate map field %s with values %v", fd.FullName().Name(), values) 144 } 145 146 // parse map key into protoreflect.Value 147 key, err := parseField(fd.MapKey(), values[0]) 148 if err != nil { 149 return fmt.Errorf("failed to parse key of map field %s: %w", fd.FullName().Name(), err) 150 } 151 152 // parse map value into protoreflect.Value 153 value, err := parseField(fd.MapValue(), values[1]) 154 if err != nil { 155 return fmt.Errorf("failed to parse value of map field %s: %w", fd.FullName().Name(), err) 156 } 157 158 // do the population 159 m.Set(key.MapKey(), value) 160 return nil 161 } 162 163 // parseField parses string value into protoreflect.Value by protoreflect.FieldDescriptor. 164 func parseField(fd protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) { 165 switch kind := fd.Kind(); kind { 166 case protoreflect.BoolKind: 167 v, err := strconv.ParseBool(value) 168 if err != nil { 169 return protoreflect.Value{}, err 170 } 171 return protoreflect.ValueOfBool(v), nil 172 case protoreflect.EnumKind: 173 return parseEnumField(fd, value) 174 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 175 v, err := strconv.ParseInt(value, 10, 32) 176 if err != nil { 177 return protoreflect.Value{}, err 178 } 179 return protoreflect.ValueOfInt32(int32(v)), nil 180 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 181 v, err := strconv.ParseInt(value, 10, 64) 182 if err != nil { 183 return protoreflect.Value{}, err 184 } 185 return protoreflect.ValueOfInt64(v), nil 186 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 187 v, err := strconv.ParseUint(value, 10, 32) 188 if err != nil { 189 return protoreflect.Value{}, err 190 } 191 return protoreflect.ValueOfUint32(uint32(v)), nil 192 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 193 v, err := strconv.ParseUint(value, 10, 64) 194 if err != nil { 195 return protoreflect.Value{}, err 196 } 197 return protoreflect.ValueOfUint64(v), nil 198 case protoreflect.FloatKind: 199 v, err := strconv.ParseFloat(value, 32) 200 if err != nil { 201 return protoreflect.Value{}, err 202 } 203 return protoreflect.ValueOfFloat32(float32(v)), nil 204 case protoreflect.DoubleKind: 205 v, err := strconv.ParseFloat(value, 64) 206 if err != nil { 207 return protoreflect.Value{}, err 208 } 209 return protoreflect.ValueOfFloat64(v), nil 210 case protoreflect.StringKind: 211 return protoreflect.ValueOfString(value), nil 212 case protoreflect.BytesKind: 213 v, err := base64.URLEncoding.DecodeString(value) 214 if err != nil { 215 return protoreflect.Value{}, err 216 } 217 return protoreflect.ValueOfBytes(v), nil 218 case protoreflect.MessageKind, protoreflect.GroupKind: 219 return parseMessage(fd.Message(), value) 220 default: 221 return protoreflect.Value{}, fmt.Errorf("unsupported field kind: %v", kind) 222 } 223 } 224 225 // parseEnumField parses enum fields. 226 func parseEnumField(fd protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) { 227 enum, err := protoregistry.GlobalTypes.FindEnumByName(fd.Enum().FullName()) 228 switch { 229 case errors.Is(err, protoregistry.NotFound): 230 return protoreflect.Value{}, fmt.Errorf("enum %s is not registered", fd.Enum().FullName()) 231 case err != nil: 232 return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err) 233 } 234 v := enum.Descriptor().Values().ByName(protoreflect.Name(value)) 235 if v == nil { 236 i, err := strconv.Atoi(value) 237 if err != nil { 238 return protoreflect.Value{}, fmt.Errorf("%s is not a valid value", value) 239 } 240 v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i)) 241 if v == nil { 242 return protoreflect.Value{}, fmt.Errorf("%s is not a valid value", value) 243 } 244 } 245 return protoreflect.ValueOfEnum(v.Number()), nil 246 } 247 248 // parseMessage parses string value into protoreflect.Value by protoreflect.MessageDescriptor. 249 // It's used to parse google.protobuf.xxx. 250 func parseMessage(md protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) { 251 switch md.FullName() { 252 case "google.protobuf.Timestamp": 253 return parseTimestampMessage(value) 254 case "google.protobuf.Duration": 255 return parseDurationMessage(value) 256 case "google.protobuf.DoubleValue": 257 return parseDoubleValueMessage(value) 258 case "google.protobuf.FloatValue": 259 return parseFloatValueMessage(value) 260 case "google.protobuf.Int64Value": 261 return parseInt64ValueMessage(value) 262 case "google.protobuf.Int32Value": 263 return parseInt32ValueMessage(value) 264 case "google.protobuf.UInt64Value": 265 return parseUInt64ValueMessage(value) 266 case "google.protobuf.UInt32Value": 267 return parseUInt32ValueMessage(value) 268 case "google.protobuf.BoolValue": 269 return parseBoolValueMessage(value) 270 case "google.protobuf.StringValue": 271 sv := &wrapperspb.StringValue{Value: value} 272 return protoreflect.ValueOfMessage(sv.ProtoReflect()), nil 273 case "google.protobuf.BytesValue": 274 return parseBytesValueMessage(value) 275 case "google.protobuf.FieldMask": 276 fm := &fieldmaskpb.FieldMask{} 277 fm.Paths = append(fm.Paths, strings.Split(value, ",")...) 278 return protoreflect.ValueOfMessage(fm.ProtoReflect()), nil 279 default: 280 return protoreflect.Value{}, fmt.Errorf("unsupported message type: %s", string(md.FullName())) 281 } 282 } 283 284 // parseTimestampMessage parses google.protobuf.Timestamp. 285 func parseTimestampMessage(value string) (protoreflect.Value, error) { 286 var msg proto.Message 287 if value != "null" { 288 t, err := time.Parse(time.RFC3339Nano, value) 289 if err != nil { 290 return protoreflect.Value{}, err 291 } 292 msg = timestamppb.New(t) 293 } 294 return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil 295 } 296 297 // parseDurationMessage parses google.protobuf.Duration. 298 func parseDurationMessage(value string) (protoreflect.Value, error) { 299 var msg proto.Message 300 if value != "null" { 301 d, err := time.ParseDuration(value) 302 if err != nil { 303 return protoreflect.Value{}, err 304 } 305 msg = durationpb.New(d) 306 } 307 return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil 308 } 309 310 // parseDoubleValueMessage parses google.protobuf.DoubleValue. 311 func parseDoubleValueMessage(value string) (protoreflect.Value, error) { 312 v, err := strconv.ParseFloat(value, 64) 313 if err != nil { 314 return protoreflect.Value{}, err 315 } 316 msg := &wrapperspb.DoubleValue{Value: v} 317 return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil 318 } 319 320 // parseFloatValueMessage parses google.protobuf.FloatValue. 321 func parseFloatValueMessage(value string) (protoreflect.Value, error) { 322 v, err := strconv.ParseFloat(value, 32) 323 if err != nil { 324 return protoreflect.Value{}, err 325 } 326 msg := &wrapperspb.FloatValue{Value: float32(v)} 327 return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil 328 } 329 330 // parseInt64ValueMessage parses google.protobuf.Int64Value. 331 func parseInt64ValueMessage(value string) (protoreflect.Value, error) { 332 v, err := strconv.ParseInt(value, 10, 64) 333 if err != nil { 334 return protoreflect.Value{}, err 335 } 336 msg := &wrapperspb.Int64Value{Value: v} 337 return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil 338 } 339 340 // parseInt32ValueMessage parses google.protobuf.Int32Value. 341 func parseInt32ValueMessage(value string) (protoreflect.Value, error) { 342 v, err := strconv.ParseInt(value, 10, 32) 343 if err != nil { 344 return protoreflect.Value{}, err 345 } 346 msg := &wrapperspb.Int32Value{Value: int32(v)} 347 return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil 348 } 349 350 // parseUInt64ValueMessage parses google.protobuf.UInt64Value. 351 func parseUInt64ValueMessage(value string) (protoreflect.Value, error) { 352 v, err := strconv.ParseUint(value, 10, 64) 353 if err != nil { 354 return protoreflect.Value{}, err 355 } 356 msg := &wrapperspb.UInt64Value{Value: v} 357 return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil 358 } 359 360 // parseUInt32ValueMessage parses google.protobuf.UInt32Value. 361 func parseUInt32ValueMessage(value string) (protoreflect.Value, error) { 362 v, err := strconv.ParseUint(value, 10, 32) 363 if err != nil { 364 return protoreflect.Value{}, err 365 } 366 msg := &wrapperspb.UInt32Value{Value: uint32(v)} 367 return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil 368 } 369 370 // parseBoolValueMessage parses google.protobuf.BoolValue. 371 func parseBoolValueMessage(value string) (protoreflect.Value, error) { 372 v, err := strconv.ParseBool(value) 373 if err != nil { 374 return protoreflect.Value{}, err 375 } 376 msg := &wrapperspb.BoolValue{Value: v} 377 return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil 378 } 379 380 // parseBytesValueMessage parses google.protobuf.BytesValue. 381 func parseBytesValueMessage(value string) (protoreflect.Value, error) { 382 v, err := base64.URLEncoding.DecodeString(value) 383 if err != nil { 384 return protoreflect.Value{}, err 385 } 386 msg := &wrapperspb.BytesValue{Value: v} 387 return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil 388 } 389 390 // setFieldMask sets field mask for the field. 391 func setFieldMask(message protoreflect.Message, fieldPath string) error { 392 maskFd := theMaskField(message) 393 if maskFd == nil { 394 return nil 395 } 396 397 partiallyUpdated, err := fdByName(message, fieldPath) 398 if err != nil { 399 return fmt.Errorf("failed to find partially updated field %s, err: %w", fieldPath, err) 400 } 401 if !isPlainMessage(partiallyUpdated) { 402 return fmt.Errorf("with FieldMask enabled, partially updated field must be a plain message") 403 } 404 message.Set(maskFd, protoreflect.ValueOfMessage((&fieldmaskpb.FieldMask{ 405 Paths: getPopulatedFieldPaths(message.Get(partiallyUpdated).Message()), 406 }).ProtoReflect())) 407 return nil 408 } 409 410 // theMaskField returns the only field whose type is googleProtobufFieldMaskFullName, otherwise, returns nil. 411 func theMaskField(message protoreflect.Message) protoreflect.FieldDescriptor { 412 var count int 413 var theFd protoreflect.FieldDescriptor 414 message.Descriptor().Fields() 415 for i, fds := 0, message.Descriptor().Fields(); i < fds.Len(); i++ { 416 fd := fds.Get(i) 417 if isPlainMessage(fd) && fd.Message().FullName() == googleProtobufFieldMaskFullName { 418 count++ 419 theFd = fd 420 } 421 } 422 423 if count == 1 { 424 return theFd 425 } 426 return nil 427 } 428 429 var googleProtobufFieldMaskFullName = (*fieldmaskpb.FieldMask)(nil).ProtoReflect().Descriptor().FullName() 430 431 func isPlainMessage(fd protoreflect.FieldDescriptor) bool { 432 return fd.Message() != nil && !fd.IsList() && !fd.IsMap() 433 } 434 435 // getPopulatedFieldPaths returns all populated field paths. 436 func getPopulatedFieldPaths(message protoreflect.Message) []string { 437 var res []string 438 dfs(message, []string{}, &res) 439 return res 440 } 441 442 // dfs performs the Depth-first search algorithm. 443 func dfs(message protoreflect.Message, paths []string, res *[]string) { 444 message.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool { 445 name := string(fd.FullName().Name()) 446 if isPlainMessage(fd) { 447 dfs(v.Message(), append(paths, name), res) 448 } else { 449 *res = append(*res, strings.Join(append(paths, name), ".")) 450 } 451 return true 452 }) 453 }