github.com/cosmos/cosmos-proto@v1.0.0-beta.3/rapidproto/rapidproto.go (about) 1 package rapidproto 2 3 import ( 4 "fmt" 5 "math" 6 7 cosmos_proto "github.com/cosmos/cosmos-proto" 8 "google.golang.org/protobuf/proto" 9 "google.golang.org/protobuf/reflect/protoreflect" 10 "google.golang.org/protobuf/reflect/protoregistry" 11 "gotest.tools/v3/assert" 12 "pgregory.net/rapid" 13 ) 14 15 func MessageGenerator[T proto.Message](x T, options GeneratorOptions) *rapid.Generator[T] { 16 msgType := x.ProtoReflect().Type() 17 return rapid.Custom(func(t *rapid.T) T { 18 msg := msgType.New() 19 20 options.setFields(t, nil, msg, 0) 21 22 return msg.Interface().(T) 23 }) 24 } 25 26 // FieldMapper is a function that can be used to override the default behavior of the generator for a specific field. 27 // The first argument is the rapid.T, the second is the field descriptor, and the third is the field name. 28 // If the function returns nil, the default behavior will be used. 29 type FieldMapper func(*rapid.T, protoreflect.FieldDescriptor, string) (protoreflect.Value, bool) 30 31 type GeneratorOptions struct { 32 AnyTypeURLs []string 33 InterfaceHints map[string]string 34 Resolver protoregistry.MessageTypeResolver 35 36 // NoEmptyLists will cause the generator to not generate empty lists 37 // Recall that an empty list will marshal (and unmarshal) to null. Some encodings may treat these states 38 // differently. For example, in JSON, an empty list is encoded as [], while null is encoded as null. 39 NoEmptyLists bool 40 41 // DisallowNilMessages will cause the generator to not generate nil messages to protoreflect.MessageKind fields 42 DisallowNilMessages bool 43 44 // FieldMaps is a list of FieldMapper functions that can be used to override the default behavior of the generator 45 // for a specific field. 46 FieldMaps []FieldMapper 47 } 48 49 const depthLimit = 10 50 51 func (opts GeneratorOptions) WithAnyTypes(anyTypes ...proto.Message) GeneratorOptions { 52 for _, a := range anyTypes { 53 opts.AnyTypeURLs = append(opts.AnyTypeURLs, fmt.Sprintf("/%s", a.ProtoReflect().Descriptor().FullName())) 54 } 55 return opts 56 } 57 58 func (opts GeneratorOptions) WithDisallowNil() GeneratorOptions { 59 o := &opts 60 o.DisallowNilMessages = true 61 return *o 62 } 63 64 func (opts GeneratorOptions) WithInterfaceHint(i string, impl proto.Message) GeneratorOptions { 65 if opts.InterfaceHints == nil { 66 opts.InterfaceHints = make(map[string]string) 67 } 68 opts.InterfaceHints[i] = string(impl.ProtoReflect().Descriptor().FullName()) 69 return opts 70 } 71 72 func (opts GeneratorOptions) setFields( 73 t *rapid.T, field protoreflect.FieldDescriptor, msg protoreflect.Message, depth int) bool { 74 // to avoid stack overflow we limit the depth of nested messages 75 if depth > depthLimit { 76 return false 77 } 78 79 descriptor := msg.Descriptor() 80 fullName := descriptor.FullName() 81 switch fullName { 82 case timestampFullName: 83 opts.genTimestamp(t, msg) 84 return true 85 case durationFullName: 86 opts.genDuration(t, msg) 87 return true 88 case anyFullName: 89 opts.genAny(t, field, msg, depth) 90 return true 91 case fieldMaskFullName: 92 opts.genFieldMask(t, msg) 93 return true 94 default: 95 fields := descriptor.Fields() 96 n := fields.Len() 97 for i := 0; i < n; i++ { 98 f := fields.Get(i) 99 if !rapid.Bool().Draw(t, fmt.Sprintf("gen-%s", f.Name())) { 100 if (f.Kind() == protoreflect.MessageKind) && !opts.DisallowNilMessages { 101 continue 102 } 103 } 104 105 opts.setFieldValue(t, msg, f, depth) 106 } 107 return true 108 } 109 } 110 111 const ( 112 timestampFullName = "google.protobuf.Timestamp" 113 durationFullName = "google.protobuf.Duration" 114 anyFullName = "google.protobuf.Any" 115 fieldMaskFullName = "google.protobuf.FieldMask" 116 ) 117 118 func (opts GeneratorOptions) setFieldValue(t *rapid.T, msg protoreflect.Message, field protoreflect.FieldDescriptor, depth int) { 119 name := string(field.Name()) 120 kind := field.Kind() 121 122 switch { 123 case field.IsList(): 124 list := msg.Mutable(field).List() 125 min := 0 126 if opts.NoEmptyLists { 127 min = 1 128 } 129 n := rapid.IntRange(min, 10).Draw(t, fmt.Sprintf("%sN", name)) 130 for i := 0; i < n; i++ { 131 if kind == protoreflect.MessageKind || kind == protoreflect.GroupKind { 132 if !opts.setFields(t, field, list.AppendMutable().Message(), depth+1) { 133 list.Truncate(i) 134 } 135 } else { 136 list.Append(opts.genScalarFieldValue(t, field, fmt.Sprintf("%s%d", name, i))) 137 } 138 } 139 case field.IsMap(): 140 m := msg.Mutable(field).Map() 141 n := rapid.IntRange(0, 10).Draw(t, fmt.Sprintf("%sN", name)) 142 for i := 0; i < n; i++ { 143 keyField := field.MapKey() 144 valueField := field.MapValue() 145 valueKind := valueField.Kind() 146 key := opts.genScalarFieldValue(t, keyField, fmt.Sprintf("%s%d-key", name, i)) 147 if valueKind == protoreflect.MessageKind || valueKind == protoreflect.GroupKind { 148 if !opts.setFields(t, field, m.Mutable(key.MapKey()).Message(), depth+1) { 149 m.Clear(key.MapKey()) 150 } 151 } else { 152 value := opts.genScalarFieldValue(t, valueField, fmt.Sprintf("%s%d-key", name, i)) 153 m.Set(key.MapKey(), value) 154 } 155 } 156 case kind == protoreflect.MessageKind: 157 mutableField := msg.Mutable(field) 158 if mutableField.Message().Descriptor().FullName() == anyFullName { 159 if !opts.genAny(t, field, mutableField.Message(), depth+1) { 160 msg.Clear(field) 161 } 162 } else if !opts.setFields(t, field, mutableField.Message(), depth+1) { 163 msg.Clear(field) 164 } 165 case kind == protoreflect.GroupKind: 166 if !opts.setFields(t, field, msg.Mutable(field).Message(), depth+1) { 167 msg.Clear(field) 168 } 169 default: 170 msg.Set(field, opts.genScalarFieldValue(t, field, name)) 171 } 172 } 173 174 func (opts GeneratorOptions) genScalarFieldValue(t *rapid.T, field protoreflect.FieldDescriptor, name string) protoreflect.Value { 175 for _, fm := range opts.FieldMaps { 176 if v, ok := fm(t, field, name); ok { 177 return v 178 } 179 } 180 181 switch field.Kind() { 182 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 183 return protoreflect.ValueOfInt32(rapid.Int32().Draw(t, name)) 184 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 185 return protoreflect.ValueOfUint32(rapid.Uint32().Draw(t, name)) 186 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 187 return protoreflect.ValueOfInt64(rapid.Int64().Draw(t, name)) 188 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 189 return protoreflect.ValueOfUint64(rapid.Uint64().Draw(t, name)) 190 case protoreflect.BoolKind: 191 return protoreflect.ValueOfBool(rapid.Bool().Draw(t, name)) 192 case protoreflect.BytesKind: 193 return protoreflect.ValueOfBytes(rapid.SliceOf(rapid.Byte()).Draw(t, name)) 194 case protoreflect.FloatKind: 195 return protoreflect.ValueOfFloat32(rapid.Float32().Draw(t, name)) 196 case protoreflect.DoubleKind: 197 return protoreflect.ValueOfFloat64(rapid.Float64().Draw(t, name)) 198 case protoreflect.EnumKind: 199 enumValues := field.Enum().Values() 200 val := rapid.Int32Range(0, int32(enumValues.Len()-1)).Draw(t, name) 201 return protoreflect.ValueOfEnum(protoreflect.EnumNumber(val)) 202 case protoreflect.StringKind: 203 return protoreflect.ValueOfString(rapid.String().Draw(t, name)) 204 default: 205 t.Fatalf("unexpected %v", field) 206 return protoreflect.Value{} 207 } 208 } 209 210 const ( 211 // MaxDurationSeconds the maximum number of seconds (when expressed as nanoseconds) which can fit in an int64. 212 // gogoproto encodes google.protobuf.Duration as a time.Duration, which is 64-bit signed integer. 213 MaxDurationSeconds = int64(math.MaxInt64/int(1e9)) - 1 214 secondsName = "seconds" 215 nanosName = "nanos" 216 ) 217 218 func (opts GeneratorOptions) genTimestamp(t *rapid.T, msg protoreflect.Message) { 219 seconds := rapid.Int64Range(-9999999999, 9999999999).Draw(t, "seconds") 220 nanos := rapid.Int32Range(0, 999999999).Draw(t, "nanos") 221 setSecondsNanosFields(t, msg, seconds, nanos) 222 } 223 224 func (opts GeneratorOptions) genDuration(t *rapid.T, msg protoreflect.Message) { 225 seconds := rapid.Int64Range(0, int64(MaxDurationSeconds)).Draw(t, "seconds") 226 nanos := rapid.Int32Range(0, 999999999).Draw(t, "nanos") 227 setSecondsNanosFields(t, msg, seconds, nanos) 228 } 229 230 func setSecondsNanosFields(t *rapid.T, message protoreflect.Message, seconds int64, nanos int32) { 231 fields := message.Descriptor().Fields() 232 233 secondsField := fields.ByName(secondsName) 234 assert.Assert(t, secondsField != nil) 235 message.Set(secondsField, protoreflect.ValueOfInt64(seconds)) 236 237 nanosField := fields.ByName(nanosName) 238 assert.Assert(t, nanosField != nil) 239 message.Set(nanosField, protoreflect.ValueOfInt32(nanos)) 240 } 241 242 const ( 243 typeURLName = "type_url" 244 valueName = "value" 245 ) 246 247 func (opts GeneratorOptions) genAny( 248 t *rapid.T, field protoreflect.FieldDescriptor, msg protoreflect.Message, depth int) bool { 249 if len(opts.AnyTypeURLs) == 0 { 250 return false 251 } 252 253 var typeURL string 254 fopts := field.Options() 255 if proto.HasExtension(fopts, cosmos_proto.E_AcceptsInterface) { 256 ai := proto.GetExtension(fopts, cosmos_proto.E_AcceptsInterface).(string) 257 if impl, found := opts.InterfaceHints[ai]; found { 258 typeURL = fmt.Sprintf("/%s", impl) 259 } else { 260 panic(fmt.Sprintf("no implementation found for interface %s", ai)) 261 } 262 } else { 263 typeURL = rapid.SampledFrom(opts.AnyTypeURLs).Draw(t, "type_url") 264 } 265 266 typ, err := opts.Resolver.FindMessageByURL(typeURL) 267 assert.NilError(t, err) 268 fields := msg.Descriptor().Fields() 269 270 typeURLField := fields.ByName(typeURLName) 271 assert.Assert(t, typeURLField != nil) 272 msg.Set(typeURLField, protoreflect.ValueOfString(typeURL)) 273 274 valueMsg := typ.New() 275 opts.setFields(t, nil, valueMsg, depth+1) 276 valueBz, err := proto.Marshal(valueMsg.Interface()) 277 assert.NilError(t, err) 278 279 valueField := fields.ByName(valueName) 280 assert.Assert(t, valueField != nil) 281 msg.Set(valueField, protoreflect.ValueOfBytes(valueBz)) 282 283 return true 284 } 285 286 const ( 287 pathsName = "paths" 288 ) 289 290 func (opts GeneratorOptions) genFieldMask(t *rapid.T, msg protoreflect.Message) { 291 paths := rapid.SliceOfN(rapid.StringMatching("[a-z]+([.][a-z]+){0,2}"), 1, 5).Draw(t, "paths") 292 pathsField := msg.Descriptor().Fields().ByName(pathsName) 293 assert.Assert(t, pathsField != nil) 294 pathsList := msg.NewField(pathsField).List() 295 for _, path := range paths { 296 pathsList.Append(protoreflect.ValueOfString(path)) 297 } 298 }