go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/client/flagpb/unmarshal.go (about)

     1  // Copyright 2016 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package flagpb
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/hex"
    20  	"encoding/json"
    21  	"fmt"
    22  	"strconv"
    23  	"strings"
    24  
    25  	"google.golang.org/protobuf/types/descriptorpb"
    26  
    27  	"go.chromium.org/luci/common/proto/google/descutil"
    28  
    29  	"github.com/golang/protobuf/jsonpb"
    30  	"github.com/golang/protobuf/proto"
    31  )
    32  
    33  // UnmarshalMessage unmarshals the proto message from flags.
    34  //
    35  // The descriptor set should be obtained from the `cproto` compiled packages'
    36  // FileDescriptorSet() method.
    37  func UnmarshalMessage(flags []string, resolver Resolver, msg proto.Message) error {
    38  	// TODO(iannucci): avoid round-trip through parser and jsonpb and populate the
    39  	// message directly. This would involve writing some additional reflection
    40  	// code that may depend on implementation details of proto's generated Go
    41  	// code, which is why this wasn't done initially.
    42  	name := proto.MessageName(msg)
    43  	dproto, ok := resolver.Resolve(name).(*descriptorpb.DescriptorProto)
    44  	if !ok {
    45  		return fmt.Errorf("could not resolve message %q", name)
    46  	}
    47  
    48  	jdata, err := UnmarshalUntyped(flags, dproto, resolver)
    49  	if err != nil {
    50  		return err
    51  	}
    52  
    53  	jtext, err := json.Marshal(jdata)
    54  	if err != nil {
    55  		return err
    56  	}
    57  
    58  	return jsonpb.Unmarshal(bytes.NewReader(jtext), msg)
    59  }
    60  
    61  // UnmarshalUntyped unmarshals a key-value map from flags
    62  // using a protobuf message descriptor.
    63  func UnmarshalUntyped(flags []string, desc *descriptorpb.DescriptorProto, resolver Resolver) (map[string]any, error) {
    64  	p := parser{resolver}
    65  	return p.parse(flags, desc)
    66  }
    67  
// message pairs the untyped values collected so far for one protobuf message
// with the descriptor used to look up that message's fields.
type message struct {
	// data holds the parsed values, keyed by the last component of the flag
	// path (a field name, or an arbitrary key when desc is a map entry).
	data map[string]any
	// desc describes the message whose values are stored in data.
	desc *descriptorpb.DescriptorProto
}
    72  
