trpc.group/trpc-go/trpc-go@v1.0.3/restful/populate_util.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package restful
    15  
    16  import (
    17  	"encoding/base64"
    18  	"errors"
    19  	"fmt"
    20  	"strconv"
    21  	"strings"
    22  	"time"
    23  
    24  	"google.golang.org/protobuf/proto"
    25  	"google.golang.org/protobuf/reflect/protoreflect"
    26  	"google.golang.org/protobuf/reflect/protoregistry"
    27  	"google.golang.org/protobuf/types/known/durationpb"
    28  	"google.golang.org/protobuf/types/known/fieldmaskpb"
    29  	"google.golang.org/protobuf/types/known/timestamppb"
    30  	"google.golang.org/protobuf/types/known/wrapperspb"
    31  )
    32  
    33  var (
    34  	// ErrTraverseNotFound is the error which indicates the field is
    35  	// not found after traversing the proto message.
    36  	ErrTraverseNotFound = errors.New("field not found")
    37  )
    38  
    39  // PopulateMessage populates a proto message.
    40  func PopulateMessage(msg proto.Message, fieldPath []string, values []string) error {
    41  	// empty check
    42  	if len(fieldPath) == 0 || len(values) == 0 {
    43  		return fmt.Errorf("fieldPath: %v or values: %v is empty", fieldPath, values)
    44  	}
    45  
    46  	// proto reflection
    47  	message := msg.ProtoReflect()
    48  
    49  	// traverse for leaf field by field path
    50  	message, fd, err := traverse(message, fieldPath)
    51  	if err != nil {
    52  		return fmt.Errorf("failed to traverse for leaf field by fieldPath %v: %w", fieldPath, err)
    53  	}
    54  
    55  	// populate the field
    56  	switch {
    57  	case fd.IsList(): // repeated field
    58  		return populateRepeatedField(fd, message.Mutable(fd).List(), values)
    59  	case fd.IsMap(): // map field
    60  		return populateMapField(fd, message.Mutable(fd).Map(), values)
    61  	default: // normal field
    62  		return populateField(fd, message, values)
    63  	}
    64  }
    65  
    66  // fdByName returns field descriptor by field name.
    67  func fdByName(message protoreflect.Message, name string) (protoreflect.FieldDescriptor, error) {
    68  	if message == nil {
    69  		return nil, errors.New("get field descriptor from nil message")
    70  	}
    71  
    72  	field := message.Descriptor().Fields().ByJSONName(name)
    73  	if field == nil {
    74  		field = message.Descriptor().Fields().ByName(protoreflect.Name(name))
    75  	}
    76  	if field == nil {
    77  		return nil, fmt.Errorf("%w: %v", ErrTraverseNotFound, name)
    78  	}
    79  	return field, nil
    80  }
    81  
    82  // traverse traverses the nested proto message by names and returns the descriptor of the leaf field.
    83  func traverse(
    84  	message protoreflect.Message,
    85  	fieldPath []string,
    86  ) (protoreflect.Message, protoreflect.FieldDescriptor, error) {
    87  	field, err := fdByName(message, fieldPath[0])
    88  	if err != nil {
    89  		return nil, nil, err
    90  	}
    91  
    92  	// leaf field
    93  	if len(fieldPath) == 1 {
    94  		return message, field, nil
    95  	}
    96  
    97  	// haven't reached the leaf field, need to continue traversing,
    98  	// and type of current field must be proto message
    99  	if field.Message() == nil || field.Cardinality() == protoreflect.Repeated {
   100  		return nil, nil, fmt.Errorf("type of field %s is not proto message", fieldPath[0])
   101  	}
   102  
   103  	// recursion
   104  	return traverse(message.Mutable(field).Message(), fieldPath[1:])
   105  }
   106  
   107  // populateField populates normal fields.
   108  func populateField(fd protoreflect.FieldDescriptor, msg protoreflect.Message, values []string) error {
   109  	// len of values should be 1
   110  	if len(values) != 1 {
   111  		return fmt.Errorf("tried to populate field %s with values %v", fd.FullName().Name(), values)
   112  	}
   113  
   114  	// parse value into protoreflect.Value
   115  	v, err := parseField(fd, values[0])
   116  	if err != nil {
   117  		return fmt.Errorf("failed to parse field %s: %w", fd.FullName().Name(), err)
   118  	}
   119  
   120  	// do the population
   121  	msg.Set(fd, v)
   122  	return nil
   123  }
   124  
   125  // populateRepeatedField populates repeated fields.
   126  func populateRepeatedField(fd protoreflect.FieldDescriptor, list protoreflect.List, values []string) error {
   127  	for _, value := range values {
   128  		// parse value into protoreflect.Value
   129  		v, err := parseField(fd, value)
   130  		if err != nil {
   131  			return fmt.Errorf("failed to parse repeated field %s: %w", fd.FullName().Name(), err)
   132  		}
   133  		// do the population
   134  		list.Append(v)
   135  	}
   136  	return nil
   137  }
   138  
   139  // populateMapField populates map fields.
   140  func populateMapField(fd protoreflect.FieldDescriptor, m protoreflect.Map, values []string) error {
   141  	// len of values should be 2
   142  	if len(values) != 2 {
   143  		return fmt.Errorf("tried to populate map field %s with values %v", fd.FullName().Name(), values)
   144  	}
   145  
   146  	// parse map key into protoreflect.Value
   147  	key, err := parseField(fd.MapKey(), values[0])
   148  	if err != nil {
   149  		return fmt.Errorf("failed to parse key of map field %s: %w", fd.FullName().Name(), err)
   150  	}
   151  
   152  	// parse map value into protoreflect.Value
   153  	value, err := parseField(fd.MapValue(), values[1])
   154  	if err != nil {
   155  		return fmt.Errorf("failed to parse value of map field %s: %w", fd.FullName().Name(), err)
   156  	}
   157  
   158  	// do the population
   159  	m.Set(key.MapKey(), value)
   160  	return nil
   161  }
   162  
   163  // parseField parses string value into protoreflect.Value by protoreflect.FieldDescriptor.
   164  func parseField(fd protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
   165  	switch kind := fd.Kind(); kind {
   166  	case protoreflect.BoolKind:
   167  		v, err := strconv.ParseBool(value)
   168  		if err != nil {
   169  			return protoreflect.Value{}, err
   170  		}
   171  		return protoreflect.ValueOfBool(v), nil
   172  	case protoreflect.EnumKind:
   173  		return parseEnumField(fd, value)
   174  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   175  		v, err := strconv.ParseInt(value, 10, 32)
   176  		if err != nil {
   177  			return protoreflect.Value{}, err
   178  		}
   179  		return protoreflect.ValueOfInt32(int32(v)), nil
   180  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   181  		v, err := strconv.ParseInt(value, 10, 64)
   182  		if err != nil {
   183  			return protoreflect.Value{}, err
   184  		}
   185  		return protoreflect.ValueOfInt64(v), nil
   186  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   187  		v, err := strconv.ParseUint(value, 10, 32)
   188  		if err != nil {
   189  			return protoreflect.Value{}, err
   190  		}
   191  		return protoreflect.ValueOfUint32(uint32(v)), nil
   192  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   193  		v, err := strconv.ParseUint(value, 10, 64)
   194  		if err != nil {
   195  			return protoreflect.Value{}, err
   196  		}
   197  		return protoreflect.ValueOfUint64(v), nil
   198  	case protoreflect.FloatKind:
   199  		v, err := strconv.ParseFloat(value, 32)
   200  		if err != nil {
   201  			return protoreflect.Value{}, err
   202  		}
   203  		return protoreflect.ValueOfFloat32(float32(v)), nil
   204  	case protoreflect.DoubleKind:
   205  		v, err := strconv.ParseFloat(value, 64)
   206  		if err != nil {
   207  			return protoreflect.Value{}, err
   208  		}
   209  		return protoreflect.ValueOfFloat64(v), nil
   210  	case protoreflect.StringKind:
   211  		return protoreflect.ValueOfString(value), nil
   212  	case protoreflect.BytesKind:
   213  		v, err := base64.URLEncoding.DecodeString(value)
   214  		if err != nil {
   215  			return protoreflect.Value{}, err
   216  		}
   217  		return protoreflect.ValueOfBytes(v), nil
   218  	case protoreflect.MessageKind, protoreflect.GroupKind:
   219  		return parseMessage(fd.Message(), value)
   220  	default:
   221  		return protoreflect.Value{}, fmt.Errorf("unsupported field kind: %v", kind)
   222  	}
   223  }
   224  
   225  // parseEnumField parses enum fields.
   226  func parseEnumField(fd protoreflect.FieldDescriptor, value string) (protoreflect.Value, error) {
   227  	enum, err := protoregistry.GlobalTypes.FindEnumByName(fd.Enum().FullName())
   228  	switch {
   229  	case errors.Is(err, protoregistry.NotFound):
   230  		return protoreflect.Value{}, fmt.Errorf("enum %s is not registered", fd.Enum().FullName())
   231  	case err != nil:
   232  		return protoreflect.Value{}, fmt.Errorf("failed to look up enum: %w", err)
   233  	}
   234  	v := enum.Descriptor().Values().ByName(protoreflect.Name(value))
   235  	if v == nil {
   236  		i, err := strconv.Atoi(value)
   237  		if err != nil {
   238  			return protoreflect.Value{}, fmt.Errorf("%s is not a valid value", value)
   239  		}
   240  		v = enum.Descriptor().Values().ByNumber(protoreflect.EnumNumber(i))
   241  		if v == nil {
   242  			return protoreflect.Value{}, fmt.Errorf("%s is not a valid value", value)
   243  		}
   244  	}
   245  	return protoreflect.ValueOfEnum(v.Number()), nil
   246  }
   247  
   248  // parseMessage parses string value into protoreflect.Value by protoreflect.MessageDescriptor.
   249  // It's used to parse google.protobuf.xxx.
   250  func parseMessage(md protoreflect.MessageDescriptor, value string) (protoreflect.Value, error) {
   251  	switch md.FullName() {
   252  	case "google.protobuf.Timestamp":
   253  		return parseTimestampMessage(value)
   254  	case "google.protobuf.Duration":
   255  		return parseDurationMessage(value)
   256  	case "google.protobuf.DoubleValue":
   257  		return parseDoubleValueMessage(value)
   258  	case "google.protobuf.FloatValue":
   259  		return parseFloatValueMessage(value)
   260  	case "google.protobuf.Int64Value":
   261  		return parseInt64ValueMessage(value)
   262  	case "google.protobuf.Int32Value":
   263  		return parseInt32ValueMessage(value)
   264  	case "google.protobuf.UInt64Value":
   265  		return parseUInt64ValueMessage(value)
   266  	case "google.protobuf.UInt32Value":
   267  		return parseUInt32ValueMessage(value)
   268  	case "google.protobuf.BoolValue":
   269  		return parseBoolValueMessage(value)
   270  	case "google.protobuf.StringValue":
   271  		sv := &wrapperspb.StringValue{Value: value}
   272  		return protoreflect.ValueOfMessage(sv.ProtoReflect()), nil
   273  	case "google.protobuf.BytesValue":
   274  		return parseBytesValueMessage(value)
   275  	case "google.protobuf.FieldMask":
   276  		fm := &fieldmaskpb.FieldMask{}
   277  		fm.Paths = append(fm.Paths, strings.Split(value, ",")...)
   278  		return protoreflect.ValueOfMessage(fm.ProtoReflect()), nil
   279  	default:
   280  		return protoreflect.Value{}, fmt.Errorf("unsupported message type: %s", string(md.FullName()))
   281  	}
   282  }
   283  
   284  // parseTimestampMessage parses google.protobuf.Timestamp.
   285  func parseTimestampMessage(value string) (protoreflect.Value, error) {
   286  	var msg proto.Message
   287  	if value != "null" {
   288  		t, err := time.Parse(time.RFC3339Nano, value)
   289  		if err != nil {
   290  			return protoreflect.Value{}, err
   291  		}
   292  		msg = timestamppb.New(t)
   293  	}
   294  	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
   295  }
   296  
   297  // parseDurationMessage parses google.protobuf.Duration.
   298  func parseDurationMessage(value string) (protoreflect.Value, error) {
   299  	var msg proto.Message
   300  	if value != "null" {
   301  		d, err := time.ParseDuration(value)
   302  		if err != nil {
   303  			return protoreflect.Value{}, err
   304  		}
   305  		msg = durationpb.New(d)
   306  	}
   307  	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
   308  }
   309  
   310  // parseDoubleValueMessage parses google.protobuf.DoubleValue.
   311  func parseDoubleValueMessage(value string) (protoreflect.Value, error) {
   312  	v, err := strconv.ParseFloat(value, 64)
   313  	if err != nil {
   314  		return protoreflect.Value{}, err
   315  	}
   316  	msg := &wrapperspb.DoubleValue{Value: v}
   317  	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
   318  }
   319  
   320  // parseFloatValueMessage parses google.protobuf.FloatValue.
   321  func parseFloatValueMessage(value string) (protoreflect.Value, error) {
   322  	v, err := strconv.ParseFloat(value, 32)
   323  	if err != nil {
   324  		return protoreflect.Value{}, err
   325  	}
   326  	msg := &wrapperspb.FloatValue{Value: float32(v)}
   327  	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
   328  }
   329  
   330  // parseInt64ValueMessage parses google.protobuf.Int64Value.
   331  func parseInt64ValueMessage(value string) (protoreflect.Value, error) {
   332  	v, err := strconv.ParseInt(value, 10, 64)
   333  	if err != nil {
   334  		return protoreflect.Value{}, err
   335  	}
   336  	msg := &wrapperspb.Int64Value{Value: v}
   337  	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
   338  }
   339  
   340  // parseInt32ValueMessage parses google.protobuf.Int32Value.
   341  func parseInt32ValueMessage(value string) (protoreflect.Value, error) {
   342  	v, err := strconv.ParseInt(value, 10, 32)
   343  	if err != nil {
   344  		return protoreflect.Value{}, err
   345  	}
   346  	msg := &wrapperspb.Int32Value{Value: int32(v)}
   347  	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
   348  }
   349  
   350  // parseUInt64ValueMessage parses google.protobuf.UInt64Value.
   351  func parseUInt64ValueMessage(value string) (protoreflect.Value, error) {
   352  	v, err := strconv.ParseUint(value, 10, 64)
   353  	if err != nil {
   354  		return protoreflect.Value{}, err
   355  	}
   356  	msg := &wrapperspb.UInt64Value{Value: v}
   357  	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
   358  }
   359  
   360  // parseUInt32ValueMessage parses google.protobuf.UInt32Value.
   361  func parseUInt32ValueMessage(value string) (protoreflect.Value, error) {
   362  	v, err := strconv.ParseUint(value, 10, 32)
   363  	if err != nil {
   364  		return protoreflect.Value{}, err
   365  	}
   366  	msg := &wrapperspb.UInt32Value{Value: uint32(v)}
   367  	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
   368  }
   369  
   370  // parseBoolValueMessage parses google.protobuf.BoolValue.
   371  func parseBoolValueMessage(value string) (protoreflect.Value, error) {
   372  	v, err := strconv.ParseBool(value)
   373  	if err != nil {
   374  		return protoreflect.Value{}, err
   375  	}
   376  	msg := &wrapperspb.BoolValue{Value: v}
   377  	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
   378  }
   379  
   380  // parseBytesValueMessage parses google.protobuf.BytesValue.
   381  func parseBytesValueMessage(value string) (protoreflect.Value, error) {
   382  	v, err := base64.URLEncoding.DecodeString(value)
   383  	if err != nil {
   384  		return protoreflect.Value{}, err
   385  	}
   386  	msg := &wrapperspb.BytesValue{Value: v}
   387  	return protoreflect.ValueOfMessage(msg.ProtoReflect()), nil
   388  }
   389  
   390  // setFieldMask sets field mask for the field.
   391  func setFieldMask(message protoreflect.Message, fieldPath string) error {
   392  	maskFd := theMaskField(message)
   393  	if maskFd == nil {
   394  		return nil
   395  	}
   396  
   397  	partiallyUpdated, err := fdByName(message, fieldPath)
   398  	if err != nil {
   399  		return fmt.Errorf("failed to find partially updated field %s, err: %w", fieldPath, err)
   400  	}
   401  	if !isPlainMessage(partiallyUpdated) {
   402  		return fmt.Errorf("with FieldMask enabled, partially updated field must be a plain message")
   403  	}
   404  	message.Set(maskFd, protoreflect.ValueOfMessage((&fieldmaskpb.FieldMask{
   405  		Paths: getPopulatedFieldPaths(message.Get(partiallyUpdated).Message()),
   406  	}).ProtoReflect()))
   407  	return nil
   408  }
   409  
   410  // theMaskField returns the only field whose type is googleProtobufFieldMaskFullName, otherwise, returns nil.
   411  func theMaskField(message protoreflect.Message) protoreflect.FieldDescriptor {
   412  	var count int
   413  	var theFd protoreflect.FieldDescriptor
   414  	message.Descriptor().Fields()
   415  	for i, fds := 0, message.Descriptor().Fields(); i < fds.Len(); i++ {
   416  		fd := fds.Get(i)
   417  		if isPlainMessage(fd) && fd.Message().FullName() == googleProtobufFieldMaskFullName {
   418  			count++
   419  			theFd = fd
   420  		}
   421  	}
   422  
   423  	if count == 1 {
   424  		return theFd
   425  	}
   426  	return nil
   427  }
   428  
   429  var googleProtobufFieldMaskFullName = (*fieldmaskpb.FieldMask)(nil).ProtoReflect().Descriptor().FullName()
   430  
   431  func isPlainMessage(fd protoreflect.FieldDescriptor) bool {
   432  	return fd.Message() != nil && !fd.IsList() && !fd.IsMap()
   433  }
   434  
   435  // getPopulatedFieldPaths returns all populated field paths.
   436  func getPopulatedFieldPaths(message protoreflect.Message) []string {
   437  	var res []string
   438  	dfs(message, []string{}, &res)
   439  	return res
   440  }
   441  
   442  // dfs performs the Depth-first search algorithm.
   443  func dfs(message protoreflect.Message, paths []string, res *[]string) {
   444  	message.Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
   445  		name := string(fd.FullName().Name())
   446  		if isPlainMessage(fd) {
   447  			dfs(v.Message(), append(paths, name), res)
   448  		} else {
   449  			*res = append(*res, strings.Join(append(paths, name), "."))
   450  		}
   451  		return true
   452  	})
   453  }