github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/recordio/deprecated/recordio.go (about)

     1  // Copyright 2017 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package deprecated
     6  
     7  import (
     8  	"encoding/binary"
     9  	"fmt"
    10  	"hash/crc32"
    11  	"io"
    12  
    13  	"github.com/Schaudge/grailbase/errors"
    14  	"github.com/Schaudge/grailbase/recordio/internal"
    15  )
    16  
// LegacyWriterOpts represents the options accepted by NewLegacyWriter.
type LegacyWriterOpts struct {
	// Marshal is called to marshal an object to a byte slice. It is only
	// needed by the writer's Marshal method, which returns an error when
	// this field is nil.
	Marshal MarshalFunc

	// Index is called to enable generating an index for a recordio file; it is
	// called whenever a new record or item is written as per the package
	// level comments. A nil Index disables the callback.
	//
	// offset is the writer's running offset at the start of the record's
	// header and length is the payload size plus the header size. v and p
	// are the object and its marshaled bytes when the record is written via
	// Marshal; both are nil for records written via Write or WriteSlices.
	Index func(offset, length uint64, v interface{}, p []byte) error
}
    27  
// LegacyScannerOpts represents the options accepted by NewLegacyScanner.
type LegacyScannerOpts struct {
	// Unmarshal is called to unmarshal an object from the supplied byte slice.
	// It may be nil, in which case the scanner's Unmarshal method returns an
	// error.
	Unmarshal UnmarshalFunc
}
    33  
// LegacyScanner is the interface for reading recordio files as streams of typed
// records. Each record is available both as raw bytes, via Bytes, and as a
// typed value, via Unmarshal.
type LegacyScanner interface {
	// Reset is equivalent to creating a new scanner, but it retains underlying
	// storage. So it is more efficient than NewLegacyScanner. Err is reset to
	// nil. Scan and Bytes will read from rd.
	Reset(rd io.Reader)

	// Scan returns true if a new record was read, false otherwise. It will return
	// false on encountering an error; the error may be retrieved using the Err
	// method. Note that Scan will reuse storage from one invocation to the next.
	Scan() bool

	// Bytes returns the current record as read by a prior call to Scan. It may
	// always be called.
	Bytes() []byte

	// Err returns the first error encountered.
	Err() error

	// Unmarshal unmarshals the raw bytes using a preconfigured UnmarshalFunc.
	// It will return an error if there is no preconfigured UnmarshalFunc.
	// Calls to Bytes and Unmarshal may be interspersed.
	Unmarshal(data interface{}) error
}
    60  
// LegacyWriter is the interface for writing recordio files as streams of typed
// records.
type LegacyWriter interface {
	// Write writes a []byte record to the supplied writer. Each call to Write
	// results in a new record being written.
	// Calls to Write and Marshal may be interspersed.
	io.Writer

	// WriteSlices writes out the supplied slices as a single record, it
	// is intended to avoid having to copy slices into a single slice purely
	// to write them out as a single record.
	WriteSlices(hdr []byte, bufs ...[]byte) (n int, err error)

	// Marshal writes a record using a preconfigured MarshalFunc to the supplied
	// writer. Each call to Marshal results in a new record being written.
	// Calls to Write and Marshal may be interspersed.
	Marshal(v interface{}) (n int, err error)
}
    79  
// Byte offsets of the fields in a record header. As encoded by marshalHeader,
// a header consists of the magic bytes, followed by the payload size as an
// 8-byte little-endian integer, followed by a CRC32 of those 8 size bytes.
// sizeOffset and crcOffset locate the size and CRC fields; dataOffset is the
// first byte after the header.
const (
	sizeOffset = internal.NumMagicBytes
	crcOffset  = internal.NumMagicBytes + 8
	dataOffset = internal.NumMagicBytes + 8 + crc32.Size
	// headerSize is the size in bytes of the recordio header.
	headerSize = dataOffset
)
    87  
