github.com/m3db/m3@v1.5.0/src/dbnode/encoding/proto/custom_unmarshal.go (about) 1 // Copyright (c) 2019 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 package proto 22 23 import ( 24 "errors" 25 "fmt" 26 "math" 27 "sort" 28 29 "github.com/golang/protobuf/proto" 30 dpb "github.com/golang/protobuf/protoc-gen-go/descriptor" 31 "github.com/jhump/protoreflect/desc" 32 ) 33 34 var ( 35 // Groups in the Protobuf wire format are deprecated, so simplify the code significantly by 36 // not supporting them. 37 errGroupsAreNotSupported = errors.New("use of groups in proto wire format is not supported") 38 zeroValue unmarshalValue 39 ) 40 41 type customFieldUnmarshaller interface { 42 sortedCustomFieldValues() sortedCustomFieldValues 43 sortedNonCustomFieldValues() sortedMarshalledFields 44 numNonCustomValues() int 45 resetAndUnmarshal(schema *desc.MessageDescriptor, buf []byte) error 46 } 47 48 type customUnmarshallerOptions struct { 49 skipUnknownFields bool 50 } 51 52 type customUnmarshaller struct { 53 schema *desc.MessageDescriptor 54 decodeBuf *buffer 55 customValues sortedCustomFieldValues 56 57 nonCustomValues sortedMarshalledFields 58 numNonCustom int 59 60 opts customUnmarshallerOptions 61 } 62 63 func newCustomFieldUnmarshaller(opts customUnmarshallerOptions) customFieldUnmarshaller { 64 return &customUnmarshaller{ 65 decodeBuf: newCodedBuffer(nil), 66 opts: opts, 67 } 68 } 69 70 func (u *customUnmarshaller) sortedCustomFieldValues() sortedCustomFieldValues { 71 return u.customValues 72 } 73 74 func (u *customUnmarshaller) numNonCustomValues() int { 75 return u.numNonCustom 76 } 77 78 func (u *customUnmarshaller) sortedNonCustomFieldValues() sortedMarshalledFields { 79 return u.nonCustomValues 80 } 81 82 func (u *customUnmarshaller) unmarshal() error { 83 u.resetCustomAndNonCustomValues() 84 85 var ( 86 areCustomValuesSorted = true 87 areNonCustomValuesSorted = true 88 ) 89 for !u.decodeBuf.eof() { 90 tagAndWireTypeStartOffset := u.decodeBuf.index 91 fieldNum, wireType, err := u.decodeBuf.decodeTagAndWireType() 92 if err != nil { 93 return err 94 } 95 96 fd := u.schema.FindFieldByNumber(fieldNum) 97 if fd == nil { 98 if !u.opts.skipUnknownFields { 99 return fmt.Errorf("encountered unknown field with field number: %d", fieldNum) 100 } 101 102 if _, err := u.skip(wireType); err != nil { 103 return err 104 } 105 continue 106 } 107 108 if !u.isCustomField(fd) { 109 _, err = u.skip(wireType) 110 if err != nil { 111 return err 112 } 113 114 var ( 115 startIdx = tagAndWireTypeStartOffset 116 endIdx = u.decodeBuf.index 117 marshalled = u.decodeBuf.buf[startIdx:endIdx] 118 ) 119 // A marshalled Protobuf message consists of a stream of <fieldNumber, wireType, value> 120 // tuples, all of which are optional, with no additional header or footer information. 121 // This means that each tuple within the stream can be thought of as its own complete 122 // marshalled message and as a result we can build up the []marshalledField one field at 123 // a time. 124 updatedExisting := false 125 if fd.IsRepeated() { 126 // If the fd is a repeated type and not using `packed` encoding then their could be multiple 127 // entries in the stream with the same field number so their marshalled bytes needs to be all 128 // concatenated together. 129 // 130 // NB(rartoul): This will have an adverse impact on the compression of map types because the 131 // key/val pairs can be encoded in any order. This means that its possible for two equivalent 132 // maps to have different byte streams which will force the encoder to re-encode the field into 133 // the stream even though it hasn't changed. This naive solution should be good enough for now, 134 // but if it proves problematic in the future the issue could be resolved by accumulating the 135 // marshalled tuples into a slice and then sorting by field number to produce a deterministic 136 // result such that equivalent maps always result in equivalent marshalled bytes slices. 137 for i, val := range u.nonCustomValues { 138 if fieldNum == val.fieldNum { 139 u.nonCustomValues[i].marshalled = append(u.nonCustomValues[i].marshalled, marshalled...) 140 updatedExisting = true 141 break 142 } 143 } 144 } 145 if !updatedExisting { 146 u.nonCustomValues = append(u.nonCustomValues, marshalledField{ 147 fieldNum: fieldNum, 148 marshalled: marshalled, 149 }) 150 } 151 152 if areNonCustomValuesSorted && len(u.nonCustomValues) > 1 { 153 // Check if the slice is sorted as it's built to avoid resorting 154 // unnecessarily at the end. 155 lastFieldNum := u.nonCustomValues[len(u.nonCustomValues)-1].fieldNum 156 if fieldNum < lastFieldNum { 157 areNonCustomValuesSorted = false 158 } 159 } 160 161 u.numNonCustom++ 162 continue 163 } 164 165 value, err := u.unmarshalCustomField(fd, wireType) 166 if err != nil { 167 return err 168 } 169 170 if areCustomValuesSorted && len(u.customValues) > 1 { 171 // Check if the slice is sorted as it's built to avoid resorting 172 // unnecessarily at the end. 173 lastFieldNum := u.customValues[len(u.customValues)-1].fieldNumber 174 if fieldNum < lastFieldNum { 175 areCustomValuesSorted = false 176 } 177 } 178 179 u.customValues = append(u.customValues, value) 180 } 181 182 u.decodeBuf.reset(u.decodeBuf.buf) 183 184 // Avoid resorting if possible. 185 if !areCustomValuesSorted { 186 sort.Sort(u.customValues) 187 } 188 if !areNonCustomValuesSorted { 189 sort.Sort(u.nonCustomValues) 190 } 191 192 return nil 193 } 194 195 // isCustomField checks whether the encoder would have custom encoded this field or left 196 // it up to the `jhump/dynamic` package to handle the encoding. This is important because 197 // it allows us to use the efficient unmarshal path only for fields that the encoder can 198 // actually take advantage of. 199 func (u *customUnmarshaller) isCustomField(fd *desc.FieldDescriptor) bool { 200 if fd.IsRepeated() || fd.IsMap() { 201 // Map should always be repeated but include the guard just in case. 202 return false 203 } 204 205 if fd.GetMessageType() != nil { 206 // Skip nested messages. 207 return false 208 } 209 210 return true 211 } 212 213 // skip will skip over the next value in the encoded stream (given that the tag and 214 // wiretype have already been decoded). 215 func (u *customUnmarshaller) skip(wireType int8) (int, error) { 216 switch wireType { 217 case proto.WireFixed32: 218 bytesSkipped := 4 219 u.decodeBuf.index += bytesSkipped 220 return bytesSkipped, nil 221 222 case proto.WireFixed64: 223 bytesSkipped := 8 224 u.decodeBuf.index += bytesSkipped 225 return bytesSkipped, nil 226 227 case proto.WireVarint: 228 var ( 229 bytesSkipped = 0 230 offsetBeforeDecodeVarInt = u.decodeBuf.index 231 ) 232 _, err := u.decodeBuf.decodeVarint() 233 if err != nil { 234 return 0, err 235 } 236 bytesSkipped += u.decodeBuf.index - offsetBeforeDecodeVarInt 237 return bytesSkipped, nil 238 239 case proto.WireBytes: 240 var ( 241 bytesSkipped = 0 242 offsetBeforeDecodeRawBytes = u.decodeBuf.index 243 ) 244 // Bytes aren't copied because they're just being skipped over so 245 // copying would be wasteful. 246 _, err := u.decodeBuf.decodeRawBytes(false) 247 if err != nil { 248 return 0, err 249 } 250 bytesSkipped += u.decodeBuf.index - offsetBeforeDecodeRawBytes 251 return bytesSkipped, nil 252 253 case proto.WireStartGroup: 254 return 0, errGroupsAreNotSupported 255 256 case proto.WireEndGroup: 257 return 0, errGroupsAreNotSupported 258 259 default: 260 return 0, proto.ErrInternalBadWireType 261 } 262 } 263 264 func (u *customUnmarshaller) unmarshalCustomField(fd *desc.FieldDescriptor, wireType int8) (unmarshalValue, error) { 265 switch wireType { 266 case proto.WireFixed32: 267 num, err := u.decodeBuf.decodeFixed32() 268 if err != nil { 269 return zeroValue, err 270 } 271 return unmarshalSimpleField(fd, num) 272 273 case proto.WireFixed64: 274 num, err := u.decodeBuf.decodeFixed64() 275 if err != nil { 276 return zeroValue, err 277 } 278 return unmarshalSimpleField(fd, num) 279 280 case proto.WireVarint: 281 num, err := u.decodeBuf.decodeVarint() 282 if err != nil { 283 return zeroValue, err 284 } 285 return unmarshalSimpleField(fd, num) 286 287 case proto.WireBytes: 288 if t := fd.GetType(); t != dpb.FieldDescriptorProto_TYPE_BYTES && 289 t != dpb.FieldDescriptorProto_TYPE_STRING { 290 // This should never happen since it means the skipping logic is not working 291 // correctly or the message is malformed since proto.WireBytes should only be 292 // used for fields of type bytes, string, group, or message. Groups/messages 293 // should be handled by the skipping logic (for now). 294 return zeroValue, fmt.Errorf( 295 "tried to unmarshal field with wire type: bytes and proto field type: %s", 296 fd.GetType().String()) 297 } 298 299 // Don't bother copying the bytes now because the encoder has exclusive ownership 300 // of them until the call to Encode() completes and they will get "copied" anyways 301 // once they're written into the OStream. 302 raw, err := u.decodeBuf.decodeRawBytes(false) 303 if err != nil { 304 return zeroValue, err 305 } 306 307 val := unmarshalValue{fieldNumber: fd.GetNumber(), bytes: raw} 308 return val, nil 309 310 case proto.WireStartGroup: 311 return zeroValue, errGroupsAreNotSupported 312 313 default: 314 return zeroValue, proto.ErrInternalBadWireType 315 } 316 } 317 318 func unmarshalSimpleField(fd *desc.FieldDescriptor, v uint64) (unmarshalValue, error) { 319 fieldNum := fd.GetNumber() 320 val := unmarshalValue{fieldNumber: fieldNum, v: v} 321 switch fd.GetType() { 322 case dpb.FieldDescriptorProto_TYPE_BOOL, 323 dpb.FieldDescriptorProto_TYPE_UINT64, 324 dpb.FieldDescriptorProto_TYPE_FIXED64, 325 dpb.FieldDescriptorProto_TYPE_INT64, 326 dpb.FieldDescriptorProto_TYPE_SFIXED64, 327 dpb.FieldDescriptorProto_TYPE_DOUBLE: 328 return val, nil 329 330 case dpb.FieldDescriptorProto_TYPE_UINT32, 331 dpb.FieldDescriptorProto_TYPE_FIXED32: 332 if v > math.MaxUint32 { 333 return zeroValue, fmt.Errorf("%d (field num %d) overflows uint32", v, fieldNum) 334 } 335 return val, nil 336 337 case dpb.FieldDescriptorProto_TYPE_INT32, 338 dpb.FieldDescriptorProto_TYPE_ENUM: 339 s := int64(v) 340 if s > math.MaxInt32 { 341 return zeroValue, fmt.Errorf("%d (field num %d) overflows int32", v, fieldNum) 342 } 343 if s < math.MinInt32 { 344 return zeroValue, fmt.Errorf("%d (field num %d) underflows int32", v, fieldNum) 345 } 346 return val, nil 347 348 case dpb.FieldDescriptorProto_TYPE_SFIXED32: 349 if v > math.MaxUint32 { 350 return zeroValue, fmt.Errorf("%d (field num %d) overflows int32", v, fieldNum) 351 } 352 return val, nil 353 354 case dpb.FieldDescriptorProto_TYPE_SINT32: 355 if v > math.MaxUint32 { 356 return zeroValue, fmt.Errorf("%d (field num %d) overflows int32", v, fieldNum) 357 } 358 val.v = uint64(decodeZigZag32(v)) 359 return val, nil 360 361 case dpb.FieldDescriptorProto_TYPE_SINT64: 362 val.v = uint64(decodeZigZag64(v)) 363 return val, nil 364 365 case dpb.FieldDescriptorProto_TYPE_FLOAT: 366 if v > math.MaxUint32 { 367 return zeroValue, fmt.Errorf("%d (field num %d) overflows uint32", v, fieldNum) 368 } 369 float32Val := math.Float32frombits(uint32(v)) 370 float64Bits := math.Float64bits(float64(float32Val)) 371 val.v = float64Bits 372 return val, nil 373 374 default: 375 // bytes, string, message, and group cannot be represented as a simple numeric value. 376 return zeroValue, fmt.Errorf("bad input; field %s requires length-delimited wire type", fd.GetFullyQualifiedName()) 377 } 378 } 379 380 func (u *customUnmarshaller) resetAndUnmarshal(schema *desc.MessageDescriptor, buf []byte) error { 381 u.schema = schema 382 u.numNonCustom = 0 383 u.resetCustomAndNonCustomValues() 384 u.decodeBuf.reset(buf) 385 return u.unmarshal() 386 } 387 388 func (u *customUnmarshaller) resetCustomAndNonCustomValues() { 389 for i := range u.customValues { 390 u.customValues[i] = unmarshalValue{} 391 } 392 u.customValues = u.customValues[:0] 393 394 for i := range u.nonCustomValues { 395 u.nonCustomValues[i] = marshalledField{} 396 } 397 u.nonCustomValues = u.nonCustomValues[:0] 398 } 399 400 type sortedCustomFieldValues []unmarshalValue 401 402 func (s sortedCustomFieldValues) Len() int { 403 return len(s) 404 } 405 406 func (s sortedCustomFieldValues) Less(i, j int) bool { 407 return s[i].fieldNumber < s[j].fieldNumber 408 } 409 410 func (s sortedCustomFieldValues) Swap(i, j int) { 411 s[i], s[j] = s[j], s[i] 412 } 413 414 type unmarshalValue struct { 415 fieldNumber int32 416 v uint64 417 bytes []byte 418 } 419 420 func (v *unmarshalValue) asBool() bool { 421 return v.v != 0 422 } 423 424 func (v *unmarshalValue) asUint64() uint64 { 425 return v.v 426 } 427 428 func (v *unmarshalValue) asInt64() int64 { 429 return int64(v.v) 430 } 431 432 func (v *unmarshalValue) asFloat64() float64 { 433 return math.Float64frombits(v.v) 434 } 435 436 func (v *unmarshalValue) asBytes() []byte { 437 return v.bytes 438 }