github.com/fibonacci-chain/fbc@v0.0.0-20231124064014-c7636198c1e9/libs/cosmos-sdk/codec/unknownproto/unknown_fields.go (about) 1 package unknownproto 2 3 import ( 4 "bytes" 5 "compress/gzip" 6 "errors" 7 "fmt" 8 "io/ioutil" 9 "reflect" 10 "strings" 11 "sync" 12 13 "github.com/fibonacci-chain/fbc/libs/cosmos-sdk/codec/types" 14 15 "github.com/gogo/protobuf/jsonpb" 16 "github.com/gogo/protobuf/proto" 17 "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" 18 "google.golang.org/protobuf/encoding/protowire" 19 ) 20 21 const bit11NonCritical = 1 << 10 22 23 type descriptorIface interface { 24 Descriptor() ([]byte, []int) 25 } 26 27 // RejectUnknownFieldsStrict rejects any bytes bz with an error that has unknown fields for the provided proto.Message type. 28 // This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message. 29 // An AnyResolver must be provided for traversing inside google.protobuf.Any's. 30 func RejectUnknownFieldsStrict(bz []byte, msg proto.Message, resolver jsonpb.AnyResolver) error { 31 _, err := RejectUnknownFields(bz, msg, false, resolver) 32 return err 33 } 34 35 // RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an 36 // option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the 37 // hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be 38 // used to treat a message with non-critical field different in different security contexts (such as transaction signing). 39 // This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message. 40 // An AnyResolver must be provided for traversing inside google.protobuf.Any's. 41 func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals bool, resolver jsonpb.AnyResolver) (hasUnknownNonCriticals bool, err error) { 42 if len(bz) == 0 { 43 return hasUnknownNonCriticals, nil 44 } 45 46 desc, ok := msg.(descriptorIface) 47 if !ok { 48 return hasUnknownNonCriticals, fmt.Errorf("%T does not have a Descriptor() method", msg) 49 } 50 51 fieldDescProtoFromTagNum, _, err := getDescriptorInfo(desc, msg) 52 if err != nil { 53 return hasUnknownNonCriticals, err 54 } 55 56 for len(bz) > 0 { 57 tagNum, wireType, m := protowire.ConsumeTag(bz) 58 if m < 0 { 59 return hasUnknownNonCriticals, errors.New("invalid length") 60 } 61 62 fieldDescProto, ok := fieldDescProtoFromTagNum[int32(tagNum)] 63 switch { 64 case ok: 65 // Assert that the wireTypes match. 66 if !canEncodeType(wireType, fieldDescProto.GetType()) { 67 return hasUnknownNonCriticals, &errMismatchedWireType{ 68 Type: reflect.ValueOf(msg).Type().String(), 69 TagNum: tagNum, 70 GotWireType: wireType, 71 WantWireType: protowire.Type(fieldDescProto.WireType()), 72 } 73 } 74 75 default: 76 isCriticalField := tagNum&bit11NonCritical == 0 77 78 if !isCriticalField { 79 hasUnknownNonCriticals = true 80 } 81 82 if isCriticalField || !allowUnknownNonCriticals { 83 // The tag is critical, so report it. 84 return hasUnknownNonCriticals, &errUnknownField{ 85 Type: reflect.ValueOf(msg).Type().String(), 86 TagNum: tagNum, 87 WireType: wireType, 88 } 89 } 90 } 91 92 // Skip over the bytes that store fieldNumber and wireType bytes. 93 bz = bz[m:] 94 n := protowire.ConsumeFieldValue(tagNum, wireType, bz) 95 if n < 0 { 96 err = fmt.Errorf("could not consume field value for tagNum: %d, wireType: %q; %w", 97 tagNum, wireTypeToString(wireType), protowire.ParseError(n)) 98 return hasUnknownNonCriticals, err 99 } 100 fieldBytes := bz[:n] 101 bz = bz[n:] 102 103 // An unknown but non-critical field or just a scalar type (aka *INT and BYTES like). 104 if fieldDescProto == nil || fieldDescProto.IsScalar() { 105 continue 106 } 107 108 protoMessageName := fieldDescProto.GetTypeName() 109 if protoMessageName == "" { 110 switch typ := fieldDescProto.GetType(); typ { 111 case descriptor.FieldDescriptorProto_TYPE_STRING, descriptor.FieldDescriptorProto_TYPE_BYTES: 112 // At this point only TYPE_STRING is expected to be unregistered, since FieldDescriptorProto.IsScalar() returns false for 113 // TYPE_BYTES and TYPE_STRING as per 114 // https://github.com/gogo/protobuf/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118 115 default: 116 return hasUnknownNonCriticals, fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ) 117 } 118 continue 119 } 120 121 // Let's recursively traverse and typecheck the field. 122 123 // consume length prefix of nested message 124 _, o := protowire.ConsumeVarint(fieldBytes) 125 fieldBytes = fieldBytes[o:] 126 127 var msg proto.Message 128 var err error 129 130 if protoMessageName == ".google.protobuf.Any" { 131 // Firstly typecheck types.Any to ensure nothing snuck in. 132 hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals, resolver) 133 hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild 134 if err != nil { 135 return hasUnknownNonCriticals, err 136 } 137 // And finally we can extract the TypeURL containing the protoMessageName. 138 any := new(types.Any) 139 if err := proto.Unmarshal(fieldBytes, any); err != nil { 140 return hasUnknownNonCriticals, err 141 } 142 protoMessageName = any.TypeUrl 143 fieldBytes = any.Value 144 msg, err = resolver.Resolve(protoMessageName) 145 if err != nil { 146 return hasUnknownNonCriticals, err 147 } 148 } else { 149 msg, err = protoMessageForTypeName(protoMessageName[1:]) 150 if err != nil { 151 return hasUnknownNonCriticals, err 152 } 153 } 154 155 hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals, resolver) 156 hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild 157 if err != nil { 158 return hasUnknownNonCriticals, err 159 } 160 } 161 162 return hasUnknownNonCriticals, nil 163 } 164 165 var protoMessageForTypeNameMu sync.RWMutex 166 var protoMessageForTypeNameCache = make(map[string]proto.Message) 167 168 // protoMessageForTypeName takes in a fully qualified name e.g. testdata.TestVersionFD1 169 // and returns a corresponding empty protobuf message that serves the prototype for typechecking. 170 func protoMessageForTypeName(protoMessageName string) (proto.Message, error) { 171 protoMessageForTypeNameMu.RLock() 172 msg, ok := protoMessageForTypeNameCache[protoMessageName] 173 protoMessageForTypeNameMu.RUnlock() 174 if ok { 175 return msg, nil 176 } 177 178 concreteGoType := proto.MessageType(protoMessageName) 179 if concreteGoType == nil { 180 return nil, fmt.Errorf("failed to retrieve the message of type %q", protoMessageName) 181 } 182 183 value := reflect.New(concreteGoType).Elem() 184 msg, ok = value.Interface().(proto.Message) 185 if !ok { 186 return nil, fmt.Errorf("%q does not implement proto.Message", protoMessageName) 187 } 188 189 // Now cache it. 190 protoMessageForTypeNameMu.Lock() 191 protoMessageForTypeNameCache[protoMessageName] = msg 192 protoMessageForTypeNameMu.Unlock() 193 194 return msg, nil 195 } 196 197 // checks is a mapping of protowire.Type to supported descriptor.FieldDescriptorProto_Type. 198 // it is implemented this way so as to have constant time lookups and avoid the overhead 199 // from O(n) walking of switch. The change to using this mapping boosts throughput by about 200%. 200 var checks = [...]map[descriptor.FieldDescriptorProto_Type]bool{ 201 // "0 Varint: int32, int64, uint32, uint64, sint32, sint64, bool, enum" 202 0: { 203 descriptor.FieldDescriptorProto_TYPE_INT32: true, 204 descriptor.FieldDescriptorProto_TYPE_INT64: true, 205 descriptor.FieldDescriptorProto_TYPE_UINT32: true, 206 descriptor.FieldDescriptorProto_TYPE_UINT64: true, 207 descriptor.FieldDescriptorProto_TYPE_SINT32: true, 208 descriptor.FieldDescriptorProto_TYPE_SINT64: true, 209 descriptor.FieldDescriptorProto_TYPE_BOOL: true, 210 descriptor.FieldDescriptorProto_TYPE_ENUM: true, 211 }, 212 213 // "1 64-bit: fixed64, sfixed64, double" 214 1: { 215 descriptor.FieldDescriptorProto_TYPE_FIXED64: true, 216 descriptor.FieldDescriptorProto_TYPE_SFIXED64: true, 217 descriptor.FieldDescriptorProto_TYPE_DOUBLE: true, 218 }, 219 220 // "2 Length-delimited: string, bytes, embedded messages, packed repeated fields" 221 2: { 222 descriptor.FieldDescriptorProto_TYPE_STRING: true, 223 descriptor.FieldDescriptorProto_TYPE_BYTES: true, 224 descriptor.FieldDescriptorProto_TYPE_MESSAGE: true, 225 // The following types can be packed repeated. 226 // ref: "Only repeated fields of primitive numeric types (types which use the varint, 32-bit, or 64-bit wire types) can be declared "packed"." 227 // ref: https://developers.google.com/protocol-buffers/docs/encoding#packed 228 descriptor.FieldDescriptorProto_TYPE_INT32: true, 229 descriptor.FieldDescriptorProto_TYPE_INT64: true, 230 descriptor.FieldDescriptorProto_TYPE_UINT32: true, 231 descriptor.FieldDescriptorProto_TYPE_UINT64: true, 232 descriptor.FieldDescriptorProto_TYPE_SINT32: true, 233 descriptor.FieldDescriptorProto_TYPE_SINT64: true, 234 descriptor.FieldDescriptorProto_TYPE_BOOL: true, 235 descriptor.FieldDescriptorProto_TYPE_ENUM: true, 236 descriptor.FieldDescriptorProto_TYPE_FIXED64: true, 237 descriptor.FieldDescriptorProto_TYPE_SFIXED64: true, 238 descriptor.FieldDescriptorProto_TYPE_DOUBLE: true, 239 }, 240 241 // "3 Start group: groups (deprecated)" 242 3: { 243 descriptor.FieldDescriptorProto_TYPE_GROUP: true, 244 }, 245 246 // "4 End group: groups (deprecated)" 247 4: { 248 descriptor.FieldDescriptorProto_TYPE_GROUP: true, 249 }, 250 251 // "5 32-bit: fixed32, sfixed32, float" 252 5: { 253 descriptor.FieldDescriptorProto_TYPE_FIXED32: true, 254 descriptor.FieldDescriptorProto_TYPE_SFIXED32: true, 255 descriptor.FieldDescriptorProto_TYPE_FLOAT: true, 256 }, 257 } 258 259 // canEncodeType returns true if the wireType is suitable for encoding the descriptor type. 260 // See https://developers.google.com/protocol-buffers/docs/encoding#structure. 261 func canEncodeType(wireType protowire.Type, descType descriptor.FieldDescriptorProto_Type) bool { 262 if iwt := int(wireType); iwt < 0 || iwt >= len(checks) { 263 return false 264 } 265 return checks[wireType][descType] 266 } 267 268 // errMismatchedWireType describes a mismatch between 269 // expected and got wireTypes for a specific tag number. 270 type errMismatchedWireType struct { 271 Type string 272 GotWireType protowire.Type 273 WantWireType protowire.Type 274 TagNum protowire.Number 275 } 276 277 // String implements fmt.Stringer. 278 func (mwt *errMismatchedWireType) String() string { 279 return fmt.Sprintf("Mismatched %q: {TagNum: %d, GotWireType: %q != WantWireType: %q}", 280 mwt.Type, mwt.TagNum, wireTypeToString(mwt.GotWireType), wireTypeToString(mwt.WantWireType)) 281 } 282 283 // Error implements the error interface. 284 func (mwt *errMismatchedWireType) Error() string { 285 return mwt.String() 286 } 287 288 var _ error = (*errMismatchedWireType)(nil) 289 290 func wireTypeToString(wt protowire.Type) string { 291 switch wt { 292 case 0: 293 return "varint" 294 case 1: 295 return "fixed64" 296 case 2: 297 return "bytes" 298 case 3: 299 return "start_group" 300 case 4: 301 return "end_group" 302 case 5: 303 return "fixed32" 304 default: 305 return fmt.Sprintf("unknown type: %d", wt) 306 } 307 } 308 309 // errUnknownField represents an error indicating that we encountered 310 // a field that isn't available in the target proto.Message. 311 type errUnknownField struct { 312 Type string 313 TagNum protowire.Number 314 WireType protowire.Type 315 } 316 317 // String implements fmt.Stringer. 318 func (twt *errUnknownField) String() string { 319 return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}", 320 twt.Type, twt.TagNum, wireTypeToString(twt.WireType)) 321 } 322 323 // Error implements the error interface. 324 func (twt *errUnknownField) Error() string { 325 return twt.String() 326 } 327 328 var _ error = (*errUnknownField)(nil) 329 330 var ( 331 protoFileToDesc = make(map[string]*descriptor.FileDescriptorProto) 332 protoFileToDescMu sync.RWMutex 333 ) 334 335 func unnestDesc(mdescs []*descriptor.DescriptorProto, indices []int) *descriptor.DescriptorProto { 336 mdesc := mdescs[indices[0]] 337 for _, index := range indices[1:] { 338 mdesc = mdesc.NestedType[index] 339 } 340 return mdesc 341 } 342 343 // Invoking descriptor.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow 344 // for every single message, thus the need for a hand-rolled custom version that's performant and cacheable. 345 func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescriptorProto, *descriptor.DescriptorProto, error) { 346 gzippedPb, indices := desc.Descriptor() 347 348 protoFileToDescMu.RLock() 349 cached, ok := protoFileToDesc[string(gzippedPb)] 350 protoFileToDescMu.RUnlock() 351 352 if ok { 353 return cached, unnestDesc(cached.MessageType, indices), nil 354 } 355 356 // Time to gunzip the content of the FileDescriptor and then proto unmarshal them. 357 gzr, err := gzip.NewReader(bytes.NewReader(gzippedPb)) 358 if err != nil { 359 return nil, nil, err 360 } 361 protoBlob, err := ioutil.ReadAll(gzr) 362 if err != nil { 363 return nil, nil, err 364 } 365 366 fdesc := new(descriptor.FileDescriptorProto) 367 if err := proto.Unmarshal(protoBlob, fdesc); err != nil { 368 return nil, nil, err 369 } 370 371 // Now cache the FileDescriptor. 372 protoFileToDescMu.Lock() 373 protoFileToDesc[string(gzippedPb)] = fdesc 374 protoFileToDescMu.Unlock() 375 376 // Unnest the type if necessary. 377 return fdesc, unnestDesc(fdesc.MessageType, indices), nil 378 } 379 380 type descriptorMatch struct { 381 cache map[int32]*descriptor.FieldDescriptorProto 382 desc *descriptor.DescriptorProto 383 } 384 385 var descprotoCacheMu sync.RWMutex 386 var descprotoCache = make(map[reflect.Type]*descriptorMatch) 387 388 // getDescriptorInfo retrieves the mapping of field numbers to their respective field descriptors. 389 func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptor.FieldDescriptorProto, *descriptor.DescriptorProto, error) { 390 key := reflect.ValueOf(msg).Type() 391 392 descprotoCacheMu.RLock() 393 got, ok := descprotoCache[key] 394 descprotoCacheMu.RUnlock() 395 396 if ok { 397 return got.cache, got.desc, nil 398 } 399 400 // Now compute and cache the index. 401 _, md, err := extractFileDescMessageDesc(desc) 402 if err != nil { 403 return nil, nil, err 404 } 405 406 tagNumToTypeIndex := make(map[int32]*descriptor.FieldDescriptorProto) 407 for _, field := range md.Field { 408 tagNumToTypeIndex[field.GetNumber()] = field 409 } 410 411 descprotoCacheMu.Lock() 412 descprotoCache[key] = &descriptorMatch{ 413 cache: tagNumToTypeIndex, 414 desc: md, 415 } 416 descprotoCacheMu.Unlock() 417 418 return tagNumToTypeIndex, md, nil 419 } 420 421 // DefaultAnyResolver is a default implementation of AnyResolver which uses 422 // the default encoding of type URLs as specified by the protobuf specification. 423 type DefaultAnyResolver struct{} 424 425 var _ jsonpb.AnyResolver = DefaultAnyResolver{} 426 427 // Resolve is the AnyResolver.Resolve method. 428 func (d DefaultAnyResolver) Resolve(typeURL string) (proto.Message, error) { 429 // Only the part of typeURL after the last slash is relevant. 430 mname := typeURL 431 if slash := strings.LastIndex(mname, "/"); slash >= 0 { 432 mname = mname[slash+1:] 433 } 434 mt := proto.MessageType(mname) 435 if mt == nil { 436 return nil, fmt.Errorf("unknown message type %q", mname) 437 } 438 return reflect.New(mt.Elem()).Interface().(proto.Message), nil 439 }