// byteWriter is the LegacyWriter implementation returned by NewLegacyWriter.
// Every record is written as a header followed by the record payload.
type byteWriter struct {
	wr     io.Writer           // destination for headers and payloads
	magic  internal.MagicBytes // magic bytes placed at the start of every header
	hdr    [headerSize]byte    // scratch buffer reused to encode each header
	offset uint64              // running count of header and payload bytes; passed to opts.Index
	opts   LegacyWriterOpts    // caller-supplied Marshal and Index callbacks
}
    95  
    96  // NewLegacyWriter is DEPRECATED. Use NewWriterV2 instead.
    97  func NewLegacyWriter(wr io.Writer, opts LegacyWriterOpts) LegacyWriter {
    98  	return &byteWriter{
    99  		magic: internal.MagicLegacyUnpacked,
   100  		wr:    wr,
   101  		opts:  opts,
   102  	}
   103  }
   104  
   105  func (w *byteWriter) writeHeader(l uint64) (n uint64, err error) {
   106  	marshalHeader(w.hdr[:], w.magic[:], l)
   107  	hdrSize, err := w.wr.Write(w.hdr[:])
   108  	if err != nil {
   109  		return 0, fmt.Errorf("recordio: failed to write header: %v", err)
   110  	}
   111  	return uint64(hdrSize), nil
   112  }
   113  
   114  func (w *byteWriter) index(offset, length uint64, v interface{}, p []byte) error {
   115  	if ifn := w.opts.Index; ifn != nil {
   116  		if err := ifn(offset, length, v, p); err != nil {
   117  			return fmt.Errorf("recordio: index callback failed: %v", err)
   118  		}
   119  	}
   120  	return nil
   121  }
   122  
   123  func (w *byteWriter) writeBody(p []byte) (n int, err error) {
   124  	n, err = w.wr.Write(p)
   125  	if err != nil {
   126  		return 0, fmt.Errorf("recordio: failed to write record %d bytes: %v", len(p), err)
   127  	}
   128  	w.offset += uint64(n)
   129  	return
   130  }
   131  
   132  func (w *byteWriter) Write(p []byte) (n int, err error) {
   133  	hdrSize, err := w.writeHeader(uint64(len(p)))
   134  	if err != nil {
   135  		return 0, err
   136  	}
   137  	if err := w.index(w.offset, uint64(len(p))+hdrSize, nil, nil); err != nil {
   138  		return 0, err
   139  	}
   140  	w.offset += hdrSize
   141  	return w.writeBody(p)
   142  }
   143  
   144  // WriteSlices writes the supplied slices as a single record. The arguments
   145  // are specified as a 'first' slice and an arbitrary number of subsequent ones
   146  // to allow for writing a 'header' and 'payload' without forcing the caller
   147  // reallocate and copy their data to match this API. Either first or bufs
   148  // may be nil.
   149  func (w *byteWriter) WriteSlices(first []byte, bufs ...[]byte) (n int, err error) {
   150  	_, _, n, err = w.writeSlices(first, bufs...)
   151  	return
   152  }
   153  
   154  func (w *byteWriter) writeSlices(first []byte, bufs ...[]byte) (headerSize, offset uint64, n int, err error) {
   155  	size := uint64(len(first))
   156  	for _, p := range bufs {
   157  		size += uint64(len(p))
   158  	}
   159  	hdrSize, err := w.writeHeader(size)
   160  	if err != nil {
   161  		return 0, 0, 0, err
   162  	}
   163  	offset = w.offset
   164  	if err := w.index(offset, size+hdrSize, nil, nil); err != nil {
   165  		return 0, 0, 0, err
   166  	}
   167  	written := 0
   168  	if len(first) > 0 {
   169  		var err error
   170  		written, err = w.wr.Write(first)
   171  		if err != nil {
   172  			return 0, 0, 0, fmt.Errorf("recordio: failed to write record %d bytes: %v", len(first), err)
   173  		}
   174  	}
   175  	for _, p := range bufs {
   176  		w, err := w.wr.Write(p)
   177  		written += w
   178  		if err != nil {
   179  			return 0, 0, 0, fmt.Errorf("recordio: failed to write record %d bytes: %v", len(p), err)
   180  		}
   181  	}
   182  	w.offset += uint64(written) + hdrSize
   183  	return hdrSize, offset, written, nil
   184  }
   185  
   186  // Marshal implements Writer.Marshal.
   187  func (w *byteWriter) Marshal(v interface{}) (n int, err error) {
   188  	mfn := w.opts.Marshal
   189  	if mfn == nil {
   190  		return 0, fmt.Errorf("Marshal function not configured for recordio.Writer")
   191  	}
   192  	p, err := mfn(nil, v)
   193  	if err != nil {
   194  		return 0, err
   195  	}
   196  	hdrSize, err := w.writeHeader(uint64(len(p)))
   197  	if err != nil {
   198  		return 0, err
   199  	}
   200  	if err := w.index(w.offset, uint64(len(p))+hdrSize, v, p); err != nil {
   201  		return 0, err
   202  	}
   203  	w.offset += hdrSize
   204  	return w.writeBody(p)
   205  }
   206  
   207  func isErr(err error) bool {
   208  	return err != nil && err != io.EOF
   209  }
   210  
