cosmossdk.io/client/v2@v2.0.0-beta.1/autocli/flag/builder.go (about)

     1  package flag
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strconv"
     7  
     8  	cosmos_proto "github.com/cosmos/cosmos-proto"
     9  	"github.com/spf13/cobra"
    10  	"github.com/spf13/pflag"
    11  	"google.golang.org/protobuf/proto"
    12  	"google.golang.org/protobuf/reflect/protodesc"
    13  	"google.golang.org/protobuf/reflect/protoreflect"
    14  	"google.golang.org/protobuf/reflect/protoregistry"
    15  
    16  	autocliv1 "cosmossdk.io/api/cosmos/autocli/v1"
    17  	msgv1 "cosmossdk.io/api/cosmos/msg/v1"
    18  	"cosmossdk.io/client/v2/autocli/keyring"
    19  	"cosmossdk.io/client/v2/internal/flags"
    20  	"cosmossdk.io/client/v2/internal/util"
    21  	"cosmossdk.io/core/address"
    22  
    23  	"github.com/cosmos/cosmos-sdk/runtime"
    24  )
    25  
// Scalar type names that this package registers a dedicated flag Type for
// (see Builder.init). They are matched against the value of the cosmos_proto
// scalar field option read by GetScalarType.
const (
	// AddressStringScalarType is the scalar name handled by addressStringType.
	AddressStringScalarType          = "cosmos.AddressString"
	// ValidatorAddressStringScalarType is the scalar name handled by validatorAddressStringType.
	ValidatorAddressStringScalarType = "cosmos.ValidatorAddressString"
	// ConsensusAddressStringScalarType is the scalar name handled by consensusAddressStringType.
	ConsensusAddressStringScalarType = "cosmos.ConsensusAddressString"
)
    31  
// Builder manages options for building pflag flags for protobuf messages.
//
// The two flag-type registries (messageFlagTypes and scalarFlagTypes) are
// created lazily by init and can be extended through DefineMessageFlagType
// and DefineScalarFlagType.
type Builder struct {
	// TypeResolver specifies how protobuf types will be resolved. If it is
	// nil protoregistry.GlobalTypes will be used.
	TypeResolver interface {
		protoregistry.MessageTypeResolver
		protoregistry.ExtensionTypeResolver
	}

	// FileResolver specifies how protobuf file descriptors will be resolved. If it is
	// nil protoregistry.GlobalFiles will be used.
	FileResolver interface {
		protodesc.Resolver
		RangeFiles(func(protoreflect.FileDescriptor) bool)
	}

	// messageFlagTypes maps a protobuf message full name to the flag Type
	// used for fields of that message kind.
	messageFlagTypes map[protoreflect.FullName]Type
	// scalarFlagTypes maps a cosmos_proto scalar name (for example
	// AddressStringScalarType) to the flag Type used for fields annotated
	// with that scalar.
	scalarFlagTypes  map[string]Type

	// Keyring is the keyring to use for client/v2.
	Keyring keyring.Keyring

	// Address Codecs are the address codecs to use for client/v2.
	AddressCodec          address.Codec
	ValidatorAddressCodec runtime.ValidatorAddressCodec
	ConsensusAddressCodec runtime.ConsensusAddressCodec
}
    59  
    60  func (b *Builder) init() {
    61  	if b.messageFlagTypes == nil {
    62  		b.messageFlagTypes = map[protoreflect.FullName]Type{}
    63  		b.messageFlagTypes["google.protobuf.Timestamp"] = timestampType{}
    64  		b.messageFlagTypes["google.protobuf.Duration"] = durationType{}
    65  		b.messageFlagTypes["cosmos.base.v1beta1.Coin"] = coinType{}
    66  	}
    67  
    68  	if b.scalarFlagTypes == nil {
    69  		b.scalarFlagTypes = map[string]Type{}
    70  		b.scalarFlagTypes[AddressStringScalarType] = addressStringType{}
    71  		b.scalarFlagTypes[ValidatorAddressStringScalarType] = validatorAddressStringType{}
    72  		b.scalarFlagTypes[ConsensusAddressStringScalarType] = consensusAddressStringType{}
    73  	}
    74  }
    75  
