github.com/go-asm/go@v1.21.1-0.20240213172139-40c5ead50c48/zstd/zstd.go (about) 1 // Copyright 2023 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package zstd provides a decompressor for zstd streams, 6 // described in RFC 8878. It does not support dictionaries. 7 package zstd 8 9 import ( 10 "encoding/binary" 11 "errors" 12 "fmt" 13 "io" 14 ) 15 16 // fuzzing is a fuzzer hook set to true when fuzzing. 17 // This is used to reject cases where we don't match zstd. 18 var fuzzing = false 19 20 // Reader implements [io.Reader] to read a zstd compressed stream. 21 type Reader struct { 22 // The underlying Reader. 23 r io.Reader 24 25 // Whether we have read the frame header. 26 // This is of interest when buffer is empty. 27 // If true we expect to see a new block. 28 sawFrameHeader bool 29 30 // Whether the current frame expects a checksum. 31 hasChecksum bool 32 33 // Whether we have read at least one frame. 34 readOneFrame bool 35 36 // True if the frame size is not known. 37 frameSizeUnknown bool 38 39 // The number of uncompressed bytes remaining in the current frame. 40 // If frameSizeUnknown is true, this is not valid. 41 remainingFrameSize uint64 42 43 // The number of bytes read from r up to the start of the current 44 // block, for error reporting. 45 blockOffset int64 46 47 // Buffered decompressed data. 48 buffer []byte 49 // Current read offset in buffer. 50 off int 51 52 // The current repeated offsets. 53 repeatedOffset1 uint32 54 repeatedOffset2 uint32 55 repeatedOffset3 uint32 56 57 // The current Huffman tree used for compressing literals. 58 huffmanTable []uint16 59 huffmanTableBits int 60 61 // The window for back references. 62 window window 63 64 // A buffer available to hold a compressed block. 65 compressedBuf []byte 66 67 // A buffer for literals. 68 literals []byte 69 70 // Sequence decode FSE tables. 71 seqTables [3][]fseBaselineEntry 72 seqTableBits [3]uint8 73 74 // Buffers for sequence decode FSE tables. 75 seqTableBuffers [3][]fseBaselineEntry 76 77 // Scratch space used for small reads, to avoid allocation. 78 scratch [16]byte 79 80 // A scratch table for reading an FSE. Only temporarily valid. 81 fseScratch []fseEntry 82 83 // For checksum computation. 84 checksum xxhash64 85 } 86 87 // NewReader creates a new Reader that decompresses data from the given reader. 88 func NewReader(input io.Reader) *Reader { 89 r := new(Reader) 90 r.Reset(input) 91 return r 92 } 93 94 // Reset discards the current state and starts reading a new stream from r. 95 // This permits reusing a Reader rather than allocating a new one. 96 func (r *Reader) Reset(input io.Reader) { 97 r.r = input 98 99 // Several fields are preserved to avoid allocation. 100 // Others are always set before they are used. 101 r.sawFrameHeader = false 102 r.hasChecksum = false 103 r.readOneFrame = false 104 r.frameSizeUnknown = false 105 r.remainingFrameSize = 0 106 r.blockOffset = 0 107 r.buffer = r.buffer[:0] 108 r.off = 0 109 // repeatedOffset1 110 // repeatedOffset2 111 // repeatedOffset3 112 // huffmanTable 113 // huffmanTableBits 114 // window 115 // compressedBuf 116 // literals 117 // seqTables 118 // seqTableBits 119 // seqTableBuffers 120 // scratch 121 // fseScratch 122 } 123 124 // Read implements [io.Reader]. 125 func (r *Reader) Read(p []byte) (int, error) { 126 if err := r.refillIfNeeded(); err != nil { 127 return 0, err 128 } 129 n := copy(p, r.buffer[r.off:]) 130 r.off += n 131 return n, nil 132 } 133 134 // ReadByte implements [io.ByteReader]. 135 func (r *Reader) ReadByte() (byte, error) { 136 if err := r.refillIfNeeded(); err != nil { 137 return 0, err 138 } 139 ret := r.buffer[r.off] 140 r.off++ 141 return ret, nil 142 } 143 144 // refillIfNeeded reads the next block if necessary. 145 func (r *Reader) refillIfNeeded() error { 146 for r.off >= len(r.buffer) { 147 if err := r.refill(); err != nil { 148 return err 149 } 150 r.off = 0 151 } 152 return nil 153 } 154 155 // refill reads and decompresses the next block. 156 func (r *Reader) refill() error { 157 if !r.sawFrameHeader { 158 if err := r.readFrameHeader(); err != nil { 159 return err 160 } 161 } 162 return r.readBlock() 163 } 164 165 // readFrameHeader reads the frame header and prepares to read a block. 166 func (r *Reader) readFrameHeader() error { 167 retry: 168 relativeOffset := 0 169 170 // Read magic number. RFC 3.1.1. 171 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { 172 // We require that the stream contains at least one frame. 173 if err == io.EOF && !r.readOneFrame { 174 err = io.ErrUnexpectedEOF 175 } 176 return r.wrapError(relativeOffset, err) 177 } 178 179 if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 { 180 if magic >= 0x184d2a50 && magic <= 0x184d2a5f { 181 // This is a skippable frame. 182 r.blockOffset += int64(relativeOffset) + 4 183 if err := r.skipFrame(); err != nil { 184 return err 185 } 186 r.readOneFrame = true 187 goto retry 188 } 189 190 return r.makeError(relativeOffset, "invalid magic number") 191 } 192 193 relativeOffset += 4 194 195 // Read Frame_Header_Descriptor. RFC 3.1.1.1.1. 196 if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil { 197 return r.wrapNonEOFError(relativeOffset, err) 198 } 199 descriptor := r.scratch[0] 200 201 singleSegment := descriptor&(1<<5) != 0 202 203 fcsFieldSize := 1 << (descriptor >> 6) 204 if fcsFieldSize == 1 && !singleSegment { 205 fcsFieldSize = 0 206 } 207 208 var windowDescriptorSize int 209 if singleSegment { 210 windowDescriptorSize = 0 211 } else { 212 windowDescriptorSize = 1 213 } 214 215 if descriptor&(1<<3) != 0 { 216 return r.makeError(relativeOffset, "reserved bit set in frame header descriptor") 217 } 218 219 r.hasChecksum = descriptor&(1<<2) != 0 220 if r.hasChecksum { 221 r.checksum.reset() 222 } 223 224 // Dictionary_ID_Flag. RFC 3.1.1.1.1.6. 225 dictionaryIdSize := 0 226 if dictIdFlag := descriptor & 3; dictIdFlag != 0 { 227 dictionaryIdSize = 1 << (dictIdFlag - 1) 228 } 229 230 relativeOffset++ 231 232 headerSize := windowDescriptorSize + dictionaryIdSize + fcsFieldSize 233 234 if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil { 235 return r.wrapNonEOFError(relativeOffset, err) 236 } 237 238 // Figure out the maximum amount of data we need to retain 239 // for backreferences. 240 var windowSize int 241 if !singleSegment { 242 // Window descriptor. RFC 3.1.1.1.2. 243 windowDescriptor := r.scratch[0] 244 exponent := uint64(windowDescriptor >> 3) 245 mantissa := uint64(windowDescriptor & 7) 246 windowLog := exponent + 10 247 windowBase := uint64(1) << windowLog 248 windowAdd := (windowBase / 8) * mantissa 249 windowSize = int(windowBase + windowAdd) 250 251 // Default zstd sets limits on the window size. 252 if fuzzing && (windowLog > 31 || windowSize > 1<<27) { 253 return r.makeError(relativeOffset, "windowSize too large") 254 } 255 } 256 257 // Dictionary_ID. RFC 3.1.1.1.3. 258 if dictionaryIdSize != 0 { 259 dictionaryId := r.scratch[windowDescriptorSize : windowDescriptorSize+dictionaryIdSize] 260 // Allow only zero Dictionary ID. 261 for _, b := range dictionaryId { 262 if b != 0 { 263 return r.makeError(relativeOffset, "dictionaries are not supported") 264 } 265 } 266 } 267 268 // Frame_Content_Size. RFC 3.1.1.1.4. 269 r.frameSizeUnknown = false 270 r.remainingFrameSize = 0 271 fb := r.scratch[windowDescriptorSize+dictionaryIdSize:] 272 switch fcsFieldSize { 273 case 0: 274 r.frameSizeUnknown = true 275 case 1: 276 r.remainingFrameSize = uint64(fb[0]) 277 case 2: 278 r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb)) 279 case 4: 280 r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb)) 281 case 8: 282 r.remainingFrameSize = binary.LittleEndian.Uint64(fb) 283 default: 284 panic("unreachable") 285 } 286 287 // RFC 3.1.1.1.2. 288 // When Single_Segment_Flag is set, Window_Descriptor is not present. 289 // In this case, Window_Size is Frame_Content_Size. 290 if singleSegment { 291 windowSize = int(r.remainingFrameSize) 292 } 293 294 // RFC 8878 3.1.1.1.1.2. permits us to set an 8M max on window size. 295 if windowSize > 8<<20 { 296 windowSize = 8 << 20 297 } 298 299 relativeOffset += headerSize 300 301 r.sawFrameHeader = true 302 r.readOneFrame = true 303 r.blockOffset += int64(relativeOffset) 304 305 // Prepare to read blocks from the frame. 306 r.repeatedOffset1 = 1 307 r.repeatedOffset2 = 4 308 r.repeatedOffset3 = 8 309 r.huffmanTableBits = 0 310 r.window.reset(windowSize) 311 r.seqTables[0] = nil 312 r.seqTables[1] = nil 313 r.seqTables[2] = nil 314 315 return nil 316 } 317 318 // skipFrame skips a skippable frame. RFC 3.1.2. 319 func (r *Reader) skipFrame() error { 320 relativeOffset := 0 321 322 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { 323 return r.wrapNonEOFError(relativeOffset, err) 324 } 325 326 relativeOffset += 4 327 328 size := binary.LittleEndian.Uint32(r.scratch[:4]) 329 if size == 0 { 330 r.blockOffset += int64(relativeOffset) 331 return nil 332 } 333 334 if seeker, ok := r.r.(io.Seeker); ok { 335 r.blockOffset += int64(relativeOffset) 336 // Implementations of Seeker do not always detect invalid offsets, 337 // so check that the new offset is valid by comparing to the end. 338 prev, err := seeker.Seek(0, io.SeekCurrent) 339 if err != nil { 340 return r.wrapError(0, err) 341 } 342 end, err := seeker.Seek(0, io.SeekEnd) 343 if err != nil { 344 return r.wrapError(0, err) 345 } 346 if prev > end-int64(size) { 347 r.blockOffset += end - prev 348 return r.makeEOFError(0) 349 } 350 351 // The new offset is valid, so seek to it. 352 _, err = seeker.Seek(prev+int64(size), io.SeekStart) 353 if err != nil { 354 return r.wrapError(0, err) 355 } 356 r.blockOffset += int64(size) 357 return nil 358 } 359 360 var skip []byte 361 const chunk = 1 << 20 // 1M 362 for size >= chunk { 363 if len(skip) == 0 { 364 skip = make([]byte, chunk) 365 } 366 if _, err := io.ReadFull(r.r, skip); err != nil { 367 return r.wrapNonEOFError(relativeOffset, err) 368 } 369 relativeOffset += chunk 370 size -= chunk 371 } 372 if size > 0 { 373 if len(skip) == 0 { 374 skip = make([]byte, size) 375 } 376 if _, err := io.ReadFull(r.r, skip); err != nil { 377 return r.wrapNonEOFError(relativeOffset, err) 378 } 379 relativeOffset += int(size) 380 } 381 382 r.blockOffset += int64(relativeOffset) 383 384 return nil 385 } 386 387 // readBlock reads the next block from a frame. 388 func (r *Reader) readBlock() error { 389 relativeOffset := 0 390 391 // Read Block_Header. RFC 3.1.1.2. 392 if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil { 393 return r.wrapNonEOFError(relativeOffset, err) 394 } 395 396 relativeOffset += 3 397 398 header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16) 399 400 lastBlock := header&1 != 0 401 blockType := (header >> 1) & 3 402 blockSize := int(header >> 3) 403 404 // Maximum block size is smaller of window size and 128K. 405 // We don't record the window size for a single segment frame, 406 // so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4. 407 if blockSize > 128<<10 || (r.window.size > 0 && blockSize > r.window.size) { 408 return r.makeError(relativeOffset, "block size too large") 409 } 410 411 // Handle different block types. RFC 3.1.1.2.2. 412 switch blockType { 413 case 0: 414 r.setBufferSize(blockSize) 415 if _, err := io.ReadFull(r.r, r.buffer); err != nil { 416 return r.wrapNonEOFError(relativeOffset, err) 417 } 418 relativeOffset += blockSize 419 r.blockOffset += int64(relativeOffset) 420 case 1: 421 r.setBufferSize(blockSize) 422 if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil { 423 return r.wrapNonEOFError(relativeOffset, err) 424 } 425 relativeOffset++ 426 v := r.scratch[0] 427 for i := range r.buffer { 428 r.buffer[i] = v 429 } 430 r.blockOffset += int64(relativeOffset) 431 case 2: 432 r.blockOffset += int64(relativeOffset) 433 if err := r.compressedBlock(blockSize); err != nil { 434 return err 435 } 436 r.blockOffset += int64(blockSize) 437 case 3: 438 return r.makeError(relativeOffset, "invalid block type") 439 } 440 441 if !r.frameSizeUnknown { 442 if uint64(len(r.buffer)) > r.remainingFrameSize { 443 return r.makeError(relativeOffset, "too many uncompressed bytes in frame") 444 } 445 r.remainingFrameSize -= uint64(len(r.buffer)) 446 } 447 448 if r.hasChecksum { 449 r.checksum.update(r.buffer) 450 } 451 452 if !lastBlock { 453 r.window.save(r.buffer) 454 } else { 455 if !r.frameSizeUnknown && r.remainingFrameSize != 0 { 456 return r.makeError(relativeOffset, "not enough uncompressed bytes for frame") 457 } 458 // Check for checksum at end of frame. RFC 3.1.1. 459 if r.hasChecksum { 460 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { 461 return r.wrapNonEOFError(0, err) 462 } 463 464 inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4]) 465 dataChecksum := uint32(r.checksum.digest()) 466 if inputChecksum != dataChecksum { 467 return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum)) 468 } 469 470 r.blockOffset += 4 471 } 472 r.sawFrameHeader = false 473 } 474 475 return nil 476 } 477 478 // setBufferSize sets the decompressed buffer size. 479 // When this is called the buffer is empty. 480 func (r *Reader) setBufferSize(size int) { 481 if cap(r.buffer) < size { 482 need := size - cap(r.buffer) 483 r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...) 484 } 485 r.buffer = r.buffer[:size] 486 } 487 488 // zstdError is an error while decompressing. 489 type zstdError struct { 490 offset int64 491 err error 492 } 493 494 func (ze *zstdError) Error() string { 495 return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err) 496 } 497 498 func (ze *zstdError) Unwrap() error { 499 return ze.err 500 } 501 502 func (r *Reader) makeEOFError(off int) error { 503 return r.wrapError(off, io.ErrUnexpectedEOF) 504 } 505 506 func (r *Reader) wrapNonEOFError(off int, err error) error { 507 if err == io.EOF { 508 err = io.ErrUnexpectedEOF 509 } 510 return r.wrapError(off, err) 511 } 512 513 func (r *Reader) makeError(off int, msg string) error { 514 return r.wrapError(off, errors.New(msg)) 515 } 516 517 func (r *Reader) wrapError(off int, err error) error { 518 if err == io.EOF { 519 return err 520 } 521 return &zstdError{r.blockOffset + int64(off), err} 522 }