cosmossdk.io/client/v2@v2.0.0-beta.1/autocli/flag/builder.go (about) 1 package flag 2 3 import ( 4 "context" 5 "fmt" 6 "strconv" 7 8 cosmos_proto "github.com/cosmos/cosmos-proto" 9 "github.com/spf13/cobra" 10 "github.com/spf13/pflag" 11 "google.golang.org/protobuf/proto" 12 "google.golang.org/protobuf/reflect/protodesc" 13 "google.golang.org/protobuf/reflect/protoreflect" 14 "google.golang.org/protobuf/reflect/protoregistry" 15 16 autocliv1 "cosmossdk.io/api/cosmos/autocli/v1" 17 msgv1 "cosmossdk.io/api/cosmos/msg/v1" 18 "cosmossdk.io/client/v2/autocli/keyring" 19 "cosmossdk.io/client/v2/internal/flags" 20 "cosmossdk.io/client/v2/internal/util" 21 "cosmossdk.io/core/address" 22 23 "github.com/cosmos/cosmos-sdk/runtime" 24 ) 25 26 const ( 27 AddressStringScalarType = "cosmos.AddressString" 28 ValidatorAddressStringScalarType = "cosmos.ValidatorAddressString" 29 ConsensusAddressStringScalarType = "cosmos.ConsensusAddressString" 30 ) 31 32 // Builder manages options for building pflag flags for protobuf messages. 33 type Builder struct { 34 // TypeResolver specifies how protobuf types will be resolved. If it is 35 // nil protoregistry.GlobalTypes will be used. 36 TypeResolver interface { 37 protoregistry.MessageTypeResolver 38 protoregistry.ExtensionTypeResolver 39 } 40 41 // FileResolver specifies how protobuf file descriptors will be resolved. If it is 42 // nil protoregistry.GlobalFiles will be used. 43 FileResolver interface { 44 protodesc.Resolver 45 RangeFiles(func(protoreflect.FileDescriptor) bool) 46 } 47 48 messageFlagTypes map[protoreflect.FullName]Type 49 scalarFlagTypes map[string]Type 50 51 // Keyring is the keyring to use for client/v2. 52 Keyring keyring.Keyring 53 54 // Address Codecs are the address codecs to use for client/v2. 
55 AddressCodec address.Codec 56 ValidatorAddressCodec runtime.ValidatorAddressCodec 57 ConsensusAddressCodec runtime.ConsensusAddressCodec 58 } 59 60 func (b *Builder) init() { 61 if b.messageFlagTypes == nil { 62 b.messageFlagTypes = map[protoreflect.FullName]Type{} 63 b.messageFlagTypes["google.protobuf.Timestamp"] = timestampType{} 64 b.messageFlagTypes["google.protobuf.Duration"] = durationType{} 65 b.messageFlagTypes["cosmos.base.v1beta1.Coin"] = coinType{} 66 } 67 68 if b.scalarFlagTypes == nil { 69 b.scalarFlagTypes = map[string]Type{} 70 b.scalarFlagTypes[AddressStringScalarType] = addressStringType{} 71 b.scalarFlagTypes[ValidatorAddressStringScalarType] = validatorAddressStringType{} 72 b.scalarFlagTypes[ConsensusAddressStringScalarType] = consensusAddressStringType{} 73 } 74 } 75 76 // DefineMessageFlagType allows to extend custom protobuf message type handling for flags (and positional arguments). 77 func (b *Builder) DefineMessageFlagType(messageName protoreflect.FullName, flagType Type) { 78 b.init() 79 b.messageFlagTypes[messageName] = flagType 80 } 81 82 // DefineScalarFlagType allows to extend custom scalar type handling for flags (and positional arguments). 83 func (b *Builder) DefineScalarFlagType(scalarName string, flagType Type) { 84 b.init() 85 b.scalarFlagTypes[scalarName] = flagType 86 } 87 88 // AddMessageFlags adds flags for each field in the message to the flag set. 89 func (b *Builder) AddMessageFlags(ctx context.Context, flagSet *pflag.FlagSet, messageType protoreflect.MessageType, commandOptions *autocliv1.RpcCommandOptions) (*MessageBinder, error) { 90 return b.addMessageFlags(ctx, flagSet, messageType, commandOptions, namingOptions{}) 91 } 92 93 // addMessageFlags adds flags for each field in the message to the flag set. 
//
// Each entry of commandOptions.PositionalArgs is bound, in order, to a flag
// named after its index on the binder's private positional FlagSet, so that
// positional values reuse the same parsers as regular flags. Every other field
// is added to flagSet, with names prefixed by options.Prefix unless
// FlagOptions gives an explicit name.
//
// Signer handling: a signer field bound positionally records its argument
// index. A signer field configured through FlagOptions is recorded as a flag
// under the flag's actual name. Otherwise the signer is exposed through the
// --from flag and gets no flag of its own. In that last case an entry for the
// signer field is written into commandOptions.FlagOptions, which mutates the
// caller's options.
func (b *Builder) addMessageFlags(ctx context.Context, flagSet *pflag.FlagSet, messageType protoreflect.MessageType, commandOptions *autocliv1.RpcCommandOptions, options namingOptions) (*MessageBinder, error) {
	messageBinder := &MessageBinder{
		messageType: messageType,
		// positional args are also parsed using a FlagSet so that we can reuse all the same parsers
		positionalFlagSet: pflag.NewFlagSet("positional", pflag.ContinueOnError),
	}

	fields := messageType.Descriptor().Fields()
	signerFieldName := GetSignerFieldName(messageType.Descriptor())

	isPositional := map[string]bool{}

	numPositional := len(commandOptions.PositionalArgs)
	for i, arg := range commandOptions.PositionalArgs {
		isPositional[arg.ProtoField] = true

		// verify if a positional field is a signer field
		if arg.ProtoField == signerFieldName {
			messageBinder.SignerInfo = SignerInfo{
				PositionalArgIndex: i,
				FieldName:          arg.ProtoField,
			}
		}

		field := fields.ByName(protoreflect.Name(arg.ProtoField))
		if field == nil {
			return nil, fmt.Errorf("can't find field %s on %s", arg.ProtoField, messageType.Descriptor().FullName())
		}

		if arg.Optional && arg.Varargs {
			return nil, fmt.Errorf("positional argument %s can't be both optional and varargs", arg.ProtoField)
		}

		isLast := i == numPositional-1

		if arg.Varargs {
			if !isLast {
				return nil, fmt.Errorf("varargs positional argument %s must be the last argument", arg.ProtoField)
			}
			messageBinder.hasVarargs = true
		}

		if arg.Optional {
			if !isLast {
				return nil, fmt.Errorf("optional positional argument %s must be the last argument", arg.ProtoField)
			}
			messageBinder.hasOptional = true
		}

		_, hasValue, err := b.addFieldFlag(
			ctx,
			messageBinder.positionalFlagSet,
			field,
			&autocliv1.FlagOptions{Name: strconv.Itoa(i)},
			namingOptions{},
		)
		if err != nil {
			return nil, err
		}

		messageBinder.positionalArgs = append(messageBinder.positionalArgs, fieldBinding{
			field:    field,
			hasValue: hasValue,
		})
	}

	// A trailing varargs or optional argument relaxes the arity check on the
	// last positional slot only.
	switch {
	case messageBinder.hasVarargs:
		messageBinder.CobraArgs = cobra.MinimumNArgs(numPositional - 1)
		messageBinder.mandatoryArgUntil = numPositional - 1
	case messageBinder.hasOptional:
		messageBinder.CobraArgs = cobra.RangeArgs(numPositional-1, numPositional)
		messageBinder.mandatoryArgUntil = numPositional - 1
	default:
		messageBinder.CobraArgs = cobra.ExactArgs(numPositional)
		messageBinder.mandatoryArgUntil = numPositional
	}

	// validate flag options
	for name, opts := range commandOptions.FlagOptions {
		field := fields.ByName(protoreflect.Name(name))
		if field == nil {
			return nil, fmt.Errorf("can't find field %s on %s specified as a flag", name, messageType.Descriptor().FullName())
		}

		// A signer configured through FlagOptions is bound as a regular flag
		// below, so record it as a flag under the name that flag will get.
		// For flag-bound signers FieldName carries the flag name, matching the
		// --from fallback. A positional signer (recorded above) wins, because
		// positional fields never receive a flag.
		if name == signerFieldName && !isPositional[name] {
			flagName := options.Prefix + util.DescriptorKebabName(field)
			if opts != nil && opts.Name != "" {
				flagName = opts.Name
			}

			messageBinder.SignerInfo = SignerInfo{
				FieldName: flagName,
				IsFlag:    true,
			}
		}
	}

	// if signer has not been specified as positional arguments,
	// add it as `--from` flag (instead of --field-name flags)
	if signerFieldName != "" && messageBinder.SignerInfo.FieldName == "" {
		if commandOptions.FlagOptions == nil {
			commandOptions.FlagOptions = make(map[string]*autocliv1.FlagOptions)
		}

		commandOptions.FlagOptions[signerFieldName] = &autocliv1.FlagOptions{
			Name:      flags.FlagFrom,
			Usage:     "Name or address with which to sign the message",
			Shorthand: "f",
		}

		messageBinder.SignerInfo = SignerInfo{
			FieldName: flags.FlagFrom,
			IsFlag:    true,
		}
	}

	// define all other fields as flags
	flagOptsByFlagName := map[string]*autocliv1.FlagOptions{}
	for i := 0; i < fields.Len(); i++ {
		field := fields.Get(i)
		fieldName := string(field.Name())

		// skips positional args and signer field if already set
		if isPositional[fieldName] ||
			(fieldName == signerFieldName && messageBinder.SignerInfo.FieldName == flags.FlagFrom) {
			continue
		}

		flagOpts := commandOptions.FlagOptions[fieldName]
		name, hasValue, err := b.addFieldFlag(ctx, flagSet, field, flagOpts, options)
		if err != nil {
			return nil, err
		}
		flagOptsByFlagName[name] = flagOpts

		messageBinder.flagBindings = append(messageBinder.flagBindings, fieldBinding{
			hasValue: hasValue,
			field:    field,
		})
	}

	flagSet.VisitAll(func(flag *pflag.Flag) {
		opts := flagOptsByFlagName[flag.Name]
		if opts == nil {
			return
		}

		// This is a bit of hacking around the pflag API, but
		// we need to set these options here using Flag.VisitAll because the flag
		// constructors that pflag gives us (StringP, Int32P, etc.) do not
		// actually return the *Flag instance
		flag.Deprecated = opts.Deprecated
		flag.ShorthandDeprecated = opts.ShorthandDeprecated
		flag.Hidden = opts.Hidden
	})

	return messageBinder, nil
}

