github.com/grailbio/base@v0.0.11/stateio/reader.go (about)

     1  // Copyright 2019 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package stateio
     6  
     7  import (
     8  	"errors"
     9  	"io"
    10  	"os"
    11  
    12  	"github.com/grailbio/base/logio"
    13  )
    14  
    15  // ErrCorrupt is returned when a corrupted log is encountered.
    16  var ErrCorrupt = errors.New("corrupt state entry")
    17  
    18  // Restore restores the state from the last epoch in the state log
    19  // read by the provided reader and the given limit. The returned
    20  // state may be nil if no snapshot was defined for the epoch.
    21  func Restore(r io.ReaderAt, limit int64) (state []byte, epoch uint64, updates *Reader, err error) {
    22  	if limit == 0 {
    23  		return nil, 0, nil, nil
    24  	}
    25  	off, err := logio.Rewind(r, limit)
    26  	if err != nil {
    27  		return
    28  	}
    29  	reader := &readerAtReader{r, off}
    30  	log := logio.NewReader(reader, off)
    31  	entry, err := log.Read()
    32  	if err != nil {
    33  		return
    34  	}
    35  	var (
    36  		typ  uint8
    37  		data []byte
    38  		ok   bool
    39  	)
    40  	typ, epoch, data, ok = parse(entry)
    41  	if !ok {
    42  		// TODO(marius): let the user deal with this? perhaps by providing
    43  		// a utility function in package logio to skip corrupted entries.
    44  		err = ErrCorrupt
    45  		return
    46  	}
    47  	if typ == entrySnap {
    48  		// Special case: the first entry is a snapshot, so we need to restore
    49  		// the correct epoch.
    50  		epoch = uint64(off)
    51  	} else {
    52  		reader.off = int64(epoch)
    53  		log.Reset(reader, reader.off)
    54  		entry, err = log.Read()
    55  		if err != nil {
    56  			return
    57  		}
    58  		typ, _, data, ok = parse(entry)
    59  		if !ok {
    60  			err = ErrCorrupt
    61  			return
    62  		}
    63  	}
    64  
    65  	if typ == entrySnap {
    66  		state = append([]byte{}, data...)
    67  	} else {
    68  		reader.off = int64(epoch)
    69  		log.Reset(reader, reader.off)
    70  	}
    71  	updates = &Reader{log, epoch}
    72  	return
    73  }
    74  
    75  // RestoreFile is a convenience function that restores the file from
    76  // the provided os file.
    77  func RestoreFile(file *os.File) (state []byte, epoch uint64, updates *Reader, err error) {
    78  	off, err := file.Seek(0, io.SeekEnd)
    79  	if err != nil {
    80  		return nil, 0, nil, err
    81  	}
    82  	state, epoch, updates, err = Restore(file, off)
    83  	if _, e := file.Seek(off, io.SeekStart); e != nil && err == nil {
    84  		err = e
    85  	}
    86  	return
    87  }
    88  
    89  // Reader reads a single epoch state updates.
    90  type Reader struct {
    91  	log    *logio.Reader
    92  	offset uint64
    93  }
    94  
    95  // Read returns the next state update entry. Read returns ErrCorrupt
    96  // if a corrupted log entry was encountered, or logio.ErrCorrupt is a
    97  // corrupt log file was encountered. In the latter case, the user may
    98  // skip the corrupted entry by issuing another read.
    99  func (r *Reader) Read() ([]byte, error) {
   100  	if r == nil {
   101  		return nil, io.EOF
   102  	}
   103  	entry, err := r.log.Read()
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	typ, offset, data, ok := parse(entry)
   108  	if !ok {
   109  		return nil, ErrCorrupt
   110  	}
   111  	if typ == entrySnap {
   112  		return nil, io.EOF
   113  	}
   114  	if offset != r.offset {
   115  		// We should always encounter a new snapshot before an offset change.
   116  		return nil, ErrCorrupt
   117  	}
   118  	return data, nil
   119  }
   120  
   121  type readerAtReader struct {
   122  	r   io.ReaderAt
   123  	off int64
   124  }
   125  
   126  func (r *readerAtReader) Read(p []byte) (n int, err error) {
   127  	n, err = r.r.ReadAt(p, r.off)
   128  	if err == io.ErrUnexpectedEOF {
   129  		err = nil
   130  	}
   131  	r.off += int64(n)
   132  	return n, err
   133  }