github.com/hoveychen/protoreflect@v1.4.7-0.20221103114119-0b4b3385ec76/codec/decode.go (about) 1 package codec 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "math" 8 9 "github.com/golang/protobuf/proto" 10 "github.com/golang/protobuf/protoc-gen-go/descriptor" 11 ) 12 13 // ErrOverflow is returned when an integer is too large to be represented. 14 var ErrOverflow = errors.New("proto: integer overflow") 15 16 // ErrBadWireType is returned when decoding a wire-type from a buffer that 17 // is not valid. 18 var ErrBadWireType = errors.New("proto: bad wiretype") 19 20 var varintTypes = map[descriptor.FieldDescriptorProto_Type]bool{} 21 var fixed32Types = map[descriptor.FieldDescriptorProto_Type]bool{} 22 var fixed64Types = map[descriptor.FieldDescriptorProto_Type]bool{} 23 24 func init() { 25 varintTypes[descriptor.FieldDescriptorProto_TYPE_BOOL] = true 26 varintTypes[descriptor.FieldDescriptorProto_TYPE_INT32] = true 27 varintTypes[descriptor.FieldDescriptorProto_TYPE_INT64] = true 28 varintTypes[descriptor.FieldDescriptorProto_TYPE_UINT32] = true 29 varintTypes[descriptor.FieldDescriptorProto_TYPE_UINT64] = true 30 varintTypes[descriptor.FieldDescriptorProto_TYPE_SINT32] = true 31 varintTypes[descriptor.FieldDescriptorProto_TYPE_SINT64] = true 32 varintTypes[descriptor.FieldDescriptorProto_TYPE_ENUM] = true 33 34 fixed32Types[descriptor.FieldDescriptorProto_TYPE_FIXED32] = true 35 fixed32Types[descriptor.FieldDescriptorProto_TYPE_SFIXED32] = true 36 fixed32Types[descriptor.FieldDescriptorProto_TYPE_FLOAT] = true 37 38 fixed64Types[descriptor.FieldDescriptorProto_TYPE_FIXED64] = true 39 fixed64Types[descriptor.FieldDescriptorProto_TYPE_SFIXED64] = true 40 fixed64Types[descriptor.FieldDescriptorProto_TYPE_DOUBLE] = true 41 } 42 43 func (cb *Buffer) decodeVarintSlow() (x uint64, err error) { 44 i := cb.index 45 l := len(cb.buf) 46 47 for shift := uint(0); shift < 64; shift += 7 { 48 if i >= l { 49 err = io.ErrUnexpectedEOF 50 return 51 } 52 b := cb.buf[i] 53 i++ 54 x |= (uint64(b) & 0x7F) << shift 55 if b < 0x80 { 56 cb.index = i 57 return 58 } 59 } 60 61 // The number is too large to represent in a 64-bit value. 62 err = ErrOverflow 63 return 64 } 65 66 // DecodeVarint reads a varint-encoded integer from the Buffer. 67 // This is the format for the 68 // int32, int64, uint32, uint64, bool, and enum 69 // protocol buffer types. 70 func (cb *Buffer) DecodeVarint() (uint64, error) { 71 i := cb.index 72 buf := cb.buf 73 74 if i >= len(buf) { 75 return 0, io.ErrUnexpectedEOF 76 } else if buf[i] < 0x80 { 77 cb.index++ 78 return uint64(buf[i]), nil 79 } else if len(buf)-i < 10 { 80 return cb.decodeVarintSlow() 81 } 82 83 var b uint64 84 // we already checked the first byte 85 x := uint64(buf[i]) - 0x80 86 i++ 87 88 b = uint64(buf[i]) 89 i++ 90 x += b << 7 91 if b&0x80 == 0 { 92 goto done 93 } 94 x -= 0x80 << 7 95 96 b = uint64(buf[i]) 97 i++ 98 x += b << 14 99 if b&0x80 == 0 { 100 goto done 101 } 102 x -= 0x80 << 14 103 104 b = uint64(buf[i]) 105 i++ 106 x += b << 21 107 if b&0x80 == 0 { 108 goto done 109 } 110 x -= 0x80 << 21 111 112 b = uint64(buf[i]) 113 i++ 114 x += b << 28 115 if b&0x80 == 0 { 116 goto done 117 } 118 x -= 0x80 << 28 119 120 b = uint64(buf[i]) 121 i++ 122 x += b << 35 123 if b&0x80 == 0 { 124 goto done 125 } 126 x -= 0x80 << 35 127 128 b = uint64(buf[i]) 129 i++ 130 x += b << 42 131 if b&0x80 == 0 { 132 goto done 133 } 134 x -= 0x80 << 42 135 136 b = uint64(buf[i]) 137 i++ 138 x += b << 49 139 if b&0x80 == 0 { 140 goto done 141 } 142 x -= 0x80 << 49 143 144 b = uint64(buf[i]) 145 i++ 146 x += b << 56 147 if b&0x80 == 0 { 148 goto done 149 } 150 x -= 0x80 << 56 151 152 b = uint64(buf[i]) 153 i++ 154 x += b << 63 155 if b&0x80 == 0 { 156 goto done 157 } 158 // x -= 0x80 << 63 // Always zero. 159 160 return 0, ErrOverflow 161 162 done: 163 cb.index = i 164 return x, nil 165 } 166 167 // DecodeTagAndWireType decodes a field tag and wire type from input. 168 // This reads a varint and then extracts the two fields from the varint 169 // value read. 170 func (cb *Buffer) DecodeTagAndWireType() (tag int32, wireType int8, err error) { 171 var v uint64 172 v, err = cb.DecodeVarint() 173 if err != nil { 174 return 175 } 176 // low 7 bits is wire type 177 wireType = int8(v & 7) 178 // rest is int32 tag number 179 v = v >> 3 180 if v > math.MaxInt32 { 181 err = fmt.Errorf("tag number out of range: %d", v) 182 return 183 } 184 tag = int32(v) 185 return 186 } 187 188 // DecodeFixed64 reads a 64-bit integer from the Buffer. 189 // This is the format for the 190 // fixed64, sfixed64, and double protocol buffer types. 191 func (cb *Buffer) DecodeFixed64() (x uint64, err error) { 192 // x, err already 0 193 i := cb.index + 8 194 if i < 0 || i > len(cb.buf) { 195 err = io.ErrUnexpectedEOF 196 return 197 } 198 cb.index = i 199 200 x = uint64(cb.buf[i-8]) 201 x |= uint64(cb.buf[i-7]) << 8 202 x |= uint64(cb.buf[i-6]) << 16 203 x |= uint64(cb.buf[i-5]) << 24 204 x |= uint64(cb.buf[i-4]) << 32 205 x |= uint64(cb.buf[i-3]) << 40 206 x |= uint64(cb.buf[i-2]) << 48 207 x |= uint64(cb.buf[i-1]) << 56 208 return 209 } 210 211 // DecodeFixed32 reads a 32-bit integer from the Buffer. 212 // This is the format for the 213 // fixed32, sfixed32, and float protocol buffer types. 214 func (cb *Buffer) DecodeFixed32() (x uint64, err error) { 215 // x, err already 0 216 i := cb.index + 4 217 if i < 0 || i > len(cb.buf) { 218 err = io.ErrUnexpectedEOF 219 return 220 } 221 cb.index = i 222 223 x = uint64(cb.buf[i-4]) 224 x |= uint64(cb.buf[i-3]) << 8 225 x |= uint64(cb.buf[i-2]) << 16 226 x |= uint64(cb.buf[i-1]) << 24 227 return 228 } 229 230 // DecodeZigZag32 decodes a signed 32-bit integer from the given 231 // zig-zag encoded value. 232 func DecodeZigZag32(v uint64) int32 { 233 return int32((uint32(v) >> 1) ^ uint32((int32(v&1)<<31)>>31)) 234 } 235 236 // DecodeZigZag64 decodes a signed 64-bit integer from the given 237 // zig-zag encoded value. 238 func DecodeZigZag64(v uint64) int64 { 239 return int64((v >> 1) ^ uint64((int64(v&1)<<63)>>63)) 240 } 241 242 // DecodeRawBytes reads a count-delimited byte buffer from the Buffer. 243 // This is the format used for the bytes protocol buffer 244 // type and for embedded messages. 245 func (cb *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) { 246 n, err := cb.DecodeVarint() 247 if err != nil { 248 return nil, err 249 } 250 251 nb := int(n) 252 if nb < 0 { 253 return nil, fmt.Errorf("proto: bad byte length %d", nb) 254 } 255 end := cb.index + nb 256 if end < cb.index || end > len(cb.buf) { 257 return nil, io.ErrUnexpectedEOF 258 } 259 260 if !alloc { 261 buf = cb.buf[cb.index:end] 262 cb.index = end 263 return 264 } 265 266 buf = make([]byte, nb) 267 copy(buf, cb.buf[cb.index:]) 268 cb.index = end 269 return 270 } 271 272 // ReadGroup reads the input until a "group end" tag is found 273 // and returns the data up to that point. Subsequent reads from 274 // the buffer will read data after the group end tag. If alloc 275 // is true, the data is copied to a new slice before being returned. 276 // Otherwise, the returned slice is a view into the buffer's 277 // underlying byte slice. 278 // 279 // This function correctly handles nested groups: if a "group start" 280 // tag is found, then that group's end tag will be included in the 281 // returned data. 282 func (cb *Buffer) ReadGroup(alloc bool) ([]byte, error) { 283 var groupEnd, dataEnd int 284 groupEnd, dataEnd, err := cb.findGroupEnd() 285 if err != nil { 286 return nil, err 287 } 288 var results []byte 289 if !alloc { 290 results = cb.buf[cb.index:dataEnd] 291 } else { 292 results = make([]byte, dataEnd-cb.index) 293 copy(results, cb.buf[cb.index:]) 294 } 295 cb.index = groupEnd 296 return results, nil 297 } 298 299 // SkipGroup is like ReadGroup, except that it discards the 300 // data and just advances the buffer to point to the input 301 // right *after* the "group end" tag. 302 func (cb *Buffer) SkipGroup() error { 303 groupEnd, _, err := cb.findGroupEnd() 304 if err != nil { 305 return err 306 } 307 cb.index = groupEnd 308 return nil 309 } 310 311 func (cb *Buffer) findGroupEnd() (groupEnd int, dataEnd int, err error) { 312 bs := cb.buf 313 start := cb.index 314 defer func() { 315 cb.index = start 316 }() 317 for { 318 fieldStart := cb.index 319 // read a field tag 320 _, wireType, err := cb.DecodeTagAndWireType() 321 if err != nil { 322 return 0, 0, err 323 } 324 // skip past the field's data 325 switch wireType { 326 case proto.WireFixed32: 327 if err := cb.Skip(4); err != nil { 328 return 0, 0, err 329 } 330 case proto.WireFixed64: 331 if err := cb.Skip(8); err != nil { 332 return 0, 0, err 333 } 334 case proto.WireVarint: 335 // skip varint by finding last byte (has high bit unset) 336 i := cb.index 337 limit := i + 10 // varint cannot be >10 bytes 338 for { 339 if i >= limit { 340 return 0, 0, ErrOverflow 341 } 342 if i >= len(bs) { 343 return 0, 0, io.ErrUnexpectedEOF 344 } 345 if bs[i]&0x80 == 0 { 346 break 347 } 348 i++ 349 } 350 // TODO: This would only overflow if buffer length was MaxInt and we 351 // read the last byte. This is not a real/feasible concern on 64-bit 352 // systems. Something to worry about for 32-bit systems? Do we care? 353 cb.index = i + 1 354 case proto.WireBytes: 355 l, err := cb.DecodeVarint() 356 if err != nil { 357 return 0, 0, err 358 } 359 if err := cb.Skip(int(l)); err != nil { 360 return 0, 0, err 361 } 362 case proto.WireStartGroup: 363 if err := cb.SkipGroup(); err != nil { 364 return 0, 0, err 365 } 366 case proto.WireEndGroup: 367 return cb.index, fieldStart, nil 368 default: 369 return 0, 0, ErrBadWireType 370 } 371 } 372 }