github.com/Finschia/finschia-sdk@v0.48.1/codec/unknownproto/unknown_fields.go (about)

     1  package unknownproto
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"reflect"
    10  	"strings"
    11  	"sync"
    12  
    13  	"github.com/gogo/protobuf/jsonpb"
    14  	"github.com/gogo/protobuf/proto"
    15  	"github.com/gogo/protobuf/protoc-gen-gogo/descriptor"
    16  	"google.golang.org/protobuf/encoding/protowire"
    17  
    18  	"github.com/Finschia/finschia-sdk/codec/types"
    19  )
    20  
    21  const bit11NonCritical = 1 << 10
    22  
    23  type descriptorIface interface {
    24  	Descriptor() ([]byte, []int)
    25  }
    26  
    27  // RejectUnknownFieldsStrict rejects any bytes bz with an error that has unknown fields for the provided proto.Message type.
    28  // This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
    29  // An AnyResolver must be provided for traversing inside google.protobuf.Any's.
    30  func RejectUnknownFieldsStrict(bz []byte, msg proto.Message, resolver jsonpb.AnyResolver) error {
    31  	_, err := RejectUnknownFields(bz, msg, false, resolver)
    32  	return err
    33  }
    34  
    35  // RejectUnknownFields rejects any bytes bz with an error that has unknown fields for the provided proto.Message type with an
    36  // option to allow non-critical fields (specified as those fields with bit 11) to pass through. In either case, the
    37  // hasUnknownNonCriticals will be set to true if non-critical fields were encountered during traversal. This flag can be
    38  // used to treat a message with non-critical field different in different security contexts (such as transaction signing).
    39  // This function traverses inside of messages nested via google.protobuf.Any. It does not do any deserialization of the proto.Message.
    40  // An AnyResolver must be provided for traversing inside google.protobuf.Any's.
    41  func RejectUnknownFields(bz []byte, msg proto.Message, allowUnknownNonCriticals bool, resolver jsonpb.AnyResolver) (hasUnknownNonCriticals bool, err error) {
    42  	if len(bz) == 0 {
    43  		return hasUnknownNonCriticals, nil
    44  	}
    45  
    46  	desc, ok := msg.(descriptorIface)
    47  	if !ok {
    48  		return hasUnknownNonCriticals, fmt.Errorf("%T does not have a Descriptor() method", msg)
    49  	}
    50  
    51  	fieldDescProtoFromTagNum, _, err := getDescriptorInfo(desc, msg)
    52  	if err != nil {
    53  		return hasUnknownNonCriticals, err
    54  	}
    55  
    56  	for len(bz) > 0 {
    57  		tagNum, wireType, m := protowire.ConsumeTag(bz)
    58  		if m < 0 {
    59  			return hasUnknownNonCriticals, errors.New("invalid length")
    60  		}
    61  
    62  		fieldDescProto, ok := fieldDescProtoFromTagNum[int32(tagNum)]
    63  		switch {
    64  		case ok:
    65  			// Assert that the wireTypes match.
    66  			if !canEncodeType(wireType, fieldDescProto.GetType()) {
    67  				return hasUnknownNonCriticals, &errMismatchedWireType{
    68  					Type:         reflect.ValueOf(msg).Type().String(),
    69  					TagNum:       tagNum,
    70  					GotWireType:  wireType,
    71  					WantWireType: protowire.Type(fieldDescProto.WireType()),
    72  				}
    73  			}
    74  
    75  		default:
    76  			isCriticalField := tagNum&bit11NonCritical == 0
    77  
    78  			if !isCriticalField {
    79  				hasUnknownNonCriticals = true
    80  			}
    81  
    82  			if isCriticalField || !allowUnknownNonCriticals {
    83  				// The tag is critical, so report it.
    84  				return hasUnknownNonCriticals, &errUnknownField{
    85  					Type:     reflect.ValueOf(msg).Type().String(),
    86  					TagNum:   tagNum,
    87  					WireType: wireType,
    88  				}
    89  			}
    90  		}
    91  
    92  		// Skip over the bytes that store fieldNumber and wireType bytes.
    93  		bz = bz[m:]
    94  		n := protowire.ConsumeFieldValue(tagNum, wireType, bz)
    95  		if n < 0 {
    96  			err = fmt.Errorf("could not consume field value for tagNum: %d, wireType: %q; %w",
    97  				tagNum, wireTypeToString(wireType), protowire.ParseError(n))
    98  			return hasUnknownNonCriticals, err
    99  		}
   100  		fieldBytes := bz[:n]
   101  		bz = bz[n:]
   102  
   103  		// An unknown but non-critical field or just a scalar type (aka *INT and BYTES like).
   104  		if fieldDescProto == nil || fieldDescProto.IsScalar() {
   105  			continue
   106  		}
   107  
   108  		protoMessageName := fieldDescProto.GetTypeName()
   109  		if protoMessageName == "" {
   110  			switch typ := fieldDescProto.GetType(); typ {
   111  			case descriptor.FieldDescriptorProto_TYPE_STRING, descriptor.FieldDescriptorProto_TYPE_BYTES:
   112  				// At this point only TYPE_STRING is expected to be unregistered, since FieldDescriptorProto.IsScalar() returns false for
   113  				// TYPE_BYTES and TYPE_STRING as per
   114  				// https://github.com/gogo/protobuf/blob/5628607bb4c51c3157aacc3a50f0ab707582b805/protoc-gen-gogo/descriptor/descriptor.go#L95-L118
   115  			default:
   116  				return hasUnknownNonCriticals, fmt.Errorf("failed to get typename for message of type %v, can only be TYPE_STRING or TYPE_BYTES", typ)
   117  			}
   118  			continue
   119  		}
   120  
   121  		// Let's recursively traverse and typecheck the field.
   122  
   123  		// consume length prefix of nested message
   124  		_, o := protowire.ConsumeVarint(fieldBytes)
   125  		fieldBytes = fieldBytes[o:]
   126  
   127  		var msg proto.Message
   128  		var err error
   129  
   130  		if protoMessageName == ".google.protobuf.Any" {
   131  			// Firstly typecheck types.Any to ensure nothing snuck in.
   132  			hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, (*types.Any)(nil), allowUnknownNonCriticals, resolver)
   133  			hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
   134  			if err != nil {
   135  				return hasUnknownNonCriticals, err
   136  			}
   137  			// And finally we can extract the TypeURL containing the protoMessageName.
   138  			any := new(types.Any)
   139  			if err := proto.Unmarshal(fieldBytes, any); err != nil {
   140  				return hasUnknownNonCriticals, err
   141  			}
   142  			protoMessageName = any.TypeUrl
   143  			fieldBytes = any.Value
   144  			msg, err = resolver.Resolve(protoMessageName)
   145  			if err != nil {
   146  				return hasUnknownNonCriticals, err
   147  			}
   148  		} else {
   149  			msg, err = protoMessageForTypeName(protoMessageName[1:])
   150  			if err != nil {
   151  				return hasUnknownNonCriticals, err
   152  			}
   153  		}
   154  
   155  		hasUnknownNonCriticalsChild, err := RejectUnknownFields(fieldBytes, msg, allowUnknownNonCriticals, resolver)
   156  		hasUnknownNonCriticals = hasUnknownNonCriticals || hasUnknownNonCriticalsChild
   157  		if err != nil {
   158  			return hasUnknownNonCriticals, err
   159  		}
   160  	}
   161  
   162  	return hasUnknownNonCriticals, nil
   163  }
   164  
   165  var (
   166  	protoMessageForTypeNameMu    sync.RWMutex
   167  	protoMessageForTypeNameCache = make(map[string]proto.Message)
   168  )
   169  
   170  // protoMessageForTypeName takes in a fully qualified name e.g. testdata.TestVersionFD1
   171  // and returns a corresponding empty protobuf message that serves the prototype for typechecking.
   172  func protoMessageForTypeName(protoMessageName string) (proto.Message, error) {
   173  	protoMessageForTypeNameMu.RLock()
   174  	msg, ok := protoMessageForTypeNameCache[protoMessageName]
   175  	protoMessageForTypeNameMu.RUnlock()
   176  	if ok {
   177  		return msg, nil
   178  	}
   179  
   180  	concreteGoType := proto.MessageType(protoMessageName)
   181  	if concreteGoType == nil {
   182  		return nil, fmt.Errorf("failed to retrieve the message of type %q", protoMessageName)
   183  	}
   184  
   185  	value := reflect.New(concreteGoType).Elem()
   186  	msg, ok = value.Interface().(proto.Message)
   187  	if !ok {
   188  		return nil, fmt.Errorf("%q does not implement proto.Message", protoMessageName)
   189  	}
   190  
   191  	// Now cache it.
   192  	protoMessageForTypeNameMu.Lock()
   193  	protoMessageForTypeNameCache[protoMessageName] = msg
   194  	protoMessageForTypeNameMu.Unlock()
   195  
   196  	return msg, nil
   197  }
   198  
   199  // checks is a mapping of protowire.Type to supported descriptor.FieldDescriptorProto_Type.
   200  // it is implemented this way so as to have constant time lookups and avoid the overhead
   201  // from O(n) walking of switch. The change to using this mapping boosts throughput by about 200%.
   202  var checks = [...]map[descriptor.FieldDescriptorProto_Type]bool{
   203  	// "0	Varint: int32, int64, uint32, uint64, sint32, sint64, bool, enum"
   204  	0: {
   205  		descriptor.FieldDescriptorProto_TYPE_INT32:  true,
   206  		descriptor.FieldDescriptorProto_TYPE_INT64:  true,
   207  		descriptor.FieldDescriptorProto_TYPE_UINT32: true,
   208  		descriptor.FieldDescriptorProto_TYPE_UINT64: true,
   209  		descriptor.FieldDescriptorProto_TYPE_SINT32: true,
   210  		descriptor.FieldDescriptorProto_TYPE_SINT64: true,
   211  		descriptor.FieldDescriptorProto_TYPE_BOOL:   true,
   212  		descriptor.FieldDescriptorProto_TYPE_ENUM:   true,
   213  	},
   214  
   215  	// "1	64-bit:	fixed64, sfixed64, double"
   216  	1: {
   217  		descriptor.FieldDescriptorProto_TYPE_FIXED64:  true,
   218  		descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
   219  		descriptor.FieldDescriptorProto_TYPE_DOUBLE:   true,
   220  	},
   221  
   222  	// "2	Length-delimited: string, bytes, embedded messages, packed repeated fields"
   223  	2: {
   224  		descriptor.FieldDescriptorProto_TYPE_STRING:  true,
   225  		descriptor.FieldDescriptorProto_TYPE_BYTES:   true,
   226  		descriptor.FieldDescriptorProto_TYPE_MESSAGE: true,
   227  		// The following types can be packed repeated.
   228  		// ref: "Only repeated fields of primitive numeric types (types which use the varint, 32-bit, or 64-bit wire types) can be declared "packed"."
   229  		// ref: https://developers.google.com/protocol-buffers/docs/encoding#packed
   230  		descriptor.FieldDescriptorProto_TYPE_INT32:    true,
   231  		descriptor.FieldDescriptorProto_TYPE_INT64:    true,
   232  		descriptor.FieldDescriptorProto_TYPE_UINT32:   true,
   233  		descriptor.FieldDescriptorProto_TYPE_UINT64:   true,
   234  		descriptor.FieldDescriptorProto_TYPE_SINT32:   true,
   235  		descriptor.FieldDescriptorProto_TYPE_SINT64:   true,
   236  		descriptor.FieldDescriptorProto_TYPE_BOOL:     true,
   237  		descriptor.FieldDescriptorProto_TYPE_ENUM:     true,
   238  		descriptor.FieldDescriptorProto_TYPE_FIXED64:  true,
   239  		descriptor.FieldDescriptorProto_TYPE_SFIXED64: true,
   240  		descriptor.FieldDescriptorProto_TYPE_DOUBLE:   true,
   241  	},
   242  
   243  	// "3	Start group:	groups (deprecated)"
   244  	3: {
   245  		descriptor.FieldDescriptorProto_TYPE_GROUP: true,
   246  	},
   247  
   248  	// "4	End group:	groups (deprecated)"
   249  	4: {
   250  		descriptor.FieldDescriptorProto_TYPE_GROUP: true,
   251  	},
   252  
   253  	// "5	32-bit:	fixed32, sfixed32, float"
   254  	5: {
   255  		descriptor.FieldDescriptorProto_TYPE_FIXED32:  true,
   256  		descriptor.FieldDescriptorProto_TYPE_SFIXED32: true,
   257  		descriptor.FieldDescriptorProto_TYPE_FLOAT:    true,
   258  	},
   259  }
   260  
   261  // canEncodeType returns true if the wireType is suitable for encoding the descriptor type.
   262  // See https://developers.google.com/protocol-buffers/docs/encoding#structure.
   263  func canEncodeType(wireType protowire.Type, descType descriptor.FieldDescriptorProto_Type) bool {
   264  	if iwt := int(wireType); iwt < 0 || iwt >= len(checks) {
   265  		return false
   266  	}
   267  	return checks[wireType][descType]
   268  }
   269  
   270  // errMismatchedWireType describes a mismatch between
   271  // expected and got wireTypes for a specific tag number.
   272  type errMismatchedWireType struct {
   273  	Type         string
   274  	GotWireType  protowire.Type
   275  	WantWireType protowire.Type
   276  	TagNum       protowire.Number
   277  }
   278  
   279  // String implements fmt.Stringer.
   280  func (mwt *errMismatchedWireType) String() string {
   281  	return fmt.Sprintf("Mismatched %q: {TagNum: %d, GotWireType: %q != WantWireType: %q}",
   282  		mwt.Type, mwt.TagNum, wireTypeToString(mwt.GotWireType), wireTypeToString(mwt.WantWireType))
   283  }
   284  
   285  // Error implements the error interface.
   286  func (mwt *errMismatchedWireType) Error() string {
   287  	return mwt.String()
   288  }
   289  
   290  var _ error = (*errMismatchedWireType)(nil)
   291  
   292  func wireTypeToString(wt protowire.Type) string {
   293  	switch wt {
   294  	case 0:
   295  		return "varint"
   296  	case 1:
   297  		return "fixed64"
   298  	case 2:
   299  		return "bytes"
   300  	case 3:
   301  		return "start_group"
   302  	case 4:
   303  		return "end_group"
   304  	case 5:
   305  		return "fixed32"
   306  	default:
   307  		return fmt.Sprintf("unknown type: %d", wt)
   308  	}
   309  }
   310  
   311  // errUnknownField represents an error indicating that we encountered
   312  // a field that isn't available in the target proto.Message.
   313  type errUnknownField struct {
   314  	Type     string
   315  	TagNum   protowire.Number
   316  	WireType protowire.Type
   317  }
   318  
   319  // String implements fmt.Stringer.
   320  func (twt *errUnknownField) String() string {
   321  	return fmt.Sprintf("errUnknownField %q: {TagNum: %d, WireType:%q}",
   322  		twt.Type, twt.TagNum, wireTypeToString(twt.WireType))
   323  }
   324  
   325  // Error implements the error interface.
   326  func (twt *errUnknownField) Error() string {
   327  	return twt.String()
   328  }
   329  
   330  var _ error = (*errUnknownField)(nil)
   331  
   332  var (
   333  	protoFileToDesc   = make(map[string]*descriptor.FileDescriptorProto)
   334  	protoFileToDescMu sync.RWMutex
   335  )
   336  
   337  func unnestDesc(mdescs []*descriptor.DescriptorProto, indices []int) *descriptor.DescriptorProto {
   338  	mdesc := mdescs[indices[0]]
   339  	for _, index := range indices[1:] {
   340  		mdesc = mdesc.NestedType[index]
   341  	}
   342  	return mdesc
   343  }
   344  
   345  // Invoking descriptor.ForMessage(proto.Message.(Descriptor).Descriptor()) is incredibly slow
   346  // for every single message, thus the need for a hand-rolled custom version that's performant and cacheable.
   347  func extractFileDescMessageDesc(desc descriptorIface) (*descriptor.FileDescriptorProto, *descriptor.DescriptorProto, error) {
   348  	gzippedPb, indices := desc.Descriptor()
   349  
   350  	protoFileToDescMu.RLock()
   351  	cached, ok := protoFileToDesc[string(gzippedPb)]
   352  	protoFileToDescMu.RUnlock()
   353  
   354  	if ok {
   355  		return cached, unnestDesc(cached.MessageType, indices), nil
   356  	}
   357  
   358  	// Time to gunzip the content of the FileDescriptor and then proto unmarshal them.
   359  	gzr, err := gzip.NewReader(bytes.NewReader(gzippedPb))
   360  	if err != nil {
   361  		return nil, nil, err
   362  	}
   363  	protoBlob, err := io.ReadAll(gzr)
   364  	if err != nil {
   365  		return nil, nil, err
   366  	}
   367  
   368  	fdesc := new(descriptor.FileDescriptorProto)
   369  	if err := proto.Unmarshal(protoBlob, fdesc); err != nil {
   370  		return nil, nil, err
   371  	}
   372  
   373  	// Now cache the FileDescriptor.
   374  	protoFileToDescMu.Lock()
   375  	protoFileToDesc[string(gzippedPb)] = fdesc
   376  	protoFileToDescMu.Unlock()
   377  
   378  	// Unnest the type if necessary.
   379  	return fdesc, unnestDesc(fdesc.MessageType, indices), nil
   380  }
   381  
   382  type descriptorMatch struct {
   383  	cache map[int32]*descriptor.FieldDescriptorProto
   384  	desc  *descriptor.DescriptorProto
   385  }
   386  
   387  var (
   388  	descprotoCacheMu sync.RWMutex
   389  	descprotoCache   = make(map[reflect.Type]*descriptorMatch)
   390  )
   391  
   392  // getDescriptorInfo retrieves the mapping of field numbers to their respective field descriptors.
   393  func getDescriptorInfo(desc descriptorIface, msg proto.Message) (map[int32]*descriptor.FieldDescriptorProto, *descriptor.DescriptorProto, error) {
   394  	key := reflect.ValueOf(msg).Type()
   395  
   396  	descprotoCacheMu.RLock()
   397  	got, ok := descprotoCache[key]
   398  	descprotoCacheMu.RUnlock()
   399  
   400  	if ok {
   401  		return got.cache, got.desc, nil
   402  	}
   403  
   404  	// Now compute and cache the index.
   405  	_, md, err := extractFileDescMessageDesc(desc)
   406  	if err != nil {
   407  		return nil, nil, err
   408  	}
   409  
   410  	tagNumToTypeIndex := make(map[int32]*descriptor.FieldDescriptorProto)
   411  	for _, field := range md.Field {
   412  		tagNumToTypeIndex[field.GetNumber()] = field
   413  	}
   414  
   415  	descprotoCacheMu.Lock()
   416  	descprotoCache[key] = &descriptorMatch{
   417  		cache: tagNumToTypeIndex,
   418  		desc:  md,
   419  	}
   420  	descprotoCacheMu.Unlock()
   421  
   422  	return tagNumToTypeIndex, md, nil
   423  }
   424  
   425  // DefaultAnyResolver is a default implementation of AnyResolver which uses
   426  // the default encoding of type URLs as specified by the protobuf specification.
   427  type DefaultAnyResolver struct{}
   428  
   429  var _ jsonpb.AnyResolver = DefaultAnyResolver{}
   430  
   431  // Resolve is the AnyResolver.Resolve method.
   432  func (d DefaultAnyResolver) Resolve(typeURL string) (proto.Message, error) {
   433  	// Only the part of typeURL after the last slash is relevant.
   434  	mname := typeURL
   435  	if slash := strings.LastIndex(mname, "/"); slash >= 0 {
   436  		mname = mname[slash+1:]
   437  	}
   438  	mt := proto.MessageType(mname)
   439  	if mt == nil {
   440  		return nil, fmt.Errorf("unknown message type %q", mname)
   441  	}
   442  	return reflect.New(mt.Elem()).Interface().(proto.Message), nil
   443  }