github.com/cosmos/cosmos-proto@v1.0.0-beta.3/rapidproto/rapidproto.go (about)

     1  package rapidproto
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  
     7  	cosmos_proto "github.com/cosmos/cosmos-proto"
     8  	"google.golang.org/protobuf/proto"
     9  	"google.golang.org/protobuf/reflect/protoreflect"
    10  	"google.golang.org/protobuf/reflect/protoregistry"
    11  	"gotest.tools/v3/assert"
    12  	"pgregory.net/rapid"
    13  )
    14  
    15  func MessageGenerator[T proto.Message](x T, options GeneratorOptions) *rapid.Generator[T] {
    16  	msgType := x.ProtoReflect().Type()
    17  	return rapid.Custom(func(t *rapid.T) T {
    18  		msg := msgType.New()
    19  
    20  		options.setFields(t, nil, msg, 0)
    21  
    22  		return msg.Interface().(T)
    23  	})
    24  }
    25  
    26  // FieldMapper is a function that can be used to override the default behavior of the generator for a specific field.
    27  // The first argument is the rapid.T, the second is the field descriptor, and the third is the field name.
    28  // If the function returns nil, the default behavior will be used.
    29  type FieldMapper func(*rapid.T, protoreflect.FieldDescriptor, string) (protoreflect.Value, bool)
    30  
    31  type GeneratorOptions struct {
    32  	AnyTypeURLs    []string
    33  	InterfaceHints map[string]string
    34  	Resolver       protoregistry.MessageTypeResolver
    35  
    36  	// NoEmptyLists will cause the generator to not generate empty lists
    37  	// Recall that an empty list will marshal (and unmarshal) to null. Some encodings may treat these states
    38  	// differently.  For example, in JSON, an empty list is encoded as [], while null is encoded as null.
    39  	NoEmptyLists bool
    40  
    41  	// DisallowNilMessages will cause the generator to not generate nil messages to protoreflect.MessageKind fields
    42  	DisallowNilMessages bool
    43  
    44  	// FieldMaps is a list of FieldMapper functions that can be used to override the default behavior of the generator
    45  	// for a specific field.
    46  	FieldMaps []FieldMapper
    47  }
    48  
    49  const depthLimit = 10
    50  
    51  func (opts GeneratorOptions) WithAnyTypes(anyTypes ...proto.Message) GeneratorOptions {
    52  	for _, a := range anyTypes {
    53  		opts.AnyTypeURLs = append(opts.AnyTypeURLs, fmt.Sprintf("/%s", a.ProtoReflect().Descriptor().FullName()))
    54  	}
    55  	return opts
    56  }
    57  
    58  func (opts GeneratorOptions) WithDisallowNil() GeneratorOptions {
    59  	o := &opts
    60  	o.DisallowNilMessages = true
    61  	return *o
    62  }
    63  
    64  func (opts GeneratorOptions) WithInterfaceHint(i string, impl proto.Message) GeneratorOptions {
    65  	if opts.InterfaceHints == nil {
    66  		opts.InterfaceHints = make(map[string]string)
    67  	}
    68  	opts.InterfaceHints[i] = string(impl.ProtoReflect().Descriptor().FullName())
    69  	return opts
    70  }
    71  
    72  func (opts GeneratorOptions) setFields(
    73  	t *rapid.T, field protoreflect.FieldDescriptor, msg protoreflect.Message, depth int) bool {
    74  	// to avoid stack overflow we limit the depth of nested messages
    75  	if depth > depthLimit {
    76  		return false
    77  	}
    78  
    79  	descriptor := msg.Descriptor()
    80  	fullName := descriptor.FullName()
    81  	switch fullName {
    82  	case timestampFullName:
    83  		opts.genTimestamp(t, msg)
    84  		return true
    85  	case durationFullName:
    86  		opts.genDuration(t, msg)
    87  		return true
    88  	case anyFullName:
    89  		opts.genAny(t, field, msg, depth)
    90  		return true
    91  	case fieldMaskFullName:
    92  		opts.genFieldMask(t, msg)
    93  		return true
    94  	default:
    95  		fields := descriptor.Fields()
    96  		n := fields.Len()
    97  		for i := 0; i < n; i++ {
    98  			f := fields.Get(i)
    99  			if !rapid.Bool().Draw(t, fmt.Sprintf("gen-%s", f.Name())) {
   100  				if (f.Kind() == protoreflect.MessageKind) && !opts.DisallowNilMessages {
   101  					continue
   102  				}
   103  			}
   104  
   105  			opts.setFieldValue(t, msg, f, depth)
   106  		}
   107  		return true
   108  	}
   109  }
   110  
   111  const (
   112  	timestampFullName = "google.protobuf.Timestamp"
   113  	durationFullName  = "google.protobuf.Duration"
   114  	anyFullName       = "google.protobuf.Any"
   115  	fieldMaskFullName = "google.protobuf.FieldMask"
   116  )
   117  
   118  func (opts GeneratorOptions) setFieldValue(t *rapid.T, msg protoreflect.Message, field protoreflect.FieldDescriptor, depth int) {
   119  	name := string(field.Name())
   120  	kind := field.Kind()
   121  
   122  	switch {
   123  	case field.IsList():
   124  		list := msg.Mutable(field).List()
   125  		min := 0
   126  		if opts.NoEmptyLists {
   127  			min = 1
   128  		}
   129  		n := rapid.IntRange(min, 10).Draw(t, fmt.Sprintf("%sN", name))
   130  		for i := 0; i < n; i++ {
   131  			if kind == protoreflect.MessageKind || kind == protoreflect.GroupKind {
   132  				if !opts.setFields(t, field, list.AppendMutable().Message(), depth+1) {
   133  					list.Truncate(i)
   134  				}
   135  			} else {
   136  				list.Append(opts.genScalarFieldValue(t, field, fmt.Sprintf("%s%d", name, i)))
   137  			}
   138  		}
   139  	case field.IsMap():
   140  		m := msg.Mutable(field).Map()
   141  		n := rapid.IntRange(0, 10).Draw(t, fmt.Sprintf("%sN", name))
   142  		for i := 0; i < n; i++ {
   143  			keyField := field.MapKey()
   144  			valueField := field.MapValue()
   145  			valueKind := valueField.Kind()
   146  			key := opts.genScalarFieldValue(t, keyField, fmt.Sprintf("%s%d-key", name, i))
   147  			if valueKind == protoreflect.MessageKind || valueKind == protoreflect.GroupKind {
   148  				if !opts.setFields(t, field, m.Mutable(key.MapKey()).Message(), depth+1) {
   149  					m.Clear(key.MapKey())
   150  				}
   151  			} else {
   152  				value := opts.genScalarFieldValue(t, valueField, fmt.Sprintf("%s%d-key", name, i))
   153  				m.Set(key.MapKey(), value)
   154  			}
   155  		}
   156  	case kind == protoreflect.MessageKind:
   157  		mutableField := msg.Mutable(field)
   158  		if mutableField.Message().Descriptor().FullName() == anyFullName {
   159  			if !opts.genAny(t, field, mutableField.Message(), depth+1) {
   160  				msg.Clear(field)
   161  			}
   162  		} else if !opts.setFields(t, field, mutableField.Message(), depth+1) {
   163  			msg.Clear(field)
   164  		}
   165  	case kind == protoreflect.GroupKind:
   166  		if !opts.setFields(t, field, msg.Mutable(field).Message(), depth+1) {
   167  			msg.Clear(field)
   168  		}
   169  	default:
   170  		msg.Set(field, opts.genScalarFieldValue(t, field, name))
   171  	}
   172  }
   173  
   174  func (opts GeneratorOptions) genScalarFieldValue(t *rapid.T, field protoreflect.FieldDescriptor, name string) protoreflect.Value {
   175  	for _, fm := range opts.FieldMaps {
   176  		if v, ok := fm(t, field, name); ok {
   177  			return v
   178  		}
   179  	}
   180  
   181  	switch field.Kind() {
   182  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   183  		return protoreflect.ValueOfInt32(rapid.Int32().Draw(t, name))
   184  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   185  		return protoreflect.ValueOfUint32(rapid.Uint32().Draw(t, name))
   186  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   187  		return protoreflect.ValueOfInt64(rapid.Int64().Draw(t, name))
   188  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   189  		return protoreflect.ValueOfUint64(rapid.Uint64().Draw(t, name))
   190  	case protoreflect.BoolKind:
   191  		return protoreflect.ValueOfBool(rapid.Bool().Draw(t, name))
   192  	case protoreflect.BytesKind:
   193  		return protoreflect.ValueOfBytes(rapid.SliceOf(rapid.Byte()).Draw(t, name))
   194  	case protoreflect.FloatKind:
   195  		return protoreflect.ValueOfFloat32(rapid.Float32().Draw(t, name))
   196  	case protoreflect.DoubleKind:
   197  		return protoreflect.ValueOfFloat64(rapid.Float64().Draw(t, name))
   198  	case protoreflect.EnumKind:
   199  		enumValues := field.Enum().Values()
   200  		val := rapid.Int32Range(0, int32(enumValues.Len()-1)).Draw(t, name)
   201  		return protoreflect.ValueOfEnum(protoreflect.EnumNumber(val))
   202  	case protoreflect.StringKind:
   203  		return protoreflect.ValueOfString(rapid.String().Draw(t, name))
   204  	default:
   205  		t.Fatalf("unexpected %v", field)
   206  		return protoreflect.Value{}
   207  	}
   208  }
   209  
   210  const (
   211  	// MaxDurationSeconds the maximum number of seconds (when expressed as nanoseconds) which can fit in an int64.
   212  	// gogoproto encodes google.protobuf.Duration as a time.Duration, which is 64-bit signed integer.
   213  	MaxDurationSeconds = int64(math.MaxInt64/int(1e9)) - 1
   214  	secondsName        = "seconds"
   215  	nanosName          = "nanos"
   216  )
   217  
   218  func (opts GeneratorOptions) genTimestamp(t *rapid.T, msg protoreflect.Message) {
   219  	seconds := rapid.Int64Range(-9999999999, 9999999999).Draw(t, "seconds")
   220  	nanos := rapid.Int32Range(0, 999999999).Draw(t, "nanos")
   221  	setSecondsNanosFields(t, msg, seconds, nanos)
   222  }
   223  
   224  func (opts GeneratorOptions) genDuration(t *rapid.T, msg protoreflect.Message) {
   225  	seconds := rapid.Int64Range(0, int64(MaxDurationSeconds)).Draw(t, "seconds")
   226  	nanos := rapid.Int32Range(0, 999999999).Draw(t, "nanos")
   227  	setSecondsNanosFields(t, msg, seconds, nanos)
   228  }
   229  
   230  func setSecondsNanosFields(t *rapid.T, message protoreflect.Message, seconds int64, nanos int32) {
   231  	fields := message.Descriptor().Fields()
   232  
   233  	secondsField := fields.ByName(secondsName)
   234  	assert.Assert(t, secondsField != nil)
   235  	message.Set(secondsField, protoreflect.ValueOfInt64(seconds))
   236  
   237  	nanosField := fields.ByName(nanosName)
   238  	assert.Assert(t, nanosField != nil)
   239  	message.Set(nanosField, protoreflect.ValueOfInt32(nanos))
   240  }
   241  
   242  const (
   243  	typeURLName = "type_url"
   244  	valueName   = "value"
   245  )
   246  
   247  func (opts GeneratorOptions) genAny(
   248  	t *rapid.T, field protoreflect.FieldDescriptor, msg protoreflect.Message, depth int) bool {
   249  	if len(opts.AnyTypeURLs) == 0 {
   250  		return false
   251  	}
   252  
   253  	var typeURL string
   254  	fopts := field.Options()
   255  	if proto.HasExtension(fopts, cosmos_proto.E_AcceptsInterface) {
   256  		ai := proto.GetExtension(fopts, cosmos_proto.E_AcceptsInterface).(string)
   257  		if impl, found := opts.InterfaceHints[ai]; found {
   258  			typeURL = fmt.Sprintf("/%s", impl)
   259  		} else {
   260  			panic(fmt.Sprintf("no implementation found for interface %s", ai))
   261  		}
   262  	} else {
   263  		typeURL = rapid.SampledFrom(opts.AnyTypeURLs).Draw(t, "type_url")
   264  	}
   265  
   266  	typ, err := opts.Resolver.FindMessageByURL(typeURL)
   267  	assert.NilError(t, err)
   268  	fields := msg.Descriptor().Fields()
   269  
   270  	typeURLField := fields.ByName(typeURLName)
   271  	assert.Assert(t, typeURLField != nil)
   272  	msg.Set(typeURLField, protoreflect.ValueOfString(typeURL))
   273  
   274  	valueMsg := typ.New()
   275  	opts.setFields(t, nil, valueMsg, depth+1)
   276  	valueBz, err := proto.Marshal(valueMsg.Interface())
   277  	assert.NilError(t, err)
   278  
   279  	valueField := fields.ByName(valueName)
   280  	assert.Assert(t, valueField != nil)
   281  	msg.Set(valueField, protoreflect.ValueOfBytes(valueBz))
   282  
   283  	return true
   284  }
   285  
   286  const (
   287  	pathsName = "paths"
   288  )
   289  
   290  func (opts GeneratorOptions) genFieldMask(t *rapid.T, msg protoreflect.Message) {
   291  	paths := rapid.SliceOfN(rapid.StringMatching("[a-z]+([.][a-z]+){0,2}"), 1, 5).Draw(t, "paths")
   292  	pathsField := msg.Descriptor().Fields().ByName(pathsName)
   293  	assert.Assert(t, pathsField != nil)
   294  	pathsList := msg.NewField(pathsField).List()
   295  	for _, path := range paths {
   296  		pathsList.Append(protoreflect.ValueOfString(path))
   297  	}
   298  }