github.com/Finschia/finschia-sdk@v0.48.1/codec/unknownproto/unknown_fields.go (about) 1 package unknownproto 2 3 import ( 4 "bytes" 5 "compress/gzip" 6 "errors" 7 "fmt" 8 "io" 9 "reflect" 10 "strings" 11 "sync" 12 13 "github.com/gogo/protobuf/jsonpb" 14 "github.com/gogo/protobuf/proto" 15 "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" 16 "google.golang.org/protobuf/encoding/protowire" 17 18 "github.com/Finschia/finschia-sdk/codec/types" 19 ) 20 21 const bit11NonCritical = 1 << 10 22 23 type descriptorIface interface { 24 Descriptor() ([]byte, []int) 25 } 26 27 // RejectUnknownFieldsStrict rejects any bytes bz with an error that has unknown fields for the provided proto.Message type. 28 // This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message. 29 // An AnyResolver must be provided for traversing inside google.protobuf.Any's. 30 func RejectUnknownFieldsStrict(bz []byte, msg proto.Message, resolver jsonpb.AnyResolver) error { 31 _, err := RejectUnknownFields(bz, msg, false, resolver) 32 return err 33 } 34 35 // RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an 36 // option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the 37 // hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be 38 // used to treat a message with non-critical field different in different security contexts (such as transaction signing). 39 // This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message. 40 // An AnyResolver must be provided for traversing inside google.protobuf.Any's. 41 func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals bool, resolver jsonpb.AnyResolver) (hasUnknownNonCriticals bool, err error) { 42 if len(bz) == 0 { 43 return hasUnknownNonCriticals, nil 44 } 45 46 desc, ok := msg.(descriptorIface) 47 if !ok { 48 return hasUnknownNonCriticals, fmt.Errorf("%T does not have a Descriptor() method", msg) 49 } 50 51 fieldDescProtoFromTagNum, _, err := getDescriptorInfo(desc, msg) 52 if err != nil { 53 return hasUnknownNonCriticals, err 54 } 55 56 for len(bz) > 0 { 57 tagNum, wireType, m := protowire.ConsumeTag(bz) 58 if m < 0 { 59 return hasUnknownNonCriticals, errors.New("invalid length") 60 } 61 62 fieldDescProto, ok := fieldDescProtoFromTagNum[int32(tagNum)] 63 switch { 64 case ok: 65 // Assert that the wireTypes match. 66 if !canEncodeType(wireType, fieldDescProto.GetType()) { 67 return hasUnknownNonCriticals, &errMismatchedWireType{ 68 Type: reflect.ValueOf(msg).Type().String(), 69 TagNum: tagNum, 70 GotWireType: wireType, 71 WantWireType: protowire.Type(fieldDescProto.WireType()), 72 } 73 } 74 75 default: 76 isCriticalField := tagNum&bit11NonCritical == 0 77 78 if !isCriticalField { 79 hasUnknownNonCriticals = true 80 } 81 82 if isCriticalField || !allowUnknownNonCriticals { 83 // The tag is critical, so report it. 84 return hasUnknownNonCriticals, &errUnknownField{ 85 Type: reflect.ValueOf(msg).Type().String(), 86 TagNum: tagNum, 87 WireType: wireType, 88 } 89 } 90 } 91 92 // Skip over the bytes that store fieldNumber and wireType bytes. 93 bz = bz[m:] 94 n := protowire.ConsumeFieldValue(tagNum, wireType, bz) 95 if n < 0 { 96 err = fmt.Errorf("could not consume field value for tagNum: %d, wireType: %q; %w", 97 tagNum, wireTypeToString(wireType), protowire.ParseError(n)) 98 return hasUnknownNonCriticals, err 99 } 100 fieldBytes := bz[:n] 101 bz = bz[n:] 102 103 // An unknown but non-critical field or just a scalar type (aka *INT and BYTES like). 104 if fieldDescProto == nil || fieldDescProto.IsScalar() { 105 continue 106 } 107 108 protoMessageName := fieldDescProto.GetTypeName() 109 if protoMessageName == "" { 110 switch typ := fieldDescProto.GetType(); typ { 111 case descriptor.FieldDescriptorProto_TYPE_STRING, descriptor.FieldDescriptorProto_TYPE_BYTES: 112 // At this point only TYPE_STRING is expected to be unregistered, since FieldDescriptorProto.IsScalar() returns false for 113 // TYPE_BYTES and TYPE_STRING as per 114 // https://github.com/gogo/protobuf/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118 115 default: 116 return hasUnknownNonCriticals, fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ) 117 } 118 continue 119 } 120 121 // Let's recursively traverse and typecheck the field. 122 123 // consume length prefix of nested message 124 _, o := protowire.ConsumeVarint(fieldBytes) 125 fieldBytes = fieldBytes[o:] 126 127 var msg proto.Message 128 var err error 129 130 if protoMessageName == ".google.protobuf.Any" { 131 // Firstly typecheck types.Any to ensure nothing snuck in. 132 hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals, resolver) 133 hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild 134 if err != nil { 135 return hasUnknownNonCriticals, err 136 } 137 // And finally we can extract the TypeURL containing the protoMessageName. 138 any := new(types.Any) 139 if err := proto.Unmarshal(fieldBytes, any); err != nil { 140 return hasUnknownNonCriticals, err 141 } 142 protoMessageName = any.TypeUrl 143 fieldBytes = any.Value 144 msg, err = resolver.Resolve(protoMessageName) 145 if err != nil { 146 return hasUnknownNonCriticals, err 147 } 148 } else { 149 msg, err = protoMessageForTypeName(protoMessageName[1:]) 150 if err != nil { 151 return hasUnknownNonCriticals, err 152 } 153 } 154 155 hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals, resolver) 156 hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild 157 if err != nil { 158 return hasUnknownNonCriticals, err 159 } 160 } 161 162 return hasUnknownNonCriticals, nil 163 } 164 165 var ( 166 protoMessageForTypeNameMu sync.RWMutex 167 protoMessageForTypeNameCache = make(map[string]proto.Message) 168 ) 169 170 // protoMessageForTypeName takes in a fully qualified name e.g. testdata.TestVersionFD1 171 // and returns a corresponding empty protobuf message that serves the prototype for typechecking. 172 func protoMessageForTypeName(protoMessageName string) (proto.Message, error) { 173 protoMessageForTypeNameMu.RLock() 174 msg, ok := protoMessageForTypeNameCache[protoMessageName] 175 protoMessageForTypeNameMu.RUnlock() 176 if ok { 177 return msg, nil 178 } 179 180 concreteGoType := proto.MessageType(protoMessageName) 181 if concreteGoType == nil { 182 return nil, fmt.Errorf("failed to retrieve the message of type %q", protoMessageName) 183 } 184 185 value := reflect.New(concreteGoType).Elem() 186 msg, ok = value.Interface().(proto.Message) 187 if !ok { 188 return nil, fmt.Errorf("%q does not implement proto.Message", protoMessageName) 189 } 190 191 // Now cache it. 192 protoMessageForTypeNameMu.Lock() 193 protoMessageForTypeNameCache[protoMessageName] = msg 194 protoMessageForTypeNameMu.Unlock() 195 196 return msg, nil 197 } 198 199 // checks is a mapping of protowire.Type to supported descriptor.FieldDescriptorProto_Type. 200 // it is implemented this way so as to have constant time lookups and avoid the overhead 201 // from O(n) walking of switch. The change to using this mapping boosts throughput by about 200%. 202 var checks = [...]map[descriptor.FieldDescriptorProto_Type]bool{ 203 // "0 Varint: int32, int64, uint32, uint64, sint32, sint64, bool, enum" 204 0: { 205 descriptor.FieldDescriptorProto_TYPE_INT32: true, 206 descriptor.FieldDescriptorProto_TYPE_INT64: true, 207 descriptor.FieldDescriptorProto_TYPE_UINT32: true, 208 descriptor.FieldDescriptorProto_TYPE_UINT64: true, 209 descriptor.FieldDescriptorProto_TYPE_SINT32: true, 210 descriptor.FieldDescriptorProto_TYPE_SINT64: true, 211 descriptor.FieldDescriptorProto_TYPE_BOOL: true, 212 descriptor.FieldDescriptorProto_TYPE_ENUM: true, 213 }, 214 215 // "1 64-bit: fixed64, sfixed64, double" 216 1: { 217 descriptor.FieldDescriptorProto_TYPE_FIXED64: true, 218 descriptor.FieldDescriptorProto_TYPE_SFIXED64: true, 219 descriptor.FieldDescriptorProto_TYPE_DOUBLE: true, 220 }, 221 222 // "2 Length-delimited: string, bytes, embedded messages, packed repeated fields" 223 2: { 224 descriptor.FieldDescriptorProto_TYPE_STRING: true, 225 descriptor.FieldDescriptorProto_TYPE_BYTES: true, 226 descriptor.FieldDescriptorProto_TYPE_MESSAGE: true, 227 // The following types can be packed repeated. 228 // ref: "Only repeated fields of primitive numeric types (types which use the varint, 32-bit, or 64-bit wire types) can be declared "packed"." 229 // ref: https://developers.google.com/protocol-buffers/docs/encoding#packed 230 descriptor.FieldDescriptorProto_TYPE_INT32: true, 231 descriptor.FieldDescriptorProto_TYPE_INT64: true, 232 descriptor.FieldDescriptorProto_TYPE_UINT32: true, 233 descriptor.FieldDescriptorProto_TYPE_UINT64: true, 234 descriptor.FieldDescriptorProto_TYPE_SINT32: true, 235 descriptor.FieldDescriptorProto_TYPE_SINT64: true, 236 descriptor.FieldDescriptorProto_TYPE_BOOL: true, 237 descriptor.FieldDescriptorProto_TYPE_ENUM: true, 238 descriptor.FieldDescriptorProto_TYPE_FIXED64: true, 239 descriptor.FieldDescriptorProto_TYPE_SFIXED64: true, 240 descriptor.FieldDescriptorProto_TYPE_DOUBLE: true, 241 }, 242 243 // "3 Start group: groups (deprecated)" 244 3: { 245 descriptor.FieldDescriptorProto_TYPE_GROUP: true, 246 }, 247 248 // "4 End group: groups (deprecated)" 249 4: { 250 descriptor.FieldDescriptorProto_TYPE_GROUP: true, 251 }, 252 253 // "5 32-bit: fixed32, sfixed32, float" 254 5: { 255 descriptor.FieldDescriptorProto_TYPE_FIXED32: true, 256 descriptor.FieldDescriptorProto_TYPE_SFIXED32: true, 257 descriptor.FieldDescriptorProto_TYPE_FLOAT: true, 258 }, 259 } 260 261 // canEncodeType returns true if the wireType is suitable for encoding the descriptor type. 262 // See https://developers.google.com/protocol-buffers/docs/encoding#structure. 263 func canEncodeType(wireType protowire.Type, descType descriptor.FieldDescriptorProto_Type) bool { 264 if iwt := int(wireType); iwt < 0 || iwt >= len(checks) { 265 return false 266 } 267 return checks[wireType][descType] 268 } 269 270 // errMismatchedWireType describes a mismatch between 271 // expected and got wireTypes for a specific tag number. 272 type errMismatchedWireType struct { 273 Type string 274 GotWireType protowire.Type 275 WantWireType protowire.Type 276 TagNum protowire.Number 277 } 278 279 // String implements fmt.Stringer. 280 func (mwt *errMismatchedWireType) String() string { 281 return fmt.Sprintf("Mismatched %q: {TagNum: %d, GotWireType: %q != WantWireType: %q}", 282 mwt.Type, mwt.TagNum, wireTypeToString(mwt.GotWireType), wireTypeToString(mwt.WantWireType)) 283 } 284 285 // Error implements the error interface. 286 func (mwt *errMismatchedWireType) Error() string { 287 return mwt.String() 288 } 289 290 var _ error = (*errMismatchedWireType)(nil) 291 292 func wireTypeToString(wt protowire.Type) string { 293 switch wt { 294 case 0: 295 return "varint" 296 case 1: 297 return "fixed64" 298 case 2: 299 return "bytes" 300 case 3: 301 return "start_group" 302 case 4: 303 return "end_group" 304 case 5: 305 return "fixed32" 306 default: 307 return fmt.Sprintf("unknown type: %d", wt) 308 } 309 } 310 311 // errUnknownField represents an error indicating that we encountered 312 // a field that isn't available in the target proto.Message. 313 type errUnknownField struct { 314 Type string 315 TagNum protowire.Number 316 WireType protowire.Type 317 } 318 319 // String implements fmt.Stringer. 320 func (twt *errUnknownField) String() string { 321 return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}", 322 twt.Type, twt.TagNum, wireTypeToString(twt.WireType)) 323 } 324 325 // Error implements the error interface. 326 func (twt *errUnknownField) Error() string { 327 return twt.String() 328 } 329 330 var _ error = (*errUnknownField)(nil) 331 332 var ( 333 protoFileToDesc = make(map[string]*descriptor.FileDescriptorProto) 334 protoFileToDescMu sync.RWMutex 335 ) 336 337 func unnestDesc(mdescs []*descriptor.DescriptorProto, indices []int) *descriptor.DescriptorProto { 338 mdesc := mdescs[indices[0]] 339 for _, index := range indices[1:] { 340 mdesc = mdesc.NestedType[index] 341 } 342 return mdesc 343 } 344 345 // Invoking descriptor.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow 346 // for every single message, thus the need for a hand-rolled custom version that's performant and cacheable. 347 func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescriptorProto, *descriptor.DescriptorProto, error) { 348 gzippedPb, indices := desc.Descriptor() 349 350 protoFileToDescMu.RLock() 351 cached, ok := protoFileToDesc[string(gzippedPb)] 352 protoFileToDescMu.RUnlock() 353 354 if ok { 355 return cached, unnestDesc(cached.MessageType, indices), nil 356 } 357 358 // Time to gunzip the content of the FileDescriptor and then proto unmarshal them. 359 gzr, err := gzip.NewReader(bytes.NewReader(gzippedPb)) 360 if err != nil { 361 return nil, nil, err 362 } 363 protoBlob, err := io.ReadAll(gzr) 364 if err != nil { 365 return nil, nil, err 366 } 367 368 fdesc := new(descriptor.FileDescriptorProto) 369 if err := proto.Unmarshal(protoBlob, fdesc); err != nil { 370 return nil, nil, err 371 } 372 373 // Now cache the FileDescriptor. 374 protoFileToDescMu.Lock() 375 protoFileToDesc[string(gzippedPb)] = fdesc 376 protoFileToDescMu.Unlock() 377 378 // Unnest the type if necessary. 379 return fdesc, unnestDesc(fdesc.MessageType, indices), nil 380 } 381 382 type descriptorMatch struct { 383 cache map[int32]*descriptor.FieldDescriptorProto 384 desc *descriptor.DescriptorProto 385 } 386 387 var ( 388 descprotoCacheMu sync.RWMutex 389 descprotoCache = make(map[reflect.Type]*descriptorMatch) 390 ) 391 392 // getDescriptorInfo retrieves the mapping of field numbers to their respective field descriptors. 393 func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptor.FieldDescriptorProto, *descriptor.DescriptorProto, error) { 394 key := reflect.ValueOf(msg).Type() 395 396 descprotoCacheMu.RLock() 397 got, ok := descprotoCache[key] 398 descprotoCacheMu.RUnlock() 399 400 if ok { 401 return got.cache, got.desc, nil 402 } 403 404 // Now compute and cache the index. 405 _, md, err := extractFileDescMessageDesc(desc) 406 if err != nil { 407 return nil, nil, err 408 } 409 410 tagNumToTypeIndex := make(map[int32]*descriptor.FieldDescriptorProto) 411 for _, field := range md.Field { 412 tagNumToTypeIndex[field.GetNumber()] = field 413 } 414 415 descprotoCacheMu.Lock() 416 descprotoCache[key] = &descriptorMatch{ 417 cache: tagNumToTypeIndex, 418 desc: md, 419 } 420 descprotoCacheMu.Unlock() 421 422 return tagNumToTypeIndex, md, nil 423 } 424 425 // DefaultAnyResolver is a default implementation of AnyResolver which uses 426 // the default encoding of type URLs as specified by the protobuf specification. 427 type DefaultAnyResolver struct{} 428 429 var _ jsonpb.AnyResolver = DefaultAnyResolver{} 430 431 // Resolve is the AnyResolver.Resolve method. 432 func (d DefaultAnyResolver) Resolve(typeURL string) (proto.Message, error) { 433 // Only the part of typeURL after the last slash is relevant. 434 mname := typeURL 435 if slash := strings.LastIndex(mname, "/"); slash >= 0 { 436 mname = mname[slash+1:] 437 } 438 mt := proto.MessageType(mname) 439 if mt == nil { 440 return nil, fmt.Errorf("unknown message type %q", mname) 441 } 442 return reflect.New(mt.Elem()).Interface().(proto.Message), nil 443 }