github.com/grailbio/base@v0.0.11/recordio/writerv2.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package recordio
     6  
     7  import (
     8  	"encoding/binary"
     9  	"fmt"
    10  	"io"
    11  	"sync"
    12  
    13  	"github.com/grailbio/base/errors"
    14  	"github.com/grailbio/base/recordio/internal"
    15  )
    16  
    17  const (
    18  	// DefaultFlushParallelism is the default value for WriterOpts.MaxFlushParallelism.
    19  	DefaultFlushParallelism = uint32(8)
    20  
    21  	// MaxFlushParallelism is the max allowed value for WriterOpts.MaxFlushParallelism.
    22  	MaxFlushParallelism = uint32(128)
    23  
    24  	// MaxPackedItems defines the max items that can be
    25  	// packed into a single record by a PackedWriter.
    26  	MaxPackedItems = uint32(10 * 1024 * 1024)
    27  	// DefaultPackedItems defines the default number of items that can
    28  	// be packed into a single record by a PackedWriter.
    29  	DefaultPackedItems = uint32(16 * 1024)
    30  )
    31  
// ItemLocation identifies the location of an item in a recordio file.
// Values of this type are passed to WriterOpts.Index callbacks.
type ItemLocation struct {
	// Location of the first byte of the block within the file. Unit is bytes.
	// The writer takes this from its internal chunk writer's Len() just before
	// the block is written.
	Block uint64
	// Index of the item within the block. The Nth item in the block (N=1,2,...)
	// has value N-1.
	Item int
}
    40  
// IndexFunc runs after an item is flushed to storage.  Parameter "loc" is the
// location of the item in the file.  It can be later passed to Reader.Seek
// method to seek to the item.
//
// A non-nil error returned by an IndexFunc becomes the writer's error (see
// Writer.Err); the remaining items of the block are still passed to it.
//
// NOTE(review): the writer in this file also invokes the IndexFunc for blocks
// whose write was skipped because an earlier error had been recorded, so loc
// may then refer to data that never reached storage.
type IndexFunc func(loc ItemLocation, item interface{}) error
    45  
// WriterOpts defines options used when creating a new writer.
type WriterOpts struct {
	// Marshal is called for every item added by Append. It serializes the
	// record. If Marshal is nil, it defaults to a function that casts the value
	// to []byte and returns it. Marshal may be called concurrently.
	Marshal MarshalFunc

	// Index is called for every item added, right after the block containing
	// the item has been passed to the underlying chunk writer. Index callback
	// may be called concurrently and out of order of locations.
	//
	// After Index is called, the Writer guarantees that it never touches
	// the value again. The application may recycle the value in a freepool, if it
	// desires. Index may be nil.
	//
	// NOTE(review): Index is also called when an earlier error caused the
	// block's bytes to be skipped.
	Index IndexFunc

	// Transformers specifies a list of functions to compress, encrypt, or modify
	// data in any other way, just before a block is written to storage.
	//
	// Each entry in Transformers must be of form "name" or "name config.."  The
	// "name" is matched against the registry (see RegisterTransformer).  The
	// "config" part is passed to the transformer factory function.  If "name" is
	// not registered, the writer will fail immediately (the error is reported
	// through Err and Finish).
	//
	// If Transformers contains multiple strings, Transformers[0] is invoked
	// first, then its results are passed to Transformers[1], so on.
	//
	// If len(Transformers)==0, then an identity transformer is used. It will
	// return the block as is.
	//
	// The transformers are applied to body and trailer blocks. Header blocks
	// are always written untransformed. Unless SkipHeader is set, each entry is
	// also recorded in the header under KeyTransformer.
	//
	// Recordio package includes the following standard transformers:
	//
	//  "zstd N" (N is -1 or an integer from 0 to 22): zstd compression level N.
	//  If " N" part is omitted or N=-1, the default compression level is used.
	//  To use zstd, import the 'recordiozstd' package and call
	//  'recordiozstd.Init()' in an init() function.
	//
	//  "flate N" (N is -1 or an integer from 0 to 9): flate compression level N.
	//  If " N" part is omitted or N=-1, the default compression level is used.
	//  To use flate, import the 'recordioflate' package and call
	//  'recordioflate.Init()' in an init() function.
	Transformers []string

	// MaxItems is the maximum number of items to pack into a single record.
	// It defaults to DefaultPackedItems if set to 0.
	// If MaxItems exceeds MaxPackedItems it is silently clamped to MaxPackedItems.
	MaxItems uint32

	// MaxFlushParallelism limits the maximum number of block flush operations in
	// flight before blocking the application. It defaults to
	// DefaultFlushParallelism if set to 0 and is silently clamped to
	// MaxFlushParallelism. It is also the number of block buffers the writer
	// preallocates.
	MaxFlushParallelism uint32

	// KeyTrailer, if true, adds the header entry KeyTrailer=true, just as
	// AddHeader(KeyTrailer, true) would. SetTrailer panics unless that entry
	// is present.
	KeyTrailer bool

	// SkipHeader skips writing out the header and starts in the
	// `wStateWritingBody` state. Transformers are then not recorded in any
	// header, although they are still applied to the blocks that are written.
	// Presumably meant for continuing a file whose header was written
	// elsewhere; confirm before relying on it.
	SkipHeader bool

	// TODO(saito) Consider providing a flag to allow out-of-order writes, like
	// ConcurrentPackedWriter.
}
   110  
