go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/client/flagpb/unmarshal.go (about) 1 // Copyright 2016 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package flagpb 16 17 import ( 18 "bytes" 19 "encoding/hex" 20 "encoding/json" 21 "fmt" 22 "strconv" 23 "strings" 24 25 "google.golang.org/protobuf/types/descriptorpb" 26 27 "go.chromium.org/luci/common/proto/google/descutil" 28 29 "github.com/golang/protobuf/jsonpb" 30 "github.com/golang/protobuf/proto" 31 ) 32 33 // UnmarshalMessage unmarshals the proto message from flags. 34 // 35 // The descriptor set should be obtained from the `cproto` compiled packages' 36 // FileDescriptorSet() method. 37 func UnmarshalMessage(flags []string, resolver Resolver, msg proto.Message) error { 38 // TODO(iannucci): avoid round-trip through parser and jsonpb and populate the 39 // message directly. This would involve writing some additional reflection 40 // code that may depend on implementation details of proto's generated Go 41 // code, which is why this wasn't done initially. 42 name := proto.MessageName(msg) 43 dproto, ok := resolver.Resolve(name).(*descriptorpb.DescriptorProto) 44 if !ok { 45 return fmt.Errorf("could not resolve message %q", name) 46 } 47 48 jdata, err := UnmarshalUntyped(flags, dproto, resolver) 49 if err != nil { 50 return err 51 } 52 53 jtext, err := json.Marshal(jdata) 54 if err != nil { 55 return err 56 } 57 58 return jsonpb.Unmarshal(bytes.NewReader(jtext), msg) 59 } 60 61 // UnmarshalUntyped unmarshals a key-value map from flags 62 // using a protobuf message descriptor. 63 func UnmarshalUntyped(flags []string, desc *descriptorpb.DescriptorProto, resolver Resolver) (map[string]any, error) { 64 p := parser{resolver} 65 return p.parse(flags, desc) 66 } 67 68 type message struct { 69 data map[string]any 70 desc *descriptorpb.DescriptorProto 71 } 72 73 type parser struct { 74 Resolver Resolver 75 } 76 77 func (p *parser) parse(flags []string, desc *descriptorpb.DescriptorProto) (map[string]any, error) { 78 if desc == nil { 79 panic("desc is nil") 80 } 81 root := message{map[string]any{}, desc} 82 83 for len(flags) > 0 { 84 var err error 85 if flags, err = p.parseOneFlag(flags, root); err != nil { 86 return nil, err 87 } 88 } 89 return root.data, nil 90 } 91 92 func (p *parser) parseOneFlag(flags []string, root message) (flagsRest []string, err error) { 93 // skip empty flags 94 for len(flags) > 0 && strings.TrimSpace(flags[0]) == "" { 95 flags = flags[1:] 96 } 97 if len(flags) == 0 { 98 return flags, nil 99 } 100 101 firstArg := flags[0] 102 flags = flags[1:] 103 104 // Prefix returned errors with flag name verbatim. 105 defer func() { 106 if err != nil { 107 err = fmt.Errorf("%s: %s", firstArg, err) 108 } 109 }() 110 111 // Trim dashes. 112 if !strings.HasPrefix(firstArg, "-") { 113 return nil, fmt.Errorf("a flag was expected") 114 } 115 flagName := strings.TrimPrefix(firstArg, "-") // -foo 116 flagName = strings.TrimPrefix(flagName, "-") // --foo 117 if strings.HasPrefix(flagName, "-") { 118 // Triple dash is too much. 119 return nil, fmt.Errorf("bad flag syntax") 120 } 121 122 // Split key-value pair x=y. 123 flagName, valueStr, hasValueStr := p.splitKeyValuePair(flagName) 124 if flagName == "" { 125 return nil, fmt.Errorf("bad flag syntax") 126 } 127 128 // Split field path "a.b.c" and resolve field names. 129 fieldPath := strings.Split(flagName, ".") 130 pathMsgs, err := p.subMessages(root, fieldPath[:len(fieldPath)-1]) 131 if err != nil { 132 return nil, err 133 } 134 135 // Where to assign the value? 136 target := &root 137 if len(pathMsgs) > 0 { 138 lastMsg := pathMsgs[len(pathMsgs)-1] 139 target = &lastMsg.message 140 } 141 name := fieldPath[len(fieldPath)-1] 142 143 // Resolve target field. 144 var fieldIndex int 145 if target.desc.GetOptions().GetMapEntry() { 146 if fieldIndex = descutil.FindField(target.desc, "value"); fieldIndex == -1 { 147 return nil, fmt.Errorf("map entry type %s does not have value field", target.desc.GetName()) 148 } 149 } else { 150 if fieldIndex = descutil.FindField(target.desc, name); fieldIndex == -1 { 151 return nil, fmt.Errorf("field %s not found in message %s", name, target.desc.GetName()) 152 } 153 } 154 field := target.desc.Field[fieldIndex] 155 156 var value any 157 hasValue := false 158 159 if !hasValueStr { 160 switch { 161 // Boolean and repeated message fields may have no value and ignore 162 // next argument. 163 case field.GetType() == descriptorpb.FieldDescriptorProto_TYPE_BOOL: 164 value = true 165 hasValue = true 166 case field.GetType() == descriptorpb.FieldDescriptorProto_TYPE_MESSAGE && descutil.Repeated(field): 167 value = map[string]any{} 168 hasValue = true 169 170 default: 171 // Read next argument as a value. 172 if len(flags) == 0 { 173 return nil, fmt.Errorf("value was expected") 174 } 175 valueStr, flags = flags[0], flags[1:] 176 } 177 } 178 179 // Check if the value is already set. 180 if target.data[name] != nil && !descutil.Repeated(field) { 181 repeatedFields := make([]string, 0, len(pathMsgs)) 182 for _, m := range pathMsgs { 183 if m.repeated { 184 repeatedFields = append(repeatedFields, "-"+strings.Join(m.path, ".")) 185 } 186 } 187 if len(repeatedFields) == 0 { 188 return nil, fmt.Errorf("value is already set to %v", target.data[name]) 189 } 190 return nil, fmt.Errorf( 191 "value is already set to %v. Did you forgot to insert %s in between to declare a new repeated message?", 192 target.data[name], strings.Join(repeatedFields, " or ")) 193 } 194 195 if !hasValue { 196 value, err = p.parseFieldValue(valueStr, target.desc.GetName(), field) 197 if err != nil { 198 return nil, err 199 } 200 } 201 202 if !descutil.Repeated(field) { 203 target.data[name] = value 204 } else { 205 target.data[name] = append(asSlice(target.data[name]), value) 206 } 207 208 return flags, nil 209 } 210 211 type subMsg struct { 212 message 213 path []string 214 repeated bool 215 } 216 217 // subMessages returns message field values at each component of the path. 218 // For example, for path ["a", "b", "c"] it will return 219 // [msg.a, msg.a.b, msg.a.b.c]. 220 // If a field is repeated, returns the last message. 221 // 222 // If a field value is nil, initializes it with an empty message or slice. 223 // If a field is not a message field, returns an error. 224 func (p *parser) subMessages(root message, path []string) ([]subMsg, error) { 225 result := make([]subMsg, 0, len(path)) 226 227 parent := &root 228 for i, name := range path { 229 curPath := path[:i+1] 230 231 var fieldIndex int 232 if parent.desc.GetOptions().GetMapEntry() { 233 if fieldIndex = descutil.FindField(parent.desc, "value"); fieldIndex == -1 { 234 return nil, fmt.Errorf("map entry type %s does not have value field", parent.desc.GetName()) 235 } 236 } else { 237 if fieldIndex = descutil.FindField(parent.desc, name); fieldIndex == -1 { 238 return nil, fmt.Errorf("field %q not found in message %s", name, parent.desc.GetName()) 239 } 240 } 241 242 f := parent.desc.Field[fieldIndex] 243 if f.GetType() != descriptorpb.FieldDescriptorProto_TYPE_MESSAGE { 244 return nil, fmt.Errorf("field %s is not a message", strings.Join(curPath, ".")) 245 } 246 247 subDescInterface, err := p.resolve(f.GetTypeName()) 248 if err != nil { 249 return nil, err 250 } 251 subDesc, ok := subDescInterface.(*descriptorpb.DescriptorProto) 252 if !ok { 253 return nil, fmt.Errorf("%s is not a message", f.GetTypeName()) 254 } 255 256 sub := subMsg{ 257 message: message{desc: subDesc}, 258 repeated: descutil.Repeated(f) && !subDesc.GetOptions().GetMapEntry(), 259 path: curPath, 260 } 261 if value, ok := parent.data[name]; !ok { 262 sub.data = map[string]any{} 263 if sub.repeated { 264 parent.data[name] = []any{sub.data} 265 } else { 266 parent.data[name] = sub.data 267 } 268 } else { 269 if sub.repeated { 270 slice := asSlice(value) 271 value = slice[len(slice)-1] 272 } 273 sub.data = value.(map[string]any) 274 } 275 276 result = append(result, sub) 277 parent = &sub.message 278 } 279 return result, nil 280 } 281 282 // parseFieldValue parses a field value according to the field type. 283 // Types: https://developers.google.com/protocol-buffers/docs/proto?hl=en#scalar 284 func (p *parser) parseFieldValue(s string, msgName string, field *descriptorpb.FieldDescriptorProto) (any, error) { 285 switch field.GetType() { 286 287 case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: 288 return strconv.ParseFloat(s, 64) 289 290 case descriptorpb.FieldDescriptorProto_TYPE_FLOAT: 291 x, err := strconv.ParseFloat(s, 32) 292 return float32(x), err 293 294 case 295 descriptorpb.FieldDescriptorProto_TYPE_INT32, 296 descriptorpb.FieldDescriptorProto_TYPE_SFIXED32, 297 descriptorpb.FieldDescriptorProto_TYPE_SINT32: 298 299 x, err := strconv.ParseInt(s, 10, 32) 300 return int32(x), err 301 302 case descriptorpb.FieldDescriptorProto_TYPE_INT64, 303 descriptorpb.FieldDescriptorProto_TYPE_SFIXED64, 304 descriptorpb.FieldDescriptorProto_TYPE_SINT64: 305 306 return strconv.ParseInt(s, 10, 64) 307 308 case descriptorpb.FieldDescriptorProto_TYPE_UINT32, descriptorpb.FieldDescriptorProto_TYPE_FIXED32: 309 x, err := strconv.ParseUint(s, 10, 32) 310 return uint32(x), err 311 312 case descriptorpb.FieldDescriptorProto_TYPE_UINT64, descriptorpb.FieldDescriptorProto_TYPE_FIXED64: 313 return strconv.ParseUint(s, 10, 64) 314 315 case descriptorpb.FieldDescriptorProto_TYPE_BOOL: 316 return strconv.ParseBool(s) 317 318 case descriptorpb.FieldDescriptorProto_TYPE_STRING: 319 return s, nil 320 321 case descriptorpb.FieldDescriptorProto_TYPE_MESSAGE: 322 return nil, fmt.Errorf( 323 "%s.%s is a message field. Specify its field values, not the message itself", 324 msgName, field.GetName()) 325 326 case descriptorpb.FieldDescriptorProto_TYPE_BYTES: 327 return hex.DecodeString(s) 328 329 case descriptorpb.FieldDescriptorProto_TYPE_ENUM: 330 obj, err := p.resolve(field.GetTypeName()) 331 if err != nil { 332 return nil, err 333 } 334 enum, ok := obj.(*descriptorpb.EnumDescriptorProto) 335 if !ok { 336 return nil, fmt.Errorf( 337 "field %s.%s is declared as of type enum %s, but %s is not an enum", 338 msgName, field.GetName(), 339 field.GetTypeName(), field.GetTypeName(), 340 ) 341 } 342 return parseEnum(enum, s) 343 344 default: 345 return nil, fmt.Errorf("field type %s is not supported", field.GetType()) 346 } 347 } 348 349 func (p *parser) resolve(name string) (any, error) { 350 if p.Resolver == nil { 351 panic(fmt.Errorf("cannot resolve type %q. Resolver is not set", name)) 352 } 353 name = strings.TrimPrefix(name, ".") 354 obj := p.Resolver.Resolve(name) 355 if obj == nil { 356 return nil, fmt.Errorf("cannot resolve type %q", name) 357 } 358 return obj, nil 359 } 360 361 // splitKeyValuePair splits a key value pair key=value if there is equals sign. 362 func (p *parser) splitKeyValuePair(s string) (key, value string, hasValue bool) { 363 parts := strings.SplitN(s, "=", 2) 364 switch len(parts) { 365 case 1: 366 key = s 367 case 2: 368 key = parts[0] 369 value = parts[1] 370 hasValue = true 371 } 372 return 373 } 374 375 // parseEnum returns the number of an enum member, which can be name or number. 376 func parseEnum(enum *descriptorpb.EnumDescriptorProto, member string) (int32, error) { 377 i := descutil.FindEnumValue(enum, member) 378 if i < 0 { 379 // Is member the number? 380 if number, err := strconv.ParseInt(member, 10, 32); err == nil { 381 i = descutil.FindValueByNumber(enum, int32(number)) 382 } 383 } 384 if i < 0 { 385 return 0, fmt.Errorf("invalid value %q for enum %s", member, enum.GetName()) 386 } 387 return enum.Value[i].GetNumber(), nil 388 } 389 390 func asSlice(x any) []any { 391 if x == nil { 392 return nil 393 } 394 return x.([]any) 395 }