github.com/cosmos/cosmos-sdk@v0.50.10/codec/unknownproto/unknown_fields.go (about)

     1  package unknownproto
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"reflect"
    10  	"strings"
    11  	"sync"
    12  
    13  	"github.com/cosmos/gogoproto/jsonpb"
    14  	"github.com/cosmos/gogoproto/proto"
    15  	"google.golang.org/protobuf/encoding/protowire"
    16  	protov2 "google.golang.org/protobuf/proto"
    17  	"google.golang.org/protobuf/types/descriptorpb"
    18  
    19  	"github.com/cosmos/cosmos-sdk/codec/types"
    20  )
    21  
    22  const bit11NonCritical = 1 << 10
    23  
    24  type descriptorIface interface {
    25  	Descriptor() ([]byte, []int)
    26  }
    27  
    28  // RejectUnknownFieldsStrict rejects any bytes bz with an error that has unknown fields for the provided proto.Message type.
    29  // This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
    30  // An AnyResolver must be provided for traversing inside google.protobuf.Any's.
    31  func RejectUnknownFieldsStrict(bz []byte, msg proto.Message, resolver jsonpb.AnyResolver) error {
    32  	_, err := RejectUnknownFields(bz, msg, false, resolver)
    33  	return err
    34  }
    35  
    36  // RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an
    37  // option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the
    38  // hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be
    39  // used to treat a message with non-critical field different in different security contexts (such as transaction signing).
    40  // This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
    41  // An AnyResolver must be provided for traversing inside google.protobuf.Any's.
    42  func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals bool, resolver jsonpb.AnyResolver) (hasUnknownNonCriticals bool, err error) {
    43  	if len(bz) == 0 {
    44  		return hasUnknownNonCriticals, nil
    45  	}
    46  
    47  	desc, ok := msg.(descriptorIface)
    48  	if !ok {
    49  		return hasUnknownNonCriticals, fmt.Errorf("%T does not have a Descriptor() method", msg)
    50  	}
    51  
    52  	fieldDescProtoFromTagNum, _, err := getDescriptorInfo(desc, msg)
    53  	if err != nil {
    54  		return hasUnknownNonCriticals, err
    55  	}
    56  
    57  	for len(bz) > 0 {
    58  		tagNum, wireType, m := protowire.ConsumeTag(bz)
    59  		if m < 0 {
    60  			return hasUnknownNonCriticals, errors.New("invalid length")
    61  		}
    62  
    63  		fieldDescProto, ok := fieldDescProtoFromTagNum[int32(tagNum)]
    64  		switch {
    65  		case ok:
    66  			// Assert that the wireTypes match.
    67  			if !canEncodeType(wireType, fieldDescProto.GetType()) {
    68  				return hasUnknownNonCriticals, &errMismatchedWireType{
    69  					Type:         reflect.ValueOf(msg).Type().String(),
    70  					TagNum:       tagNum,
    71  					GotWireType:  wireType,
    72  					WantWireType: toProtowireType(fieldDescProto.GetType()),
    73  				}
    74  			}
    75  
    76  		default:
    77  			isCriticalField := tagNum&bit11NonCritical == 0
    78  
    79  			if !isCriticalField {
    80  				hasUnknownNonCriticals = true
    81  			}
    82  
    83  			if isCriticalField || !allowUnknownNonCriticals {
    84  				// The tag is critical, so report it.
    85  				return hasUnknownNonCriticals, &errUnknownField{
    86  					Type:     reflect.ValueOf(msg).Type().String(),
    87  					TagNum:   tagNum,
    88  					WireType: wireType,
    89  				}
    90  			}
    91  		}
    92  
    93  		// Skip over the bytes that store fieldNumber and wireType bytes.
    94  		bz = bz[m:]
    95  		n := protowire.ConsumeFieldValue(tagNum, wireType, bz)
    96  		if n < 0 {
    97  			err = fmt.Errorf("could not consume field value for tagNum: %d, wireType: %q; %w",
    98  				tagNum, wireTypeToString(wireType), protowire.ParseError(n))
    99  			return hasUnknownNonCriticals, err
   100  		}
   101  		fieldBytes := bz[:n]
   102  		bz = bz[n:]
   103  
   104  		// An unknown but non-critical field or just a scalar type (aka *INT and BYTES like).
   105  		if fieldDescProto == nil || isScalar(fieldDescProto) {
   106  			continue
   107  		}
   108  
   109  		protoMessageName := fieldDescProto.GetTypeName()
   110  		if protoMessageName == "" {
   111  			switch typ := fieldDescProto.GetType(); typ {
   112  			case descriptorpb.FieldDescriptorProto_TYPE_STRING, descriptorpb.FieldDescriptorProto_TYPE_BYTES:
   113  				// At this point only TYPE_STRING is expected to be unregistered, since FieldDescriptorProto.IsScalar() returns false for
   114  				// TYPE_BYTES and TYPE_STRING as per
   115  				// https://github.com/cosmos/gogoproto/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118
   116  			default:
   117  				return hasUnknownNonCriticals, fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ)
   118  			}
   119  			continue
   120  		}
   121  
   122  		// Let's recursively traverse and typecheck the field.
   123  
   124  		// consume length prefix of nested message
   125  		_, o := protowire.ConsumeVarint(fieldBytes)
   126  		fieldBytes = fieldBytes[o:]
   127  
   128  		var msg proto.Message
   129  		var err error
   130  
   131  		if protoMessageName == ".google.protobuf.Any" {
   132  			// Firstly typecheck types.Any to ensure nothing snuck in.
   133  			hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals, resolver)
   134  			hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
   135  			if err != nil {
   136  				return hasUnknownNonCriticals, err
   137  			}
   138  			// And finally we can extract the TypeURL containing the protoMessageName.
   139  			any := new(types.Any)
   140  			if err := proto.Unmarshal(fieldBytes, any); err != nil {
   141  				return hasUnknownNonCriticals, err
   142  			}
   143  			protoMessageName = any.TypeUrl
   144  			fieldBytes = any.Value
   145  			msg, err = resolver.Resolve(protoMessageName)
   146  			if err != nil {
   147  				return hasUnknownNonCriticals, err
   148  			}
   149  		} else {
   150  			msg, err = protoMessageForTypeName(protoMessageName[1:])
   151  			if err != nil {
   152  				return hasUnknownNonCriticals, err
   153  			}
   154  		}
   155  
   156  		hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals, resolver)
   157  		hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
   158  		if err != nil {
   159  			return hasUnknownNonCriticals, err
   160  		}
   161  	}
   162  
   163  	return hasUnknownNonCriticals, nil
   164  }
   165  
   166  var (
   167  	protoMessageForTypeNameMu    sync.RWMutex
   168  	protoMessageForTypeNameCache = make(map[string]proto.Message)
   169  )
   170  
   171  // protoMessageForTypeName takes in a fully qualified name e.g. testdata.TestVersionFD1
   172  // and returns a corresponding empty protobuf message that serves the prototype for typechecking.
   173  func protoMessageForTypeName(protoMessageName string) (proto.Message, error) {
   174  	protoMessageForTypeNameMu.RLock()
   175  	msg, ok := protoMessageForTypeNameCache[protoMessageName]
   176  	protoMessageForTypeNameMu.RUnlock()
   177  	if ok {
   178  		return msg, nil
   179  	}
   180  
   181  	concreteGoType := proto.MessageType(protoMessageName)
   182  	if concreteGoType == nil {
   183  		return nil, fmt.Errorf("failed to retrieve the message of type %q", protoMessageName)
   184  	}
   185  
   186  	value := reflect.New(concreteGoType).Elem()
   187  	msg, ok = value.Interface().(proto.Message)
   188  	if !ok {
   189  		return nil, fmt.Errorf("%q does not implement proto.Message", protoMessageName)
   190  	}
   191  
   192  	// Now cache it.
   193  	protoMessageForTypeNameMu.Lock()
   194  	protoMessageForTypeNameCache[protoMessageName] = msg
   195  	protoMessageForTypeNameMu.Unlock()
   196  
   197  	return msg, nil
   198  }
   199  
   200  // checks is a mapping of protowire.Type to supported descriptor.FieldDescriptorProto_Type.
   201  // it is implemented this way so as to have constant time lookups and avoid the overhead
   202  // from O(n) walking of switch. The change to using this mapping boosts throughput by about 200%.
   203  var checks = [...]map[descriptorpb.FieldDescriptorProto_Type]bool{
   204  	// "0	Varint: int32, int64, uint32, uint64, sint32, sint64, bool, enum"
   205  	0: {
   206  		descriptorpb.FieldDescriptorProto_TYPE_INT32:  true,
   207  		descriptorpb.FieldDescriptorProto_TYPE_INT64:  true,
   208  		descriptorpb.FieldDescriptorProto_TYPE_UINT32: true,
   209  		descriptorpb.FieldDescriptorProto_TYPE_UINT64: true,
   210  		descriptorpb.FieldDescriptorProto_TYPE_SINT32: true,
   211  		descriptorpb.FieldDescriptorProto_TYPE_SINT64: true,
   212  		descriptorpb.FieldDescriptorProto_TYPE_BOOL:   true,
   213  		descriptorpb.FieldDescriptorProto_TYPE_ENUM:   true,
   214  	},
   215  
   216  	// "1	64-bit:	fixed64, sfixed64, double"
   217  	1: {
   218  		descriptorpb.FieldDescriptorProto_TYPE_FIXED64:  true,
   219  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: true,
   220  		descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:   true,
   221  	},
   222  
   223  	// "2	Length-delimited: string, bytes, embedded messages, packed repeated fields"
   224  	2: {
   225  		descriptorpb.FieldDescriptorProto_TYPE_STRING:  true,
   226  		descriptorpb.FieldDescriptorProto_TYPE_BYTES:   true,
   227  		descriptorpb.FieldDescriptorProto_TYPE_MESSAGE: true,
   228  		// The following types can be packed repeated.
   229  		// ref: "Only repeated fields of primitive numeric types (types which use the varint, 32-bit, or 64-bit wire types) can be declared "packed"."
   230  		// ref: https://developers.google.com/protocol-buffers/docs/encoding#packed
   231  		descriptorpb.FieldDescriptorProto_TYPE_INT32:    true,
   232  		descriptorpb.FieldDescriptorProto_TYPE_INT64:    true,
   233  		descriptorpb.FieldDescriptorProto_TYPE_UINT32:   true,
   234  		descriptorpb.FieldDescriptorProto_TYPE_UINT64:   true,
   235  		descriptorpb.FieldDescriptorProto_TYPE_SINT32:   true,
   236  		descriptorpb.FieldDescriptorProto_TYPE_SINT64:   true,
   237  		descriptorpb.FieldDescriptorProto_TYPE_BOOL:     true,
   238  		descriptorpb.FieldDescriptorProto_TYPE_ENUM:     true,
   239  		descriptorpb.FieldDescriptorProto_TYPE_FIXED64:  true,
   240  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED64: true,
   241  		descriptorpb.FieldDescriptorProto_TYPE_DOUBLE:   true,
   242  	},
   243  
   244  	// "3	Start group:	groups (deprecated)"
   245  	3: {
   246  		descriptorpb.FieldDescriptorProto_TYPE_GROUP: true,
   247  	},
   248  
   249  	// "4	End group:	groups (deprecated)"
   250  	4: {
   251  		descriptorpb.FieldDescriptorProto_TYPE_GROUP: true,
   252  	},
   253  
   254  	// "5	32-bit:	fixed32, sfixed32, float"
   255  	5: {
   256  		descriptorpb.FieldDescriptorProto_TYPE_FIXED32:  true,
   257  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED32: true,
   258  		descriptorpb.FieldDescriptorProto_TYPE_FLOAT:    true,
   259  	},
   260  }
   261  
   262  // canEncodeType returns true if the wireType is suitable for encoding the descriptor type.
   263  // See https://developers.google.com/protocol-buffers/docs/encoding#structure.
   264  func canEncodeType(wireType protowire.Type, descType descriptorpb.FieldDescriptorProto_Type) bool {
   265  	if iwt := int(wireType); iwt < 0 || iwt >= len(checks) {
   266  		return false
   267  	}
   268  	return checks[wireType][descType]
   269  }
   270  
   271  // errMismatchedWireType describes a mismatch between
   272  // expected and got wireTypes for a specific tag number.
   273  type errMismatchedWireType struct {
   274  	Type         string
   275  	GotWireType  protowire.Type
   276  	WantWireType protowire.Type
   277  	TagNum       protowire.Number
   278  }
   279  
   280  // String implements fmt.Stringer.
   281  func (mwt *errMismatchedWireType) String() string {
   282  	return fmt.Sprintf("Mismatched %q: {TagNum: %d, GotWireType: %q != WantWireType: %q}",
   283  		mwt.Type, mwt.TagNum, wireTypeToString(mwt.GotWireType), wireTypeToString(mwt.WantWireType))
   284  }
   285  
   286  // Error implements the error interface.
   287  func (mwt *errMismatchedWireType) Error() string {
   288  	return mwt.String()
   289  }
   290  
   291  var _ error = (*errMismatchedWireType)(nil)
   292  
   293  func wireTypeToString(wt protowire.Type) string {
   294  	switch wt {
   295  	case 0:
   296  		return "varint"
   297  	case 1:
   298  		return "fixed64"
   299  	case 2:
   300  		return "bytes"
   301  	case 3:
   302  		return "start_group"
   303  	case 4:
   304  		return "end_group"
   305  	case 5:
   306  		return "fixed32"
   307  	default:
   308  		return fmt.Sprintf("unknown type: %d", wt)
   309  	}
   310  }
   311  
   312  // errUnknownField represents an error indicating that we encountered
   313  // a field that isn't available in the target proto.Message.
   314  type errUnknownField struct {
   315  	Type     string
   316  	TagNum   protowire.Number
   317  	WireType protowire.Type
   318  }
   319  
   320  // String implements fmt.Stringer.
   321  func (twt *errUnknownField) String() string {
   322  	return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}",
   323  		twt.Type, twt.TagNum, wireTypeToString(twt.WireType))
   324  }
   325  
   326  // Error implements the error interface.
   327  func (twt *errUnknownField) Error() string {
   328  	return twt.String()
   329  }
   330  
   331  var _ error = (*errUnknownField)(nil)
   332  
   333  var (
   334  	protoFileToDesc   = make(map[string]*descriptorpb.FileDescriptorProto)
   335  	protoFileToDescMu sync.RWMutex
   336  )
   337  
   338  func unnestDesc(mdescs []*descriptorpb.DescriptorProto, indices []int) *descriptorpb.DescriptorProto {
   339  	mdesc := mdescs[indices[0]]
   340  	for _, index := range indices[1:] {
   341  		mdesc = mdesc.NestedType[index]
   342  	}
   343  	return mdesc
   344  }
   345  
   346  // Invoking descriptorpb.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow
   347  // for every single message, thus the need for a hand-rolled custom version that's performant and cacheable.
   348  func extractFileDescMessageDesc(desc descriptorIface) (*descriptorpb.FileDescriptorProto, *descriptorpb.DescriptorProto, error) {
   349  	gzippedPb, indices := desc.Descriptor()
   350  
   351  	protoFileToDescMu.RLock()
   352  	cached, ok := protoFileToDesc[string(gzippedPb)]
   353  	protoFileToDescMu.RUnlock()
   354  
   355  	if ok {
   356  		return cached, unnestDesc(cached.MessageType, indices), nil
   357  	}
   358  
   359  	// Time to gunzip the content of the FileDescriptor and then proto unmarshal them.
   360  	gzr, err := gzip.NewReader(bytes.NewReader(gzippedPb))
   361  	if err != nil {
   362  		return nil, nil, err
   363  	}
   364  	protoBlob, err := io.ReadAll(gzr)
   365  	if err != nil {
   366  		return nil, nil, err
   367  	}
   368  
   369  	fdesc := new(descriptorpb.FileDescriptorProto)
   370  	if err := protov2.Unmarshal(protoBlob, fdesc); err != nil {
   371  		return nil, nil, err
   372  	}
   373  
   374  	// Now cache the FileDescriptor.
   375  	protoFileToDescMu.Lock()
   376  	protoFileToDesc[string(gzippedPb)] = fdesc
   377  	protoFileToDescMu.Unlock()
   378  
   379  	// Unnest the type if necessary.
   380  	return fdesc, unnestDesc(fdesc.MessageType, indices), nil
   381  }
   382  
   383  type descriptorMatch struct {
   384  	cache map[int32]*descriptorpb.FieldDescriptorProto
   385  	desc  *descriptorpb.DescriptorProto
   386  }
   387  
   388  var (
   389  	descprotoCacheMu sync.RWMutex
   390  	descprotoCache   = make(map[reflect.Type]*descriptorMatch)
   391  )
   392  
   393  // getDescriptorInfo retrieves the mapping of field numbers to their respective field descriptors.
   394  func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptorpb.FieldDescriptorProto, *descriptorpb.DescriptorProto, error) {
   395  	key := reflect.ValueOf(msg).Type()
   396  
   397  	descprotoCacheMu.RLock()
   398  	got, ok := descprotoCache[key]
   399  	descprotoCacheMu.RUnlock()
   400  
   401  	if ok {
   402  		return got.cache, got.desc, nil
   403  	}
   404  
   405  	// Now compute and cache the index.
   406  	_, md, err := extractFileDescMessageDesc(desc)
   407  	if err != nil {
   408  		return nil, nil, err
   409  	}
   410  
   411  	tagNumToTypeIndex := make(map[int32]*descriptorpb.FieldDescriptorProto)
   412  	for _, field := range md.Field {
   413  		tagNumToTypeIndex[field.GetNumber()] = field
   414  	}
   415  
   416  	descprotoCacheMu.Lock()
   417  	descprotoCache[key] = &descriptorMatch{
   418  		cache: tagNumToTypeIndex,
   419  		desc:  md,
   420  	}
   421  	descprotoCacheMu.Unlock()
   422  
   423  	return tagNumToTypeIndex, md, nil
   424  }
   425  
   426  // DefaultAnyResolver is a default implementation of AnyResolver which uses
   427  // the default encoding of type URLs as specified by the protobuf specification.
   428  type DefaultAnyResolver struct{}
   429  
   430  var _ jsonpb.AnyResolver = DefaultAnyResolver{}
   431  
   432  // Resolve is the AnyResolver.Resolve method.
   433  func (d DefaultAnyResolver) Resolve(typeURL string) (proto.Message, error) {
   434  	// Only the part of typeURL after the last slash is relevant.
   435  	mname := typeURL
   436  	if slash := strings.LastIndex(mname, "/"); slash >= 0 {
   437  		mname = mname[slash+1:]
   438  	}
   439  	mt := proto.MessageType(mname)
   440  	if mt == nil {
   441  		return nil, fmt.Errorf("unknown message type %q", mname)
   442  	}
   443  	return reflect.New(mt.Elem()).Interface().(proto.Message), nil
   444  }
   445  
   446  // toProtowireType converts a descriptorpb.FieldDescriptorProto_Type to a protowire.Type.
   447  func toProtowireType(fieldType descriptorpb.FieldDescriptorProto_Type) protowire.Type {
   448  	switch fieldType {
   449  	// varint encoded
   450  	case descriptorpb.FieldDescriptorProto_TYPE_INT64,
   451  		descriptorpb.FieldDescriptorProto_TYPE_UINT64,
   452  		descriptorpb.FieldDescriptorProto_TYPE_INT32,
   453  		descriptorpb.FieldDescriptorProto_TYPE_UINT32,
   454  		descriptorpb.FieldDescriptorProto_TYPE_BOOL,
   455  		descriptorpb.FieldDescriptorProto_TYPE_ENUM,
   456  		descriptorpb.FieldDescriptorProto_TYPE_SINT32,
   457  		descriptorpb.FieldDescriptorProto_TYPE_SINT64:
   458  		return protowire.VarintType
   459  
   460  	// fixed64 encoded
   461  	case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE,
   462  		descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
   463  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED64:
   464  		return protowire.Fixed64Type
   465  
   466  	// fixed32 encoded
   467  	case descriptorpb.FieldDescriptorProto_TYPE_FLOAT,
   468  		descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
   469  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED32:
   470  		return protowire.Fixed32Type
   471  
   472  	// bytes encoded
   473  	case descriptorpb.FieldDescriptorProto_TYPE_STRING,
   474  		descriptorpb.FieldDescriptorProto_TYPE_BYTES,
   475  		descriptorpb.FieldDescriptorProto_TYPE_MESSAGE,
   476  		descriptorpb.FieldDescriptorProto_TYPE_GROUP:
   477  		return protowire.BytesType
   478  	default:
   479  		panic(fmt.Sprintf("unknown field type %s", fieldType))
   480  	}
   481  }
   482  
   483  // isScalar defines whether a field is a scalar type.
   484  // Copied from gogo/protobuf/protoc-gen-gogo
   485  // https://github.com/gogo/protobuf/blob/b03c65ea87cdc3521ede29f62fe3ce239267c1bc/protoc-gen-gogo/descriptor/descriptor.go#L95
   486  func isScalar(field *descriptorpb.FieldDescriptorProto) bool {
   487  	if field.Type == nil {
   488  		return false
   489  	}
   490  	switch *field.Type {
   491  	case descriptorpb.FieldDescriptorProto_TYPE_DOUBLE,
   492  		descriptorpb.FieldDescriptorProto_TYPE_FLOAT,
   493  		descriptorpb.FieldDescriptorProto_TYPE_INT64,
   494  		descriptorpb.FieldDescriptorProto_TYPE_UINT64,
   495  		descriptorpb.FieldDescriptorProto_TYPE_INT32,
   496  		descriptorpb.FieldDescriptorProto_TYPE_FIXED64,
   497  		descriptorpb.FieldDescriptorProto_TYPE_FIXED32,
   498  		descriptorpb.FieldDescriptorProto_TYPE_BOOL,
   499  		descriptorpb.FieldDescriptorProto_TYPE_UINT32,
   500  		descriptorpb.FieldDescriptorProto_TYPE_ENUM,
   501  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED32,
   502  		descriptorpb.FieldDescriptorProto_TYPE_SFIXED64,
   503  		descriptorpb.FieldDescriptorProto_TYPE_SINT32,
   504  		descriptorpb.FieldDescriptorProto_TYPE_SINT64:
   505  		return true
   506  	default:
   507  		return false
   508  	}
   509  }