github.com/grailbio/base@v0.0.11/recordio/deprecated/packed.go

// Copyright 2017 GRAIL, Inc. All rights reserved.
// Use of this source code is governed by the Apache-2.0
// license that can be found in the LICENSE file.

package deprecated

import (
	"fmt"
	"io"
	"sync"

	"github.com/grailbio/base/recordio/internal"
)

var (
	// MaxPackedItems defines the maximum number of items that can be
	// packed into a single record by a PackedWriter.
	MaxPackedItems = uint32(10 * 1024 * 1024)
	// DefaultPackedItems defines the default number of items that can
	// be packed into a single record by a PackedWriter.
	DefaultPackedItems = uint32(16 * 1024)
	// DefaultPackedBytes defines the default number of bytes that can
	// be packed into a single record by a PackedWriter.
	DefaultPackedBytes = uint32(16 * 1024 * 1024)
)

// RecordIndex is called every time a new record is added to a stream. It is
// called with the offset and size of the record, and the number of items being
// written to that record. It may optionally return a function that will be
// subsequently called for each item written to the record. This makes it
// possible to ensure that all calls to index the items in a single record are
// handled by the same method and object, and hence to index records
// concurrently.
type RecordIndex func(recordOffset, recordLength, nitems uint64) (ItemIndexFunc, error)

// ItemIndexFunc is called every time an item is added to a record.
type ItemIndexFunc func(itemOffset, itemLength uint64, v interface{}, p []byte) error

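// As a rough sketch only: one way to use RecordIndex is to capture per-record
// and per-item offsets into an in-memory table. The recordEntry type and the
// mu/entries variables below are illustrative and not part of this package.
//
//	type recordEntry struct {
//		offset, length uint64
//		items          []uint64 // offsets of items within the record
//	}
//
//	var (
//		mu      sync.Mutex
//		entries []*recordEntry
//	)
//
//	index := RecordIndex(func(recordOffset, recordLength, nitems uint64) (ItemIndexFunc, error) {
//		e := &recordEntry{offset: recordOffset, length: recordLength, items: make([]uint64, 0, nitems)}
//		mu.Lock()
//		entries = append(entries, e)
//		mu.Unlock()
//		// Each record gets its own closure, so the per-item calls for one
//		// record never mix with those of another.
//		return func(itemOffset, itemLength uint64, v interface{}, p []byte) error {
//			e.items = append(e.items, itemOffset)
//			return nil
//		}, nil
//	})
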
// LegacyPackedWriterOpts represents the options accepted by NewLegacyPackedWriter.
type LegacyPackedWriterOpts struct {
	// Marshal is called to marshal an object to a byte slice.
	Marshal MarshalFunc

	// Index is called whenever a new record is written.
	Index RecordIndex

	// Flushed is called after each record has been written to the
	// underlying stream. It can be used to determine when it is safe to
	// reuse buffers previously passed to Write.
	Flushed func() error

	// Transform is called when buffered data is about to be written to a record.
	// It is intended for implementing data transformations such as compression
	// and/or encryption. The Transform function specified here must be
	// reversible by the Transform function in the Scanner.
	Transform func(in [][]byte) (buf []byte, err error)

	// MaxItems is the maximum number of items to pack into a single record.
	// It defaults to DefaultPackedItems if set to 0.
	// If MaxItems exceeds MaxPackedItems it is silently capped at MaxPackedItems.
	MaxItems uint32

	// MaxBytes is the maximum number of bytes to pack into a single record.
	// It defaults to DefaultPackedBytes if set to 0.
	MaxBytes uint32
}

// LegacyPackedScannerOpts represents the options accepted by NewLegacyPackedScanner.
type LegacyPackedScannerOpts struct {
	LegacyScannerOpts

	// Transform is called on the data read from a record to reverse any
	// transformations performed when creating the record. It is intended
	// for decompression, decryption etc.
	Transform func(scratch, in []byte) (out []byte, err error)
}

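// As a hedged sketch of how the writer and scanner Transform hooks pair up,
// a flate-based combination might look like the following; it assumes the
// calling code imports bytes, compress/flate and io/ioutil:
//
//	wopts := LegacyPackedWriterOpts{
//		Transform: func(in [][]byte) ([]byte, error) {
//			var buf bytes.Buffer
//			fw, err := flate.NewWriter(&buf, flate.DefaultCompression)
//			if err != nil {
//				return nil, err
//			}
//			for _, b := range in {
//				if _, err := fw.Write(b); err != nil {
//					return nil, err
//				}
//			}
//			if err := fw.Close(); err != nil {
//				return nil, err
//			}
//			return buf.Bytes(), nil
//		},
//	}
//	sopts := LegacyPackedScannerOpts{
//		Transform: func(scratch, in []byte) ([]byte, error) {
//			return ioutil.ReadAll(flate.NewReader(bytes.NewReader(in)))
//		},
//	}
//
// Any reversible pair will do; the only requirement is that the scanner's
// Transform undoes the writer's.
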
// LegacyPackedWriter represents an interface that can be used to write
// multiple items to the same recordio record.
type LegacyPackedWriter interface {
	// Write buffers a []byte item to be packed into the current record. A
	// new record is only written to the underlying stream when adding
	// another item would exceed the configured limits, or when Flush is
	// called. Calls to Write and Marshal may be interspersed.
	io.Writer

	// Marshal marshals an object prior to writing it to the underlying
	// recordio stream.
	Marshal(v interface{}) (n int, err error)

	// Flush is called to write any currently buffered data to the current
	// record. A subsequent write will result in a new record being
	// written. Flush must be called to ensure that the last record is
	// completely written.
	Flush() error
}

// LegacyPackedScanner represents an interface that can be used to read items
// from a recordio file written using a LegacyPackedWriter.
type LegacyPackedScanner interface {
	LegacyScanner
}

type packedWriter struct {
	sync.Mutex
	wr      *byteWriter
	pbw     *Packer
	opts    LegacyPackedWriterOpts
	objects []interface{}
}

// NewLegacyPackedWriter is deprecated. Use NewWriterV2 instead.
//
// NewLegacyPackedWriter returns a writer that packs up to MaxItems items or
// MaxBytes bytes, whichever limit is reached first, into a single write to the
// underlying recordio stream. Callers of Write must guarantee that they will
// not modify the buffers passed as arguments, since Write does not make an
// internal copy until the buffered data is written. A caller can count
// items/bytes, or provide a Flushed callback, to determine when it is safe to
// reuse any storage. This scheme avoids an unnecessary copy for []byte
// arguments; most implementations of Marshal will in any case create a new
// buffer to store the marshaled data.
func NewLegacyPackedWriter(wr io.Writer, opts LegacyPackedWriterOpts) LegacyPackedWriter {
	if opts.MaxItems == 0 {
		opts.MaxItems = DefaultPackedItems
	}
	if opts.MaxBytes == 0 {
		opts.MaxBytes = DefaultPackedBytes
	}
	if opts.MaxItems > MaxPackedItems {
		opts.MaxItems = MaxPackedItems
	}
	bufStorage := make([][]byte, 0, opts.MaxItems+1)
	objectStorage := make([]interface{}, 0, opts.MaxItems)
	subopts := PackerOpts{
		Transform: opts.Transform,
		Buffers:   bufStorage, // reserve the first buffer for the hdr.
	}
	pw := &packedWriter{
		opts:    opts,
		wr:      NewLegacyWriter(wr, LegacyWriterOpts{Marshal: opts.Marshal}).(*byteWriter),
		pbw:     NewPacker(subopts),
		objects: objectStorage,
	}
	pw.wr.magic = internal.MagicPacked
	return pw
}

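// A minimal usage sketch, assuming dst is an io.Writer supplied by the caller
// and the default limits are acceptable:
//
//	w := NewLegacyPackedWriter(dst, LegacyPackedWriterOpts{})
//	if _, err := w.Write([]byte("item 1")); err != nil {
//		// handle error
//	}
//	if _, err := w.Write([]byte("item 2")); err != nil {
//		// handle error
//	}
//	// Flush before closing dst so that the final, partially filled record
//	// is written out.
//	if err := w.Flush(); err != nil {
//		// handle error
//	}
//
// To use Marshal instead of Write, LegacyPackedWriterOpts.Marshal must be set.
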
// Write implements recordio.LegacyPackedWriter.Write.
func (pw *packedWriter) Write(p []byte) (n int, err error) {
	pw.Lock()
	defer pw.Unlock()
	if err := pw.flushIfNeeded(len(p)); err != nil {
		return 0, err
	}
	pw.objects = append(pw.objects, nil)
	return pw.pbw.Write(p)
}

// flushIfNeeded writes out the current record if appending an item of lp bytes
// would exceed MaxBytes or MaxItems. The caller must hold pw's lock.
func (pw *packedWriter) flushIfNeeded(lp int) error {
	if lp > int(pw.opts.MaxBytes) {
		return fmt.Errorf("buffer is too large %v > %v", lp, pw.opts.MaxBytes)
	}
	nItems, nBytes := pw.pbw.Stored()
	if ((nBytes + lp) > int(pw.opts.MaxBytes)) || ((nItems + 1) > int(pw.opts.MaxItems)) {
		return pw.flush()
	}
	return nil
}
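
// For example, with the default limits (MaxItems of 16*1024 and MaxBytes of
// 16 MiB), writing a 1 MiB item while 15.5 MiB are already buffered first
// flushes the buffered record and then starts a new record with that item,
// whereas a single item larger than MaxBytes is rejected outright.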

// Marshal implements recordio.LegacyPackedWriter.Marshal.
func (pw *packedWriter) Marshal(v interface{}) (n int, err error) {
	mfn := pw.opts.Marshal
	if mfn == nil {
		return 0, fmt.Errorf("Marshal function not configured for recordio.PackedWriter")
	}
	p, err := mfn(nil, v)
	if err != nil {
		return 0, err
	}
	pw.Lock()
	defer pw.Unlock()
	if err := pw.flushIfNeeded(len(p)); err != nil {
		return 0, err
	}
	pw.objects = append(pw.objects, v)
	return pw.pbw.Write(p)
}

// Flush implements recordio.LegacyPackedWriter.Flush.
func (pw *packedWriter) Flush() error {
	pw.Lock()
	defer pw.Unlock()
	return pw.flush()
}

// flush writes all currently buffered items to the underlying stream as a
// single record. The caller must hold pw's lock.
func (pw *packedWriter) flush() error {
	stored, _ := pw.pbw.Stored()
	hdr, dataSize, buffers, err := pw.pbw.Pack()
	if err != nil {
		return err
	}
	if len(buffers) == 0 {
		// Zero-length buffers may legitimately be written out, so
		// dataSize == 0 cannot be used to determine whether there is
		// data to write.
		return nil
	}
	hdrSize, offset, n, err := pw.wr.writeSlices(hdr, buffers...)
	if err != nil {
		return err
	}
	if got, want := n, dataSize+len(hdr); got != want {
		return fmt.Errorf("recordio: buffered write too short wrote %v instead of %v", got, want)
	}

	// Call the indexing funcs.
	if rifn := pw.opts.Index; rifn != nil {
		next := uint64(0)
		ifn, err := rifn(offset, uint64(n)+hdrSize, uint64(stored))
		if err != nil {
			return err
		}
		if ifn != nil {
			for i, b := range buffers {
				err := ifn(next, uint64(len(b)), pw.objects[i], b)
				if err != nil {
					return err
				}
				next += uint64(len(b))
			}
		}
	}
	// Reset everything ready for the next record.
	pw.objects = pw.objects[0:0]
	if flfn := pw.opts.Flushed; flfn != nil {
		return flfn()
	}
	return nil
}

type packedScanner struct {
	err      error
	sc       *LegacyScannerImpl
	buffered [][]byte
	record   []byte
	opts     LegacyPackedScannerOpts
	nextItem int
	pbr      *Unpacker
}

// NewLegacyPackedScanner is deprecated. Use NewScannerV2 instead.
func NewLegacyPackedScanner(rd io.Reader, opts LegacyPackedScannerOpts) LegacyScanner {
	return &packedScanner{
		sc: NewLegacyScanner(rd, opts.LegacyScannerOpts).(*LegacyScannerImpl),
		pbr: NewUnpacker(UnpackerOpts{
			Transform: opts.Transform,
		}),
		opts: opts,
	}
}

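// A minimal scanning sketch, assuming src is an io.Reader positioned at the
// start of a stream produced by a LegacyPackedWriter with matching options:
//
//	sc := NewLegacyPackedScanner(src, LegacyPackedScannerOpts{})
//	for sc.Scan() {
//		item := sc.Bytes() // one packed item per call to Scan
//		_ = item
//	}
//	if err := sc.Err(); err != nil {
//		// handle error
//	}
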
// setErr records err unless an error has already been recorded.
func (ps *packedScanner) setErr(err error) {
	if ps.err == nil {
		ps.err = err
	}
}

// Reset implements recordio.Scanner.Reset.
func (ps *packedScanner) Reset(rd io.Reader) {
	ps.sc.Reset(rd)
	ps.nextItem = 0
	ps.buffered = ps.buffered[:0]
	ps.err = nil
}

// Scan implements recordio.Scanner.Scan.
func (ps *packedScanner) Scan() bool {
	if ps.err != nil {
		return false
	}
	if ps.nextItem < len(ps.buffered) {
		ps.record = ps.buffered[ps.nextItem]
		ps.nextItem++
		return true
	}
	// Need to read the next record.
	magic, ok := ps.sc.InternalScan()
	if !ok {
		return false
	}
	if magic != internal.MagicPacked {
		ps.sc.err.Set(fmt.Errorf("recordio: invalid magic number: %v, expect %v", magic, internal.MagicPacked))
		return false
	}
	tmp, err := ps.pbr.Unpack(ps.sc.Bytes())
	if err != nil {
		ps.setErr(err)
		return false
	}
	ps.buffered = tmp
	ps.record = ps.buffered[0]
	ps.nextItem = 1
	return true
}

// Bytes implements recordio.Scanner.Bytes.
func (ps *packedScanner) Bytes() []byte {
	return ps.record
}

// Err implements recordio.Scanner.Err.
func (ps *packedScanner) Err() error {
	if ps.err != nil {
		return ps.err
	}
	return ps.sc.Err()
}

// Unmarshal implements recordio.Scanner.Unmarshal.
func (ps *packedScanner) Unmarshal(v interface{}) error {
	if ufn := ps.opts.Unmarshal; ufn != nil {
		return ufn(ps.Bytes(), v)
	}
	err := fmt.Errorf("Unmarshal function not configured for recordio.PackedScanner")
	ps.setErr(err)
	return err
}