github.com/grailbio/base@v0.0.11/recordio/scannerv2.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package recordio
     6  
     7  import (
     8  	"encoding/binary"
     9  	"fmt"
    10  	"io"
    11  	"sync"
    12  
    13  	"github.com/grailbio/base/errors"
    14  	"github.com/grailbio/base/recordio/internal"
    15  )
    16  
    17  var scannerFreePool = sync.Pool{
    18  	New: func() interface{} {
    19  		return &scannerv2{}
    20  	},
    21  }
    22  
    23  // rawItemList is the result of uncompressing & parsing one recordio block.
    24  type rawItemList struct {
    25  	bytes    []byte // raw bytes, post transformation.
    26  	firstOff int    // bytes[firstOff:] contain the application payload
    27  	cumSize  []int  // cumSize[x] is the cumulative bytesize of items [0,x].
    28  }
    29  
    30  func (ri *rawItemList) clear() {
    31  	ri.bytes = ri.bytes[:0]
    32  	ri.cumSize = ri.cumSize[:0]
    33  	ri.firstOff = 0
    34  }
    35  
    36  // len returns the number of items in the block.
    37  func (ri *rawItemList) len() int { return len(ri.cumSize) }
    38  
    39  // item returns the i'th (base 0) item.
    40  //
    41  // REQUIRES: 0 <= i < ri.len().
    42  func (ri *rawItemList) item(i int) []byte {
    43  	startOff := ri.firstOff
    44  	if i > 0 {
    45  		startOff += ri.cumSize[i-1]
    46  	}
    47  	limitOff := ri.firstOff + ri.cumSize[i]
    48  	return ri.bytes[startOff:limitOff]
    49  }
    50  
    51  // Given block contents, apply transformation if any, and parse it into a list
    52  // of items. If transform is nil, it defaults to identity.
    53  func parseChunksToItems(rawItems *rawItemList, chunks [][]byte, transform TransformFunc) error {
    54  	if transform == nil {
    55  		// TODO(saito) Allow TransformFunc to return an iov, and refactor the rest
    56  		// of the codebase to consume it.
    57  		transform = idTransform
    58  	}
    59  	var err error
    60  	if rawItems.bytes != nil {
    61  		// zstd doesn't like an empty slice (zstd.go:100)
    62  		//
    63  		// TODO(saito) fix upstream.
    64  		rawItems.bytes = rawItems.bytes[:cap(rawItems.bytes)]
    65  	}
    66  	if rawItems.bytes, err = transform(rawItems.bytes, chunks); err != nil {
    67  		return err
    68  	}
    69  	block := rawItems.bytes
    70  	unItems, n := binary.Uvarint(block)
    71  	if n <= 0 {
    72  		return fmt.Errorf("recordio: failed to read number of packed items: %v", n)
    73  	}
    74  	nItems := int(unItems)
    75  	pos := n
    76  
    77  	if cap(rawItems.cumSize) < nItems {
    78  		rawItems.cumSize = make([]int, nItems)
    79  	} else {
    80  		rawItems.cumSize = rawItems.cumSize[:nItems]
    81  	}
    82  	total := 0
    83  	for i := 0; i < nItems; i++ {
    84  		size, n := binary.Uvarint(block[pos:])
    85  		if n <= 0 {
    86  			return fmt.Errorf("recordio: likely corrupt data, failed to read size of packed item %v: %v", i, n)
    87  		}
    88  		total += int(size)
    89  		rawItems.cumSize[i] = total
    90  		pos += n
    91  	}
    92  	rawItems.firstOff = pos
    93  	if total+pos != len(block) {
    94  		return fmt.Errorf("recordio: corrupt block header, got block size %d, expected %d", len(block), total+pos)
    95  	}
    96  	return nil
    97  }
    98  
// ScannerOpts defines options used when creating a new scanner.
type ScannerOpts struct {
	// LegacyTransform is used only to read the legacy recordio files. For the V2
	// recordio files, this field is ignored, and transformers are constructed
	// from the header metadata.
	LegacyTransform TransformFunc

	// Unmarshal transforms a byte slice into an application object. It is called
	// for every item read from storage. If nil, a function that returns []byte
	// unchanged is used. The return value from Unmarshal can be retrieved using
	// the Scanner.Get method.
	//
	// For V2 files, an error returned by Unmarshal makes Scan return false and
	// is reported by Scanner.Err. The input slice aliases the scanner's block
	// buffer, which is reused as scratch space when later blocks are read and
	// may be overwritten, so Unmarshal should copy any bytes it needs to keep.
	// With the default (nil) Unmarshal, the same applies to the []byte that
	// Get returns.
	Unmarshal func(in []byte) (out interface{}, err error)
}
   112  
