// Source: github.com/cosmos/cosmos-sdk@v0.50.10/codec/unknownproto/unknown_fields.go

package unknownproto

import (
	"bytes"
	"compress/gzip"
	"errors"
	"fmt"
	"io"
	"reflect"
	"strings"
	"sync"

	"github.com/cosmos/gogoproto/jsonpb"
	"github.com/cosmos/gogoproto/proto"
	"google.golang.org/protobuf/encoding/protowire"
	protov2 "google.golang.org/protobuf/proto"
	"google.golang.org/protobuf/types/descriptorpb"

	"github.com/cosmos/cosmos-sdk/codec/types"
)

// bit11NonCritical is the mask (1 << 10) for bit 11 of a field number.
// RejectUnknownFields treats an unknown field as non-critical when this bit
// is set in its field number, and as critical otherwise.
const bit11NonCritical = 1 << 10

// descriptorIface is implemented by messages that expose a Descriptor method.
// In this file the first return value is gunzipped and unmarshaled as a
// FileDescriptorProto, and the second is used as the index path locating the
// message's descriptor inside that file (see extractFileDescMessageDesc).
type descriptorIface interface {
	Descriptor() ([]byte, []int)
}

// RejectUnknownFieldsStrict rejects any bytes bz with an error that has unknown fields for the provided proto.Message type.
// It calls RejectUnknownFields with allowUnknownNonCriticals set to false and discards the hasUnknownNonCriticals flag,
// so unknown non-critical fields are rejected as well.
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
// An AnyResolver must be provided for traversing inside google.protobuf.Any's.
func RejectUnknownFieldsStrict(bz []byte, msg proto.Message, resolver jsonpb.AnyResolver) error {
	_, err := RejectUnknownFields(bz, msg, false, resolver)
	return err
}

// RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an
// option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the
// hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be
// used to treat a message with non-critical fields differently in different security contexts (such as transaction signing).
// This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
// An AnyResolver must be provided for traversing inside google.protobuf.Any's.
42 func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals bool, resolver jsonpb.AnyResolver) (hasUnknownNonCriticals bool, err error) { 43 if len(bz) == 0 { 44 return hasUnknownNonCriticals, nil 45 } 46 47 desc, ok := msg.(descriptorIface) 48 if !ok { 49 return hasUnknownNonCriticals, fmt.Errorf("%T does not have a Descriptor() method", msg) 50 } 51 52 fieldDescProtoFromTagNum, _, err := getDescriptorInfo(desc, msg) 53 if err != nil { 54 return hasUnknownNonCriticals, err 55 } 56 57 for len(bz) > 0 { 58 tagNum, wireType, m := protowire.ConsumeTag(bz) 59 if m < 0 { 60 return hasUnknownNonCriticals, errors.New("invalid length") 61 } 62 63 fieldDescProto, ok := fieldDescProtoFromTagNum[int32(tagNum)] 64 switch { 65 case ok: 66 // Assert that the wireTypes match. 67 if !canEncodeType(wireType, fieldDescProto.GetType()) { 68 return hasUnknownNonCriticals, &errMismatchedWireType{ 69 Type: reflect.ValueOf(msg).Type().String(), 70 TagNum: tagNum, 71 GotWireType: wireType, 72 WantWireType: toProtowireType(fieldDescProto.GetType()), 73 } 74 } 75 76 default: 77 isCriticalField := tagNum&bit11NonCritical == 0 78 79 if !isCriticalField { 80 hasUnknownNonCriticals = true 81 } 82 83 if isCriticalField || !allowUnknownNonCriticals { 84 // The tag is critical, so report it. 85 return hasUnknownNonCriticals, &errUnknownField{ 86 Type: reflect.ValueOf(msg).Type().String(), 87 TagNum: tagNum, 88 WireType: wireType, 89 } 90 } 91 } 92 93 // Skip over the bytes that store fieldNumber and wireType bytes. 94 bz = bz[m:] 95 n := protowire.ConsumeFieldValue(tagNum, wireType, bz) 96 if n < 0 { 97 err = fmt.Errorf("could not consume field value for tagNum: %d, wireType: %q; %w", 98 tagNum, wireTypeToString(wireType), protowire.ParseError(n)) 99 return hasUnknownNonCriticals, err 100 } 101 fieldBytes := bz[:n] 102 bz = bz[n:] 103 104 // An unknown but non-critical field or just a scalar type (aka *INT and BYTES like). 
105 if fieldDescProto == nil || isScalar(fieldDescProto) { 106 continue 107 } 108 109 protoMessageName := fieldDescProto.GetTypeName() 110 if protoMessageName == "" { 111 switch typ := fieldDescProto.GetType(); typ { 112 case descriptorpb.FieldDescriptorProto_TYPE_STRING, descriptorpb.FieldDescriptorProto_TYPE_BYTES: 113 // At this point only TYPE_STRING is expected to be unregistered, since FieldDescriptorProto.IsScalar() returns false for 114 // TYPE_BYTES and TYPE_STRING as per 115 // https://github.com/cosmos/gogoproto/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118 116 default: 117 return hasUnknownNonCriticals, fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ) 118 } 119 continue 120 } 121 122 // Let's recursively traverse and typecheck the field. 123 124 // consume length prefix of nested message 125 _, o := protowire.ConsumeVarint(fieldBytes) 126 fieldBytes = fieldBytes[o:] 127 128 var msg proto.Message 129 var err error 130 131 if protoMessageName == ".google.protobuf.Any" { 132 // Firstly typecheck types.Any to ensure nothing snuck in. 133 hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals, resolver) 134 hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild 135 if err != nil { 136 return hasUnknownNonCriticals, err 137 } 138 // And finally we can extract the TypeURL containing the protoMessageName. 
139 any := new(types.Any) 140 if err := proto.Unmarshal(fieldBytes, any); err != nil { 141 return hasUnknownNonCriticals, err 142 } 143 protoMessageName = any.TypeUrl 144 fieldBytes = any.Value 145 msg, err = resolver.Resolve(protoMessageName) 146 if err != nil { 147 return hasUnknownNonCriticals, err 148 } 149 } else { 150 msg, err = protoMessageForTypeName(protoMessageName[1:]) 151 if err != nil { 152 return hasUnknownNonCriticals, err 153 } 154 } 155 156 hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals, resolver) 157 hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild 158 if err != nil { 159 return hasUnknownNonCriticals, err 160 } 161 } 162 163 return hasUnknownNonCriticals, nil 164 } 165 166 var ( 167 protoMessageForTypeNameMu sync.RWMutex 168 protoMessageForTypeNameCache = make(map[string]proto.Message) 169 ) 170 171 // protoMessageForTypeName takes in a fully qualified name e.g. testdata.TestVersionFD1 172 // and returns a corresponding empty protobuf message that serves the prototype for typechecking. 173 func protoMessageForTypeName(protoMessageName string) (proto.Message, error) { 174 protoMessageForTypeNameMu.RLock() 175 msg, ok := protoMessageForTypeNameCache[protoMessageName] 176 protoMessageForTypeNameMu.RUnlock() 177 if ok { 178 return msg, nil 179 } 180 181 concreteGoType := proto.MessageType(protoMessageName) 182 if concreteGoType == nil { 183 return nil, fmt.Errorf("failed to retrieve the message of type %q", protoMessageName) 184 } 185 186 value := reflect.New(concreteGoType).Elem() 187 msg, ok = value.Interface().(proto.Message) 188 if !ok { 189 return nil, fmt.Errorf("%q does not implement proto.Message", protoMessageName) 190 } 191 192 // Now cache it. 
193 protoMessageForTypeNameMu.Lock() 194 protoMessageForTypeNameCache[protoMessageName] = msg 195 protoMessageForTypeNameMu.Unlock() 196 197 return msg, nil 198 } 199 200 // checks is a mapping of protowire.Type to supported descriptor.FieldDescriptorProto_Type. 201 // it is implemented this way so as to have constant time lookups and avoid the overhead 202 // from O(n) walking of switch. The change to using this mapping boosts throughput by about 200%. 203 var checks = [...]map[descriptorpb.FieldDescriptorProto_Type]bool{ 204 // "0 Varint: int32, int64, uint32, uint64, sint32, sint64, bool, enum" 205 0: { 206 descriptorpb.FieldDescriptorProto_TYPE_INT32: true, 207 descriptorpb.FieldDescriptorProto_TYPE_INT64: true, 208 descriptorpb.FieldDescriptorProto_TYPE_UINT32: true, 209 descriptorpb.FieldDescriptorProto_TYPE_UINT64: true, 210 descriptorpb.FieldDescriptorProto_TYPE_SINT32: true, 211 descriptorpb.FieldDescriptorProto_TYPE_SINT64: true, 212 descriptorpb.FieldDescriptorProto_TYPE_BOOL: true, 213 descriptorpb.FieldDescriptorProto_TYPE_ENUM: true, 214 }, 215 216 // "1 64-bit: fixed64, sfixed64, double" 217 1: { 218 descriptorpb.FieldDescriptorProto_TYPE_FIXED64: true, 219 descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: true, 220 descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: true, 221 }, 222 223 // "2 Length-delimited: string, bytes, embedded messages, packed repeated fields" 224 2: { 225 descriptorpb.FieldDescriptorProto_TYPE_STRING: true, 226 descriptorpb.FieldDescriptorProto_TYPE_BYTES: true, 227 descriptorpb.FieldDescriptorProto_TYPE_MESSAGE: true, 228 // The following types can be packed repeated. 229 // ref: "Only repeated fields of primitive numeric types (types which use the varint, 32-bit, or 64-bit wire types) can be declared "packed"." 
230 // ref: https://developers.google.com/protocol-buffers/docs/encoding#packed 231 descriptorpb.FieldDescriptorProto_TYPE_INT32: true, 232 descriptorpb.FieldDescriptorProto_TYPE_INT64: true, 233 descriptorpb.FieldDescriptorProto_TYPE_UINT32: true, 234 descriptorpb.FieldDescriptorProto_TYPE_UINT64: true, 235 descriptorpb.FieldDescriptorProto_TYPE_SINT32: true, 236 descriptorpb.FieldDescriptorProto_TYPE_SINT64: true, 237 descriptorpb.FieldDescriptorProto_TYPE_BOOL: true, 238 descriptorpb.FieldDescriptorProto_TYPE_ENUM: true, 239 descriptorpb.FieldDescriptorProto_TYPE_FIXED64: true, 240 descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: true, 241 descriptorpb.FieldDescriptorProto_TYPE_DOUBLE: true, 242 }, 243 244 // "3 Start group: groups (deprecated)" 245 3: { 246 descriptorpb.FieldDescriptorProto_TYPE_GROUP: true, 247 }, 248 249 // "4 End group: groups (deprecated)" 250 4: { 251 descriptorpb.FieldDescriptorProto_TYPE_GROUP: true, 252 }, 253 254 // "5 32-bit: fixed32, sfixed32, float" 255 5: { 256 descriptorpb.FieldDescriptorProto_TYPE_FIXED32: true, 257 descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: true, 258 descriptorpb.FieldDescriptorProto_TYPE_FLOAT: true, 259 }, 260 } 261 262 // canEncodeType returns true if the wireType is suitable for encoding the descriptor type. 263 // See https://developers.google.com/protocol-buffers/docs/encoding#structure. 264 func canEncodeType(wireType protowire.Type, descType descriptorpb.FieldDescriptorProto_Type) bool { 265 if iwt := int(wireType); iwt < 0 || iwt >= len(checks) { 266 return false 267 } 268 return checks[wireType][descType] 269 } 270 271 // errMismatchedWireType describes a mismatch between 272 // expected and got wireTypes for a specific tag number. 273 type errMismatchedWireType struct { 274 Type string 275 GotWireType protowire.Type 276 WantWireType protowire.Type 277 TagNum protowire.Number 278 } 279 280 // String implements fmt.Stringer. 
281 func (mwt *errMismatchedWireType) String() string { 282 return fmt.Sprintf("Mismatched %q: {TagNum: %d, GotWireType: %q != WantWireType: %q}", 283 mwt.Type, mwt.TagNum, wireTypeToString(mwt.GotWireType), wireTypeToString(mwt.WantWireType)) 284 } 285 286 // Error implements the error interface. 287 func (mwt *errMismatchedWireType) Error() string { 288 return mwt.String() 289 } 290 291 var _ error = (*errMismatchedWireType)(nil) 292 293 func wireTypeToString(wt protowire.Type) string { 294 switch wt { 295 case 0: 296 return "varint" 297 case 1: 298 return "fixed64" 299 case 2: 300 return "bytes" 301 case 3: 302 return "start_group" 303 case 4: 304 return "end_group" 305 case 5: 306 return "fixed32" 307 default: 308 return fmt.Sprintf("unknown type: %d", wt) 309 } 310 } 311 312 // errUnknownField represents an error indicating that we encountered 313 // a field that isn't available in the target proto.Message. 314 type errUnknownField struct { 315 Type string 316 TagNum protowire.Number 317 WireType protowire.Type 318 } 319 320 // String implements fmt.Stringer. 321 func (twt *errUnknownField) String() string { 322 return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}", 323 twt.Type, twt.TagNum, wireTypeToString(twt.WireType)) 324 } 325 326 // Error implements the error interface. 
327 func (twt *errUnknownField) Error() string { 328 return twt.String() 329 } 330 331 var _ error = (*errUnknownField)(nil) 332 333 var ( 334 protoFileToDesc = make(map[string]*descriptorpb.FileDescriptorProto) 335 protoFileToDescMu sync.RWMutex 336 ) 337 338 func unnestDesc(mdescs []*descriptorpb.DescriptorProto, indices []int) *descriptorpb.DescriptorProto { 339 mdesc := mdescs[indices[0]] 340 for _, index := range indices[1:] { 341 mdesc = mdesc.NestedType[index] 342 } 343 return mdesc 344 } 345 346 // Invoking descriptorpb.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow 347 // for every single message, thus the need for a hand-rolled custom version that's performant and cacheable. 348 func extractFileDescMessageDesc(desc descriptorIface) (*descriptorpb.FileDescriptorProto, *descriptorpb.DescriptorProto, error) { 349 gzippedPb, indices := desc.Descriptor() 350 351 protoFileToDescMu.RLock() 352 cached, ok := protoFileToDesc[string(gzippedPb)] 353 protoFileToDescMu.RUnlock() 354 355 if ok { 356 return cached, unnestDesc(cached.MessageType, indices), nil 357 } 358 359 // Time to gunzip the content of the FileDescriptor and then proto unmarshal them. 360 gzr, err := gzip.NewReader(bytes.NewReader(gzippedPb)) 361 if err != nil { 362 return nil, nil, err 363 } 364 protoBlob, err := io.ReadAll(gzr) 365 if err != nil { 366 return nil, nil, err 367 } 368 369 fdesc := new(descriptorpb.FileDescriptorProto) 370 if err := protov2.Unmarshal(protoBlob, fdesc); err != nil { 371 return nil, nil, err 372 } 373 374 // Now cache the FileDescriptor. 375 protoFileToDescMu.Lock() 376 protoFileToDesc[string(gzippedPb)] = fdesc 377 protoFileToDescMu.Unlock() 378 379 // Unnest the type if necessary. 
380 return fdesc, unnestDesc(fdesc.MessageType, indices), nil 381 } 382 383 type descriptorMatch struct { 384 cache map[int32]*descriptorpb.FieldDescriptorProto 385 desc *descriptorpb.DescriptorProto 386 } 387 388 var ( 389 descprotoCacheMu sync.RWMutex 390 descprotoCache = make(map[reflect.Type]*descriptorMatch) 391 ) 392 393 // getDescriptorInfo retrieves the mapping of field numbers to their respective field descriptors. 394 func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptorpb.FieldDescriptorProto, *descriptorpb.DescriptorProto, error) { 395 key := reflect.ValueOf(msg).Type() 396 397 descprotoCacheMu.RLock() 398 got, ok := descprotoCache[key] 399 descprotoCacheMu.RUnlock() 400 401 if ok { 402 return got.cache, got.desc, nil 403 } 404 405 // Now compute and cache the index. 406 _, md, err := extractFileDescMessageDesc(desc) 407 if err != nil { 408 return nil, nil, err 409 } 410 411 tagNumToTypeIndex := make(map[int32]*descriptorpb.FieldDescriptorProto) 412 for _, field := range md.Field { 413 tagNumToTypeIndex[field.GetNumber()] = field 414 } 415 416 descprotoCacheMu.Lock() 417 descprotoCache[key] = &descriptorMatch{ 418 cache: tagNumToTypeIndex, 419 desc: md, 420 } 421 descprotoCacheMu.Unlock() 422 423 return tagNumToTypeIndex, md, nil 424 } 425 426 // DefaultAnyResolver is a default implementation of AnyResolver which uses 427 // the default encoding of type URLs as specified by the protobuf specification. 428 type DefaultAnyResolver struct{} 429 430 var _ jsonpb.AnyResolver = DefaultAnyResolver{} 431 432 // Resolve is the AnyResolver.Resolve method. 433 func (d DefaultAnyResolver) Resolve(typeURL string) (proto.Message, error) { 434 // Only the part of typeURL after the last slash is relevant. 
435 mname := typeURL 436 if slash := strings.LastIndex(mname, "/"); slash >= 0 { 437 mname = mname[slash+1:] 438 } 439 mt := proto.MessageType(mname) 440 if mt == nil { 441 return nil, fmt.Errorf("unknown message type %q", mname) 442 } 443 return reflect.New(mt.Elem()).Interface().(proto.Message), nil 444 } 445 446 // toProtowireType converts a descriptorpb.FieldDescriptorProto_Type to a protowire.Type. 447 func toProtowireType(fieldType descriptorpb.FieldDescriptorProto_Type) protowire.Type { 448 switch fieldType { 449 // varint encoded 450 case descriptorpb.FieldDescriptorProto_TYPE_INT64, 451 descriptorpb.FieldDescriptorProto_TYPE_UINT64, 452 descriptorpb.FieldDescriptorProto_TYPE_INT32, 453 descriptorpb.FieldDescriptorProto_TYPE_UINT32, 454 descriptorpb.FieldDescriptorProto_TYPE_BOOL, 455 descriptorpb.FieldDescriptorProto_TYPE_ENUM, 456 descriptorpb.FieldDescriptorProto_TYPE_SINT32, 457 descriptorpb.FieldDescriptorProto_TYPE_SINT64: 458 return protowire.VarintType 459 460 // fixed64 encoded 461 case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE, 462 descriptorpb.FieldDescriptorProto_TYPE_FIXED64, 463 descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: 464 return protowire.Fixed64Type 465 466 // fixed32 encoded 467 case descriptorpb.FieldDescriptorProto_TYPE_FLOAT, 468 descriptorpb.FieldDescriptorProto_TYPE_FIXED32, 469 descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: 470 return protowire.Fixed32Type 471 472 // bytes encoded 473 case descriptorpb.FieldDescriptorProto_TYPE_STRING, 474 descriptorpb.FieldDescriptorProto_TYPE_BYTES, 475 descriptorpb.FieldDescriptorProto_TYPE_MESSAGE, 476 descriptorpb.FieldDescriptorProto_TYPE_GROUP: 477 return protowire.BytesType 478 default: 479 panic(fmt.Sprintf("unknown field type %s", fieldType)) 480 } 481 } 482 483 // isScalar defines whether a field is a scalar type. 
484 // Copied from gogo/protobuf/protoc-gen-gogo 485 // https://github.com/gogo/protobuf/blob/b03c65ea87cdc3521ede29f62fe3ce239267c1bc/protoc-gen-gogo/descriptor/descriptor.go#L95 486 func isScalar(field *descriptorpb.FieldDescriptorProto) bool { 487 if field.Type == nil { 488 return false 489 } 490 switch *field.Type { 491 case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE, 492 descriptorpb.FieldDescriptorProto_TYPE_FLOAT, 493 descriptorpb.FieldDescriptorProto_TYPE_INT64, 494 descriptorpb.FieldDescriptorProto_TYPE_UINT64, 495 descriptorpb.FieldDescriptorProto_TYPE_INT32, 496 descriptorpb.FieldDescriptorProto_TYPE_FIXED64, 497 descriptorpb.FieldDescriptorProto_TYPE_FIXED32, 498 descriptorpb.FieldDescriptorProto_TYPE_BOOL, 499 descriptorpb.FieldDescriptorProto_TYPE_UINT32, 500 descriptorpb.FieldDescriptorProto_TYPE_ENUM, 501 descriptorpb.FieldDescriptorProto_TYPE_SFIXED32, 502 descriptorpb.FieldDescriptorProto_TYPE_SFIXED64, 503 descriptorpb.FieldDescriptorProto_TYPE_SINT32, 504 descriptorpb.FieldDescriptorProto_TYPE_SINT64: 505 return true 506 default: 507 return false 508 } 509 }