github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/internal/zstd/zstd.go (about) 1 // Copyright 2023 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package zstd provides a decompressor for zstd streams, 6 // described in RFC 8878. It does not support dictionaries. 7 package zstd 8 9 import ( 10 "encoding/binary" 11 "errors" 12 "fmt" 13 "io" 14 ) 15 16 // fuzzing is a fuzzer hook set to true when fuzzing. 17 // This is used to reject cases where we don't match zstd. 18 var fuzzing = false 19 20 // Reader implements [io.Reader] to read a zstd compressed stream. 21 type Reader struct { 22 // The underlying Reader. 23 r io.Reader 24 25 // Whether we have read the frame header. 26 // This is of interest when buffer is empty. 27 // If true we expect to see a new block. 28 sawFrameHeader bool 29 30 // Whether the current frame expects a checksum. 31 hasChecksum bool 32 33 // Whether we have read at least one frame. 34 readOneFrame bool 35 36 // True if the frame size is not known. 37 frameSizeUnknown bool 38 39 // The number of uncompressed bytes remaining in the current frame. 40 // If frameSizeUnknown is true, this is not valid. 41 remainingFrameSize uint64 42 43 // The number of bytes read from r up to the start of the current 44 // block, for error reporting. 45 blockOffset int64 46 47 // Buffered decompressed data. 48 buffer []byte 49 // Current read offset in buffer. 50 off int 51 52 // The current repeated offsets. 53 repeatedOffset1 uint32 54 repeatedOffset2 uint32 55 repeatedOffset3 uint32 56 57 // The current Huffman tree used for compressing literals. 58 huffmanTable []uint16 59 huffmanTableBits int 60 61 // The window for back references. 62 windowSize int // maximum required window size 63 window []byte // window data 64 65 // A buffer available to hold a compressed block. 66 compressedBuf []byte 67 68 // A buffer for literals. 69 literals []byte 70 71 // Sequence decode FSE tables. 72 seqTables [3][]fseBaselineEntry 73 seqTableBits [3]uint8 74 75 // Buffers for sequence decode FSE tables. 76 seqTableBuffers [3][]fseBaselineEntry 77 78 // Scratch space used for small reads, to avoid allocation. 79 scratch [16]byte 80 81 // A scratch table for reading an FSE. Only temporarily valid. 82 fseScratch []fseEntry 83 84 // For checksum computation. 85 checksum xxhash64 86 } 87 88 // NewReader creates a new Reader that decompresses data from the given reader. 89 func NewReader(input io.Reader) *Reader { 90 r := new(Reader) 91 r.Reset(input) 92 return r 93 } 94 95 // Reset discards the current state and starts reading a new stream from r. 96 // This permits reusing a Reader rather than allocating a new one. 97 func (r *Reader) Reset(input io.Reader) { 98 r.r = input 99 100 // Several fields are preserved to avoid allocation. 101 // Others are always set before they are used. 102 r.sawFrameHeader = false 103 r.hasChecksum = false 104 r.readOneFrame = false 105 r.frameSizeUnknown = false 106 r.remainingFrameSize = 0 107 r.blockOffset = 0 108 // buffer 109 r.off = 0 110 // repeatedOffset1 111 // repeatedOffset2 112 // repeatedOffset3 113 // huffmanTable 114 // huffmanTableBits 115 // windowSize 116 // window 117 // compressedBuf 118 // literals 119 // seqTables 120 // seqTableBits 121 // seqTableBuffers 122 // scratch 123 // fseScratch 124 } 125 126 // Read implements [io.Reader]. 127 func (r *Reader) Read(p []byte) (int, error) { 128 if err := r.refillIfNeeded(); err != nil { 129 return 0, err 130 } 131 n := copy(p, r.buffer[r.off:]) 132 r.off += n 133 return n, nil 134 } 135 136 // ReadByte implements [io.ByteReader]. 137 func (r *Reader) ReadByte() (byte, error) { 138 if err := r.refillIfNeeded(); err != nil { 139 return 0, err 140 } 141 ret := r.buffer[r.off] 142 r.off++ 143 return ret, nil 144 } 145 146 // refillIfNeeded reads the next block if necessary. 147 func (r *Reader) refillIfNeeded() error { 148 for r.off >= len(r.buffer) { 149 if err := r.refill(); err != nil { 150 return err 151 } 152 r.off = 0 153 } 154 return nil 155 } 156 157 // refill reads and decompresses the next block. 158 func (r *Reader) refill() error { 159 if !r.sawFrameHeader { 160 if err := r.readFrameHeader(); err != nil { 161 return err 162 } 163 } 164 return r.readBlock() 165 } 166 167 // readFrameHeader reads the frame header and prepares to read a block. 168 func (r *Reader) readFrameHeader() error { 169 retry: 170 relativeOffset := 0 171 172 // Read magic number. RFC 3.1.1. 173 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { 174 // We require that the stream contain at least one frame. 175 if err == io.EOF && !r.readOneFrame { 176 err = io.ErrUnexpectedEOF 177 } 178 return r.wrapError(relativeOffset, err) 179 } 180 181 if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 { 182 if magic >= 0x184d2a50 && magic <= 0x184d2a5f { 183 // This is a skippable frame. 184 r.blockOffset += int64(relativeOffset) + 4 185 if err := r.skipFrame(); err != nil { 186 return err 187 } 188 goto retry 189 } 190 191 return r.makeError(relativeOffset, "invalid magic number") 192 } 193 194 relativeOffset += 4 195 196 // Read Frame_Header_Descriptor. RFC 3.1.1.1.1. 197 if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil { 198 return r.wrapNonEOFError(relativeOffset, err) 199 } 200 descriptor := r.scratch[0] 201 202 singleSegment := descriptor&(1<<5) != 0 203 204 fcsFieldSize := 1 << (descriptor >> 6) 205 if fcsFieldSize == 1 && !singleSegment { 206 fcsFieldSize = 0 207 } 208 209 var windowDescriptorSize int 210 if singleSegment { 211 windowDescriptorSize = 0 212 } else { 213 windowDescriptorSize = 1 214 } 215 216 if descriptor&(1<<3) != 0 { 217 return r.makeError(relativeOffset, "reserved bit set in frame header descriptor") 218 } 219 220 r.hasChecksum = descriptor&(1<<2) != 0 221 if r.hasChecksum { 222 r.checksum.reset() 223 } 224 225 if descriptor&3 != 0 { 226 return r.makeError(relativeOffset, "dictionaries are not supported") 227 } 228 229 relativeOffset++ 230 231 headerSize := windowDescriptorSize + fcsFieldSize 232 233 if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil { 234 return r.wrapNonEOFError(relativeOffset, err) 235 } 236 237 // Figure out the maximum amount of data we need to retain 238 // for backreferences. 239 240 if singleSegment { 241 // No window required, as all the data is in a single buffer. 242 r.windowSize = 0 243 } else { 244 // Window descriptor. RFC 3.1.1.1.2. 245 windowDescriptor := r.scratch[0] 246 exponent := uint64(windowDescriptor >> 3) 247 mantissa := uint64(windowDescriptor & 7) 248 windowLog := exponent + 10 249 windowBase := uint64(1) << windowLog 250 windowAdd := (windowBase / 8) * mantissa 251 windowSize := windowBase + windowAdd 252 253 // Default zstd sets limits on the window size. 254 if fuzzing && (windowLog > 31 || windowSize > 1<<27) { 255 return r.makeError(relativeOffset, "windowSize too large") 256 } 257 258 // RFC 8878 permits us to set an 8M max on window size. 259 if windowSize > 8<<20 { 260 windowSize = 8 << 20 261 } 262 263 r.windowSize = int(windowSize) 264 } 265 266 // Frame_Content_Size. RFC 3.1.1.4. 267 r.frameSizeUnknown = false 268 r.remainingFrameSize = 0 269 fb := r.scratch[windowDescriptorSize:] 270 switch fcsFieldSize { 271 case 0: 272 r.frameSizeUnknown = true 273 case 1: 274 r.remainingFrameSize = uint64(fb[0]) 275 case 2: 276 r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb)) 277 case 4: 278 r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb)) 279 case 8: 280 r.remainingFrameSize = binary.LittleEndian.Uint64(fb) 281 default: 282 panic("unreachable") 283 } 284 285 relativeOffset += headerSize 286 287 r.sawFrameHeader = true 288 r.readOneFrame = true 289 r.blockOffset += int64(relativeOffset) 290 291 // Prepare to read blocks from the frame. 292 r.repeatedOffset1 = 1 293 r.repeatedOffset2 = 4 294 r.repeatedOffset3 = 8 295 r.huffmanTableBits = 0 296 r.window = r.window[:0] 297 r.seqTables[0] = nil 298 r.seqTables[1] = nil 299 r.seqTables[2] = nil 300 301 return nil 302 } 303 304 // skipFrame skips a skippable frame. RFC 3.1.2. 305 func (r *Reader) skipFrame() error { 306 relativeOffset := 0 307 308 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { 309 return r.wrapNonEOFError(relativeOffset, err) 310 } 311 312 relativeOffset += 4 313 314 size := binary.LittleEndian.Uint32(r.scratch[:4]) 315 316 if seeker, ok := r.r.(io.Seeker); ok { 317 if _, err := seeker.Seek(int64(size), io.SeekCurrent); err != nil { 318 return err 319 } 320 r.blockOffset += int64(relativeOffset) + int64(size) 321 return nil 322 } 323 324 var skip []byte 325 const chunk = 1 << 20 // 1M 326 for size >= chunk { 327 if len(skip) == 0 { 328 skip = make([]byte, chunk) 329 } 330 if _, err := io.ReadFull(r.r, skip); err != nil { 331 return r.wrapNonEOFError(relativeOffset, err) 332 } 333 relativeOffset += chunk 334 size -= chunk 335 } 336 if size > 0 { 337 if len(skip) == 0 { 338 skip = make([]byte, size) 339 } 340 if _, err := io.ReadFull(r.r, skip); err != nil { 341 return r.wrapNonEOFError(relativeOffset, err) 342 } 343 relativeOffset += int(size) 344 } 345 346 r.blockOffset += int64(relativeOffset) 347 348 return nil 349 } 350 351 // readBlock reads the next block from a frame. 352 func (r *Reader) readBlock() error { 353 relativeOffset := 0 354 355 // Read Block_Header. RFC 3.1.1.2. 356 if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil { 357 return r.wrapNonEOFError(relativeOffset, err) 358 } 359 360 relativeOffset += 3 361 362 header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16) 363 364 lastBlock := header&1 != 0 365 blockType := (header >> 1) & 3 366 blockSize := int(header >> 3) 367 368 // Maximum block size is smaller of window size and 128K. 369 // We don't record the window size for a single segment frame, 370 // so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4. 371 if blockSize > 128<<10 || (r.windowSize > 0 && blockSize > r.windowSize) { 372 return r.makeError(relativeOffset, "block size too large") 373 } 374 375 // Handle different block types. RFC 3.1.1.2.2. 376 switch blockType { 377 case 0: 378 r.setBufferSize(blockSize) 379 if _, err := io.ReadFull(r.r, r.buffer); err != nil { 380 return r.wrapNonEOFError(relativeOffset, err) 381 } 382 relativeOffset += blockSize 383 r.blockOffset += int64(relativeOffset) 384 case 1: 385 r.setBufferSize(blockSize) 386 if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil { 387 return r.wrapNonEOFError(relativeOffset, err) 388 } 389 relativeOffset++ 390 v := r.scratch[0] 391 for i := range r.buffer { 392 r.buffer[i] = v 393 } 394 r.blockOffset += int64(relativeOffset) 395 case 2: 396 r.blockOffset += int64(relativeOffset) 397 if err := r.compressedBlock(blockSize); err != nil { 398 return err 399 } 400 r.blockOffset += int64(blockSize) 401 case 3: 402 return r.makeError(relativeOffset, "invalid block type") 403 } 404 405 if !r.frameSizeUnknown { 406 if uint64(len(r.buffer)) > r.remainingFrameSize { 407 return r.makeError(relativeOffset, "too many uncompressed bytes in frame") 408 } 409 r.remainingFrameSize -= uint64(len(r.buffer)) 410 } 411 412 if r.hasChecksum { 413 r.checksum.update(r.buffer) 414 } 415 416 if !lastBlock { 417 r.saveWindow(r.buffer) 418 } else { 419 if !r.frameSizeUnknown && r.remainingFrameSize != 0 { 420 return r.makeError(relativeOffset, "not enough uncompressed bytes for frame") 421 } 422 // Check for checksum at end of frame. RFC 3.1.1. 423 if r.hasChecksum { 424 if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil { 425 return r.wrapNonEOFError(0, err) 426 } 427 428 inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4]) 429 dataChecksum := uint32(r.checksum.digest()) 430 if inputChecksum != dataChecksum { 431 return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum)) 432 } 433 434 r.blockOffset += 4 435 } 436 r.sawFrameHeader = false 437 } 438 439 return nil 440 } 441 442 // setBufferSize sets the decompressed buffer size. 443 // When this is called the buffer is empty. 444 func (r *Reader) setBufferSize(size int) { 445 if cap(r.buffer) < size { 446 need := size - cap(r.buffer) 447 r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...) 448 } 449 r.buffer = r.buffer[:size] 450 } 451 452 // saveWindow saves bytes in the backreference window. 453 // TODO: use a circular buffer for less data movement. 454 func (r *Reader) saveWindow(buf []byte) { 455 if r.windowSize == 0 { 456 return 457 } 458 459 if len(buf) >= r.windowSize { 460 from := len(buf) - r.windowSize 461 r.window = append(r.window[:0], buf[from:]...) 462 return 463 } 464 465 keep := r.windowSize - len(buf) // must be positive 466 if keep < len(r.window) { 467 remove := len(r.window) - keep 468 copy(r.window[:], r.window[remove:]) 469 } 470 471 r.window = append(r.window, buf...) 472 } 473 474 // zstdError is an error while decompressing. 475 type zstdError struct { 476 offset int64 477 err error 478 } 479 480 func (ze *zstdError) Error() string { 481 return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err) 482 } 483 484 func (ze *zstdError) Unwrap() error { 485 return ze.err 486 } 487 488 func (r *Reader) makeEOFError(off int) error { 489 return r.wrapError(off, io.ErrUnexpectedEOF) 490 } 491 492 func (r *Reader) wrapNonEOFError(off int, err error) error { 493 if err == io.EOF { 494 err = io.ErrUnexpectedEOF 495 } 496 return r.wrapError(off, err) 497 } 498 499 func (r *Reader) makeError(off int, msg string) error { 500 return r.wrapError(off, errors.New(msg)) 501 } 502 503 func (r *Reader) wrapError(off int, err error) error { 504 if err == io.EOF { 505 return err 506 } 507 return &zstdError{r.blockOffset + int64(off), err} 508 }