// DefineMessageFlagType registers flagType as the handler for fields whose
// message type is messageName, for both flags and positional arguments.
// The built-in registrations are created first (via init), so registering a
// name that already exists replaces the previous handler.
func (b *Builder) DefineMessageFlagType(messageName protoreflect.FullName, flagType Type) {
	b.init()
	b.messageFlagTypes[messageName] = flagType
}
    81  
// DefineScalarFlagType registers flagType as the handler for fields annotated
// with the scalar scalarName, for both flags and positional arguments.
// The built-in registrations are created first (via init), so registering a
// name that already exists replaces the previous handler.
func (b *Builder) DefineScalarFlagType(scalarName string, flagType Type) {
	b.init()
	b.scalarFlagTypes[scalarName] = flagType
}
    87  
    88  // AddMessageFlags adds flags for each field in the message to the flag set.
    89  func (b *Builder) AddMessageFlags(ctx context.Context, flagSet *pflag.FlagSet, messageType protoreflect.MessageType, commandOptions *autocliv1.RpcCommandOptions) (*MessageBinder, error) {
    90  	return b.addMessageFlags(ctx, flagSet, messageType, commandOptions, namingOptions{})
    91  }
    92  
    93  // addMessageFlags adds flags for each field in the message to the flag set.
    94  func (b *Builder) addMessageFlags(ctx context.Context, flagSet *pflag.FlagSet, messageType protoreflect.MessageType, commandOptions *autocliv1.RpcCommandOptions, options namingOptions) (*MessageBinder, error) {
    95  	messageBinder := &MessageBinder{
    96  		messageType: messageType,
    97  		// positional args are also parsed using a FlagSet so that we can reuse all the same parsers
    98  		positionalFlagSet: pflag.NewFlagSet("positional", pflag.ContinueOnError),
    99  	}
   100  
   101  	fields := messageType.Descriptor().Fields()
   102  	signerFieldName := GetSignerFieldName(messageType.Descriptor())
   103  
   104  	isPositional := map[string]bool{}
   105  
   106  	lengthPositionalArgsOptions := len(commandOptions.PositionalArgs)
   107  	for i, arg := range commandOptions.PositionalArgs {
   108  		isPositional[arg.ProtoField] = true
   109  
   110  		// verify if a positional field is a signer field
   111  		if arg.ProtoField == signerFieldName {
   112  			messageBinder.SignerInfo = SignerInfo{
   113  				PositionalArgIndex: i,
   114  				FieldName:          arg.ProtoField,
   115  			}
   116  		}
   117  
   118  		field := fields.ByName(protoreflect.Name(arg.ProtoField))
   119  		if field == nil {
   120  			return nil, fmt.Errorf("can't find field %s on %s", arg.ProtoField, messageType.Descriptor().FullName())
   121  		}
   122  
   123  		if arg.Optional && arg.Varargs {
   124  			return nil, fmt.Errorf("positional argument %s can't be both optional and varargs", arg.ProtoField)
   125  		}
   126  
   127  		if arg.Varargs {
   128  			if i != lengthPositionalArgsOptions-1 {
   129  				return nil, fmt.Errorf("varargs positional argument %s must be the last argument", arg.ProtoField)
   130  			}
   131  
   132  			messageBinder.hasVarargs = true
   133  		}
   134  
   135  		if arg.Optional {
   136  			if i != lengthPositionalArgsOptions-1 {
   137  				return nil, fmt.Errorf("optional positional argument %s must be the last argument", arg.ProtoField)
   138  			}
   139  
   140  			messageBinder.hasOptional = true
   141  		}
   142  
   143  		_, hasValue, err := b.addFieldFlag(
   144  			ctx,
   145  			messageBinder.positionalFlagSet,
   146  			field,
   147  			&autocliv1.FlagOptions{Name: fmt.Sprintf("%d", i)},
   148  			namingOptions{},
   149  		)
   150  		if err != nil {
   151  			return nil, err
   152  		}
   153  
   154  		messageBinder.positionalArgs = append(messageBinder.positionalArgs, fieldBinding{
   155  			field:    field,
   156  			hasValue: hasValue,
   157  		})
   158  	}
   159  
   160  	if messageBinder.hasVarargs {
   161  		messageBinder.CobraArgs = cobra.MinimumNArgs(lengthPositionalArgsOptions - 1)
   162  		messageBinder.mandatoryArgUntil = lengthPositionalArgsOptions - 1
   163  	} else if messageBinder.hasOptional {
   164  		messageBinder.CobraArgs = cobra.RangeArgs(lengthPositionalArgsOptions-1, lengthPositionalArgsOptions)
   165  		messageBinder.mandatoryArgUntil = lengthPositionalArgsOptions - 1
   166  	} else {
   167  		messageBinder.CobraArgs = cobra.ExactArgs(lengthPositionalArgsOptions)
   168  		messageBinder.mandatoryArgUntil = lengthPositionalArgsOptions
   169  	}
   170  
   171  	// validate flag options
   172  	for name := range commandOptions.FlagOptions {
   173  		if fields.ByName(protoreflect.Name(name)) == nil {
   174  			return nil, fmt.Errorf("can't find field %s on %s specified as a flag", name, messageType.Descriptor().FullName())
   175  		}
   176  
   177  		// verify if a flag is a signer field
   178  		if name == signerFieldName {
   179  			messageBinder.SignerInfo = SignerInfo{
   180  				FieldName: name,
   181  				IsFlag:    false,
   182  			}
   183  		}
   184  	}
   185  
   186  	// if signer has not been specified as positional arguments,
   187  	// add it as `--from` flag (instead of --field-name flags)
   188  	if signerFieldName != "" && messageBinder.SignerInfo.FieldName == "" {
   189  		if commandOptions.FlagOptions == nil {
   190  			commandOptions.FlagOptions = make(map[string]*autocliv1.FlagOptions)
   191  		}
   192  
   193  		commandOptions.FlagOptions[signerFieldName] = &autocliv1.FlagOptions{
   194  			Name:      flags.FlagFrom,
   195  			Usage:     "Name or address with which to sign the message",
   196  			Shorthand: "f",
   197  		}
   198  
   199  		messageBinder.SignerInfo = SignerInfo{
   200  			FieldName: flags.FlagFrom,
   201  			IsFlag:    true,
   202  		}
   203  	}
   204  
   205  	// define all other fields as flags
   206  	flagOptsByFlagName := map[string]*autocliv1.FlagOptions{}
   207  	for i := 0; i < fields.Len(); i++ {
   208  		field := fields.Get(i)
   209  		// skips positional args and signer field if already set
   210  		if isPositional[string(field.Name())] ||
   211  			(string(field.Name()) == signerFieldName && messageBinder.SignerInfo.FieldName == flags.FlagFrom) {
   212  			continue
   213  		}
   214  
   215  		flagOpts := commandOptions.FlagOptions[string(field.Name())]
   216  		name, hasValue, err := b.addFieldFlag(ctx, flagSet, field, flagOpts, options)
   217  		flagOptsByFlagName[name] = flagOpts
   218  		if err != nil {
   219  			return nil, err
   220  		}
   221  
   222  		messageBinder.flagBindings = append(messageBinder.flagBindings, fieldBinding{
   223  			hasValue: hasValue,
   224  			field:    field,
   225  		})
   226  	}
   227  
   228  	flagSet.VisitAll(func(flag *pflag.Flag) {
   229  		opts := flagOptsByFlagName[flag.Name]
   230  		if opts != nil {
   231  			// This is a bit of hacking around the pflag API, but
   232  			// we need to set these options here using Flag.VisitAll because the flag
   233  			// constructors that pflag gives us (StringP, Int32P, etc.) do not
   234  			// actually return the *Flag instance
   235  			flag.Deprecated = opts.Deprecated
   236  			flag.ShorthandDeprecated = opts.ShorthandDeprecated
   237  			flag.Hidden = opts.Hidden
   238  		}
   239  	})
   240  
   241  	return messageBinder, nil
   242  }
   243  
   244  // bindPageRequest create a flag for pagination
   245  func (b *Builder) bindPageRequest(ctx context.Context, flagSet *pflag.FlagSet, field protoreflect.FieldDescriptor) (HasValue, error) {
   246  	return b.addMessageFlags(
   247  		ctx,
   248  		flagSet,
   249  		util.ResolveMessageType(b.TypeResolver, field.Message()),
   250  		&autocliv1.RpcCommandOptions{},
   251  		namingOptions{Prefix: "page-"},
   252  	)
   253  }
   254  