// Writer defines an interface for recordio writer. An implementation must be
// thread safe.
//
// Legal path expression is defined below. Err can be called at any time, so it
// is not included in the expression. ? means 0 or 1 call, * means 0 or more
// calls.
//
//   AddHeader*
//   (Append|Flush)*
//   SetTrailer?
//   Finish
type Writer interface {
	// AddHeader adds an arbitrary metadata entry to the file header. This
	// method must be called before any other Append* or Set* functions. If the
	// key had been already added to the header, this method will overwrite it
	// with the value.
	//
	// NOTE(review): the implementation returned by NewWriter appends a second
	// entry for an existing key instead of overwriting it. Confirm which
	// contract readers rely on.
	//
	// REQUIRES: Append, SetTrailer, Finish have not been called.
	AddHeader(key string, value interface{})

	// Append writes one item. The marshaler will be eventually called to
	// serialize the item. The type of v must match the input type for
	// the Marshal function passed when the writer is created. Note that
	// since marshalling is performed asynchronously, the object passed
	// to Append should be considered owned by the writer, and must not
	// be reused by the caller.
	//
	// The writer flushes items to the storage in the order of addition.
	//
	// REQUIRES: Finish and SetTrailer have not been called.
	Append(v interface{})

	// Flush schedules the current block for writing. The next item will be
	// written in a new block. This method just schedules the flush, and
	// returns before the block is actually written to storage. Call Wait to
	// wait for Flush to finish.
	Flush()

	// Wait blocks the caller until all the prior Flush calls finish.
	Wait()

	// SetTrailer adds arbitrary data at the end of the file. After this
	// function, no {Add*,Append*,Set*} functions may be called.
	//
	// REQUIRES: AddHeader(KeyTrailer, true) has been called, or
	// WriterOpts.KeyTrailer was set.
	SetTrailer([]byte)

	// Err returns any error encountered by the writer. Once Err() becomes
	// non-nil, it stays so.
	Err() error

	// Finish must be called at the end of writing. Finish will internally call
	// Flush, then returns the value of Err. No method, other than Err, shall be
	// called afterwards.
	Finish() error
}
   165  
   166  type blockType int
   167  
   168  const (
   169  	bTypeInvalid blockType = iota
   170  	bTypeHeader
   171  	bTypeBody
   172  	bTypeTrailer
   173  )
   174  
   175  var magicv2Bytes = []internal.MagicBytes{
   176  	internal.MagicInvalid,
   177  	internal.MagicHeader,
   178  	internal.MagicPacked,
   179  	internal.MagicTrailer,
   180  }
   181  
