github.com/cosmos/cosmos-proto@v1.0.0-beta.3/internal/fuzz/message.go (about)

     1  package fuzz
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  
     7  	"google.golang.org/protobuf/reflect/protoreflect"
     8  	"pgregory.net/rapid"
     9  )
    10  
    11  const (
    12  	MaxDepthDefault   = 2
    13  	MaxListLength     = 50
    14  	MaxBytesArraySize = 100
    15  )
    16  
    17  func Message(t *rapid.T, typ protoreflect.MessageType) protoreflect.Message {
    18  	g := &generator{
    19  		typ:            typ,
    20  		m:              typ.New(),
    21  		t:              t,
    22  		pickedOneofs:   map[protoreflect.FullName]protoreflect.FullName{},
    23  		fillAll:        true,
    24  		invalidValue:   false,
    25  		depth:          0,
    26  		maxDepth:       MaxDepthDefault,
    27  		maxListLength:  MaxListLength,
    28  		maxMapLength:   MaxListLength,
    29  		maxBytesLength: MaxBytesArraySize,
    30  	}
    31  
    32  	g.generate()
    33  	return g.m
    34  }
    35  
    36  type generator struct {
    37  	typ protoreflect.MessageType
    38  	m   protoreflect.Message
    39  	t   *rapid.T
    40  
    41  	pickedOneofs map[protoreflect.FullName]protoreflect.FullName // maps oneof fullname to picked field descriptor full name
    42  
    43  	fillAll      bool
    44  	invalidValue bool
    45  
    46  	depth    int
    47  	maxDepth int
    48  
    49  	maxListLength  int
    50  	maxMapLength   int
    51  	maxBytesLength int
    52  }
    53  
    54  func (g *generator) generate() {
    55  	if g.depth == g.maxDepth {
    56  		g.m = g.typ.New()
    57  		return
    58  	}
    59  
    60  	// pick oneofs
    61  	g.decideOneofs()
    62  
    63  	for i := 0; i < g.typ.Descriptor().Fields().Len(); i++ {
    64  		fd := g.typ.Descriptor().Fields().Get(i)
    65  
    66  		switch g.fillAll {
    67  		case true:
    68  			g.field(fd)
    69  		default:
    70  			genField := rapid.Bool().Draw(g.t, fmt.Sprintf("skip field generation: %s", fd.FullName()))
    71  			if !genField {
    72  				continue
    73  			}
    74  
    75  			g.field(fd)
    76  		}
    77  	}
    78  }
    79  
    80  // field fill the message with a random value
    81  func (g *generator) field(fd protoreflect.FieldDescriptor) {
    82  	// check if field is part of a oneof and if it is check if it was the picked one
    83  	if isOneof(fd) && !g.chosenOneof(fd) {
    84  		return
    85  	}
    86  	// check if we can set an invalid value
    87  	if g.invalidValue && rapid.Bool().Draw(g.t, fmt.Sprintf("generate invalid value for field %s", fd.FullName())) {
    88  		g.m.Set(fd, protoreflect.Value{})
    89  	}
    90  
    91  	switch {
    92  	case fd.IsList():
    93  		g.list(fd)
    94  	case fd.IsMap():
    95  		g.mapp(fd)
    96  	default:
    97  		g.value(fd)
    98  	}
    99  }
   100  
   101  func isOneof(fd protoreflect.FieldDescriptor) bool {
   102  	return fd.ContainingOneof() != nil
   103  }
   104  
   105  func (g *generator) list(fd protoreflect.FieldDescriptor) {
   106  	list := g.m.NewField(fd).List()
   107  	length := rapid.IntRange(0, g.maxListLength).Draw(g.t, fmt.Sprintf("list length for %s", fd.FullName()))
   108  
   109  	for i := 0; i < length; i++ {
   110  		switch fd.Kind() {
   111  		case protoreflect.MessageKind:
   112  			gen := g.embeddedMessage(list.NewElement().Message().Type())
   113  			list.Append(protoreflect.ValueOfMessage(gen))
   114  		default:
   115  			list.Append(g.valueFor(fd))
   116  		}
   117  	}
   118  
   119  	g.m.Set(fd, protoreflect.ValueOfList(list))
   120  }
   121  
   122  func (g *generator) mapp(fd protoreflect.FieldDescriptor) {
   123  	keyDesc := fd.MapKey()
   124  	valueDesc := fd.MapValue()
   125  
   126  	mapValue := g.m.NewField(fd).Map()
   127  
   128  	length := rapid.IntRange(0, g.maxMapLength).Draw(g.t, "map length for "+string(fd.FullName()))
   129  
   130  	for i := 0; i < length; i++ {
   131  		keyValue := protoreflect.MapKey(g.valueFor(keyDesc))
   132  		var valueValue protoreflect.Value
   133  
   134  		switch valueDesc.Kind() {
   135  		case protoreflect.MessageKind:
   136  			gen := g.embeddedMessage(mapValue.NewValue().Message().Type())
   137  			valueValue = protoreflect.ValueOfMessage(gen)
   138  		default:
   139  			valueValue = g.valueFor(valueDesc)
   140  		}
   141  		mapValue.Set(keyValue, valueValue)
   142  	}
   143  
   144  	g.m.Set(fd, protoreflect.ValueOfMap(mapValue))
   145  }
   146  
   147  func (g *generator) value(fd protoreflect.FieldDescriptor) {
   148  	var value protoreflect.Value
   149  	switch fd.Kind() {
   150  	case protoreflect.MessageKind:
   151  		msg := g.embeddedMessage(g.m.NewField(fd).Message().Type())
   152  		value = protoreflect.ValueOfMessage(msg)
   153  	default:
   154  		value = g.valueFor(fd)
   155  	}
   156  
   157  	g.m.Set(fd, value)
   158  }
   159  
   160  // valueFor generates a random protoreflect.Value which is not of protoreflect.MessageKind
   161  func (g *generator) valueFor(fd protoreflect.FieldDescriptor) protoreflect.Value {
   162  	switch fd.Kind() {
   163  	// bool kind
   164  	case protoreflect.BoolKind:
   165  		value := rapid.Bool().Draw(g.t, label(fd))
   166  		return protoreflect.ValueOfBool(value)
   167  	// int32 kinds
   168  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   169  		value := rapid.Int32().Draw(g.t, label(fd))
   170  		return protoreflect.ValueOfInt32(value)
   171  	// int64 kinds
   172  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   173  		value := rapid.Int64().Draw(g.t, label(fd))
   174  		return protoreflect.ValueOfInt64(value)
   175  	// uint32 kinds
   176  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   177  		value := rapid.Uint32().Draw(g.t, label(fd))
   178  		return protoreflect.ValueOfUint32(value)
   179  	// uint64 kinds
   180  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   181  		value := rapid.Uint64().Draw(g.t, label(fd))
   182  		return protoreflect.ValueOfUint64(value)
   183  	// float32 kind
   184  	case protoreflect.FloatKind:
   185  		value := rapid.Float32Max(math.MaxFloat32).Draw(g.t, label(fd))
   186  		return protoreflect.ValueOfFloat32(value)
   187  	// float64 kind
   188  	case protoreflect.DoubleKind:
   189  		value := rapid.Float64().Draw(g.t, label(fd))
   190  		return protoreflect.ValueOfFloat64(value)
   191  	// string kind
   192  	case protoreflect.StringKind:
   193  		value := rapid.String().Draw(g.t, label(fd))
   194  		return protoreflect.ValueOfString(value)
   195  	// bytes kind
   196  	case protoreflect.BytesKind:
   197  		value := randomBytes(g.t, fd)
   198  		return protoreflect.ValueOfBytes(value)
   199  	// enum kind
   200  	case protoreflect.EnumKind:
   201  		enumIndex := rapid.IntRange(0, fd.Enum().Values().Len()-1).Draw(g.t, "random enum index for "+string(fd.FullName()))
   202  		enum := fd.Enum().Values().Get(enumIndex)
   203  		return protoreflect.ValueOfEnum(enum.Number())
   204  	default:
   205  		panic(fmt.Errorf("cannot handle: %s", fd.Kind()))
   206  	}
   207  }
   208  
   209  // embeddedMessage returns a generator for a message which is contained within the current message
   210  // it is needed mainly to avoid endless cycles on recursive messages
   211  func (g *generator) embeddedMessage(typ protoreflect.MessageType) protoreflect.Message {
   212  	gen := &generator{
   213  		typ:            typ,
   214  		m:              typ.New(),
   215  		t:              g.t,
   216  		fillAll:        g.fillAll,
   217  		invalidValue:   g.invalidValue,
   218  		depth:          g.depth + 1,
   219  		maxDepth:       g.maxDepth,
   220  		maxListLength:  g.maxListLength,
   221  		maxMapLength:   g.maxMapLength,
   222  		maxBytesLength: g.maxBytesLength,
   223  		pickedOneofs:   map[protoreflect.FullName]protoreflect.FullName{},
   224  	}
   225  
   226  	gen.generate()
   227  	return gen.m
   228  }
   229  
   230  // decideOneofs picks the one protoreflect.FieldDescriptor for each oneof
   231  func (g *generator) decideOneofs() {
   232  	md := g.typ.Descriptor()
   233  	for i := 0; i < md.Oneofs().Len(); i++ {
   234  		oneof := md.Oneofs().Get(i)
   235  		index := rapid.IntRange(0, oneof.Fields().Len()-1).Draw(g.t, "deciding oneof field for: "+string(oneof.FullName()))
   236  		decidedFd := oneof.Fields().Get(index)
   237  		g.pickedOneofs[oneof.FullName()] = decidedFd.FullName()
   238  	}
   239  }
   240  
   241  func (g *generator) chosenOneof(fd protoreflect.FieldDescriptor) bool {
   242  	chosenFdName := g.pickedOneofs[fd.ContainingOneof().FullName()]
   243  
   244  	return chosenFdName == fd.FullName()
   245  }
   246  
   247  func label(fd protoreflect.FieldDescriptor) string {
   248  	return fmt.Sprintf("value for %s", fd.FullName())
   249  }
   250  
   251  func randomBytes(t *rapid.T, fd protoreflect.FieldDescriptor) []byte {
   252  	size := rapid.IntRange(0, MaxBytesArraySize).Draw(t, "bytes slice size for %s"+string(fd.FullName()))
   253  	return rapid.SliceOfN(rapid.Byte(), 0, size).Draw(t, "bytes slice for %s"+string(fd.FullName()))
   254  }