// bindPageRequest create a flag for pagination: every field of the
// PageRequest message becomes its own flag prefixed with "page-".
func (b *Builder) bindPageRequest(ctx context.Context, flagSet *pflag.FlagSet, field protoreflect.FieldDescriptor) (HasValue, error) {
	binder, err := b.addMessageFlags(
		ctx,
		flagSet,
		util.ResolveMessageType(b.TypeResolver, field.Message()),
		&autocliv1.RpcCommandOptions{},
		namingOptions{Prefix: "page-"},
	)
	if err != nil {
		// Return a literal nil: a nil *MessageBinder wrapped in HasValue
		// would compare as non-nil.
		return nil, err
	}
	return binder, nil
}

// namingOptions specifies internal naming options for flags.
type namingOptions struct {
	// Prefix is a prefix to prepend to all flags.
	Prefix string
}

// addFieldFlag adds a flag for the provided field to the flag set.
262 func (b *Builder) addFieldFlag(ctx context.Context, flagSet *pflag.FlagSet, field protoreflect.FieldDescriptor, opts *autocliv1.FlagOptions, options namingOptions) (name string, hasValue HasValue, err error) { 263 if opts == nil { 264 opts = &autocliv1.FlagOptions{} 265 } 266 267 if field.Kind() == protoreflect.MessageKind && field.Message().FullName() == "cosmos.base.query.v1beta1.PageRequest" { 268 hasValue, err := b.bindPageRequest(ctx, flagSet, field) 269 return "", hasValue, err 270 } 271 272 name = opts.Name 273 if name == "" { 274 name = options.Prefix + util.DescriptorKebabName(field) 275 } 276 277 usage := opts.Usage 278 if usage == "" { 279 usage = util.DescriptorDocs(field) 280 } 281 282 shorthand := opts.Shorthand 283 defaultValue := opts.DefaultValue 284 285 if typ := b.resolveFlagType(field); typ != nil { 286 if defaultValue == "" { 287 defaultValue = typ.DefaultValue() 288 } 289 290 val := typ.NewValue(ctx, b) 291 flagSet.AddFlag(&pflag.Flag{ 292 Name: name, 293 Shorthand: shorthand, 294 Usage: usage, 295 DefValue: defaultValue, 296 Value: val, 297 }) 298 return name, val, nil 299 } 300 301 // use the built-in pflag StringP, Int32P, etc. functions 302 var val HasValue 303 304 if field.IsList() { 305 val = bindSimpleListFlag(flagSet, field.Kind(), name, shorthand, usage) 306 } else if field.IsMap() { 307 keyKind := field.MapKey().Kind() 308 valKind := field.MapValue().Kind() 309 val = bindSimpleMapFlag(flagSet, keyKind, valKind, name, shorthand, usage) 310 } else { 311 val = bindSimpleFlag(flagSet, field.Kind(), name, shorthand, usage) 312 } 313 314 // This is a bit of hacking around the pflag API, but the 315 // defaultValue is set in this way because this is much easier than trying 316 // to parse the string into the types that StringSliceP, Int32P, etc. 
317 if defaultValue != "" { 318 err = flagSet.Set(name, defaultValue) 319 } 320 321 return name, val, err 322 } 323 324 func (b *Builder) resolveFlagType(field protoreflect.FieldDescriptor) Type { 325 typ := b.resolveFlagTypeBasic(field) 326 if field.IsList() { 327 if typ != nil { 328 return compositeListType{simpleType: typ} 329 } 330 return nil 331 } 332 if field.IsMap() { 333 keyKind := field.MapKey().Kind() 334 valType := b.resolveFlagType(field.MapValue()) 335 if valType != nil { 336 switch keyKind { 337 case protoreflect.StringKind: 338 ct := new(compositeMapType[string]) 339 ct.keyValueResolver = func(s string) (string, error) { return s, nil } 340 ct.valueType = valType 341 ct.keyType = "string" 342 return ct 343 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 344 ct := new(compositeMapType[int32]) 345 ct.keyValueResolver = func(s string) (int32, error) { 346 i, err := strconv.ParseInt(s, 10, 32) 347 return int32(i), err 348 } 349 ct.valueType = valType 350 ct.keyType = "int32" 351 return ct 352 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 353 ct := new(compositeMapType[int64]) 354 ct.keyValueResolver = func(s string) (int64, error) { 355 i, err := strconv.ParseInt(s, 10, 64) 356 return i, err 357 } 358 ct.valueType = valType 359 ct.keyType = "int64" 360 return ct 361 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 362 ct := new(compositeMapType[uint32]) 363 ct.keyValueResolver = func(s string) (uint32, error) { 364 i, err := strconv.ParseUint(s, 10, 32) 365 return uint32(i), err 366 } 367 ct.valueType = valType 368 ct.keyType = "uint32" 369 return ct 370 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 371 ct := new(compositeMapType[uint64]) 372 ct.keyValueResolver = func(s string) (uint64, error) { 373 i, err := strconv.ParseUint(s, 10, 64) 374 return i, err 375 } 376 ct.valueType = valType 377 ct.keyType = "uint64" 378 return ct 379 case protoreflect.BoolKind: 380 ct := 
new(compositeMapType[bool]) 381 ct.keyValueResolver = strconv.ParseBool 382 ct.valueType = valType 383 ct.keyType = "bool" 384 return ct 385 } 386 return nil 387 388 } 389 return nil 390 } 391 392 return typ 393 } 394 395 func (b *Builder) resolveFlagTypeBasic(field protoreflect.FieldDescriptor) Type { 396 scalar, ok := GetScalarType(field) 397 if ok { 398 b.init() 399 if typ, ok := b.scalarFlagTypes[scalar]; ok { 400 return typ 401 } 402 } 403 404 switch field.Kind() { 405 case protoreflect.BytesKind: 406 return binaryType{} 407 case protoreflect.EnumKind: 408 return enumType{enum: field.Enum()} 409 case protoreflect.MessageKind: 410 b.init() 411 if flagType, ok := b.messageFlagTypes[field.Message().FullName()]; ok { 412 return flagType 413 } 414 return jsonMessageFlagType{ 415 messageDesc: field.Message(), 416 } 417 default: 418 return nil 419 } 420 } 421 422 // GetScalarType gets scalar type of a field. 423 func GetScalarType(field protoreflect.FieldDescriptor) (string, bool) { 424 scalar := proto.GetExtension(field.Options(), cosmos_proto.E_Scalar) 425 scalarStr, ok := scalar.(string) 426 return scalarStr, ok 427 } 428 429 // GetSignerFieldName gets signer field name of a message. 430 // AutoCLI supports only one signer field per message. 431 func GetSignerFieldName(descriptor protoreflect.MessageDescriptor) string { 432 signersFields := proto.GetExtension(descriptor.Options(), msgv1.E_Signer).([]string) 433 if len(signersFields) == 0 { 434 return "" 435 } 436 437 return signersFields[0] 438 }