github.com/grailbio/base@v0.0.11/recordio/internal/chunk.go

// Copyright 2018 GRAIL, Inc. All rights reserved.
// Use of this source code is governed by the Apache-2.0
// license that can be found in the LICENSE file.

package internal

import (
	"encoding/binary"
	"fmt"
	"hash"
	"hash/crc32"
	"io"
	"math"

	"github.com/grailbio/base/errors"
)

type chunkFlag uint32

const (
	// ChunkHeaderSize is the fixed header size for a chunk.
	ChunkHeaderSize = 28

	// ChunkSize is the fixed size of a chunk, including its header.
	ChunkSize = 32 << 10

	// MaxChunkPayloadSize is the maximum size of payload a chunk can carry.
	MaxChunkPayloadSize = ChunkSize - ChunkHeaderSize
)
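
// In concrete terms, ChunkSize is 32<<10 = 32768 bytes, so MaxChunkPayloadSize
// is 32768 - 28 = 32740 bytes. A logical block larger than that is split
// across chunks: for example, a 70,000-byte block becomes 3 chunks carrying
// 32,740 + 32,740 + 4,520 payload bytes, and occupies 3*32768 bytes on disk
// once headers and the final chunk's padding are included.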

// Chunk layout:
//
//   magic [8B]
//   crc   [4B LE]
//   flag  [4B LE]
//   size  [4B LE]
//   total [4B LE]
//   index [4B LE]
//   data [size]
//   padding [32768 - 28 - size]
//
// magic: one of MagicHeader, MagicPacked, MagicTrailer.
// size: size of the chunk payload (data). size <= (32<<10) - 28.
// padding: garbage data added to make the chunk size exactly 32768B.
//
// total: the total # of chunks in the block.
// index: the index of the chunk within the block. Index is 0 for the first
// chunk, 1 for the 2nd chunk, and so on.
// flag: currently unused.
//
// crc: IEEE CRC32 of the succeeding fields: flag, size, total, index, and data.
// Note: padding is not included in the CRC.
type chunkHeader [ChunkHeaderSize]byte

func (h *chunkHeader) TotalChunks() int {
	return int(binary.LittleEndian.Uint32(h[20:]))
}

func (h *chunkHeader) Index() int {
	return int(binary.LittleEndian.Uint32(h[24:]))
}

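// chunkPadding holds the filler bytes appended after the payload of a block's
// last chunk so that every chunk occupies exactly ChunkSize bytes on disk.
// init() below fills it with a repeating 0xde 0xad 0xbe 0xef pattern,
// presumably so that padding is easy to recognize in a hex dump.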
var chunkPadding [MaxChunkPayloadSize]byte

// Seek moves the read pointer of "r" to "off". It returns nil iff the seek
// pointer moves to "off".
func Seek(r io.ReadSeeker, off int64) error {
	n, err := r.Seek(off, io.SeekStart)
	if err != nil {
		return err
	}
	if n != off {
		return fmt.Errorf("seek: got %v, expect %v", n, off)
	}
	return nil
}

func init() {
	temp := [4]byte{0xde, 0xad, 0xbe, 0xef}
	for i := range chunkPadding {
		chunkPadding[i] = temp[i%len(temp)]
	}
}

// ChunkWriter implements low-level block-write operations. It takes a logical
// block and stores it as a sequence of chunks. Thread compatible.
type ChunkWriter struct {
	nWritten int64
	w        io.Writer
	err      *errors.Once
	crc      hash.Hash32
}

// Len returns the number of bytes successfully written so far.
// The value is meaningful only when err.Err()==nil.
func (w *ChunkWriter) Len() int64 {
	return w.nWritten
}

// Write writes one block. An error is reported through w.err.
func (w *ChunkWriter) Write(magic MagicBytes, payload []byte) {
	var header chunkHeader
	copy(header[:], magic[:])

	chunkIndex := 0
	totalChunks := (len(payload)-1)/MaxChunkPayloadSize + 1
	for {
		var chunkPayload []byte
		lastChunk := false
		if len(payload) <= MaxChunkPayloadSize {
			lastChunk = true
			chunkPayload = payload
			payload = nil
		} else {
			chunkPayload = payload[:MaxChunkPayloadSize]
			payload = payload[MaxChunkPayloadSize:]
		}
		binary.LittleEndian.PutUint32(header[12:], uint32(0))
		binary.LittleEndian.PutUint32(header[16:], uint32(len(chunkPayload)))
		binary.LittleEndian.PutUint32(header[20:], uint32(totalChunks))
		binary.LittleEndian.PutUint32(header[24:], uint32(chunkIndex))

		w.crc.Reset()
		w.crc.Write(header[12:])
		w.crc.Write(chunkPayload)
		csum := w.crc.Sum32()
		binary.LittleEndian.PutUint32(header[8:], csum)
		w.doWrite(header[:])
		w.doWrite(chunkPayload)
		chunkIndex++
		if lastChunk {
			paddingSize := MaxChunkPayloadSize - len(chunkPayload)
			if paddingSize > 0 {
				w.doWrite(chunkPadding[:paddingSize])
			}
			break
		}
	}
	if chunkIndex != totalChunks {
		panic(fmt.Sprintf("nchunks %d, total %d", chunkIndex, totalChunks))
	}
}

func (w *ChunkWriter) doWrite(data []byte) {
	n, err := w.w.Write(data)
	if err != nil {
		w.err.Set(err)
		return
	}
	w.nWritten += int64(len(data))
	if n != len(data) {
		w.err.Set(fmt.Errorf("Failed to write %d bytes (got %d)", len(data), n))
	}
}

// NewChunkWriter creates a new chunk writer. Any error is reported through
// "err".
func NewChunkWriter(w io.Writer, err *errors.Once) *ChunkWriter {
	return &ChunkWriter{w: w, err: err, crc: crc32.New(IEEECRC)}
}
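
// Illustrative ChunkWriter usage (a sketch, not code from this package): write
// one small block into an in-memory buffer. bytes.Buffer stands in for any
// io.Writer, MagicPacked is one of the magic values defined elsewhere in this
// package, and the errors.Once is assumed to be usable as a zero value.
//
//	var e errors.Once
//	buf := &bytes.Buffer{}
//	w := NewChunkWriter(buf, &e)
//	w.Write(MagicPacked, []byte("hello")) // fits in a single chunk
//	if e.Err() != nil {
//		// handle the write error
//	}
//	// The block occupies exactly one chunk: w.Len() == ChunkSize (32768 bytes),
//	// i.e. a 28-byte header, 5 payload bytes, and 32735 bytes of padding.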

// ChunkScanner reads a sequence of chunks and reconstructs a logical
// block. Thread compatible.
type ChunkScanner struct {
	r   io.ReadSeeker
	err *errors.Once

	fileSize int64
	off      int64
	limit    int64

	magic  MagicBytes
	chunks [][]byte

	pool                 [][]byte
	unused               int // the first unused buf in pool.
	approxChunksPerBlock float64
}

// NewChunkScanner creates a new chunk scanner. Any error is reported through "err".
func NewChunkScanner(r io.ReadSeeker, err *errors.Once) *ChunkScanner {
	rx := &ChunkScanner{r: r, err: err}
	// Compute the file size.
	var e error
	if rx.fileSize, e = r.Seek(0, io.SeekEnd); e != nil {
		rx.err.Set(e)
	}
	rx.err.Set(Seek(r, 0))
	rx.limit = math.MaxInt64
	return rx
}

// LimitShard limits this scanner to scan the blocks belonging to a shard range
// [start,limit) out of [0, nshard). The shard range begins at the scanner's
// current offset, which must be on a block boundary. The file (beginning at the
// current scanner offset) is divided into nshard shards. Each shard scans
// blocks until the start of the next shard. If a shard boundary falls in the
// middle of a block, that block belongs to the previous shard.
func (r *ChunkScanner) LimitShard(start, limit, nshard int) {
	// Compute the offset and limit for shard-of-nshard.
	// Invariant: limit is the offset at or after which a new block
	// should not be scanned.
	numChunks := (r.fileSize - r.off) / ChunkSize
	chunksPerShard := float64(numChunks) / float64(nshard)
	startOff := r.off
	r.off = startOff + int64(float64(start)*chunksPerShard)*ChunkSize
	r.limit = startOff + int64(float64(limit)*chunksPerShard)*ChunkSize
	if start == 0 {
		// No more work to do. We assume LimitShard is called on a block boundary.
		return
	}
	r.err.Set(Seek(r.r, r.off))
	if r.err.Err() != nil {
		return
	}
	var header chunkHeader
	if !r.readChunkHeader(&header) {
		return
	}
	if r.err.Err() != nil {
		return
	}
	index := header.Index()
	if index == 0 {
		// No more work to do: we're already on a block boundary.
		return
	}
	// We're in the middle of a block. The current block belongs to the
	// previous shard, so we skip forward to the next block boundary.
	total := header.TotalChunks()
	if total <= index {
		r.err.Set(errors.New("invalid chunk header"))
		return
	}
	r.off += ChunkSize * int64(total-index)
	r.err.Set(Seek(r.r, r.off))
}
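
// Illustrative sharded scan (a sketch, not code from this package): split the
// chunks of a file among nshard shards and scan each shard's blocks in turn.
// "f" is an assumed io.ReadSeeker over a recordio file; in practice each shard
// would typically run in its own goroutine with its own reader, since a single
// io.ReadSeeker cannot be shared concurrently.
//
//	const nshard = 4
//	for shard := 0; shard < nshard; shard++ {
//		var e errors.Once
//		sc := NewChunkScanner(f, &e) // NewChunkScanner rewinds f to offset 0.
//		sc.LimitShard(shard, shard+1, nshard)
//		for sc.Scan() {
//			// Process sc.Block(); only blocks starting inside this shard
//			// are returned.
//		}
//	}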

// Tell returns the file offset of the next block to be read.
// Any error is reported in r.Err().
func (r *ChunkScanner) Tell() int64 {
	return r.off
}

// Seek moves the read pointer so that the next Scan() will read the block at
// the given file offset. Any error is reported in r.Err().
func (r *ChunkScanner) Seek(off int64) { // "go vet" complaint expected
	r.off = off
	r.err.Set(Seek(r.r, off))
}

// Scan reads the next block. It returns false on EOF or any error.
// Any error is reported in r.Err().
func (r *ChunkScanner) Scan() bool {
	r.resetChunks()
	r.magic = MagicInvalid
	if r.err.Err() != nil {
		return false
	}
	if r.off >= r.limit {
		r.err.Set(io.EOF)
		return false
	}
	totalChunks := -1
	for {
		chunkMagic, _, nchunks, index, chunkPayload := r.readChunk()
		if chunkMagic == MagicInvalid || r.err.Err() != nil {
			return false
		}
		if len(r.chunks) == 0 {
			r.magic = chunkMagic
			totalChunks = nchunks
		}
		if chunkMagic != r.magic {
			r.err.Set(fmt.Errorf("Magic number changed in the middle of a chunk sequence, got %v, expect %v",
				chunkMagic, r.magic))
			return false
		}
		if len(r.chunks) != index {
			r.err.Set(fmt.Errorf("Chunk index mismatch, got %v, expect %v for magic %x",
				index, len(r.chunks), r.magic))
			return false
		}
		if nchunks != totalChunks {
			r.err.Set(fmt.Errorf("Chunk nchunk mismatch, got %v, expect %v for magic %x",
				nchunks, totalChunks, r.magic))
			return false
		}
		r.chunks = append(r.chunks, chunkPayload)
		if index == totalChunks-1 {
			break
		}
	}
	return true
}

// Block returns the current block contents.
//
// REQUIRES: Last Scan() call returned true.
func (r *ChunkScanner) Block() (MagicBytes, [][]byte) {
	return r.magic, r.chunks
}
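
// Illustrative scan loop (a sketch, not code from this package): iterate over
// all blocks in a file. "f" is an assumed io.ReadSeeker. Note that io.EOF may
// be reported through the errors.Once, so callers that treat a clean
// end-of-file as success should filter it out, depending on how the
// errors.Once was configured.
//
//	var e errors.Once
//	sc := NewChunkScanner(f, &e)
//	for sc.Scan() {
//		magic, chunks := sc.Block()
//		_ = magic // MagicHeader, MagicPacked, or MagicTrailer
//		for _, chunk := range chunks {
//			_ = chunk // valid only until the next Scan, since buffers are recycled
//		}
//	}
//	if err := e.Err(); err != nil && err != io.EOF {
//		// handle the error
//	}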

func (r *ChunkScanner) readChunkHeader(header *chunkHeader) bool {
	_, err := io.ReadFull(r.r, header[:])
	if err != nil {
		r.err.Set(err)
		return false
	}
	r.off, err = r.r.Seek(-ChunkHeaderSize, io.SeekCurrent)
	r.err.Set(err)
	return true
}

// Read one chunk. On error or EOF, returns MagicInvalid. The caller should
// check r.err.Err() to distinguish EOF from a real error.
func (r *ChunkScanner) readChunk() (MagicBytes, chunkFlag, int, int, []byte) {
	chunkBuf := r.allocChunk()
	n, err := io.ReadFull(r.r, chunkBuf)
	r.off += int64(n)
	if err != nil {
		r.err.Set(err)
		return MagicInvalid, chunkFlag(0), 0, 0, nil
	}
	header := chunkBuf[:ChunkHeaderSize]

	var magic MagicBytes
	copy(magic[:], header[:])
	expectedCsum := binary.LittleEndian.Uint32(header[8:])
	flag := chunkFlag(binary.LittleEndian.Uint32(header[12:]))
	size := binary.LittleEndian.Uint32(header[16:])
	totalChunks := int(binary.LittleEndian.Uint32(header[20:]))
	index := int(binary.LittleEndian.Uint32(header[24:]))
	if size > MaxChunkPayloadSize {
		r.err.Set(fmt.Errorf("Invalid chunk size %d", size))
		return MagicInvalid, chunkFlag(0), 0, 0, nil
	}

	chunkPayload := chunkBuf[ChunkHeaderSize : ChunkHeaderSize+size]
	actualCsum := crc32.Checksum(chunkBuf[12:ChunkHeaderSize+size], IEEECRC)
	if expectedCsum != actualCsum {
		r.err.Set(fmt.Errorf("Chunk checksum mismatch, expect %d, got %d",
			expectedCsum, actualCsum))
	}
	return magic, flag, totalChunks, index, chunkPayload
}
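
// The checksum check above mirrors the writer: the IEEE CRC32 covers bytes
// [12, ChunkHeaderSize+size) of the chunk, i.e. flag, size, total, index and
// the payload, but not the magic, the crc field itself, or the padding. As a
// sketch, recomputing the checksum over a raw chunk buffer "chunk" (a
// hypothetical []byte of at least ChunkHeaderSize+size bytes) looks like:
//
//	expected := binary.LittleEndian.Uint32(chunk[8:])
//	size := binary.LittleEndian.Uint32(chunk[16:])
//	actual := crc32.Checksum(chunk[12:ChunkHeaderSize+size], IEEECRC)
//	ok := expected == actual
//	_ = ok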

func (r *ChunkScanner) resetChunks() {
	// Avoid keeping too much data in the free pool. If the pool size exceeds 2x
	// the avg size of recent blocks, trim it down.
	nChunks := float64(len(r.chunks))
	if r.approxChunksPerBlock == 0 {
		r.approxChunksPerBlock = nChunks
	} else {
		r.approxChunksPerBlock =
			r.approxChunksPerBlock*0.9 + nChunks*0.1
	}
	max := int(r.approxChunksPerBlock*2) + 1
	if len(r.pool) > max {
		r.pool = r.pool[:max]
	}
	r.unused = 0
	r.chunks = r.chunks[:0]
}

func (r *ChunkScanner) allocChunk() []byte {
	for len(r.pool) <= r.unused {
		r.pool = append(r.pool, make([]byte, ChunkSize))
	}
	b := r.pool[r.unused]
	r.unused++
	if len(b) != ChunkSize {
		panic(r)
	}
	return b
}

// ReadLastBlock reads the trailer block. It sets err if the trailer does not
// exist or is corrupt. After the call, the read pointer is at an undefined
// position, so the user must call Seek() explicitly.
func (r *ChunkScanner) ReadLastBlock() (MagicBytes, [][]byte) {
	var err error
	r.off, err = r.r.Seek(-ChunkSize, io.SeekEnd)
	if err != nil {
		r.err.Set(err)
		return MagicInvalid, nil
	}
	magic, _, totalChunks, index, payload := r.readChunk()
	if magic != MagicTrailer {
		r.err.Set(fmt.Errorf("Missing magic trailer; found %v", magic))
		return MagicInvalid, nil
	}
	if index == 0 && totalChunks == 1 {
		// Fast path for a single-chunk trailer.
		return magic, [][]byte{payload}
	}
	// Seek to the beginning of the block.
	r.off, err = r.r.Seek(-int64(index+1)*ChunkSize, io.SeekEnd)
	if err != nil {
		r.err.Set(err)
		return MagicInvalid, nil
	}
	if !r.Scan() {
		r.err.Set(fmt.Errorf("Failed to read trailer"))
		return MagicInvalid, nil
	}
	return r.magic, r.chunks
}
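
// Illustrative trailer read (a sketch, not code from this package): read the
// trailer first, then rewind and scan the data blocks. "f" is an assumed
// io.ReadSeeker over a recordio file that has a trailer.
//
//	var e errors.Once
//	sc := NewChunkScanner(f, &e)
//	magic, trailer := sc.ReadLastBlock() // magic == MagicTrailer on success
//	_, _ = magic, trailer
//	sc.Seek(0) // required: ReadLastBlock leaves the read pointer undefined
//	for sc.Scan() {
//		// process sc.Block()
//	}
//	if err := e.Err(); err != nil && err != io.EOF {
//		// handle the error
//	}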