github.com/jhump/protoreflect@v1.16.0/internal/codec/decode.go (about) 1 package codec 2 3 import ( 4 "errors" 5 "fmt" 6 "io" 7 "math" 8 9 "github.com/golang/protobuf/proto" 10 ) 11 12 // ErrOverflow is returned when an integer is too large to be represented. 13 var ErrOverflow = errors.New("proto: integer overflow") 14 15 // ErrBadWireType is returned when decoding a wire-type from a buffer that 16 // is not valid. 17 var ErrBadWireType = errors.New("proto: bad wiretype") 18 19 func (cb *Buffer) decodeVarintSlow() (x uint64, err error) { 20 i := cb.index 21 l := len(cb.buf) 22 23 for shift := uint(0); shift < 64; shift += 7 { 24 if i >= l { 25 err = io.ErrUnexpectedEOF 26 return 27 } 28 b := cb.buf[i] 29 i++ 30 x |= (uint64(b) & 0x7F) << shift 31 if b < 0x80 { 32 cb.index = i 33 return 34 } 35 } 36 37 // The number is too large to represent in a 64-bit value. 38 err = ErrOverflow 39 return 40 } 41 42 // DecodeVarint reads a varint-encoded integer from the Buffer. 43 // This is the format for the 44 // int32, int64, uint32, uint64, bool, and enum 45 // protocol buffer types. 46 func (cb *Buffer) DecodeVarint() (uint64, error) { 47 i := cb.index 48 buf := cb.buf 49 50 if i >= len(buf) { 51 return 0, io.ErrUnexpectedEOF 52 } else if buf[i] < 0x80 { 53 cb.index++ 54 return uint64(buf[i]), nil 55 } else if len(buf)-i < 10 { 56 return cb.decodeVarintSlow() 57 } 58 59 var b uint64 60 // we already checked the first byte 61 x := uint64(buf[i]) - 0x80 62 i++ 63 64 b = uint64(buf[i]) 65 i++ 66 x += b << 7 67 if b&0x80 == 0 { 68 goto done 69 } 70 x -= 0x80 << 7 71 72 b = uint64(buf[i]) 73 i++ 74 x += b << 14 75 if b&0x80 == 0 { 76 goto done 77 } 78 x -= 0x80 << 14 79 80 b = uint64(buf[i]) 81 i++ 82 x += b << 21 83 if b&0x80 == 0 { 84 goto done 85 } 86 x -= 0x80 << 21 87 88 b = uint64(buf[i]) 89 i++ 90 x += b << 28 91 if b&0x80 == 0 { 92 goto done 93 } 94 x -= 0x80 << 28 95 96 b = uint64(buf[i]) 97 i++ 98 x += b << 35 99 if b&0x80 == 0 { 100 goto done 101 } 102 x -= 0x80 << 35 103 104 b = uint64(buf[i]) 105 i++ 106 x += b << 42 107 if b&0x80 == 0 { 108 goto done 109 } 110 x -= 0x80 << 42 111 112 b = uint64(buf[i]) 113 i++ 114 x += b << 49 115 if b&0x80 == 0 { 116 goto done 117 } 118 x -= 0x80 << 49 119 120 b = uint64(buf[i]) 121 i++ 122 x += b << 56 123 if b&0x80 == 0 { 124 goto done 125 } 126 x -= 0x80 << 56 127 128 b = uint64(buf[i]) 129 i++ 130 x += b << 63 131 if b&0x80 == 0 { 132 goto done 133 } 134 // x -= 0x80 << 63 // Always zero. 135 136 return 0, ErrOverflow 137 138 done: 139 cb.index = i 140 return x, nil 141 } 142 143 // DecodeTagAndWireType decodes a field tag and wire type from input. 144 // This reads a varint and then extracts the two fields from the varint 145 // value read. 146 func (cb *Buffer) DecodeTagAndWireType() (tag int32, wireType int8, err error) { 147 var v uint64 148 v, err = cb.DecodeVarint() 149 if err != nil { 150 return 151 } 152 // low 7 bits is wire type 153 wireType = int8(v & 7) 154 // rest is int32 tag number 155 v = v >> 3 156 if v > math.MaxInt32 { 157 err = fmt.Errorf("tag number out of range: %d", v) 158 return 159 } 160 tag = int32(v) 161 return 162 } 163 164 // DecodeFixed64 reads a 64-bit integer from the Buffer. 165 // This is the format for the 166 // fixed64, sfixed64, and double protocol buffer types. 167 func (cb *Buffer) DecodeFixed64() (x uint64, err error) { 168 // x, err already 0 169 i := cb.index + 8 170 if i < 0 || i > len(cb.buf) { 171 err = io.ErrUnexpectedEOF 172 return 173 } 174 cb.index = i 175 176 x = uint64(cb.buf[i-8]) 177 x |= uint64(cb.buf[i-7]) << 8 178 x |= uint64(cb.buf[i-6]) << 16 179 x |= uint64(cb.buf[i-5]) << 24 180 x |= uint64(cb.buf[i-4]) << 32 181 x |= uint64(cb.buf[i-3]) << 40 182 x |= uint64(cb.buf[i-2]) << 48 183 x |= uint64(cb.buf[i-1]) << 56 184 return 185 } 186 187 // DecodeFixed32 reads a 32-bit integer from the Buffer. 188 // This is the format for the 189 // fixed32, sfixed32, and float protocol buffer types. 190 func (cb *Buffer) DecodeFixed32() (x uint64, err error) { 191 // x, err already 0 192 i := cb.index + 4 193 if i < 0 || i > len(cb.buf) { 194 err = io.ErrUnexpectedEOF 195 return 196 } 197 cb.index = i 198 199 x = uint64(cb.buf[i-4]) 200 x |= uint64(cb.buf[i-3]) << 8 201 x |= uint64(cb.buf[i-2]) << 16 202 x |= uint64(cb.buf[i-1]) << 24 203 return 204 } 205 206 // DecodeRawBytes reads a count-delimited byte buffer from the Buffer. 207 // This is the format used for the bytes protocol buffer 208 // type and for embedded messages. 209 func (cb *Buffer) DecodeRawBytes(alloc bool) (buf []byte, err error) { 210 n, err := cb.DecodeVarint() 211 if err != nil { 212 return nil, err 213 } 214 215 nb := int(n) 216 if nb < 0 { 217 return nil, fmt.Errorf("proto: bad byte length %d", nb) 218 } 219 end := cb.index + nb 220 if end < cb.index || end > len(cb.buf) { 221 return nil, io.ErrUnexpectedEOF 222 } 223 224 if !alloc { 225 buf = cb.buf[cb.index:end] 226 cb.index = end 227 return 228 } 229 230 buf = make([]byte, nb) 231 copy(buf, cb.buf[cb.index:]) 232 cb.index = end 233 return 234 } 235 236 // ReadGroup reads the input until a "group end" tag is found 237 // and returns the data up to that point. Subsequent reads from 238 // the buffer will read data after the group end tag. If alloc 239 // is true, the data is copied to a new slice before being returned. 240 // Otherwise, the returned slice is a view into the buffer's 241 // underlying byte slice. 242 // 243 // This function correctly handles nested groups: if a "group start" 244 // tag is found, then that group's end tag will be included in the 245 // returned data. 246 func (cb *Buffer) ReadGroup(alloc bool) ([]byte, error) { 247 var groupEnd, dataEnd int 248 groupEnd, dataEnd, err := cb.findGroupEnd() 249 if err != nil { 250 return nil, err 251 } 252 var results []byte 253 if !alloc { 254 results = cb.buf[cb.index:dataEnd] 255 } else { 256 results = make([]byte, dataEnd-cb.index) 257 copy(results, cb.buf[cb.index:]) 258 } 259 cb.index = groupEnd 260 return results, nil 261 } 262 263 // SkipGroup is like ReadGroup, except that it discards the 264 // data and just advances the buffer to point to the input 265 // right *after* the "group end" tag. 266 func (cb *Buffer) SkipGroup() error { 267 groupEnd, _, err := cb.findGroupEnd() 268 if err != nil { 269 return err 270 } 271 cb.index = groupEnd 272 return nil 273 } 274 275 // SkipField attempts to skip the value of a field with the given wire 276 // type. When consuming a protobuf-encoded stream, it can be called immediately 277 // after DecodeTagAndWireType to discard the subsequent data for the field. 278 func (cb *Buffer) SkipField(wireType int8) error { 279 switch wireType { 280 case proto.WireFixed32: 281 if err := cb.Skip(4); err != nil { 282 return err 283 } 284 case proto.WireFixed64: 285 if err := cb.Skip(8); err != nil { 286 return err 287 } 288 case proto.WireVarint: 289 // skip varint by finding last byte (has high bit unset) 290 i := cb.index 291 limit := i + 10 // varint cannot be >10 bytes 292 for { 293 if i >= limit { 294 return ErrOverflow 295 } 296 if i >= len(cb.buf) { 297 return io.ErrUnexpectedEOF 298 } 299 if cb.buf[i]&0x80 == 0 { 300 break 301 } 302 i++ 303 } 304 // TODO: This would only overflow if buffer length was MaxInt and we 305 // read the last byte. This is not a real/feasible concern on 64-bit 306 // systems. Something to worry about for 32-bit systems? Do we care? 307 cb.index = i + 1 308 case proto.WireBytes: 309 l, err := cb.DecodeVarint() 310 if err != nil { 311 return err 312 } 313 if err := cb.Skip(int(l)); err != nil { 314 return err 315 } 316 case proto.WireStartGroup: 317 if err := cb.SkipGroup(); err != nil { 318 return err 319 } 320 default: 321 return ErrBadWireType 322 } 323 return nil 324 } 325 326 func (cb *Buffer) findGroupEnd() (groupEnd int, dataEnd int, err error) { 327 start := cb.index 328 defer func() { 329 cb.index = start 330 }() 331 for { 332 fieldStart := cb.index 333 // read a field tag 334 _, wireType, err := cb.DecodeTagAndWireType() 335 if err != nil { 336 return 0, 0, err 337 } 338 if wireType == proto.WireEndGroup { 339 return cb.index, fieldStart, nil 340 } 341 // skip past the field's data 342 if err := cb.SkipField(wireType); err != nil { 343 return 0, 0, err 344 } 345 } 346 }