go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/proto/msgpackpb/unmarshal.go (about) 1 // Copyright 2022 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package msgpackpb 16 17 import ( 18 "bytes" 19 "io" 20 21 "github.com/vmihailenco/msgpack/v5" 22 "github.com/vmihailenco/msgpack/v5/msgpcode" 23 24 "google.golang.org/protobuf/encoding/protowire" 25 "google.golang.org/protobuf/proto" 26 "google.golang.org/protobuf/reflect/protoreflect" 27 28 "go.chromium.org/luci/common/errors" 29 ) 30 31 func numericMapKey(key int32, kind protoreflect.Kind) (protoreflect.Value, error) { 32 switch kind { 33 case protoreflect.BoolKind: 34 return protoreflect.ValueOfBool(key != 0), nil 35 36 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 37 return protoreflect.ValueOfInt32(key), nil 38 39 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 40 return protoreflect.ValueOfInt64(int64(key)), nil 41 42 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 43 return protoreflect.ValueOfUint32(uint32(key)), nil 44 45 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 46 return protoreflect.ValueOfUint64(uint64(key)), nil 47 } 48 49 return protoreflect.Value{}, errors.New("cannot convert numeric map key") 50 } 51 52 // unmarshalScalar will decode a value from the Decoder and return it as a Value, 53 // using arproximate protobuf decoding compatibility rules (i.e. Go numeric casts... 54 // official proto rules state that the casts should be "C++" style, and from my 55 // cursory read of the Golang spec, Go uses the same numeric conversion rules). 56 // 57 // NOTE: I considered the possibility where lua has encoded large int values 58 // with floats. However, inspecting the lua C msgpack library (all versions), it 59 // looks like it will already do the work to avoid using a float where possible. 60 // This means that if we get a float in a field which is supposed to have an 61 // integer type, we can treat it as a hard error. 62 func (o *options) unmarshalScalar(dec *msgpack.Decoder, fd protoreflect.FieldDescriptor) (ret protoreflect.Value, err error) { 63 // DecodeInterfaceLoose will return: 64 // - int8, int16, and int32 are converted to int64, 65 // - uint8, uint16, and uint32 are converted to uint64, 66 // - float32 is converted to float64. 67 // - []byte is converted to string. 68 val, err := dec.DecodeInterfaceLoose() 69 if err != nil { 70 err = errors.Annotate(err, "decoding scalar").Err() 71 return 72 } 73 74 switch fd.Kind() { 75 case protoreflect.BoolKind: 76 switch x := val.(type) { 77 case bool: 78 return protoreflect.ValueOfBool(x), nil 79 case uint64: 80 return protoreflect.ValueOfBool(x != 0), nil 81 case int64: 82 return protoreflect.ValueOfBool(x != 0), nil 83 } 84 85 case protoreflect.EnumKind: 86 switch x := val.(type) { 87 case bool: 88 if x { 89 return protoreflect.ValueOfEnum(1), nil 90 } 91 return protoreflect.ValueOfEnum(0), nil 92 case uint64: 93 return protoreflect.ValueOfEnum(protoreflect.EnumNumber(x)), nil 94 case int64: 95 return protoreflect.ValueOfEnum(protoreflect.EnumNumber(x)), nil 96 } 97 98 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 99 switch x := val.(type) { 100 case bool: 101 if x { 102 return protoreflect.ValueOfInt32(1), nil 103 } 104 return protoreflect.ValueOfInt32(0), nil 105 case uint64: 106 return protoreflect.ValueOfInt32(int32(x)), nil 107 case int64: 108 return protoreflect.ValueOfInt32(int32(x)), nil 109 } 110 111 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 112 switch x := val.(type) { 113 case bool: 114 if x { 115 return protoreflect.ValueOfInt64(1), nil 116 } 117 return protoreflect.ValueOfInt64(0), nil 118 case uint64: 119 return protoreflect.ValueOfInt64(int64(x)), nil 120 case int64: 121 return protoreflect.ValueOfInt64(x), nil 122 } 123 124 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 125 switch x := val.(type) { 126 case bool: 127 if x { 128 return protoreflect.ValueOfUint32(1), nil 129 } 130 return protoreflect.ValueOfUint32(0), nil 131 case uint64: 132 return protoreflect.ValueOfUint32(uint32(x)), nil 133 case int64: 134 return protoreflect.ValueOfUint32(uint32(x)), nil 135 } 136 137 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 138 switch x := val.(type) { 139 case bool: 140 if x { 141 return protoreflect.ValueOfUint64(1), nil 142 } 143 return protoreflect.ValueOfUint64(0), nil 144 case uint64: 145 return protoreflect.ValueOfUint64(x), nil 146 case int64: 147 return protoreflect.ValueOfUint64(uint64(x)), nil 148 } 149 150 case protoreflect.FloatKind: 151 switch x := val.(type) { 152 case uint64: 153 // allowed, because lua will encode non-floatlike numbers as integers. 154 return protoreflect.ValueOfFloat32(float32(x)), nil 155 case int64: 156 // allowed, because lua will encode non-floatlike negative numbers as integers. 157 return protoreflect.ValueOfFloat32(float32(x)), nil 158 case float32: 159 return protoreflect.ValueOfFloat32(x), nil 160 case float64: 161 return protoreflect.ValueOfFloat32(float32(x)), nil 162 } 163 164 case protoreflect.DoubleKind: 165 switch x := val.(type) { 166 case uint64: 167 // allowed, because lua will encode non-floatlike numbers as integers. 168 return protoreflect.ValueOfFloat64(float64(x)), nil 169 case int64: 170 // allowed, because lua will encode non-floatlike negative numbers as integers. 171 return protoreflect.ValueOfFloat64(float64(x)), nil 172 case float32: 173 return protoreflect.ValueOfFloat64(float64(x)), nil 174 case float64: 175 return protoreflect.ValueOfFloat64(x), nil 176 } 177 178 case protoreflect.StringKind, protoreflect.BytesKind: 179 var checkIntern bool 180 var internIdx int 181 switch x := val.(type) { 182 case string: 183 return protoreflect.ValueOf(val), nil 184 case uint64: 185 checkIntern = true 186 internIdx = int(x) 187 case int64: 188 checkIntern = true 189 internIdx = int(x) 190 } 191 192 if checkIntern { 193 if internIdx < len(o.internUnmarshalTable) { 194 return protoreflect.ValueOfString(o.internUnmarshalTable[internIdx]), nil 195 } 196 err = errors.Reason("interned string has index out of bounds: %d", internIdx).Err() 197 return 198 } 199 } 200 201 err = errors.Reason("bad type: expected %s, got %T", fd.Kind(), val).Err() 202 return 203 } 204 205 func isMap(dec *msgpack.Decoder) (bool, error) { 206 c, err := dec.PeekCode() 207 if err != nil { 208 return false, err 209 } 210 211 if msgpcode.IsFixedMap(c) || c == msgpcode.Map16 || c == msgpcode.Map32 { 212 return true, nil 213 } 214 return false, nil 215 } 216 217 // Because lua tables are used for both maps and lists, we can't reliably encode 218 // a map as a map, because if it HAPPENS to have numeric indexes which are all 219 // 1..N, cmsgpack will consider this to be a list and encode just the values in 220 // sequence. Fortunately, in this case, the list is guaranteed to be already 221 // sorted (by definition)! 222 func getMapLen(dec *msgpack.Decoder) (n int, nextKey func() int32, err error) { 223 ism, err := isMap(dec) 224 if err != nil { 225 return 226 } 227 228 if ism { 229 n, err = dec.DecodeMapLen() 230 return 231 } 232 233 n, err = dec.DecodeArrayLen() 234 var idx int32 // remember; lua indexes are 1 based, so we ++ and then return 235 nextKey = func() int32 { idx++; return idx } 236 return 237 } 238 239 func getNextMsgTag(dec *msgpack.Decoder, nextKey func() int32) (tag int32, err error) { 240 if nextKey != nil { 241 tag = nextKey() 242 } else { 243 if tag, err = dec.DecodeInt32(); err != nil { 244 return 245 } 246 } 247 return 248 } 249 250 func (o *options) unmarshalMessage(dec *msgpack.Decoder, to protoreflect.Message) error { 251 msgItemLen, nextKey, err := getMapLen(dec) 252 if err != nil { 253 return errors.Annotate(err, "expected message length").Err() 254 } 255 256 d := to.Descriptor() 257 fieldsD := d.Fields() 258 259 var unknownFields map[int32]msgpack.RawMessage 260 261 for i := 0; i < msgItemLen; i++ { 262 tag, err := getNextMsgTag(dec, nextKey) 263 if err != nil { 264 return errors.Annotate(err, "reading message tag").Err() 265 } 266 267 fd := fieldsD.ByNumber(protowire.Number(tag)) 268 if fd == nil { 269 switch o.unknownFieldBehavior { 270 case ignoreUnknownFields: 271 //pass 272 case disallowUnknownFields: 273 return errors.Reason("unknown field tag %d on decoded field %d", tag, i).Err() 274 case preserveUnknownFields: 275 if unknownFields == nil { 276 unknownFields = map[int32]msgpack.RawMessage{} 277 } 278 if unknownFields[tag], err = dec.DecodeRaw(); err != nil { 279 return errors.Reason("unknown field tag %d on decoded field %d: cannot decode msgpack", tag, i).Err() 280 } 281 default: 282 panic("unknown value of o.unknownFieldBehavior") 283 } 284 continue 285 } 286 name := fd.Name() 287 288 // now we check that the encoded thing is the thing we expect to find. 289 if fd.IsList() { 290 // note that if the input array was `sparse` (contained nil values), it MAY 291 // be encoded as a map. 292 ism, err := isMap(dec) 293 if err != nil { 294 return errors.Annotate(err, "%s: expected list or map", name).Err() 295 } 296 297 lst := to.Mutable(fd).List() 298 299 var mapLen int 300 var decodeIdx func() (int, error) 301 var addValue func(i int, v protoreflect.Value) 302 var postProcess func() 303 if ism { 304 if mapLen, err = dec.DecodeMapLen(); err != nil { 305 return errors.Annotate(err, "%s: expected sparse list", name).Err() 306 } 307 308 maxIdx := 0 309 decodeIdx = func() (int, error) { 310 ret, err := dec.DecodeInt() 311 if err != nil { 312 return ret, err 313 } 314 if ret > maxIdx { 315 maxIdx = ret 316 } 317 return ret, err 318 } 319 sparse := make(map[int]protoreflect.Value, mapLen) 320 addValue = func(i int, v protoreflect.Value) { sparse[i] = v } 321 zero := lst.NewElement() 322 postProcess = func() { 323 for i := 0; i <= maxIdx; i++ { 324 if val, ok := sparse[i]; ok { 325 lst.Append(val) 326 } else { 327 lst.Append(zero) 328 } 329 } 330 } 331 } else { 332 if mapLen, err = dec.DecodeArrayLen(); err != nil { 333 return errors.Annotate(err, "%s: expected list", name).Err() 334 } 335 336 addValue = func(_ int, v protoreflect.Value) { lst.Append(v) } 337 decodeIdx = func() (int, error) { return 0, nil } 338 } 339 340 for i := 0; i < mapLen; i++ { 341 idx, err := decodeIdx() 342 if err != nil { 343 return errors.Annotate(err, "%s[%d]: expected int key", name, i).Err() 344 } 345 346 var el protoreflect.Value 347 if fd.Kind() == protoreflect.MessageKind { 348 el = lst.NewElement() 349 if err = o.unmarshalMessage(dec, el.Message()); err != nil { 350 return errors.Annotate(err, "%s[%d]", name, i).Err() 351 } 352 } else { 353 if el, err = o.unmarshalScalar(dec, fd); err != nil { 354 return errors.Annotate(err, "%s[%d]", name, i).Err() 355 } 356 } 357 addValue(idx, el) 358 } 359 if postProcess != nil { 360 postProcess() 361 } 362 continue 363 } 364 365 if fd.IsMap() { 366 mapLen, nextKey, err := getMapLen(dec) 367 if err != nil { 368 return errors.Annotate(err, "%s: expected map", name).Err() 369 } 370 371 valFD := fd.MapValue() 372 373 // ok, we're a map and they're a map, do the decode 374 keyFD := fd.MapKey() 375 mapp := to.Mutable(fd).Map() 376 for i := 0; i < mapLen; i++ { 377 var key protoreflect.Value 378 if nextKey == nil { 379 if key, err = o.unmarshalScalar(dec, keyFD); err != nil { 380 return errors.Annotate(err, "%s[idx:%d]: bad map key", name, i).Err() 381 } 382 } else { 383 if key, err = numericMapKey(nextKey(), keyFD.Kind()); err != nil { 384 return errors.Annotate(err, "%s[idx:%d]: bad map key", name, i).Err() 385 } 386 } 387 388 if valFD.Kind() == protoreflect.MessageKind { 389 if err := o.unmarshalMessage(dec, mapp.Mutable(key.MapKey()).Message()); err != nil { 390 return errors.Annotate(err, "%s[%s]", name, key).Err() 391 } 392 } else { 393 val, err := o.unmarshalScalar(dec, valFD) 394 if err != nil { 395 return errors.Annotate(err, "%s[%s]", name, key).Err() 396 } 397 mapp.Set(key.MapKey(), val) 398 } 399 } 400 continue 401 } 402 403 // singular field 404 if fd.Kind() == protoreflect.MessageKind { 405 if err := o.unmarshalMessage(dec, to.Mutable(fd).Message()); err != nil { 406 return errors.Annotate(err, "%s", name).Err() 407 } 408 } else { 409 val, err := o.unmarshalScalar(dec, fd) 410 if err != nil { 411 return errors.Annotate(err, "%s", name).Err() 412 } 413 to.Set(fd, val) 414 } 415 } 416 417 if len(unknownFields) > 0 { 418 unknownBuf := bytes.Buffer{} 419 unknownEnc := msgpack.GetEncoder() 420 defer msgpack.PutEncoder(unknownEnc) 421 422 unknownEnc.Reset(&unknownBuf) 423 unknownEnc.UseCompactFloats(true) 424 unknownEnc.UseCompactInts(true) 425 if err := unknownEnc.Encode(unknownFields); err != nil { 426 panic(err) 427 } 428 protoEncUnknown, err := proto.Marshal(&UnknownFields{MsgpackpbData: unknownBuf.Bytes()}) 429 if err != nil { 430 panic(err) 431 } 432 to.SetUnknown(protoEncUnknown) 433 } 434 return nil 435 } 436 437 // UnmarshalStream is like Unmarshal but takes an io.Reader instead of accepting 438 // a string. 439 // 440 // If the reader contains multiple msgpackpb messages, this function will stop 441 // exactly at where the next message in the stream begins (i.e. you could call 442 // this in a loop until the reader is exhausted to merge the messages together). 443 func UnmarshalStream(reader io.Reader, to proto.Message, opts ...Option) (err error) { 444 o := &options{} 445 for _, fn := range opts { 446 fn(o) 447 } 448 449 dec := msgpack.GetDecoder() 450 defer msgpack.PutDecoder(dec) 451 452 dec.Reset(reader) 453 454 return o.unmarshalMessage(dec, to.ProtoReflect()) 455 } 456 457 // Unmarshal parses the encoded msgpack into the given proto message. 458 // 459 // This does NOT reset the Message; if it is partially populated, this will 460 // effectively do a proto.Merge on top of it. 461 // 462 // By default, this will output unknown fields in the Message, but this will 463 // only be usable by the corresponding Marshal function in this package. Pass 464 // IgnoreUnknownFields or DisallowUnknownFields to affect this behavior. 465 func Unmarshal(msg msgpack.RawMessage, to proto.Message, opts ...Option) (err error) { 466 return UnmarshalStream(bytes.NewReader(msg), to, opts...) 467 }