// writerv2Block holds the contents of one recordio block while it is filled,
// serialized and written. Blocks are recycled through writerv2.freeBlocks.
type writerv2Block struct {
	// Kind of block. Selects the magic bytes written in front of it and
	// whether the configured transformers are applied.
	bType blockType

	// Objects added by Append. Used only when bType == bTypeBody.
	objects []interface{}
	// Already-encoded payload of a header or trailer block, serialized as the
	// block's single item.
	rawData []byte

	// Result of serializing and transforming the block; these are the bytes
	// handed to the chunk writer.
	serialized []byte

	// Block write order.  The domain is (0,1,2,...)
	flushSeq int

	// Tmp used during data serialization. Element 0 holds the packed item-size
	// header; elements 1..n hold the serialized items (or rawData). It is kept
	// across reuses so that old slices can be passed to Marshal as scratch.
	tmpBuf [][]byte
}
   199  
   200  func (b *writerv2Block) reset() {
   201  	b.serialized = b.serialized[:0]
   202  	b.objects = b.objects[:0]
   203  	b.bType = bTypeInvalid
   204  }
   205  
   206  // State of the writerv2. The state transitions in one direction only.
   207  type writerState int
   208  
   209  const (
   210  	// No writes started. AddHeader() can be done in this state only.
   211  	wStateInitial writerState = iota
   212  	// The main state. Append and Flush can be called.
   213  	wStateWritingBody
   214  	// State after a SetTrailer call.
   215  	wStateWritingTrailer
   216  	// State after Finish call.
   217  	wStateFinished
   218  )
   219  
// writerv2 is the implementation of Writer returned by NewWriter.
type writerv2 struct {
	// List of empty writerv2Blocks. Capacity is fixed at
	// opts.MaxFlushParallelism. Every block is at any time either in this
	// channel, held as curBodyBlock, or owned by a serialize/flush goroutine.
	freeBlocks chan *writerv2Block
	// Options after NewWriter applied defaults and limits.
	opts WriterOpts
	// The writer's error (see Err). Shared by pointer with fq and the chunk
	// writer.
	err errors.Once
	// Orders serialized blocks and writes them to storage.
	fq flushQueue

	// mu protects state, header and curBodyBlock. AddHeader, Append, Flush,
	// SetTrailer and Wait take it.
	mu    sync.Mutex
	state writerState
	// Header entries, marshaled into the header block by startFlushHeader.
	header ParsedHeader
	// Body block currently being filled by Append, or nil.
	curBodyBlock *writerv2Block
}
   234  
