go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/proto/msgpackpb/unmarshal.go (about)

     1  // Copyright 2022 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package msgpackpb
    16  
    17  import (
    18  	"bytes"
    19  	"io"
    20  
    21  	"github.com/vmihailenco/msgpack/v5"
    22  	"github.com/vmihailenco/msgpack/v5/msgpcode"
    23  
    24  	"google.golang.org/protobuf/encoding/protowire"
    25  	"google.golang.org/protobuf/proto"
    26  	"google.golang.org/protobuf/reflect/protoreflect"
    27  
    28  	"go.chromium.org/luci/common/errors"
    29  )
    30  
    31  func numericMapKey(key int32, kind protoreflect.Kind) (protoreflect.Value, error) {
    32  	switch kind {
    33  	case protoreflect.BoolKind:
    34  		return protoreflect.ValueOfBool(key != 0), nil
    35  
    36  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
    37  		return protoreflect.ValueOfInt32(key), nil
    38  
    39  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
    40  		return protoreflect.ValueOfInt64(int64(key)), nil
    41  
    42  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
    43  		return protoreflect.ValueOfUint32(uint32(key)), nil
    44  
    45  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
    46  		return protoreflect.ValueOfUint64(uint64(key)), nil
    47  	}
    48  
    49  	return protoreflect.Value{}, errors.New("cannot convert numeric map key")
    50  }
    51  
    52  // unmarshalScalar will decode a value from the Decoder and return it as a Value,
    53  // using arproximate protobuf decoding compatibility rules (i.e. Go numeric casts...
    54  // official proto rules state that the casts should be "C++" style, and from my
    55  // cursory read of the Golang spec, Go uses the same numeric conversion rules).
    56  //
    57  // NOTE: I considered the possibility where lua has encoded large int values
    58  // with floats. However, inspecting the lua C msgpack library (all versions), it
    59  // looks like it will already do the work to avoid using a float where possible.
    60  // This means that if we get a float in a field which is supposed to have an
    61  // integer type, we can treat it as a hard error.
    62  func (o *options) unmarshalScalar(dec *msgpack.Decoder, fd protoreflect.FieldDescriptor) (ret protoreflect.Value, err error) {
    63  	// DecodeInterfaceLoose will return:
    64  	//   - int8, int16, and int32 are converted to int64,
    65  	//   - uint8, uint16, and uint32 are converted to uint64,
    66  	//   - float32 is converted to float64.
    67  	//   - []byte is converted to string.
    68  	val, err := dec.DecodeInterfaceLoose()
    69  	if err != nil {
    70  		err = errors.Annotate(err, "decoding scalar").Err()
    71  		return
    72  	}
    73  
    74  	switch fd.Kind() {
    75  	case protoreflect.BoolKind:
    76  		switch x := val.(type) {
    77  		case bool:
    78  			return protoreflect.ValueOfBool(x), nil
    79  		case uint64:
    80  			return protoreflect.ValueOfBool(x != 0), nil
    81  		case int64:
    82  			return protoreflect.ValueOfBool(x != 0), nil
    83  		}
    84  
    85  	case protoreflect.EnumKind:
    86  		switch x := val.(type) {
    87  		case bool:
    88  			if x {
    89  				return protoreflect.ValueOfEnum(1), nil
    90  			}
    91  			return protoreflect.ValueOfEnum(0), nil
    92  		case uint64:
    93  			return protoreflect.ValueOfEnum(protoreflect.EnumNumber(x)), nil
    94  		case int64:
    95  			return protoreflect.ValueOfEnum(protoreflect.EnumNumber(x)), nil
    96  		}
    97  
    98  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
    99  		switch x := val.(type) {
   100  		case bool:
   101  			if x {
   102  				return protoreflect.ValueOfInt32(1), nil
   103  			}
   104  			return protoreflect.ValueOfInt32(0), nil
   105  		case uint64:
   106  			return protoreflect.ValueOfInt32(int32(x)), nil
   107  		case int64:
   108  			return protoreflect.ValueOfInt32(int32(x)), nil
   109  		}
   110  
   111  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   112  		switch x := val.(type) {
   113  		case bool:
   114  			if x {
   115  				return protoreflect.ValueOfInt64(1), nil
   116  			}
   117  			return protoreflect.ValueOfInt64(0), nil
   118  		case uint64:
   119  			return protoreflect.ValueOfInt64(int64(x)), nil
   120  		case int64:
   121  			return protoreflect.ValueOfInt64(x), nil
   122  		}
   123  
   124  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   125  		switch x := val.(type) {
   126  		case bool:
   127  			if x {
   128  				return protoreflect.ValueOfUint32(1), nil
   129  			}
   130  			return protoreflect.ValueOfUint32(0), nil
   131  		case uint64:
   132  			return protoreflect.ValueOfUint32(uint32(x)), nil
   133  		case int64:
   134  			return protoreflect.ValueOfUint32(uint32(x)), nil
   135  		}
   136  
   137  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   138  		switch x := val.(type) {
   139  		case bool:
   140  			if x {
   141  				return protoreflect.ValueOfUint64(1), nil
   142  			}
   143  			return protoreflect.ValueOfUint64(0), nil
   144  		case uint64:
   145  			return protoreflect.ValueOfUint64(x), nil
   146  		case int64:
   147  			return protoreflect.ValueOfUint64(uint64(x)), nil
   148  		}
   149  
   150  	case protoreflect.FloatKind:
   151  		switch x := val.(type) {
   152  		case uint64:
   153  			// allowed, because lua will encode non-floatlike numbers as integers.
   154  			return protoreflect.ValueOfFloat32(float32(x)), nil
   155  		case int64:
   156  			// allowed, because lua will encode non-floatlike negative numbers as integers.
   157  			return protoreflect.ValueOfFloat32(float32(x)), nil
   158  		case float32:
   159  			return protoreflect.ValueOfFloat32(x), nil
   160  		case float64:
   161  			return protoreflect.ValueOfFloat32(float32(x)), nil
   162  		}
   163  
   164  	case protoreflect.DoubleKind:
   165  		switch x := val.(type) {
   166  		case uint64:
   167  			// allowed, because lua will encode non-floatlike numbers as integers.
   168  			return protoreflect.ValueOfFloat64(float64(x)), nil
   169  		case int64:
   170  			// allowed, because lua will encode non-floatlike negative numbers as integers.
   171  			return protoreflect.ValueOfFloat64(float64(x)), nil
   172  		case float32:
   173  			return protoreflect.ValueOfFloat64(float64(x)), nil
   174  		case float64:
   175  			return protoreflect.ValueOfFloat64(x), nil
   176  		}
   177  
   178  	case protoreflect.StringKind, protoreflect.BytesKind:
   179  		var checkIntern bool
   180  		var internIdx int
   181  		switch x := val.(type) {
   182  		case string:
   183  			return protoreflect.ValueOf(val), nil
   184  		case uint64:
   185  			checkIntern = true
   186  			internIdx = int(x)
   187  		case int64:
   188  			checkIntern = true
   189  			internIdx = int(x)
   190  		}
   191  
   192  		if checkIntern {
   193  			if internIdx < len(o.internUnmarshalTable) {
   194  				return protoreflect.ValueOfString(o.internUnmarshalTable[internIdx]), nil
   195  			}
   196  			err = errors.Reason("interned string has index out of bounds: %d", internIdx).Err()
   197  			return
   198  		}
   199  	}
   200  
   201  	err = errors.Reason("bad type: expected %s, got %T", fd.Kind(), val).Err()
   202  	return
   203  }
   204  
   205  func isMap(dec *msgpack.Decoder) (bool, error) {
   206  	c, err := dec.PeekCode()
   207  	if err != nil {
   208  		return false, err
   209  	}
   210  
   211  	if msgpcode.IsFixedMap(c) || c == msgpcode.Map16 || c == msgpcode.Map32 {
   212  		return true, nil
   213  	}
   214  	return false, nil
   215  }
   216  
   217  // Because lua tables are used for both maps and lists, we can't reliably encode
   218  // a map as a map, because if it HAPPENS to have numeric indexes which are all
   219  // 1..N, cmsgpack will consider this to be a list and encode just the values in
   220  // sequence. Fortunately, in this case, the list is guaranteed to be already
   221  // sorted (by definition)!
   222  func getMapLen(dec *msgpack.Decoder) (n int, nextKey func() int32, err error) {
   223  	ism, err := isMap(dec)
   224  	if err != nil {
   225  		return
   226  	}
   227  
   228  	if ism {
   229  		n, err = dec.DecodeMapLen()
   230  		return
   231  	}
   232  
   233  	n, err = dec.DecodeArrayLen()
   234  	var idx int32 // remember; lua indexes are 1 based, so we ++ and then return
   235  	nextKey = func() int32 { idx++; return idx }
   236  	return
   237  }
   238  
   239  func getNextMsgTag(dec *msgpack.Decoder, nextKey func() int32) (tag int32, err error) {
   240  	if nextKey != nil {
   241  		tag = nextKey()
   242  	} else {
   243  		if tag, err = dec.DecodeInt32(); err != nil {
   244  			return
   245  		}
   246  	}
   247  	return
   248  }
   249  
   250  func (o *options) unmarshalMessage(dec *msgpack.Decoder, to protoreflect.Message) error {
   251  	msgItemLen, nextKey, err := getMapLen(dec)
   252  	if err != nil {
   253  		return errors.Annotate(err, "expected message length").Err()
   254  	}
   255  
   256  	d := to.Descriptor()
   257  	fieldsD := d.Fields()
   258  
   259  	var unknownFields map[int32]msgpack.RawMessage
   260  
   261  	for i := 0; i < msgItemLen; i++ {
   262  		tag, err := getNextMsgTag(dec, nextKey)
   263  		if err != nil {
   264  			return errors.Annotate(err, "reading message tag").Err()
   265  		}
   266  
   267  		fd := fieldsD.ByNumber(protowire.Number(tag))
   268  		if fd == nil {
   269  			switch o.unknownFieldBehavior {
   270  			case ignoreUnknownFields:
   271  				//pass
   272  			case disallowUnknownFields:
   273  				return errors.Reason("unknown field tag %d on decoded field %d", tag, i).Err()
   274  			case preserveUnknownFields:
   275  				if unknownFields == nil {
   276  					unknownFields = map[int32]msgpack.RawMessage{}
   277  				}
   278  				if unknownFields[tag], err = dec.DecodeRaw(); err != nil {
   279  					return errors.Reason("unknown field tag %d on decoded field %d: cannot decode msgpack", tag, i).Err()
   280  				}
   281  			default:
   282  				panic("unknown value of o.unknownFieldBehavior")
   283  			}
   284  			continue
   285  		}
   286  		name := fd.Name()
   287  
   288  		// now we check that the encoded thing is the thing we expect to find.
   289  		if fd.IsList() {
   290  			// note that if the input array was `sparse` (contained nil values), it MAY
   291  			// be encoded as a map.
   292  			ism, err := isMap(dec)
   293  			if err != nil {
   294  				return errors.Annotate(err, "%s: expected list or map", name).Err()
   295  			}
   296  
   297  			lst := to.Mutable(fd).List()
   298  
   299  			var mapLen int
   300  			var decodeIdx func() (int, error)
   301  			var addValue func(i int, v protoreflect.Value)
   302  			var postProcess func()
   303  			if ism {
   304  				if mapLen, err = dec.DecodeMapLen(); err != nil {
   305  					return errors.Annotate(err, "%s: expected sparse list", name).Err()
   306  				}
   307  
   308  				maxIdx := 0
   309  				decodeIdx = func() (int, error) {
   310  					ret, err := dec.DecodeInt()
   311  					if err != nil {
   312  						return ret, err
   313  					}
   314  					if ret > maxIdx {
   315  						maxIdx = ret
   316  					}
   317  					return ret, err
   318  				}
   319  				sparse := make(map[int]protoreflect.Value, mapLen)
   320  				addValue = func(i int, v protoreflect.Value) { sparse[i] = v }
   321  				zero := lst.NewElement()
   322  				postProcess = func() {
   323  					for i := 0; i <= maxIdx; i++ {
   324  						if val, ok := sparse[i]; ok {
   325  							lst.Append(val)
   326  						} else {
   327  							lst.Append(zero)
   328  						}
   329  					}
   330  				}
   331  			} else {
   332  				if mapLen, err = dec.DecodeArrayLen(); err != nil {
   333  					return errors.Annotate(err, "%s: expected list", name).Err()
   334  				}
   335  
   336  				addValue = func(_ int, v protoreflect.Value) { lst.Append(v) }
   337  				decodeIdx = func() (int, error) { return 0, nil }
   338  			}
   339  
   340  			for i := 0; i < mapLen; i++ {
   341  				idx, err := decodeIdx()
   342  				if err != nil {
   343  					return errors.Annotate(err, "%s[%d]: expected int key", name, i).Err()
   344  				}
   345  
   346  				var el protoreflect.Value
   347  				if fd.Kind() == protoreflect.MessageKind {
   348  					el = lst.NewElement()
   349  					if err = o.unmarshalMessage(dec, el.Message()); err != nil {
   350  						return errors.Annotate(err, "%s[%d]", name, i).Err()
   351  					}
   352  				} else {
   353  					if el, err = o.unmarshalScalar(dec, fd); err != nil {
   354  						return errors.Annotate(err, "%s[%d]", name, i).Err()
   355  					}
   356  				}
   357  				addValue(idx, el)
   358  			}
   359  			if postProcess != nil {
   360  				postProcess()
   361  			}
   362  			continue
   363  		}
   364  
   365  		if fd.IsMap() {
   366  			mapLen, nextKey, err := getMapLen(dec)
   367  			if err != nil {
   368  				return errors.Annotate(err, "%s: expected map", name).Err()
   369  			}
   370  
   371  			valFD := fd.MapValue()
   372  
   373  			// ok, we're a map and they're a map, do the decode
   374  			keyFD := fd.MapKey()
   375  			mapp := to.Mutable(fd).Map()
   376  			for i := 0; i < mapLen; i++ {
   377  				var key protoreflect.Value
   378  				if nextKey == nil {
   379  					if key, err = o.unmarshalScalar(dec, keyFD); err != nil {
   380  						return errors.Annotate(err, "%s[idx:%d]: bad map key", name, i).Err()
   381  					}
   382  				} else {
   383  					if key, err = numericMapKey(nextKey(), keyFD.Kind()); err != nil {
   384  						return errors.Annotate(err, "%s[idx:%d]: bad map key", name, i).Err()
   385  					}
   386  				}
   387  
   388  				if valFD.Kind() == protoreflect.MessageKind {
   389  					if err := o.unmarshalMessage(dec, mapp.Mutable(key.MapKey()).Message()); err != nil {
   390  						return errors.Annotate(err, "%s[%s]", name, key).Err()
   391  					}
   392  				} else {
   393  					val, err := o.unmarshalScalar(dec, valFD)
   394  					if err != nil {
   395  						return errors.Annotate(err, "%s[%s]", name, key).Err()
   396  					}
   397  					mapp.Set(key.MapKey(), val)
   398  				}
   399  			}
   400  			continue
   401  		}
   402  
   403  		// singular field
   404  		if fd.Kind() == protoreflect.MessageKind {
   405  			if err := o.unmarshalMessage(dec, to.Mutable(fd).Message()); err != nil {
   406  				return errors.Annotate(err, "%s", name).Err()
   407  			}
   408  		} else {
   409  			val, err := o.unmarshalScalar(dec, fd)
   410  			if err != nil {
   411  				return errors.Annotate(err, "%s", name).Err()
   412  			}
   413  			to.Set(fd, val)
   414  		}
   415  	}
   416  
   417  	if len(unknownFields) > 0 {
   418  		unknownBuf := bytes.Buffer{}
   419  		unknownEnc := msgpack.GetEncoder()
   420  		defer msgpack.PutEncoder(unknownEnc)
   421  
   422  		unknownEnc.Reset(&unknownBuf)
   423  		unknownEnc.UseCompactFloats(true)
   424  		unknownEnc.UseCompactInts(true)
   425  		if err := unknownEnc.Encode(unknownFields); err != nil {
   426  			panic(err)
   427  		}
   428  		protoEncUnknown, err := proto.Marshal(&UnknownFields{MsgpackpbData: unknownBuf.Bytes()})
   429  		if err != nil {
   430  			panic(err)
   431  		}
   432  		to.SetUnknown(protoEncUnknown)
   433  	}
   434  	return nil
   435  }
   436  
   437  // UnmarshalStream is like Unmarshal but takes an io.Reader instead of accepting
   438  // a string.
   439  //
   440  // If the reader contains multiple msgpackpb messages, this function will stop
   441  // exactly at where the next message in the stream begins (i.e. you could call
   442  // this in a loop until the reader is exhausted to merge the messages together).
   443  func UnmarshalStream(reader io.Reader, to proto.Message, opts ...Option) (err error) {
   444  	o := &options{}
   445  	for _, fn := range opts {
   446  		fn(o)
   447  	}
   448  
   449  	dec := msgpack.GetDecoder()
   450  	defer msgpack.PutDecoder(dec)
   451  
   452  	dec.Reset(reader)
   453  
   454  	return o.unmarshalMessage(dec, to.ProtoReflect())
   455  }
   456  
   457  // Unmarshal parses the encoded msgpack into the given proto message.
   458  //
   459  // This does NOT reset the Message; if it is partially populated, this will
   460  // effectively do a proto.Merge on top of it.
   461  //
   462  // By default, this will output unknown fields in the Message, but this will
   463  // only be usable by the corresponding Marshal function in this package. Pass
   464  // IgnoreUnknownFields or DisallowUnknownFields to affect this behavior.
   465  func Unmarshal(msg msgpack.RawMessage, to proto.Message, opts ...Option) (err error) {
   466  	return UnmarshalStream(bytes.NewReader(msg), to, opts...)
   467  }