// parser converts command-line flags into an untyped key-value tree.
type parser struct {
	// Resolver looks up message and enum types by name; see the resolve
	// method, which panics if it is nil.
	Resolver Resolver
}
    76  
    77  func (p *parser) parse(flags []string, desc *descriptorpb.DescriptorProto) (map[string]any, error) {
    78  	if desc == nil {
    79  		panic("desc is nil")
    80  	}
    81  	root := message{map[string]any{}, desc}
    82  
    83  	for len(flags) > 0 {
    84  		var err error
    85  		if flags, err = p.parseOneFlag(flags, root); err != nil {
    86  			return nil, err
    87  		}
    88  	}
    89  	return root.data, nil
    90  }
    91  
    92  func (p *parser) parseOneFlag(flags []string, root message) (flagsRest []string, err error) {
    93  	// skip empty flags
    94  	for len(flags) > 0 && strings.TrimSpace(flags[0]) == "" {
    95  		flags = flags[1:]
    96  	}
    97  	if len(flags) == 0 {
    98  		return flags, nil
    99  	}
   100  
   101  	firstArg := flags[0]
   102  	flags = flags[1:]
   103  
   104  	// Prefix returned errors with flag name verbatim.
   105  	defer func() {
   106  		if err != nil {
   107  			err = fmt.Errorf("%s: %s", firstArg, err)
   108  		}
   109  	}()
   110  
   111  	// Trim dashes.
   112  	if !strings.HasPrefix(firstArg, "-") {
   113  		return nil, fmt.Errorf("a flag was expected")
   114  	}
   115  	flagName := strings.TrimPrefix(firstArg, "-") // -foo
   116  	flagName = strings.TrimPrefix(flagName, "-")  // --foo
   117  	if strings.HasPrefix(flagName, "-") {
   118  		// Triple dash is too much.
   119  		return nil, fmt.Errorf("bad flag syntax")
   120  	}
   121  
   122  	// Split key-value pair x=y.
   123  	flagName, valueStr, hasValueStr := p.splitKeyValuePair(flagName)
   124  	if flagName == "" {
   125  		return nil, fmt.Errorf("bad flag syntax")
   126  	}
   127  
   128  	// Split field path "a.b.c" and resolve field names.
   129  	fieldPath := strings.Split(flagName, ".")
   130  	pathMsgs, err := p.subMessages(root, fieldPath[:len(fieldPath)-1])
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  
   135  	// Where to assign the value?
   136  	target := &root
   137  	if len(pathMsgs) > 0 {
   138  		lastMsg := pathMsgs[len(pathMsgs)-1]
   139  		target = &lastMsg.message
   140  	}
   141  	name := fieldPath[len(fieldPath)-1]
   142  
   143  	// Resolve target field.
   144  	var fieldIndex int
   145  	if target.desc.GetOptions().GetMapEntry() {
   146  		if fieldIndex = descutil.FindField(target.desc, "value"); fieldIndex == -1 {
   147  			return nil, fmt.Errorf("map entry type %s does not have value field", target.desc.GetName())
   148  		}
   149  	} else {
   150  		if fieldIndex = descutil.FindField(target.desc, name); fieldIndex == -1 {
   151  			return nil, fmt.Errorf("field %s not found in message %s", name, target.desc.GetName())
   152  		}
   153  	}
   154  	field := target.desc.Field[fieldIndex]
   155  
   156  	var value any
   157  	hasValue := false
   158  
   159  	if !hasValueStr {
   160  		switch {
   161  		// Boolean and repeated message fields may have no value and ignore
   162  		// next argument.
   163  		case field.GetType() == descriptorpb.FieldDescriptorProto_TYPE_BOOL:
   164  			value = true
   165  			hasValue = true
   166  		case field.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE && descutil.Repeated(field):
   167  			value = map[string]any{}
   168  			hasValue = true
   169  
   170  		default:
   171  			// Read next argument as a value.
   172  			if len(flags) == 0 {
   173  				return nil, fmt.Errorf("value was expected")
   174  			}
   175  			valueStr, flags = flags[0], flags[1:]
   176  		}
   177  	}
   178  
   179  	// Check if the value is already set.
   180  	if target.data[name] != nil && !descutil.Repeated(field) {
   181  		repeatedFields := make([]string, 0, len(pathMsgs))
   182  		for _, m := range pathMsgs {
   183  			if m.repeated {
   184  				repeatedFields = append(repeatedFields, "-"+strings.Join(m.path, "."))
   185  			}
   186  		}
   187  		if len(repeatedFields) == 0 {
   188  			return nil, fmt.Errorf("value is already set to %v", target.data[name])
   189  		}
   190  		return nil, fmt.Errorf(
   191  			"value is already set to %v. Did you forgot to insert %s in between to declare a new repeated message?",
   192  			target.data[name], strings.Join(repeatedFields, " or "))
   193  	}
   194  
   195  	if !hasValue {
   196  		value, err = p.parseFieldValue(valueStr, target.desc.GetName(), field)
   197  		if err != nil {
   198  			return nil, err
   199  		}
   200  	}
   201  
   202  	if !descutil.Repeated(field) {
   203  		target.data[name] = value
   204  	} else {
   205  		target.data[name] = append(asSlice(target.data[name]), value)
   206  	}
   207  
   208  	return flags, nil
   209  }
   210  
