github.com/grailbio/bigslice@v0.0.0-20230519005545-30c4c12152ad/sliceio/codec_test.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package sliceio
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"encoding/json"
    11  	"math/rand"
    12  	"reflect"
    13  	"strings"
    14  	"testing"
    15  
    16  	fuzz "github.com/google/gofuzz"
    17  	"github.com/grailbio/base/errors"
    18  	"github.com/grailbio/bigslice/frame"
    19  	"github.com/grailbio/bigslice/slicetype"
    20  )
    21  
    22  type testStruct struct{ A, B, C int }
    23  
    24  func init() {
    25  	var key = frame.FreshKey()
    26  
    27  	frame.RegisterOps(func(slice []testStruct) frame.Ops {
    28  		return frame.Ops{
    29  			Encode: func(e frame.Encoder, i, j int) error {
    30  				p, err := json.Marshal(slice[i:j])
    31  				if err != nil {
    32  					return err
    33  				}
    34  				return e.Encode(p)
    35  			},
    36  			Decode: func(d frame.Decoder, i, j int) error {
    37  				var p *[]byte
    38  				if d.State(key, &p) {
    39  					*p = []byte{}
    40  				}
    41  				if err := d.Decode(p); err != nil {
    42  					return err
    43  				}
    44  				x := slice[i:j]
    45  				if err := json.Unmarshal(*p, &x); err != nil {
    46  					return err
    47  				}
    48  				if len(x) != j-i {
    49  					return errors.New("bad json decode")
    50  				}
    51  				return nil
    52  			},
    53  		}
    54  	})
    55  }
    56  
    57  var (
    58  	typeOfString = reflect.TypeOf("")
    59  	typeOfInt    = reflect.TypeOf(0)
    60  )
    61  
    62  func TestCodec(t *testing.T) {
    63  	const N = 1000
    64  	fz := fuzz.New()
    65  	fz.NilChance(0)
    66  	fz.NumElements(N, N)
    67  	var (
    68  		c0 []string
    69  		c1 []testStruct
    70  	)
    71  	fz.Fuzz(&c0)
    72  	fz.Fuzz(&c1)
    73  
    74  	var b bytes.Buffer
    75  	enc := NewEncodingWriter(&b)
    76  
    77  	ctx := context.Background()
    78  	in := frame.Slices(c0, c1)
    79  	if err := enc.Write(ctx, in); err != nil {
    80  		t.Fatal(err)
    81  	}
    82  	if err := enc.Write(ctx, in); err != nil {
    83  		t.Fatal(err)
    84  	}
    85  	data := b.Bytes()
    86  	for _, chunkSize := range []int{1, N / 3, N / 2, N, N * 2} {
    87  		dec := NewDecodingReader(bytes.NewReader(data))
    88  		out := frame.Make(in, N*2, N*2)
    89  		for i := 0; i < N*2; {
    90  			j := i + chunkSize
    91  			if j > N*2 {
    92  				j = N * 2
    93  			}
    94  			n, err := dec.Read(ctx, out.Slice(i, j))
    95  			if err != nil {
    96  				t.Fatal(err)
    97  			}
    98  			i += n
    99  		}
   100  		for i := 0; i < in.NumOut(); i++ {
   101  			if !reflect.DeepEqual(in.Interface(i), out.Slice(0, N).Interface(i)) {
   102  				t.Errorf("column %d mismatch", i)
   103  			}
   104  			if !reflect.DeepEqual(in.Interface(i), out.Slice(N, N*2).Interface(i)) {
   105  				t.Errorf("column %d mismatch", i)
   106  			}
   107  		}
   108  		n, err := dec.Read(ctx, out)
   109  		if got, want := err, EOF; got != want {
   110  			t.Errorf("got %v, want %v", got, want)
   111  		}
   112  		if got, want := n, 0; got != want {
   113  			t.Errorf("got %v, want %v", got, want)
   114  		}
   115  	}
   116  
   117  	/*
   118  		// Make sure we don't reallocate if we're providing slices with enough
   119  		// capacity already.
   120  		outptrs := make([]uintptr, len(out))
   121  		for i := range out {
   122  			outptrs[i] = out[i].Pointer() // points to the slice header's data
   123  		}
   124  		if err := enc.Write(in); err != nil {
   125  			t.Fatal(err)
   126  		}
   127  		if err := dec.Decode(out...); err != nil {
   128  			t.Fatal(err)
   129  		}
   130  		for i := range out {
   131  			if outptrs[i] != out[i].Pointer() {
   132  				t.Errorf("column slice %d reallocated", i)
   133  			}
   134  		}
   135  	*/
   136  }
   137  
   138  func TestDecodingReaderWithZeros(t *testing.T) {
   139  	// Gob, in its infinite cleverness, does not transmit zero values.
   140  	// However, it apparently also does not zero out zero values in
   141  	// structs that are reused. This requires special handling that is
   142  	// tested here.
   143  	type fields struct{ A, B, C int }
   144  	var b bytes.Buffer
   145  	in := []fields{{1, 2, 3}, {1, 0, 3}}
   146  	ctx := context.Background()
   147  	enc := NewEncodingWriter(&b)
   148  	if err := enc.Write(ctx, frame.Slices(in[0:1])); err != nil {
   149  		t.Fatal(err)
   150  	}
   151  	if err := enc.Write(ctx, frame.Slices(in[1:2])); err != nil {
   152  		t.Fatal(err)
   153  	}
   154  
   155  	r := NewDecodingReader(&b)
   156  
   157  	var out []fields
   158  	if err := ReadAll(ctx, r, &out); err != nil {
   159  		t.Fatal(err)
   160  	}
   161  
   162  	if got, want := out, in; !reflect.DeepEqual(got, want) {
   163  		t.Errorf("got %+v, want %+v", got, want)
   164  	}
   165  }
   166  
   167  func TestDecodingReaderCorrupted(t *testing.T) {
   168  	const N = 100
   169  	col := make([]int, 100)
   170  	for i := range col {
   171  		col[i] = i
   172  	}
   173  	var b bytes.Buffer
   174  	enc := NewEncodingWriter(&b)
   175  	ctx := context.Background()
   176  	if err := enc.Write(ctx, frame.Slices(col, col, col)); err != nil {
   177  		t.Fatal(err)
   178  	}
   179  	buf := func() []byte {
   180  		p := b.Bytes()
   181  		if len(p) == 0 {
   182  			t.Fatal(p)
   183  		}
   184  		return append([]byte{}, p...)
   185  	}
   186  	rnd := rand.New(rand.NewSource(1234))
   187  	var nintegrity int
   188  	for i := 0; i < N; i++ {
   189  		// First, check that it reads (valid); then corrupt a single bit
   190  		// and make sure we get an error.
   191  		p := buf()
   192  		r := NewDecodingReader(bytes.NewReader(p))
   193  		ctx := context.Background()
   194  		var c1, c2, c3 []int
   195  		if err := ReadAll(ctx, r, &c1, &c2, &c3); err != nil {
   196  			t.Error(err)
   197  			continue
   198  		}
   199  		i := rnd.Intn(len(p))
   200  		p[i] ^= byte(1 << uint(rnd.Intn(8)))
   201  		r = NewDecodingReader(bytes.NewReader(p))
   202  		err := ReadAll(ctx, r, &c1, &c2, &c3)
   203  		if err == nil {
   204  			t.Error("got nil err")
   205  			continue
   206  		}
   207  		// Depending on which bit gets flipped, we might end up with
   208  		// a different decoding error.
   209  		switch errors.Recover(err).Kind {
   210  		default:
   211  			t.Errorf("invalid error %v", err)
   212  		case errors.Integrity:
   213  			nintegrity++
   214  		case errors.Other:
   215  			switch {
   216  			default:
   217  				t.Errorf("invalid error %v", err)
   218  			case strings.HasPrefix(err.Error(), "gob:"):
   219  			case err.Error() == "extra data in buffer":
   220  			case err.Error() == "unexpected EOF":
   221  			}
   222  		}
   223  	}
   224  	if nintegrity == 0 {
   225  		t.Error("encountered no integrity errors")
   226  	}
   227  }
   228  
   229  func TestDecodingSlices(t *testing.T) {
   230  	// Gob will reuse slices during decoding if we're not careful.
   231  	var b bytes.Buffer
   232  	ctx := context.Background()
   233  	in := [][]string{{"a", "b"}, {"c", "d"}}
   234  	enc := NewEncodingWriter(&b)
   235  	if err := enc.Write(ctx, frame.Slices(in[0:1])); err != nil {
   236  		t.Fatal(err)
   237  	}
   238  	if err := enc.Write(ctx, frame.Slices(in[1:2])); err != nil {
   239  		t.Fatal(err)
   240  	}
   241  
   242  	r := NewDecodingReader(&b)
   243  	var out [][]string
   244  	if err := ReadAll(ctx, r, &out); err != nil {
   245  		t.Fatal(err)
   246  	}
   247  	if got, want := out, in; !reflect.DeepEqual(got, want) {
   248  		t.Errorf("got %+v, want %+v", got, want)
   249  	}
   250  }
   251  
   252  func TestEmptyDecodingReader(t *testing.T) {
   253  	r := NewDecodingReader(bytes.NewReader(nil))
   254  	f := frame.Make(slicetype.New(typeOfString, typeOfInt), 100, 100)
   255  	n, err := r.Read(context.Background(), f)
   256  	if got, want := n, 0; got != want {
   257  		t.Errorf("got %v, want %v", got, want)
   258  	}
   259  	if got, want := err, EOF; got != want {
   260  		t.Errorf("got %v, want %v", got, want)
   261  	}
   262  	n, err = r.Read(context.Background(), f)
   263  	if got, want := n, 0; got != want {
   264  		t.Errorf("got %v, want %v", got, want)
   265  	}
   266  	if got, want := err, EOF; got != want {
   267  		t.Errorf("got %v, want %v", got, want)
   268  	}
   269  }
   270  
   271  // TestScratchBufferGrowth verifies that decodingReader can buffer encoded
   272  // frames of increasing length. Note that this is a very
   273  // implementation-dependent test that verifies that scratch buffer resizing
   274  // works correctly.
   275  func TestScratchBufferGrowth(t *testing.T) {
   276  	var (
   277  		b   bytes.Buffer
   278  		in0 = []int{0, 1}
   279  		in1 = []int{2, 3, 4}
   280  		enc = NewEncodingWriter(&b)
   281  		ctx = context.Background()
   282  	)
   283  	// Encode a 2-length frame followed by a 3-length frame. This will cause
   284  	// the decoder to resize its scratch buffer.
   285  	if err := enc.Write(ctx, frame.Slices(in0)); err != nil {
   286  		t.Fatal(err)
   287  	}
   288  	if err := enc.Write(ctx, frame.Slices(in1)); err != nil {
   289  		t.Fatal(err)
   290  	}
   291  
   292  	r := NewDecodingReader(&b)
   293  	// Read one row at a time to force the decodingReader to internally buffer
   294  	// the decoded frames.
   295  	f := frame.Make(slicetype.New(typeOfInt), 1, 1)
   296  	var out []int
   297  	for {
   298  		n, err := r.Read(ctx, f)
   299  		if err == EOF {
   300  			break
   301  		}
   302  		if err != nil {
   303  			t.Fatalf("unexpected error: %v", err)
   304  		}
   305  		out = append(out, f.Slice(0, n).Value(0).Interface().([]int)...)
   306  	}
   307  	if got, want := len(out), 5; got != want {
   308  		t.Errorf("got %v, want %v", got, want)
   309  	}
   310  	if got, want := out, append(in0, in1...); !reflect.DeepEqual(got, want) {
   311  		t.Errorf("got %v, want %v", got, want)
   312  	}
   313  }
   314  
   315  func testRoundTrip(t *testing.T, cols ...interface{}) {
   316  	t.Helper()
   317  	var N = 1000
   318  	if testing.Short() {
   319  		N = 10
   320  	}
   321  	var Stride = N / 5
   322  	fz := fuzz.New()
   323  	fz.NilChance(0)
   324  	fz.NumElements(N, N)
   325  	for i := range cols {
   326  		ptr := reflect.New(reflect.TypeOf(cols[i]))
   327  		fz.Fuzz(ptr.Interface())
   328  		cols[i] = reflect.Indirect(ptr).Interface()
   329  	}
   330  	var b bytes.Buffer
   331  	enc := NewEncodingWriter(&b)
   332  	for i := 0; i < N; i += Stride {
   333  		j := i + Stride
   334  		if j > N {
   335  			j = N
   336  		}
   337  		args := make([]interface{}, len(cols))
   338  		for k := range args {
   339  			args[k] = reflect.ValueOf(cols[k]).Slice(i, j).Interface()
   340  		}
   341  		ctx := context.Background()
   342  		if err := enc.Write(ctx, frame.Slices(args...)); err != nil {
   343  			t.Fatal(err)
   344  		}
   345  	}
   346  	args := make([]interface{}, len(cols))
   347  	for i := range args {
   348  		// Create an empty slice from the end of the parent slice.
   349  		slice := reflect.ValueOf(cols[i]).Slice(N, N)
   350  		ptr := reflect.New(slice.Type())
   351  		reflect.Indirect(ptr).Set(slice)
   352  		args[i] = ptr.Interface()
   353  	}
   354  	if err := ReadAll(context.Background(), NewDecodingReader(&b), args...); err != nil {
   355  		t.Fatal(err)
   356  	}
   357  	for i, want := range cols {
   358  		got := reflect.Indirect(reflect.ValueOf(args[i])).Interface()
   359  		if !reflect.DeepEqual(got, want) {
   360  			t.Errorf("got %v, want %v", got, want)
   361  		}
   362  	}
   363  }
   364  
   365  func TestTypes(t *testing.T) {
   366  	types := [][]interface{}{
   367  		{[]int{}, []string{}},
   368  		{[]struct{ A, B, C int }{}},
   369  		{[][]string{}, []int{}},
   370  		{[]struct {
   371  			A *int
   372  			B string
   373  		}{}, []*int{}},
   374  		{[]rune{}, []byte{}, [][]byte{}, []int16{}, []int8{}, []*[]string{}, []int64{}},
   375  	}
   376  	for _, cols := range types {
   377  		testRoundTrip(t, cols...)
   378  	}
   379  }
   380  
   381  func TestSession(t *testing.T) {
   382  	s := make(session)
   383  	k1, k2 := frame.FreshKey(), frame.FreshKey()
   384  	var x *int
   385  	if !s.State(k1, &x) {
   386  		t.Fatal("k1 not initialized")
   387  	}
   388  	var y *string
   389  	if !s.State(k2, &y) {
   390  		t.Fatal("k2 not initialized")
   391  	}
   392  	*y = "ok"
   393  
   394  	x = nil
   395  	if s.State(k1, &x) {
   396  		t.Fatal("k1 initialized twice")
   397  	}
   398  	if got, want := *x, 0; got != want {
   399  		t.Errorf("got %v, want %v", got, want)
   400  	}
   401  
   402  	y = nil
   403  	if s.State(k2, &y) {
   404  		t.Fatal("k2 initialized twice")
   405  	}
   406  	if got, want := *y, "ok"; got != want {
   407  		t.Errorf("got %v, want %v", got, want)
   408  	}
   409  }