// Scanner defines an interface for recordio scanners.
//
// A Scanner implementation must be thread safe.  The legal call sequence is
// given by the path expression below. Err, Header, and Trailer can be called
// at any time.
//
//	((Scan Get*) | Seek)* Finish
//
// NOTE(review): scannerv2 in this file keeps its scan position in plain
// fields with no lock around them; confirm whether the thread-safety
// requirement above is meant literally.
type Scanner interface {
	// Header returns the contents of the header block.
	Header() ParsedHeader

	// Scan returns true if a new record was read, false otherwise. It will return
	// false on encountering an error; the error may be retrieved using the Err
	// method. Note that Scan will reuse storage from one invocation to the next.
	Scan() bool

	// Get returns the current item as read by a prior call to Scan.
	//
	// REQUIRES: Preceding Scan calls have returned true. There is no Seek
	// call between the last Scan call and the Get call.
	Get() interface{}

	// Err returns any error encountered by the scanner. Once Err() becomes
	// non-nil, it stays so.
	Err() error

	// Seek arranges for the next Scan() call to read the item at the given
	// location.  On any error, Err() will be set.
	//
	// REQUIRES: loc must be one of the values passed to the Index callback
	// during writes.
	Seek(loc ItemLocation)

	// Trailer returns the trailer block contents.  If the trailer does not exist,
	// or is corrupt, it returns nil.  The caller should examine Err() if Trailer
	// returns nil.
	Trailer() []byte

	// Version returns the file format version. Not for general use.
	Version() FormatVersion

	// Finish should be called exactly once, after the application has finished
	// using the scanner. It returns the value of Err().
	//
	// The Finish method recycles the internal scanner resources for use by other
	// scanners, thereby reducing GC overhead. The application must not touch the
	// scanner object after Finish.
	Finish() error
}
   162  