// LegacyScannerImpl is the LegacyScanner implementation returned by
// NewLegacyScanner; it reads records in the legacy recordio format.
type LegacyScannerImpl struct {
	rd     io.Reader         // source of headers and payloads
	record []byte            // payload of the most recently scanned record; storage is reused
	err    errors.Once       // scanner error state; io.EOF marks a clean end of input
	opts   LegacyScannerOpts // caller-supplied Unmarshal callback
	hdr    [headerSize]byte  // scratch buffer the record header is read into
}
   219  
   220  // NewLegacyScanner is DEPRECATED. Use NewScannerV2 instead.
   221  func NewLegacyScanner(rd io.Reader, opts LegacyScannerOpts) LegacyScanner {
   222  	return &LegacyScannerImpl{
   223  		rd:   rd,
   224  		opts: opts,
   225  	}
   226  }
   227  
   228  // Reset implements Scanner.Reset.
   229  func (s *LegacyScannerImpl) Reset(rd io.Reader) {
   230  	s.rd = rd
   231  	s.err = errors.Once{}
   232  }
   233  
   234  // Unmarshal implements Scanner.Unmarshal.
   235  func (s *LegacyScannerImpl) Unmarshal(v interface{}) error {
   236  	if ufn := s.opts.Unmarshal; ufn != nil {
   237  		return ufn(s.record, v)
   238  	}
   239  	err := fmt.Errorf("Unmarshal function not configured for recordio.Scanner")
   240  	s.err.Set(err)
   241  	return err
   242  }
   243  
   244  // Scan implements Scanner.Scan.
   245  func (s *LegacyScannerImpl) Scan() bool {
   246  	magic, ok := s.InternalScan()
   247  	if !ok {
   248  		return false
   249  	}
   250  	if magic != internal.MagicLegacyUnpacked {
   251  		s.err.Set(fmt.Errorf("recordio: invalid magic number: %v, expect %v", magic,
   252  			internal.MagicLegacyUnpacked))
   253  		return false
   254  	}
   255  	return true
   256  }
   257  
   258  func (s *LegacyScannerImpl) InternalScan() (internal.MagicBytes, bool) {
   259  	if s.err.Err() != nil {
   260  		return internal.MagicInvalid, false
   261  	}
   262  	n, err := io.ReadFull(s.rd, s.hdr[:])
   263  	if n == 0 && err == io.EOF {
   264  		s.err.Set(io.EOF)
   265  		return internal.MagicInvalid, false
   266  	}
   267  	if isErr(err) {
   268  		s.err.Set(fmt.Errorf("recordio: failed to read header: %v", err))
   269  		return internal.MagicInvalid, false
   270  	}
   271  	magic, size, err := unmarshalHeader(s.hdr[:])
   272  	if err != nil {
   273  		s.err.Set(err)
   274  		return magic, false
   275  	}
   276  	if size == 0 {
   277  		s.record = s.record[:0]
   278  		return magic, true
   279  	}
   280  	if size > internal.MaxReadRecordSize {
   281  		s.record = s.record[:0]
   282  		s.err.Set(fmt.Errorf("recordio: unreasonably large read record encountered: %d > %d bytes", size, internal.MaxReadRecordSize))
   283  		return magic, false
   284  	}
   285  	if uint64(cap(s.record)) < size {
   286  		s.record = make([]byte, size)
   287  	} else {
   288  		s.record = s.record[:size]
   289  	}
   290  	n, err = io.ReadFull(s.rd, s.record)
   291  	if isErr(err) {
   292  		s.err.Set(fmt.Errorf("recordio: failed to read record: %v", err))
   293  		return magic, false
   294  	}
   295  	if uint64(n) != size {
   296  		s.err.Set(fmt.Errorf("recordio: short/long record: %d < %d", n, size))
   297  		return magic, false
   298  	}
   299  	return magic, true
   300  }
   301  
   302  // Bytes implements Scanner.Bytes.
   303  func (s *LegacyScannerImpl) Bytes() []byte {
   304  	return s.record
   305  }
   306  
   307  // Err implements Scanner.Err.
   308  func (s *LegacyScannerImpl) Err() error {
   309  	err := s.err.Err()
   310  	if err == io.EOF {
   311  		return nil
   312  	}
   313  	return err
   314  }
   315  
   316  func marshalHeader(out []byte, magic []byte, size uint64) {
   317  	pos := copy(out, magic)
   318  	binary.LittleEndian.PutUint64(out[pos:], size)
   319  	crc := crc32.Checksum(out[pos:pos+8], internal.IEEECRC)
   320  	pos += 8
   321  	binary.LittleEndian.PutUint32(out[pos:], crc)
   322  }
   323  
   324  func unmarshalHeader(buf []byte) (internal.MagicBytes, uint64, error) {
   325  	var magic internal.MagicBytes
   326  	copy(magic[:], buf[0:sizeOffset])
   327  	size := binary.LittleEndian.Uint64(buf[sizeOffset:])
   328  	crc := binary.LittleEndian.Uint32(buf[crcOffset:])
   329  	ncrc := crc32.Checksum(buf[sizeOffset:crcOffset], internal.IEEECRC)
   330  	if ncrc != crc {
   331  		return magic, 0, fmt.Errorf("recordio: crc check failed - corrupt record header (%v != %v)?", ncrc, crc)
   332  	}
   333  	return magic, size, nil
   334  }