github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/sliceio/codec.go

// Copyright 2018 GRAIL, Inc. All rights reserved.
// Use of this source code is governed by the Apache 2.0
// license that can be found in the LICENSE file.

package sliceio

import (
	"bufio"
	"context"
	"encoding/gob"
	"fmt"
	"hash"
	"hash/crc32"
	"io"
	"reflect"
	"strings"
	"unsafe"

	"github.com/grailbio/base/errors"
	"github.com/grailbio/bigslice/frame"
)

// A session stores per-stream state on behalf of custom codecs, keyed
// by frame.Key, so that a codec can carry state across successive
// encode or decode calls on the same stream.
type session map[frame.Key]reflect.Value

// State stores the state associated with the provided key into state,
// which must be a pointer. If no state exists yet for the key, a fresh
// value is allocated and stored, and State returns true.
func (s session) State(key frame.Key, state interface{}) (fresh bool) {
	v, ok := s[key]
	if !ok {
		typ := reflect.TypeOf(state).Elem()
		if typ.Kind() == reflect.Ptr {
			v = reflect.New(typ.Elem())
		} else {
			v = reflect.Zero(typ)
		}
		s[key] = v
	}
	reflect.Indirect(reflect.ValueOf(state)).Set(v)
	return !ok
}
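
// To illustrate State's contract (a hypothetical caller; not part of this
// file): the first call for a key allocates fresh, zeroed state, while later
// calls for the same key return the same value, so mutations persist.
//
//	var counter *int
//	if fresh := s.State(key, &counter); fresh {
//		// counter was just allocated and *counter == 0.
//	}
//	*counter++ // visible to the next State call with the same key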

// A gobEncoder pairs a gob.Encoder with a session so that custom
// codecs can keep per-stream encoding state.
type gobEncoder struct {
	*gob.Encoder
	session
}

func newGobEncoder(w io.Writer) *gobEncoder {
	return &gobEncoder{
		Encoder: gob.NewEncoder(w),
		session: make(session),
	}
}

// A gobDecoder pairs a gob.Decoder with a session so that custom
// codecs can keep per-stream decoding state.
type gobDecoder struct {
	*gob.Decoder
	session
}

func newGobDecoder(r io.Reader) *gobDecoder {
	return &gobDecoder{
		Decoder: gob.NewDecoder(r),
		session: make(session),
	}
}

// An Encoder manages transmission of slices through an underlying
// io.Writer. The stream of slice values is represented by batches of
// rows stored in column-major order. Streams can be read back by the
// Reader returned by NewDecodingReader.
type Encoder struct {
	enc *gobEncoder
	crc hash.Hash32
}

// NewEncodingWriter returns a Writer that streams slices into the provided
// writer.
func NewEncodingWriter(w io.Writer) *Encoder {
	crc := crc32.NewIEEE()
	return &Encoder{
		enc: newGobEncoder(io.MultiWriter(w, crc)),
		crc: crc,
	}
}

// Write encodes a batch of rows and writes the encoded output into
// the encoder's writer.
func (e *Encoder) Write(_ context.Context, f frame.Frame) error {
	e.crc.Reset()
	if err := e.enc.Encode(f.Len()); err != nil {
		return err
	}
	for col := 0; col < f.NumOut(); col++ {
		codec := f.HasCodec(col)
		if err := e.enc.Encode(codec); err != nil {
			return err
		}
		var err error
		if codec {
			err = f.Encode(col, e.enc)
		} else {
			err = e.enc.EncodeValue(f.Value(col))
		}
		if err != nil {
			// Here we're encoding a user-defined type. We pessimistically
			// attribute any errors that appear to come from gob as being
			// related to the inability to encode this user-defined type.
			if strings.HasPrefix(err.Error(), "gob: ") {
				err = errors.E(errors.Fatal, err)
			}
			return err
		}
	}
	return e.enc.Encode(e.crc.Sum32())
}
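
// A minimal encoding sketch (illustrative only: it assumes frame.Slices
// constructs a Frame from parallel column slices, and that ctx is a
// context.Context in scope):
//
//	var buf bytes.Buffer
//	in := frame.Slices([]int{1, 2, 3}, []string{"a", "b", "c"})
//	enc := NewEncodingWriter(&buf)
//	if err := enc.Write(ctx, in); err != nil {
//		return err
//	}
//	// buf now holds one batch: its length, then each column (via a custom
//	// codec where one is registered), then a CRC32 checksum.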

// A decodingReader provides a Reader on top of a gob stream
// encoded with batches of rows stored in column-major order.
type decodingReader struct {
	dec     *gobDecoder
	crc     hash.Hash32
	scratch frame.Frame
	buf     frame.Frame
	err     error
}

// NewDecodingReader returns a new Reader that decodes values from
// the provided stream. Since values are streamed in vectors, the
// decoding reader must buffer values until they are read by the
// consumer.
func NewDecodingReader(r io.Reader) Reader {
	// We need to compute checksums by inspecting the underlying
	// bytestream. However, gob uses whether the reader implements
	// io.ByteReader as a proxy for whether the passed reader is
	// buffered. io.TeeReader does not implement io.ByteReader, and thus
	// gob.Decoder would insert its own buffered reader, leaving us with no
	// means of synchronizing stream positions, which is required for
	// checksumming. Instead we fake an implementation of io.ByteReader,
	// and take over the responsibility of ensuring that IO is buffered.
	crc := crc32.NewIEEE()
	if _, ok := r.(io.ByteReader); !ok {
		r = bufio.NewReader(r)
	}
	r = io.TeeReader(r, crc)
	return &decodingReader{dec: newGobDecoder(readerByteReader{Reader: r}), crc: crc}
}
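
// The resulting reader stack, from the underlying stream up to gob, is:
//
//	r
//	  -> bufio.Reader (added only when r lacks ReadByte)
//	  -> io.TeeReader (copies every consumed byte into crc)
//	  -> readerByteReader (satisfies gob's io.ByteReader check)
//	  -> gob.Decoder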

func (d *decodingReader) Read(ctx context.Context, f frame.Frame) (n int, err error) {
	if d.err != nil {
		return 0, d.err
	}
	for d.buf.Len() == 0 {
		d.crc.Reset()
		if d.err = d.dec.Decode(&n); d.err != nil {
			if d.err == io.EOF {
				d.err = EOF
			}
			return 0, d.err
		}
		// In most cases, we should be able to decode directly into the
		// provided frame without any buffering.
		if n <= f.Len() {
			if d.err = d.decode(f.Slice(0, n)); d.err != nil {
				return 0, d.err
			}
			return n, nil
		}
		// Otherwise we have to buffer the decoded frame.
		if d.scratch.IsZero() {
			d.scratch = frame.Make(f, n, n)
		} else {
			d.scratch = d.scratch.Ensure(n)
		}
		d.buf = d.scratch
		if d.err = d.decode(d.buf); d.err != nil {
			return 0, d.err
		}
	}
	n = frame.Copy(f, d.buf)
	d.buf = d.buf.Slice(n, d.buf.Len())
	return n, nil
}
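
// A typical consumer loop, sketched (reusing in, buf, and ctx from the
// encoding sketch above; EOF is sliceio's EOF, not io.EOF, and process is a
// hypothetical per-batch function):
//
//	out := frame.Make(in, 1024, 1024) // same column types as the stream
//	r := NewDecodingReader(&buf)
//	for {
//		n, err := r.Read(ctx, out)
//		if err == EOF {
//			break
//		}
//		if err != nil {
//			return err
//		}
//		process(out.Slice(0, n))
//	}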

// decode decodes a batch of column vectors into the provided frame.
// The frame is preallocated and is guaranteed to have enough
// space to decode all of the values.
func (d *decodingReader) decode(f frame.Frame) error {
	// Always zero memory before decoding with gob, as it will reuse
	// existing memory. This can be dangerous, especially when user
	// code is involved.
	f.Zero()
	for col := 0; col < f.NumOut(); col++ {
		var codec bool
		if err := d.dec.Decode(&codec); err != nil {
			return err
		}
		if codec && !f.HasCodec(col) {
			return errors.New("column encoded with custom codec but no codec available on receipt")
		}
		if codec {
			if err := f.Decode(col, d.dec); err != nil {
				return err
			}
			continue
		}
		// Arrange for gob to decode directly into the frame's underlying
		// slice. We have to do some gymnastics to produce a pointer to
		// this value (which we'll anyway discard) so that gob can do its
		// job.
		sh := f.SliceHeader(col)
		var p []unsafe.Pointer
		pHdr := (*reflect.SliceHeader)(unsafe.Pointer(&p))
		pHdr.Data = sh.Data
		pHdr.Len = sh.Len
		pHdr.Cap = sh.Cap
		v := reflect.NewAt(reflect.SliceOf(f.Out(col)), unsafe.Pointer(pHdr))
		err := d.dec.DecodeValue(v)
		if err != nil {
			if err == io.EOF {
				return EOF
			}
			return err
		}
		// This is guaranteed by gob (it reuses an existing slice with
		// sufficient capacity), but it seems worthy of some defensive
		// programming here. It's also an extra check of the codec's
		// correctness.
		if pHdr.Data != sh.Data {
			panic("gob reallocated a slice")
		}
	}
	sum := d.crc.Sum32()
	var decoded uint32
	if err := d.dec.Decode(&decoded); err != nil {
		return err
	}
	if sum != decoded {
		return errors.E(errors.Integrity, fmt.Errorf("computed checksum %x but expected checksum %x", sum, decoded))
	}
	return nil
}
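
// The aliasing step above, restated on a concrete column type (a simplified
// sketch of the same technique, not additional functionality):
//
//	col := make([]int, 4) // frame-owned backing memory, len == batch size
//	var p []unsafe.Pointer
//	pHdr := (*reflect.SliceHeader)(unsafe.Pointer(&p))
//	sh := (*reflect.SliceHeader)(unsafe.Pointer(&col))
//	pHdr.Data, pHdr.Len, pHdr.Cap = sh.Data, sh.Len, sh.Cap
//	v := reflect.NewAt(reflect.TypeOf(col), unsafe.Pointer(pHdr))
//	// v is a *[]int whose pointee aliases col's header, so DecodeValue(v)
//	// fills col's elements in place without allocating a new slice.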

// readerByteReader is used to provide an (invalid) implementation of
// io.ByteReader to gob.Decoder: the embedded nil ByteReader's ReadByte is
// never called in practice; its presence merely convinces gob.NewDecoder
// not to wrap the reader in its own bufio.Reader. See the comment in
// NewDecodingReader for details.
type readerByteReader struct {
	io.Reader
	io.ByteReader
}