// scannerv2 is the Scanner implementation for V2 recordio files. Instances are
// recycled through scannerFreePool: newScanner resets them and Finish returns
// them to the pool.
type scannerv2 struct {
	// err holds the first recorded error. newScanner configures it to ignore
	// io.EOF and hands a pointer to it to sc.
	err errors.Once
	// sc reads the raw blocks of the file.
	sc *internal.ChunkScanner
	// opts are the caller's options; NewShardScanner fills in a default
	// Unmarshal when none is given.
	opts ScannerOpts
	// untransform undoes the transformers listed in the header. nil means
	// identity (parseChunksToItems substitutes idTransform).
	untransform TransformFunc
	// header is the parsed header block.
	header ParsedHeader

	// rawItems is the current data block; its buffers are reused across blocks
	// and across scanners taken from scannerFreePool.
	rawItems rawItemList
	// item is the Unmarshal result for the item last returned by Scan; Get
	// returns it.
	item interface{}
	// nextItem is the index in rawItems of the item the next Scan returns.
	nextItem int
}
   174  
   175  func idUnmarshal(data []byte) (interface{}, error) {
   176  	return data, nil
   177  }
   178  
   179  type errorScanner struct {
   180  	err error
   181  }
   182  
   183  func (s *errorScanner) Header() (p ParsedHeader)   { return }
   184  func (s *errorScanner) Trailer() (b []byte)        { return }
   185  func (s *errorScanner) Version() (v FormatVersion) { return }
   186  func (s *errorScanner) Get() interface{}           { panic(fmt.Sprintf("errorscannerv2.Get: %v", s.err)) }
   187  func (s *errorScanner) Scan() bool                 { return false }
   188  func (s *errorScanner) Seek(ItemLocation)          {}
   189  func (s *errorScanner) Finish() error              { return s.Err() }
   190  func (s *errorScanner) Err() error {
   191  	if s.err == io.EOF {
   192  		return nil
   193  	}
   194  	return s.err
   195  }
   196  
   197  // NewScanner creates a new recordio scanner. The reader can read both legacy
   198  // recordio files (packed or unpacked) or the new-format files. Any error is
   199  // reported through the Scanner.Err method.
   200  func NewScanner(in io.ReadSeeker, opts ScannerOpts) Scanner {
   201  	return NewShardScanner(in, opts, 0, 1, 1)
   202  }
   203  
   204  // NewShardScanner creates a new sharded recordio scanner. The returned scanner
   205  // reads shard [start,limit) (of [0,nshard)) of the recordio file at the
   206  // ReadSeeker in.  Sharding is only supported for v2 recordio files; an error
   207  // scanner is returned if NewShardScanner is called for a legacy recordio file.
   208  //
   209  // NewShardScanner with shard and nshard set to 0 and 1 respectively (i.e.,
   210  // a single shard) behaves as NewScanner.
   211  func NewShardScanner(in io.ReadSeeker, opts ScannerOpts, start, limit, nshard int) Scanner {
   212  	if opts.Unmarshal == nil {
   213  		opts.Unmarshal = idUnmarshal
   214  	}
   215  	if err := internal.Seek(in, 0); err != nil {
   216  		return &errorScanner{err}
   217  	}
   218  	var magic internal.MagicBytes
   219  	if _, err := io.ReadFull(in, magic[:]); err != nil {
   220  		return &errorScanner{err}
   221  	}
   222  	if err := internal.Seek(in, 0); err != nil {
   223  		return &errorScanner{err}
   224  	}
   225  	if start >= limit || limit > nshard || start < 0 || nshard <= 0 {
   226  		return &errorScanner{fmt.Errorf("invalid sharding [%d,%d) of %d", start, limit, nshard)}
   227  	}
   228  	if magic != internal.MagicHeader {
   229  		if start != 0 || limit != 1 || nshard != 1 {
   230  			return &errorScanner{errors.New("legacy record IOs do not support sharding")}
   231  		}
   232  		return newLegacyScannerAdapter(in, opts)
   233  	}
   234  	return newScanner(in, start, limit, nshard, opts)
   235  }
   236  
   237  func newScanner(in io.ReadSeeker, start, limit, nshard int, opts ScannerOpts) Scanner {
   238  	s := scannerFreePool.Get().(*scannerv2)
   239  	if s == nil {
   240  		panic("newScannerV2")
   241  	}
   242  	s.err = errors.Once{Ignored: []error{io.EOF}}
   243  	s.opts = opts
   244  	s.untransform = nil
   245  	s.header = nil
   246  	s.nextItem = 0
   247  	s.item = nil
   248  	s.sc = internal.NewChunkScanner(in, &s.err)
   249  	s.rawItems.clear()
   250  	s.readHeader()
   251  	if s.Err() != nil {
   252  		return s
   253  	}
   254  	// Technically, we shouldn't be reading the trailer again, but
   255  	// the block scanner just ignores it anyway.
   256  	s.sc.LimitShard(start, limit, nshard)
   257  	return s
   258  }
   259  
   260  func (s *scannerv2) readSpecialBlock(expectedMagic internal.MagicBytes, tr TransformFunc) []byte {
   261  	if !s.sc.Scan() {
   262  		s.err.Set(fmt.Errorf("Failed to read block %v", expectedMagic))
   263  		return nil
   264  	}
   265  	magic, chunks := s.sc.Block()
   266  	if magic != expectedMagic {
   267  		s.err.Set(fmt.Errorf("Failed to read block, expect %v, got %v", expectedMagic, magic))
   268  		return nil
   269  	}
   270  	rawItems := rawItemList{}
   271  	err := parseChunksToItems(&rawItems, chunks, tr)
   272  	if err != nil {
   273  		s.err.Set(err)
   274  		return nil
   275  	}
   276  	if rawItems.len() != 1 {
   277  		s.err.Set(fmt.Errorf("Wrong # of items in header block, %d", rawItems.len()))
   278  		return nil
   279  	}
   280  	return rawItems.item(0)
   281  }
   282  
   283  func (s *scannerv2) readHeader() {
   284  	payload := s.readSpecialBlock(internal.MagicHeader, idTransform)
   285  	if s.err.Err() != nil {
   286  		return
   287  	}
   288  	if err := s.header.unmarshal(payload); err != nil {
   289  		s.err.Set(err)
   290  		return
   291  	}
   292  	transformers := []string{}
   293  	for _, h := range s.header {
   294  		if h.Key == KeyTransformer {
   295  			str, ok := h.Value.(string)
   296  			if !ok {
   297  				s.err.Set(fmt.Errorf("Expect string value for key %v, but found %v", h.Key, h.Value))
   298  				return
   299  			}
   300  			transformers = append(transformers, str)
   301  		}
   302  	}
   303  	var err error
   304  	s.untransform, err = registry.GetUntransformer(transformers)
   305  	s.err.Set(err)
   306  }
   307  
   308  func (s *scannerv2) Version() FormatVersion {
   309  	return V2
   310  }
   311  
   312  func (s *scannerv2) Header() ParsedHeader {
   313  	return s.header
   314  }
   315  
   316  func (s *scannerv2) Trailer() []byte {
   317  	if !s.header.HasTrailer() {
   318  		return nil
   319  	}
   320  	curOff := s.sc.Tell()
   321  	defer s.sc.Seek(curOff)
   322  
   323  	magic, chunks := s.sc.ReadLastBlock()
   324  	if s.err.Err() != nil {
   325  		return nil
   326  	}
   327  	if magic != internal.MagicTrailer {
   328  		s.err.Set(fmt.Errorf("Did not found the trailer, instead found magic %v", magic))
   329  		return nil
   330  	}
   331  	rawItems := rawItemList{}
   332  	err := parseChunksToItems(&rawItems, chunks, s.untransform)
   333  	if err != nil {
   334  		s.err.Set(err)
   335  		return nil
   336  	}
   337  	if rawItems.len() != 1 {
   338  		s.err.Set(fmt.Errorf("Expect exactly one trailer item, but found %d", rawItems.len()))
   339  		return nil
   340  	}
   341  	return rawItems.item(0)
   342  }
   343  
   344  func (s *scannerv2) Get() interface{} {
   345  	return s.item
   346  }
   347  
   348  func (s *scannerv2) Seek(loc ItemLocation) {
   349  	// TODO(saito) Avoid seeking the file if loc.Block points to the current block.
   350  	if s.err.Err() == io.EOF {
   351  		s.err = errors.Once{}
   352  	}
   353  	s.sc.Seek(int64(loc.Block))
   354  	if !s.scanNextBlock() {
   355  		return
   356  	}
   357  	if loc.Item >= s.rawItems.len() {
   358  		s.err.Set(fmt.Errorf("Invalid location %+v, block has only %d items", loc, s.rawItems.len()))
   359  	}
   360  	s.nextItem = loc.Item
   361  }
   362  
   363  func (s *scannerv2) scanNextBlock() bool {
   364  	s.rawItems.clear()
   365  	s.nextItem = 0
   366  	if s.Err() != nil {
   367  		return false
   368  	}
   369  	// Need to read the next record.
   370  	if !s.sc.Scan() {
   371  		return false
   372  	}
   373  	magic, chunks := s.sc.Block()
   374  	if magic == internal.MagicPacked {
   375  		if err := parseChunksToItems(&s.rawItems, chunks, s.untransform); err != nil {
   376  			s.err.Set(err)
   377  			return false
   378  		}
   379  		s.nextItem = 0
   380  		return true
   381  	}
   382  	if magic == internal.MagicTrailer {
   383  		// EOF
   384  		return false
   385  	}
   386  	s.err.Set(fmt.Errorf("recordio: invalid magic number: %v", magic))
   387  	return false
   388  }
   389  
   390  func (s *scannerv2) Scan() bool {
   391  	for s.nextItem >= s.rawItems.len() {
   392  		if !s.scanNextBlock() {
   393  			return false
   394  		}
   395  	}
   396  	item, err := s.opts.Unmarshal(s.rawItems.item(s.nextItem))
   397  	if err != nil {
   398  		s.err.Set(err)
   399  		return false
   400  	}
   401  	s.item = item
   402  	s.nextItem++
   403  	return true
   404  }
   405  
   406  func (s *scannerv2) Err() error {
   407  	err := s.err.Err()
   408  	if err == io.EOF {
   409  		err = nil
   410  	}
   411  	return err
   412  }
   413  
   414  func (s *scannerv2) Finish() error {
   415  	err := s.Err()
   416  	s.err = errors.Once{}
   417  	s.opts = ScannerOpts{}
   418  	s.sc = nil
   419  	s.untransform = nil
   420  	s.header = nil
   421  	s.nextItem = 0
   422  	s.item = nil
   423  	scannerFreePool.Put(s)
   424  	return err
   425  }