github.com/fibonacci-chain/fbc@v0.0.0-20231124064014-c7636198c1e9/libs/cosmos-sdk/codec/unknownproto/unknown_fields.go (about)

     1  package unknownproto
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"errors"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"reflect"
    10  	"strings"
    11  	"sync"
    12  
    13  	"github.com/fibonacci-chain/fbc/libs/cosmos-sdk/codec/types"
    14  
    15  	"github.com/gogo/protobuf/jsonpb"
    16  	"github.com/gogo/protobuf/proto"
    17  	"github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
    18  	"google.golang.org/protobuf/encoding/protowire"
    19  )
    20  
    21  const bit11NonCritical = 1 << 10
    22  
    23  type descriptorIface interface {
    24  	Descriptor() ([]byte, []int)
    25  }
    26  
    27  // RejectUnknownFieldsStrict rejects any bytes bz with an error that has unknown fields for the provided proto.Message type.
    28  // This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
    29  // An AnyResolver must be provided for traversing inside google.protobuf.Any's.
    30  func RejectUnknownFieldsStrict(bz []byte, msg proto.Message, resolver jsonpb.AnyResolver) error {
    31  	_, err := RejectUnknownFields(bz, msg, false, resolver)
    32  	return err
    33  }
    34  
    35  // RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an
    36  // option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the
    37  // hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be
    38  // used to treat a message with non-critical field different in different security contexts (such as transaction signing).
    39  // This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
    40  // An AnyResolver must be provided for traversing inside google.protobuf.Any's.
    41  func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals bool, resolver jsonpb.AnyResolver) (hasUnknownNonCriticals bool, err error) {
    42  	if len(bz) == 0 {
    43  		return hasUnknownNonCriticals, nil
    44  	}
    45  
    46  	desc, ok := msg.(descriptorIface)
    47  	if !ok {
    48  		return hasUnknownNonCriticals, fmt.Errorf("%T does not have a Descriptor() method", msg)
    49  	}
    50  
    51  	fieldDescProtoFromTagNum, _, err := getDescriptorInfo(desc, msg)
    52  	if err != nil {
    53  		return hasUnknownNonCriticals, err
    54  	}
    55  
    56  	for len(bz) > 0 {
    57  		tagNum, wireType, m := protowire.ConsumeTag(bz)
    58  		if m < 0 {
    59  			return hasUnknownNonCriticals, errors.New("invalid length")
    60  		}
    61  
    62  		fieldDescProto, ok := fieldDescProtoFromTagNum[int32(tagNum)]
    63  		switch {
    64  		case ok:
    65  			// Assert that the wireTypes match.
    66  			if !canEncodeType(wireType, fieldDescProto.GetType()) {
    67  				return hasUnknownNonCriticals, &errMismatchedWireType{
    68  					Type:         reflect.ValueOf(msg).Type().String(),
    69  					TagNum:       tagNum,
    70  					GotWireType:  wireType,
    71  					WantWireType: protowire.Type(fieldDescProto.WireType()),
    72  				}
    73  			}
    74  
    75  		default:
    76  			isCriticalField := tagNum&bit11NonCritical == 0
    77  
    78  			if !isCriticalField {
    79  				hasUnknownNonCriticals = true
    80  			}
    81  
    82  			if isCriticalField || !allowUnknownNonCriticals {
    83  				// The tag is critical, so report it.
    84  				return hasUnknownNonCriticals, &errUnknownField{
    85  					Type:     reflect.ValueOf(msg).Type().String(),
    86  					TagNum:   tagNum,
    87  					WireType: wireType,
    88  				}
    89  			}
    90  		}
    91  
    92  		// Skip over the bytes that store fieldNumber and wireType bytes.
    93  		bz = bz[m:]
    94  		n := protowire.ConsumeFieldValue(tagNum, wireType, bz)
    95  		if n < 0 {
    96  			err = fmt.Errorf("could not consume field value for tagNum: %d, wireType: %q; %w",
    97  				tagNum, wireTypeToString(wireType), protowire.ParseError(n))
    98  			return hasUnknownNonCriticals, err
    99  		}
   100  		fieldBytes := bz[:n]
   101  		bz = bz[n:]
   102  
   103  		// An unknown but non-critical field or just a scalar type (aka *INT and BYTES like).
   104  		if fieldDescProto == nil || fieldDescProto.IsScalar() {
   105  			continue
   106  		}
   107  
   108  		protoMessageName := fieldDescProto.GetTypeName()
   109  		if protoMessageName == "" {
   110  			switch typ := fieldDescProto.GetType(); typ {
   111  			case descriptor.FieldDescriptorProto_TYPE_STRING, descriptor.FieldDescriptorProto_TYPE_BYTES:
   112  				// At this point only TYPE_STRING is expected to be unregistered, since FieldDescriptorProto.IsScalar() returns false for
   113  				// TYPE_BYTES and TYPE_STRING as per
   114  				// https://github.com/gogo/protobuf/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118
   115  			default:
   116  				return hasUnknownNonCriticals, fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ)
   117  			}
   118  			continue
   119  		}
   120  
   121  		// Let's recursively traverse and typecheck the field.
   122  
   123  		// consume length prefix of nested message
   124  		_, o := protowire.ConsumeVarint(fieldBytes)
   125  		fieldBytes = fieldBytes[o:]
   126  
   127  		var msg proto.Message
   128  		var err error
   129  
   130  		if protoMessageName == ".google.protobuf.Any" {
   131  			// Firstly typecheck types.Any to ensure nothing snuck in.
   132  			hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals, resolver)
   133  			hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
   134  			if err != nil {
   135  				return hasUnknownNonCriticals, err
   136  			}
   137  			// And finally we can extract the TypeURL containing the protoMessageName.
   138  			any := new(types.Any)
   139  			if err := proto.Unmarshal(fieldBytes, any); err != nil {
   140  				return hasUnknownNonCriticals, err
   141  			}
   142  			protoMessageName = any.TypeUrl
   143  			fieldBytes = any.Value
   144  			msg, err = resolver.Resolve(protoMessageName)
   145  			if err != nil {
   146  				return hasUnknownNonCriticals, err
   147  			}
   148  		} else {
   149  			msg, err = protoMessageForTypeName(protoMessageName[1:])
   150  			if err != nil {
   151  				return hasUnknownNonCriticals, err
   152  			}
   153  		}
   154  
   155  		hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals, resolver)
   156  		hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
   157  		if err != nil {
   158  			return hasUnknownNonCriticals, err
   159  		}
   160  	}
   161  
   162  	return hasUnknownNonCriticals, nil
   163  }
   164  
   165  var protoMessageForTypeNameMu sync.RWMutex
   166  var protoMessageForTypeNameCache = make(map[string]proto.Message)
   167  
   168  // protoMessageForTypeName takes in a fully qualified name e.g. testdata.TestVersionFD1
   169  // and returns a corresponding empty protobuf message that serves the prototype for typechecking.
   170  func protoMessageForTypeName(protoMessageName string) (proto.Message, error) {
   171  	protoMessageForTypeNameMu.RLock()
   172  	msg, ok := protoMessageForTypeNameCache[protoMessageName]
   173  	protoMessageForTypeNameMu.RUnlock()
   174  	if ok {
   175  		return msg, nil
   176  	}
   177  
   178  	concreteGoType := proto.MessageType(protoMessageName)
   179  	if concreteGoType == nil {
   180  		return nil, fmt.Errorf("failed to retrieve the message of type %q", protoMessageName)
   181  	}
   182  
   183  	value := reflect.New(concreteGoType).Elem()
   184  	msg, ok = value.Interface().(proto.Message)
   185  	if !ok {
   186  		return nil, fmt.Errorf("%q does not implement proto.Message", protoMessageName)
   187  	}
   188  
   189  	// Now cache it.
   190  	protoMessageForTypeNameMu.Lock()
   191  	protoMessageForTypeNameCache[protoMessageName] = msg
   192  	protoMessageForTypeNameMu.Unlock()
   193  
   194  	return msg, nil
   195  }
   196  
   197  // checks is a mapping of protowire.Type to supported descriptor.FieldDescriptorProto_Type.
   198  // it is implemented this way so as to have constant time lookups and avoid the overhead
   199  // from O(n) walking of switch. The change to using this mapping boosts throughput by about 200%.
   200  var checks = [...]map[descriptor.FieldDescriptorProto_Type]bool{
   201  	// "0	Varint: int32, int64, uint32, uint64, sint32, sint64, bool, enum"
   202  	0: {
   203  		descriptor.FieldDescriptorProto_TYPE_INT32:  true,
   204  		descriptor.FieldDescriptorProto_TYPE_INT64:  true,
   205  		descriptor.FieldDescriptorProto_TYPE_UINT32: true,
   206  		descriptor.FieldDescriptorProto_TYPE_UINT64: true,
   207  		descriptor.FieldDescriptorProto_TYPE_SINT32: true,
   208  		descriptor.FieldDescriptorProto_TYPE_SINT64: true,
   209  		descriptor.FieldDescriptorProto_TYPE_BOOL:   true,
   210  		descriptor.FieldDescriptorProto_TYPE_ENUM:   true,
   211  	},
   212  
   213  	// "1	64-bit:	fixed64, sfixed64, double"
   214  	1: {
   215  		descriptor.FieldDescriptorProto_TYPE_FIXED64:  true,
   216  		descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
   217  		descriptor.FieldDescriptorProto_TYPE_DOUBLE:   true,
   218  	},
   219  
   220  	// "2	Length-delimited: string, bytes, embedded messages, packed repeated fields"
   221  	2: {
   222  		descriptor.FieldDescriptorProto_TYPE_STRING:  true,
   223  		descriptor.FieldDescriptorProto_TYPE_BYTES:   true,
   224  		descriptor.FieldDescriptorProto_TYPE_MESSAGE: true,
   225  		// The following types can be packed repeated.
   226  		// ref: "Only repeated fields of primitive numeric types (types which use the varint, 32-bit, or 64-bit wire types) can be declared "packed"."
   227  		// ref: https://developers.google.com/protocol-buffers/docs/encoding#packed
   228  		descriptor.FieldDescriptorProto_TYPE_INT32:    true,
   229  		descriptor.FieldDescriptorProto_TYPE_INT64:    true,
   230  		descriptor.FieldDescriptorProto_TYPE_UINT32:   true,
   231  		descriptor.FieldDescriptorProto_TYPE_UINT64:   true,
   232  		descriptor.FieldDescriptorProto_TYPE_SINT32:   true,
   233  		descriptor.FieldDescriptorProto_TYPE_SINT64:   true,
   234  		descriptor.FieldDescriptorProto_TYPE_BOOL:     true,
   235  		descriptor.FieldDescriptorProto_TYPE_ENUM:     true,
   236  		descriptor.FieldDescriptorProto_TYPE_FIXED64:  true,
   237  		descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
   238  		descriptor.FieldDescriptorProto_TYPE_DOUBLE:   true,
   239  	},
   240  
   241  	// "3	Start group:	groups (deprecated)"
   242  	3: {
   243  		descriptor.FieldDescriptorProto_TYPE_GROUP: true,
   244  	},
   245  
   246  	// "4	End group:	groups (deprecated)"
   247  	4: {
   248  		descriptor.FieldDescriptorProto_TYPE_GROUP: true,
   249  	},
   250  
   251  	// "5	32-bit:	fixed32, sfixed32, float"
   252  	5: {
   253  		descriptor.FieldDescriptorProto_TYPE_FIXED32:  true,
   254  		descriptor.FieldDescriptorProto_TYPE_SFIXED32: true,
   255  		descriptor.FieldDescriptorProto_TYPE_FLOAT:    true,
   256  	},
   257  }
   258  
   259  // canEncodeType returns true if the wireType is suitable for encoding the descriptor type.
   260  // See https://developers.google.com/protocol-buffers/docs/encoding#structure.
   261  func canEncodeType(wireType protowire.Type, descType descriptor.FieldDescriptorProto_Type) bool {
   262  	if iwt := int(wireType); iwt < 0 || iwt >= len(checks) {
   263  		return false
   264  	}
   265  	return checks[wireType][descType]
   266  }
   267  
   268  // errMismatchedWireType describes a mismatch between
   269  // expected and got wireTypes for a specific tag number.
   270  type errMismatchedWireType struct {
   271  	Type         string
   272  	GotWireType  protowire.Type
   273  	WantWireType protowire.Type
   274  	TagNum       protowire.Number
   275  }
   276  
   277  // String implements fmt.Stringer.
   278  func (mwt *errMismatchedWireType) String() string {
   279  	return fmt.Sprintf("Mismatched %q: {TagNum: %d, GotWireType: %q != WantWireType: %q}",
   280  		mwt.Type, mwt.TagNum, wireTypeToString(mwt.GotWireType), wireTypeToString(mwt.WantWireType))
   281  }
   282  
   283  // Error implements the error interface.
   284  func (mwt *errMismatchedWireType) Error() string {
   285  	return mwt.String()
   286  }
   287  
   288  var _ error = (*errMismatchedWireType)(nil)
   289  
   290  func wireTypeToString(wt protowire.Type) string {
   291  	switch wt {
   292  	case 0:
   293  		return "varint"
   294  	case 1:
   295  		return "fixed64"
   296  	case 2:
   297  		return "bytes"
   298  	case 3:
   299  		return "start_group"
   300  	case 4:
   301  		return "end_group"
   302  	case 5:
   303  		return "fixed32"
   304  	default:
   305  		return fmt.Sprintf("unknown type: %d", wt)
   306  	}
   307  }
   308  
   309  // errUnknownField represents an error indicating that we encountered
   310  // a field that isn't available in the target proto.Message.
   311  type errUnknownField struct {
   312  	Type     string
   313  	TagNum   protowire.Number
   314  	WireType protowire.Type
   315  }
   316  
   317  // String implements fmt.Stringer.
   318  func (twt *errUnknownField) String() string {
   319  	return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}",
   320  		twt.Type, twt.TagNum, wireTypeToString(twt.WireType))
   321  }
   322  
   323  // Error implements the error interface.
   324  func (twt *errUnknownField) Error() string {
   325  	return twt.String()
   326  }
   327  
   328  var _ error = (*errUnknownField)(nil)
   329  
   330  var (
   331  	protoFileToDesc   = make(map[string]*descriptor.FileDescriptorProto)
   332  	protoFileToDescMu sync.RWMutex
   333  )
   334  
   335  func unnestDesc(mdescs []*descriptor.DescriptorProto, indices []int) *descriptor.DescriptorProto {
   336  	mdesc := mdescs[indices[0]]
   337  	for _, index := range indices[1:] {
   338  		mdesc = mdesc.NestedType[index]
   339  	}
   340  	return mdesc
   341  }
   342  
   343  // Invoking descriptor.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow
   344  // for every single message, thus the need for a hand-rolled custom version that's performant and cacheable.
   345  func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescriptorProto, *descriptor.DescriptorProto, error) {
   346  	gzippedPb, indices := desc.Descriptor()
   347  
   348  	protoFileToDescMu.RLock()
   349  	cached, ok := protoFileToDesc[string(gzippedPb)]
   350  	protoFileToDescMu.RUnlock()
   351  
   352  	if ok {
   353  		return cached, unnestDesc(cached.MessageType, indices), nil
   354  	}
   355  
   356  	// Time to gunzip the content of the FileDescriptor and then proto unmarshal them.
   357  	gzr, err := gzip.NewReader(bytes.NewReader(gzippedPb))
   358  	if err != nil {
   359  		return nil, nil, err
   360  	}
   361  	protoBlob, err := ioutil.ReadAll(gzr)
   362  	if err != nil {
   363  		return nil, nil, err
   364  	}
   365  
   366  	fdesc := new(descriptor.FileDescriptorProto)
   367  	if err := proto.Unmarshal(protoBlob, fdesc); err != nil {
   368  		return nil, nil, err
   369  	}
   370  
   371  	// Now cache the FileDescriptor.
   372  	protoFileToDescMu.Lock()
   373  	protoFileToDesc[string(gzippedPb)] = fdesc
   374  	protoFileToDescMu.Unlock()
   375  
   376  	// Unnest the type if necessary.
   377  	return fdesc, unnestDesc(fdesc.MessageType, indices), nil
   378  }
   379  
   380  type descriptorMatch struct {
   381  	cache map[int32]*descriptor.FieldDescriptorProto
   382  	desc  *descriptor.DescriptorProto
   383  }
   384  
   385  var descprotoCacheMu sync.RWMutex
   386  var descprotoCache = make(map[reflect.Type]*descriptorMatch)
   387  
   388  // getDescriptorInfo retrieves the mapping of field numbers to their respective field descriptors.
   389  func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptor.FieldDescriptorProto, *descriptor.DescriptorProto, error) {
   390  	key := reflect.ValueOf(msg).Type()
   391  
   392  	descprotoCacheMu.RLock()
   393  	got, ok := descprotoCache[key]
   394  	descprotoCacheMu.RUnlock()
   395  
   396  	if ok {
   397  		return got.cache, got.desc, nil
   398  	}
   399  
   400  	// Now compute and cache the index.
   401  	_, md, err := extractFileDescMessageDesc(desc)
   402  	if err != nil {
   403  		return nil, nil, err
   404  	}
   405  
   406  	tagNumToTypeIndex := make(map[int32]*descriptor.FieldDescriptorProto)
   407  	for _, field := range md.Field {
   408  		tagNumToTypeIndex[field.GetNumber()] = field
   409  	}
   410  
   411  	descprotoCacheMu.Lock()
   412  	descprotoCache[key] = &descriptorMatch{
   413  		cache: tagNumToTypeIndex,
   414  		desc:  md,
   415  	}
   416  	descprotoCacheMu.Unlock()
   417  
   418  	return tagNumToTypeIndex, md, nil
   419  }
   420  
   421  // DefaultAnyResolver is a default implementation of AnyResolver which uses
   422  // the default encoding of type URLs as specified by the protobuf specification.
   423  type DefaultAnyResolver struct{}
   424  
   425  var _ jsonpb.AnyResolver = DefaultAnyResolver{}
   426  
   427  // Resolve is the AnyResolver.Resolve method.
   428  func (d DefaultAnyResolver) Resolve(typeURL string) (proto.Message, error) {
   429  	// Only the part of typeURL after the last slash is relevant.
   430  	mname := typeURL
   431  	if slash := strings.LastIndex(mname, "/"); slash >= 0 {
   432  		mname = mname[slash+1:]
   433  	}
   434  	mt := proto.MessageType(mname)
   435  	if mt == nil {
   436  		return nil, fmt.Errorf("unknown message type %q", mname)
   437  	}
   438  	return reflect.New(mt.Elem()).Interface().(proto.Message), nil
   439  }