// namingOptions specifies internal naming options for flags.
type namingOptions struct {
	// Prefix is prepended to the flag name derived from the field. It is
	// not applied when the flag options provide an explicit name.
	Prefix string
}
   260  
   261  // addFieldFlag adds a flag for the provided field to the flag set.
   262  func (b *Builder) addFieldFlag(ctx context.Context, flagSet *pflag.FlagSet, field protoreflect.FieldDescriptor, opts *autocliv1.FlagOptions, options namingOptions) (name string, hasValue HasValue, err error) {
   263  	if opts == nil {
   264  		opts = &autocliv1.FlagOptions{}
   265  	}
   266  
   267  	if field.Kind() == protoreflect.MessageKind && field.Message().FullName() == "cosmos.base.query.v1beta1.PageRequest" {
   268  		hasValue, err := b.bindPageRequest(ctx, flagSet, field)
   269  		return "", hasValue, err
   270  	}
   271  
   272  	name = opts.Name
   273  	if name == "" {
   274  		name = options.Prefix + util.DescriptorKebabName(field)
   275  	}
   276  
   277  	usage := opts.Usage
   278  	if usage == "" {
   279  		usage = util.DescriptorDocs(field)
   280  	}
   281  
   282  	shorthand := opts.Shorthand
   283  	defaultValue := opts.DefaultValue
   284  
   285  	if typ := b.resolveFlagType(field); typ != nil {
   286  		if defaultValue == "" {
   287  			defaultValue = typ.DefaultValue()
   288  		}
   289  
   290  		val := typ.NewValue(ctx, b)
   291  		flagSet.AddFlag(&pflag.Flag{
   292  			Name:      name,
   293  			Shorthand: shorthand,
   294  			Usage:     usage,
   295  			DefValue:  defaultValue,
   296  			Value:     val,
   297  		})
   298  		return name, val, nil
   299  	}
   300  
   301  	// use the built-in pflag StringP, Int32P, etc. functions
   302  	var val HasValue
   303  
   304  	if field.IsList() {
   305  		val = bindSimpleListFlag(flagSet, field.Kind(), name, shorthand, usage)
   306  	} else if field.IsMap() {
   307  		keyKind := field.MapKey().Kind()
   308  		valKind := field.MapValue().Kind()
   309  		val = bindSimpleMapFlag(flagSet, keyKind, valKind, name, shorthand, usage)
   310  	} else {
   311  		val = bindSimpleFlag(flagSet, field.Kind(), name, shorthand, usage)
   312  	}
   313  
   314  	// This is a bit of hacking around the pflag API, but the
   315  	// defaultValue is set in this way because this is much easier than trying
   316  	// to parse the string into the types that StringSliceP, Int32P, etc.
   317  	if defaultValue != "" {
   318  		err = flagSet.Set(name, defaultValue)
   319  	}
   320  
   321  	return name, val, err
   322  }
   323  
   324  func (b *Builder) resolveFlagType(field protoreflect.FieldDescriptor) Type {
   325  	typ := b.resolveFlagTypeBasic(field)
   326  	if field.IsList() {
   327  		if typ != nil {
   328  			return compositeListType{simpleType: typ}
   329  		}
   330  		return nil
   331  	}
   332  	if field.IsMap() {
   333  		keyKind := field.MapKey().Kind()
   334  		valType := b.resolveFlagType(field.MapValue())
   335  		if valType != nil {
   336  			switch keyKind {
   337  			case protoreflect.StringKind:
   338  				ct := new(compositeMapType[string])
   339  				ct.keyValueResolver = func(s string) (string, error) { return s, nil }
   340  				ct.valueType = valType
   341  				ct.keyType = "string"
   342  				return ct
   343  			case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   344  				ct := new(compositeMapType[int32])
   345  				ct.keyValueResolver = func(s string) (int32, error) {
   346  					i, err := strconv.ParseInt(s, 10, 32)
   347  					return int32(i), err
   348  				}
   349  				ct.valueType = valType
   350  				ct.keyType = "int32"
   351  				return ct
   352  			case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   353  				ct := new(compositeMapType[int64])
   354  				ct.keyValueResolver = func(s string) (int64, error) {
   355  					i, err := strconv.ParseInt(s, 10, 64)
   356  					return i, err
   357  				}
   358  				ct.valueType = valType
   359  				ct.keyType = "int64"
   360  				return ct
   361  			case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   362  				ct := new(compositeMapType[uint32])
   363  				ct.keyValueResolver = func(s string) (uint32, error) {
   364  					i, err := strconv.ParseUint(s, 10, 32)
   365  					return uint32(i), err
   366  				}
   367  				ct.valueType = valType
   368  				ct.keyType = "uint32"
   369  				return ct
   370  			case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   371  				ct := new(compositeMapType[uint64])
   372  				ct.keyValueResolver = func(s string) (uint64, error) {
   373  					i, err := strconv.ParseUint(s, 10, 64)
   374  					return i, err
   375  				}
   376  				ct.valueType = valType
   377  				ct.keyType = "uint64"
   378  				return ct
   379  			case protoreflect.BoolKind:
   380  				ct := new(compositeMapType[bool])
   381  				ct.keyValueResolver = strconv.ParseBool
   382  				ct.valueType = valType
   383  				ct.keyType = "bool"
   384  				return ct
   385  			}
   386  			return nil
   387  
   388  		}
   389  		return nil
   390  	}
   391  
   392  	return typ
   393  }
   394  
   395  func (b *Builder) resolveFlagTypeBasic(field protoreflect.FieldDescriptor) Type {
   396  	scalar, ok := GetScalarType(field)
   397  	if ok {
   398  		b.init()
   399  		if typ, ok := b.scalarFlagTypes[scalar]; ok {
   400  			return typ
   401  		}
   402  	}
   403  
   404  	switch field.Kind() {
   405  	case protoreflect.BytesKind:
   406  		return binaryType{}
   407  	case protoreflect.EnumKind:
   408  		return enumType{enum: field.Enum()}
   409  	case protoreflect.MessageKind:
   410  		b.init()
   411  		if flagType, ok := b.messageFlagTypes[field.Message().FullName()]; ok {
   412  			return flagType
   413  		}
   414  		return jsonMessageFlagType{
   415  			messageDesc: field.Message(),
   416  		}
   417  	default:
   418  		return nil
   419  	}
   420  }
   421  
   422  // GetScalarType gets scalar type of a field.
   423  func GetScalarType(field protoreflect.FieldDescriptor) (string, bool) {
   424  	scalar := proto.GetExtension(field.Options(), cosmos_proto.E_Scalar)
   425  	scalarStr, ok := scalar.(string)
   426  	return scalarStr, ok
   427  }
   428  
   429  // GetSignerFieldName gets signer field name of a message.
   430  // AutoCLI supports only one signer field per message.
   431  func GetSignerFieldName(descriptor protoreflect.MessageDescriptor) string {
   432  	signersFields := proto.GetExtension(descriptor.Options(), msgv1.E_Signer).([]string)
   433  	if len(signersFields) == 0 {
   434  		return ""
   435  	}
   436  
   437  	return signersFields[0]
   438  }