// subMsg is a nested message reached while walking a dotted flag path.
type subMsg struct {
	// message holds the nested message's values and descriptor.
	message
	// path is the field path from the root up to and including this message.
	path     []string
	// repeated is true if the field holding this message is repeated and is
	// not a map field.
	repeated bool
}
   216  
   217  // subMessages returns message field values at each component of the path.
   218  // For example, for path ["a", "b", "c"] it will return
   219  // [msg.a, msg.a.b, msg.a.b.c].
   220  // If a field is repeated, returns the last message.
   221  //
   222  // If a field value is nil, initializes it with an empty message or slice.
   223  // If a field is not a message field, returns an error.
   224  func (p *parser) subMessages(root message, path []string) ([]subMsg, error) {
   225  	result := make([]subMsg, 0, len(path))
   226  
   227  	parent := &root
   228  	for i, name := range path {
   229  		curPath := path[:i+1]
   230  
   231  		var fieldIndex int
   232  		if parent.desc.GetOptions().GetMapEntry() {
   233  			if fieldIndex = descutil.FindField(parent.desc, "value"); fieldIndex == -1 {
   234  				return nil, fmt.Errorf("map entry type %s does not have value field", parent.desc.GetName())
   235  			}
   236  		} else {
   237  			if fieldIndex = descutil.FindField(parent.desc, name); fieldIndex == -1 {
   238  				return nil, fmt.Errorf("field %q not found in message %s", name, parent.desc.GetName())
   239  			}
   240  		}
   241  
   242  		f := parent.desc.Field[fieldIndex]
   243  		if f.GetType() != descriptorpb.FieldDescriptorProto_TYPE_MESSAGE {
   244  			return nil, fmt.Errorf("field %s is not a message", strings.Join(curPath, "."))
   245  		}
   246  
   247  		subDescInterface, err := p.resolve(f.GetTypeName())
   248  		if err != nil {
   249  			return nil, err
   250  		}
   251  		subDesc, ok := subDescInterface.(*descriptorpb.DescriptorProto)
   252  		if !ok {
   253  			return nil, fmt.Errorf("%s is not a message", f.GetTypeName())
   254  		}
   255  
   256  		sub := subMsg{
   257  			message:  message{desc: subDesc},
   258  			repeated: descutil.Repeated(f) && !subDesc.GetOptions().GetMapEntry(),
   259  			path:     curPath,
   260  		}
   261  		if value, ok := parent.data[name]; !ok {
   262  			sub.data = map[string]any{}
   263  			if sub.repeated {
   264  				parent.data[name] = []any{sub.data}
   265  			} else {
   266  				parent.data[name] = sub.data
   267  			}
   268  		} else {
   269  			if sub.repeated {
   270  				slice := asSlice(value)
   271  				value = slice[len(slice)-1]
   272  			}
   273  			sub.data = value.(map[string]any)
   274  		}
   275  
   276  		result = append(result, sub)
   277  		parent = &sub.message
   278  	}
   279  	return result, nil
   280  }
   281  
   282  // parseFieldValue parses a field value according to the field type.
   283  // Types: https://developers.google.com/protocol-buffers/docs/proto?hl=en#scalar
   284  func (p *parser) parseFieldValue(s string, msgName string, field *descriptorpb.FieldDescriptorProto) (any, error) {
   285  	switch field.GetType() {
   286  
   287  	case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:
   288  		return strconv.ParseFloat(s, 64)
   289  
   290  	case descriptorpb.FieldDescriptorProto_TYPE_FLOAT:
   291  		x, err := strconv.ParseFloat(s, 32)
   292  		return float32(x), err
   293  
   294  	case
   295  		descriptorpb.FieldDescriptorProto_TYPE_INT32,
   296  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED32,
   297  		descriptorpb.FieldDescriptorProto_TYPE_SINT32:
   298  
   299  		x, err := strconv.ParseInt(s, 10, 32)
   300  		return int32(x), err
   301  
   302  	case descriptorpb.FieldDescriptorProto_TYPE_INT64,
   303  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED64,
   304  		descriptorpb.FieldDescriptorProto_TYPE_SINT64:
   305  
   306  		return strconv.ParseInt(s, 10, 64)
   307  
   308  	case descriptorpb.FieldDescriptorProto_TYPE_UINT32, descriptorpb.FieldDescriptorProto_TYPE_FIXED32:
   309  		x, err := strconv.ParseUint(s, 10, 32)
   310  		return uint32(x), err
   311  
   312  	case descriptorpb.FieldDescriptorProto_TYPE_UINT64, descriptorpb.FieldDescriptorProto_TYPE_FIXED64:
   313  		return strconv.ParseUint(s, 10, 64)
   314  
   315  	case descriptorpb.FieldDescriptorProto_TYPE_BOOL:
   316  		return strconv.ParseBool(s)
   317  
   318  	case descriptorpb.FieldDescriptorProto_TYPE_STRING:
   319  		return s, nil
   320  
   321  	case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE:
   322  		return nil, fmt.Errorf(
   323  			"%s.%s is a message field. Specify its field values, not the message itself",
   324  			msgName, field.GetName())
   325  
   326  	case descriptorpb.FieldDescriptorProto_TYPE_BYTES:
   327  		return hex.DecodeString(s)
   328  
   329  	case descriptorpb.FieldDescriptorProto_TYPE_ENUM:
   330  		obj, err := p.resolve(field.GetTypeName())
   331  		if err != nil {
   332  			return nil, err
   333  		}
   334  		enum, ok := obj.(*descriptorpb.EnumDescriptorProto)
   335  		if !ok {
   336  			return nil, fmt.Errorf(
   337  				"field %s.%s is declared as of type enum %s, but %s is not an enum",
   338  				msgName, field.GetName(),
   339  				field.GetTypeName(), field.GetTypeName(),
   340  			)
   341  		}
   342  		return parseEnum(enum, s)
   343  
   344  	default:
   345  		return nil, fmt.Errorf("field type %s is not supported", field.GetType())
   346  	}
   347  }
   348  
   349  func (p *parser) resolve(name string) (any, error) {
   350  	if p.Resolver == nil {
   351  		panic(fmt.Errorf("cannot resolve type %q. Resolver is not set", name))
   352  	}
   353  	name = strings.TrimPrefix(name, ".")
   354  	obj := p.Resolver.Resolve(name)
   355  	if obj == nil {
   356  		return nil, fmt.Errorf("cannot resolve type %q", name)
   357  	}
   358  	return obj, nil
   359  }
   360  
   361  // splitKeyValuePair splits a key value pair key=value if there is equals sign.
   362  func (p *parser) splitKeyValuePair(s string) (key, value string, hasValue bool) {
   363  	parts := strings.SplitN(s, "=", 2)
   364  	switch len(parts) {
   365  	case 1:
   366  		key = s
   367  	case 2:
   368  		key = parts[0]
   369  		value = parts[1]
   370  		hasValue = true
   371  	}
   372  	return
   373  }
   374  
   375  // parseEnum returns the number of an enum member, which can be name or number.
   376  func parseEnum(enum *descriptorpb.EnumDescriptorProto, member string) (int32, error) {
   377  	i := descutil.FindEnumValue(enum, member)
   378  	if i < 0 {
   379  		// Is member the number?
   380  		if number, err := strconv.ParseInt(member, 10, 32); err == nil {
   381  			i = descutil.FindValueByNumber(enum, int32(number))
   382  		}
   383  	}
   384  	if i < 0 {
   385  		return 0, fmt.Errorf("invalid value %q for enum %s", member, enum.GetName())
   386  	}
   387  	return enum.Value[i].GetNumber(), nil
   388  }
   389  
   390  func asSlice(x any) []any {
   391  	if x == nil {
   392  		return nil
   393  	}
   394  	return x.([]any)
   395  }