// flushQueue serializes block writes. Blocks may be marshaled and transformed
// concurrently (one goroutine per block), but they are handed to the chunk
// writer strictly in flushSeq order. Thread safe.
type flushQueue struct {
	freeBlocks chan *writerv2Block   // Copy of writerv2.freeBlocks.
	opts       WriterOpts            // Copy of writerv2.opts.
	err        *errors.Once          // Points to writerv2.err.
	wr         *internal.ChunkWriter // Raw chunk writer.

	// Transformer chain built from opts.Transformers by NewWriter. Applied to
	// body and trailer blocks; header blocks use the identity transform.
	transform TransformFunc

	// mu guards flushing, nextSeq, lastSeq and queue.
	mu sync.Mutex
	// flushing is true iff. flushBlocks() is scheduled, i.e. some goroutine is
	// currently draining the queue.
	flushing bool
	// block sequence numbers are dense integer sequence (0, 1, 2, ...)  assigned
	// to blocks. Blocks are written to the storage in the sequence order.
	nextSeq int                    // Seq# to be assigned to the next block.
	lastSeq int                    // Seq# of last block flushed to storage.
	queue   map[int]*writerv2Block // Blocks ready to be flushed. Keys are seq#s.
}
   253  
   254  // Assign a new block-flush sequence number.
   255  func (fq *flushQueue) newSeq() int {
   256  	fq.mu.Lock()
   257  	seq := fq.nextSeq
   258  	fq.nextSeq++
   259  	fq.mu.Unlock()
   260  	return seq
   261  }
   262  
   263  func idMarshal(scratch []byte, v interface{}) ([]byte, error) {
   264  	return v.([]byte), nil
   265  }
   266  
   267  // NewWriter creates a new writer.  New users should use this class instead of
   268  // Writer, PackedWriter, or ConcurrentPackedWriter.
   269  //
   270  // Caution: files created by this writer cannot be read by a legacy
   271  // recordio.Scanner.
   272  func NewWriter(wr io.Writer, opts WriterOpts) Writer {
   273  	if opts.Marshal == nil {
   274  		opts.Marshal = idMarshal
   275  	}
   276  	if opts.MaxItems == 0 {
   277  		opts.MaxItems = DefaultPackedItems
   278  	}
   279  	if opts.MaxItems > MaxPackedItems {
   280  		opts.MaxItems = MaxPackedItems
   281  	}
   282  	if opts.MaxFlushParallelism == 0 {
   283  		opts.MaxFlushParallelism = DefaultFlushParallelism
   284  	}
   285  	if opts.MaxFlushParallelism > MaxFlushParallelism {
   286  		opts.MaxFlushParallelism = MaxFlushParallelism
   287  	}
   288  
   289  	w := &writerv2{
   290  		opts:       opts,
   291  		freeBlocks: make(chan *writerv2Block, opts.MaxFlushParallelism),
   292  	}
   293  
   294  	if opts.SkipHeader {
   295  		w.state = wStateWritingBody
   296  	} else {
   297  		for _, val := range opts.Transformers {
   298  			w.header = append(w.header, KeyValue{KeyTransformer, val})
   299  		}
   300  	}
   301  	if opts.KeyTrailer {
   302  		w.header = append(w.header, KeyValue{KeyTrailer, true})
   303  	}
   304  
   305  	w.fq = flushQueue{
   306  		wr:         internal.NewChunkWriter(wr, &w.err),
   307  		opts:       opts,
   308  		freeBlocks: w.freeBlocks,
   309  		err:        &w.err,
   310  		lastSeq:    -1,
   311  		queue:      make(map[int]*writerv2Block),
   312  	}
   313  	for i := uint32(0); i < opts.MaxFlushParallelism; i++ {
   314  		w.freeBlocks <- &writerv2Block{
   315  			objects: make([]interface{}, 0, opts.MaxItems+1),
   316  		}
   317  	}
   318  	var err error
   319  	if w.fq.transform, err = registry.getTransformer(opts.Transformers); err != nil {
   320  		w.err.Set(err)
   321  	}
   322  	return w
   323  }
   324  
   325  func (w *writerv2) AddHeader(key string, value interface{}) {
   326  	w.mu.Lock()
   327  	if w.state != wStateInitial {
   328  		panic(fmt.Sprintf("AddHeader: wrong state: %v", w.state))
   329  	}
   330  	w.header = append(w.header, KeyValue{key, value})
   331  	w.mu.Unlock()
   332  }
   333  
   334  func (w *writerv2) startFlushHeader() {
   335  	data, err := w.header.marshal()
   336  	if err != nil {
   337  		w.err.Set(err)
   338  		return
   339  	}
   340  	b := <-w.freeBlocks
   341  	b.bType = bTypeHeader
   342  	b.rawData = data
   343  	b.flushSeq = w.fq.newSeq()
   344  	go w.fq.serializeAndEnqueueBlock(b)
   345  }
   346  
   347  func (w *writerv2) startFlushBodyBlock() {
   348  	b := w.curBodyBlock
   349  	w.curBodyBlock = nil
   350  	b.bType = bTypeBody
   351  	b.flushSeq = w.fq.newSeq()
   352  	go w.fq.serializeAndEnqueueBlock(b)
   353  }
   354  
   355  func (w *writerv2) Append(v interface{}) {
   356  	w.mu.Lock()
   357  	if w.state == wStateInitial {
   358  		w.startFlushHeader()
   359  		w.state = wStateWritingBody
   360  	} else if w.state != wStateWritingBody {
   361  		panic(fmt.Sprintf("Append: wrong state: %v", w.state))
   362  	}
   363  	if w.curBodyBlock == nil {
   364  		w.curBodyBlock = <-w.freeBlocks
   365  	}
   366  	w.curBodyBlock.objects = append(w.curBodyBlock.objects, v)
   367  	if len(w.curBodyBlock.objects) >= cap(w.curBodyBlock.objects) {
   368  		w.startFlushBodyBlock()
   369  	}
   370  	w.mu.Unlock()
   371  }
   372  
   373  func (w *writerv2) Flush() {
   374  	w.mu.Lock()
   375  	if w.state == wStateInitial {
   376  		w.mu.Unlock()
   377  		return
   378  	}
   379  	if w.state != wStateWritingBody {
   380  		panic(fmt.Sprintf("Flush: wrong state: %v", w.state))
   381  	}
   382  	if w.curBodyBlock != nil {
   383  		w.startFlushBodyBlock()
   384  	}
   385  	w.mu.Unlock()
   386  }
   387  
   388  func generatePackedHeaderv2(items [][]byte) []byte {
   389  	// 1 varint for # items, n for the size of each of n items.
   390  	hdrSize := (len(items) + 1) * binary.MaxVarintLen32
   391  	hdr := make([]byte, hdrSize)
   392  
   393  	// Write the number of items in this record.
   394  	pos := binary.PutUvarint(hdr, uint64(len(items)))
   395  	// Write the size of each item.
   396  	for _, p := range items {
   397  		pos += binary.PutUvarint(hdr[pos:], uint64(len(p)))
   398  	}
   399  	hdr = hdr[:pos]
   400  	return hdr
   401  }
   402  
   403  // Produce a packed recordio block.
   404  func (fq *flushQueue) serializeBlock(b *writerv2Block) {
   405  	getChunks := func(n int) [][]byte {
   406  		if cap(b.tmpBuf) >= n+1 {
   407  			b.tmpBuf = b.tmpBuf[:n+1]
   408  		} else {
   409  			b.tmpBuf = make([][]byte, n+1)
   410  		}
   411  		return b.tmpBuf
   412  	}
   413  	if fq.err.Err() != nil {
   414  		return
   415  	}
   416  	var tmpBuf [][]byte // tmpBuf[0] is for the packed header.
   417  	if b.bType == bTypeBody {
   418  		tmpBuf = getChunks(len(b.objects))
   419  		// Marshal items into bytes.
   420  		for i, v := range b.objects {
   421  			s, err := fq.opts.Marshal(tmpBuf[i+1], v)
   422  			if err != nil {
   423  				fq.err.Set(err)
   424  			}
   425  			tmpBuf[i+1] = s
   426  		}
   427  	} else {
   428  		tmpBuf = getChunks(1)
   429  		tmpBuf[1] = b.rawData
   430  	}
   431  
   432  	tmpBuf[0] = generatePackedHeaderv2(tmpBuf[1:])
   433  	transform := idTransform
   434  	if b.bType == bTypeBody || b.bType == bTypeTrailer {
   435  		transform = fq.transform
   436  	}
   437  
   438  	var err error
   439  	if b.serialized, err = transform(b.serialized, tmpBuf); err != nil {
   440  		fq.err.Set(err)
   441  	}
   442  }
   443  
   444  // Schedule "b" for writes. Caller must have marshaled and transformed "b"
   445  // before the call.  It's ok to call enqueue concurrently; blocks are written to
   446  // the storage in flushSeq order.
   447  func (fq *flushQueue) enqueueBlock(b *writerv2Block) {
   448  	fq.mu.Lock()
   449  	fq.queue[b.flushSeq] = b
   450  	if !fq.flushing && b.flushSeq == fq.lastSeq+1 {
   451  		fq.flushing = true
   452  		fq.mu.Unlock()
   453  		fq.flushBlocks()
   454  	} else {
   455  		fq.mu.Unlock()
   456  	}
   457  }
   458  
   459  func (fq *flushQueue) serializeAndEnqueueBlock(b *writerv2Block) {
   460  	fq.serializeBlock(b)
   461  	fq.enqueueBlock(b)
   462  }
   463  
   464  func (fq *flushQueue) flushBlocks() {
   465  	fq.mu.Lock()
   466  	if !fq.flushing {
   467  		panic(fq)
   468  	}
   469  
   470  	for {
   471  		b, ok := fq.queue[fq.lastSeq+1]
   472  		if !ok {
   473  			break
   474  		}
   475  		delete(fq.queue, b.flushSeq)
   476  		fq.lastSeq++
   477  		fq.mu.Unlock()
   478  
   479  		fq.flushBlock(b)
   480  		b.reset()
   481  		fq.freeBlocks <- b
   482  		fq.mu.Lock()
   483  	}
   484  	if !fq.flushing {
   485  		panic(fq)
   486  	}
   487  	fq.flushing = false
   488  	fq.mu.Unlock()
   489  }
   490  
   491  func (fq *flushQueue) flushBlock(b *writerv2Block) {
   492  	offset := uint64(fq.wr.Len())
   493  	if fq.err.Err() == nil {
   494  		fq.wr.Write(magicv2Bytes[b.bType], b.serialized)
   495  	}
   496  	if b.bType == bTypeBody && fq.opts.Index != nil {
   497  		// Call the indexing funcs.
   498  		//
   499  		// TODO(saito) Run this code in a separate thread.
   500  		ifn := fq.opts.Index
   501  		for i := range b.objects {
   502  			loc := ItemLocation{Block: offset, Item: i}
   503  			if err := ifn(loc, b.objects[i]); err != nil {
   504  				fq.err.Set(err)
   505  			}
   506  		}
   507  	}
   508  }
   509  
   510  func (w *writerv2) SetTrailer(data []byte) {
   511  	w.mu.Lock()
   512  	if !w.header.HasTrailer() {
   513  		panic(fmt.Sprintf("settrailer: Key '%v' must be set to true", KeyTrailer))
   514  	}
   515  	if w.state == wStateInitial {
   516  		w.startFlushHeader()
   517  	} else if w.state == wStateWritingBody {
   518  		if w.curBodyBlock != nil {
   519  			w.startFlushBodyBlock()
   520  		}
   521  	} else {
   522  		panic(fmt.Sprintf("SetTrailer: wrong state: %v", w.state))
   523  	}
   524  	if w.curBodyBlock != nil {
   525  		panic(w)
   526  	}
   527  	w.state = wStateWritingTrailer
   528  	w.mu.Unlock()
   529  
   530  	b := <-w.freeBlocks
   531  	b.bType = bTypeTrailer
   532  	b.rawData = make([]byte, len(data))
   533  	copy(b.rawData, data)
   534  	b.flushSeq = w.fq.newSeq()
   535  	go w.fq.serializeAndEnqueueBlock(b)
   536  }
   537  
   538  func (w *writerv2) Err() error {
   539  	return w.err.Err()
   540  }
   541  
   542  func (w *writerv2) Wait() {
   543  	w.mu.Lock()
   544  	n := 0
   545  	if w.curBodyBlock != nil {
   546  		n++
   547  	}
   548  
   549  	tmp := make([]*writerv2Block, 0, cap(w.freeBlocks))
   550  	for n < cap(w.freeBlocks) {
   551  		b := <-w.freeBlocks
   552  		tmp = append(tmp, b)
   553  		n++
   554  	}
   555  
   556  	for _, b := range tmp {
   557  		w.freeBlocks <- b
   558  	}
   559  	w.mu.Unlock()
   560  }
   561  
   562  func (w *writerv2) Finish() error {
   563  	if w.state == wStateInitial {
   564  		w.startFlushHeader()
   565  		w.state = wStateWritingBody
   566  	}
   567  	if w.state == wStateWritingBody {
   568  		if w.curBodyBlock != nil {
   569  			w.startFlushBodyBlock()
   570  		}
   571  	} else if w.state != wStateWritingTrailer {
   572  		panic(w)
   573  	}
   574  	if w.curBodyBlock != nil {
   575  		w.startFlushBodyBlock()
   576  	}
   577  	w.state = wStateFinished
   578  	// Drain all ongoing flushes.
   579  	for i := 0; i < cap(w.freeBlocks); i++ {
   580  		<-w.freeBlocks
   581  	}
   582  	close(w.freeBlocks)
   583  	if len(w.fq.queue) > 0 {
   584  		panic(w)
   585  	}
   586  	return w.err.Err()
   587  }