github.com/grailbio/base@v0.0.11/recordio/deprecated/packed_test.go (about)

     1  // Copyright 2017 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package deprecated_test
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/sha256"
    10  	"encoding/hex"
    11  	"encoding/json"
    12  	"fmt"
    13  	"math/rand"
    14  	"reflect"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/grailbio/base/recordio/deprecated"
    20  	"github.com/grailbio/base/recordio/internal"
    21  	"github.com/grailbio/testutil"
    22  	"github.com/grailbio/testutil/assert"
    23  	"github.com/grailbio/testutil/expect"
    24  )
    25  
    26  func TestPackedWriteRead(t *testing.T) {
    27  	sl := func(s ...string) [][]byte {
    28  		var bs [][]byte
    29  		for _, t := range s {
    30  			bs = append(bs, []byte(t))
    31  		}
    32  		return bs
    33  	}
    34  	read := func(sc deprecated.LegacyScanner) string {
    35  		if !sc.Scan() {
    36  			assert.NoError(t, sc.Err())
    37  			return "eof"
    38  		}
    39  		return string(sc.Bytes())
    40  	}
    41  
    42  	for _, tc := range []struct {
    43  		in                 [][]byte
    44  		maxItems, maxBytes uint32
    45  		nreads             int
    46  	}{
    47  		{sl(""), 1, 10, 1},
    48  		{sl("a", "b"), 2, 10, 1},
    49  		{sl("", "", ""), 2, 10, 2},
    50  		{sl("hello", "world", "line", "2", "line3"), 2, 100, 3},
    51  		{sl("a", "b", "c", "d", "e", "f", "g"), 100, 2, 4},
    52  	} {
    53  		out := &bytes.Buffer{}
    54  		nflushs := 0
    55  		wropts := deprecated.LegacyPackedWriterOpts{
    56  			MaxItems: tc.maxItems,
    57  			MaxBytes: tc.maxBytes,
    58  			Flushed:  func() error { nflushs++; return nil },
    59  		}
    60  		wr := deprecated.NewLegacyPackedWriter(out, wropts)
    61  
    62  		for _, p := range tc.in {
    63  			n, err := wr.Write(p)
    64  			assert.NoError(t, err)
    65  			assert.True(t, n == len(p))
    66  		}
    67  		assert.NoError(t, wr.Flush())
    68  		// Make sure Flushing with nothing to flush has no effect.
    69  		assert.NoError(t, wr.Flush())
    70  		assert.EQ(t, bytes.Count(out.Bytes(), internal.MagicPacked[:]), tc.nreads)
    71  		assert.EQ(t, nflushs, tc.nreads)
    72  		assert.EQ(t, bytes.Count(out.Bytes(), internal.MagicLegacyUnpacked[:]), 0)
    73  
    74  		sc := deprecated.NewLegacyPackedScanner(bytes.NewReader(out.Bytes()), deprecated.LegacyPackedScannerOpts{})
    75  		for _, expected := range tc.in {
    76  			expect.EQ(t, read(sc), string(expected))
    77  		}
    78  		expect.EQ(t, read(sc), "eof")
    79  		sc.Reset(bytes.NewReader(out.Bytes()))
    80  		for _, expected := range tc.in[:1] {
    81  			expect.EQ(t, read(sc), string(expected))
    82  		}
    83  		sc.Reset(bytes.NewReader(out.Bytes()))
    84  		for _, expected := range tc.in {
    85  			expect.EQ(t, read(sc), string(expected))
    86  		}
    87  		expect.EQ(t, read(sc), "eof")
    88  	}
    89  }
    90  
    91  func TestPackedMarshal(t *testing.T) {
    92  	type js struct {
    93  		I  int
    94  		IS string
    95  	}
    96  	wropts := deprecated.LegacyPackedWriterOpts{
    97  		MaxItems: 2,
    98  		MaxBytes: 0,
    99  	}
   100  	wropts.Marshal = func(scratch []byte, v interface{}) ([]byte, error) {
   101  		return json.Marshal(v)
   102  	}
   103  	scopts := deprecated.LegacyPackedScannerOpts{}
   104  	scopts.Unmarshal = json.Unmarshal
   105  	for i, tc := range []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 100, 233} {
   106  		out := &bytes.Buffer{}
   107  		wr := deprecated.NewLegacyPackedWriter(out, wropts)
   108  		for i := 0; i < tc; i++ {
   109  			if _, err := wr.Marshal(&js{i, fmt.Sprintf("%d", i)}); err != nil {
   110  				t.Fatalf("%v: %v", i, err)
   111  			}
   112  		}
   113  		if err := wr.Flush(); err != nil {
   114  			t.Fatalf("%v: %v", i, err)
   115  		}
   116  		sc := deprecated.NewLegacyPackedScanner(out, scopts)
   117  		data := make([]js, tc)
   118  		next := 0
   119  		for sc.Scan() {
   120  			if err := sc.Unmarshal(&data[next]); err != nil {
   121  				t.Fatalf("%v: %v", next, err)
   122  			}
   123  			next++
   124  		}
   125  		if err := sc.Err(); err != nil {
   126  			t.Fatalf("%v: %v", i, err)
   127  		}
   128  		for i, d := range data {
   129  			w := &js{i, fmt.Sprintf("%v", i)}
   130  			if got, want := &d, w; !reflect.DeepEqual(got, want) {
   131  				t.Errorf("%v: got %v, want %v", i, got, want)
   132  			}
   133  		}
   134  		if got, want := len(data), tc; got != want {
   135  			t.Errorf("%v: got %v, want %v", i, got, want)
   136  		}
   137  	}
   138  }
   139  func TestPackedMixed(t *testing.T) {
   140  	type js struct {
   141  		I  int
   142  		IS string
   143  	}
   144  	indexedObjs := 0
   145  	indexedBufs := 0
   146  	wropts := deprecated.LegacyPackedWriterOpts{
   147  		MaxItems: 2,
   148  		MaxBytes: 0,
   149  	}
   150  	wropts.Marshal = func(scratch []byte, v interface{}) ([]byte, error) {
   151  		return json.Marshal(v)
   152  	}
   153  	wropts.Index = func(record, recordSize, items uint64) (deprecated.ItemIndexFunc, error) {
   154  		return func(offset, extent uint64, v interface{}, p []byte) error {
   155  			if p != nil {
   156  				indexedBufs++
   157  			}
   158  			if v != nil {
   159  				indexedObjs++
   160  			}
   161  			return nil
   162  		}, nil
   163  	}
   164  	scopts := deprecated.LegacyPackedScannerOpts{}
   165  	scopts.Unmarshal = json.Unmarshal
   166  	wantIndexedObjs, wantIndexedBufs := 0, 0
   167  	for i, tc := range []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 100, 233} {
   168  		out := &bytes.Buffer{}
   169  		wr := deprecated.NewLegacyPackedWriter(out, wropts)
   170  		for i := 0; i < tc; i++ {
   171  			wantIndexedBufs++
   172  			if i%2 == 0 {
   173  				buf, err := json.Marshal(&js{i, fmt.Sprintf("%d", i)})
   174  				if err != nil {
   175  					t.Fatalf("%v: %v", i, err)
   176  				}
   177  				wr.Write(buf)
   178  			} else {
   179  				if _, err := wr.Marshal(&js{i, fmt.Sprintf("%d", i)}); err != nil {
   180  					t.Fatalf("%v: %v", i, err)
   181  				}
   182  				wantIndexedObjs++
   183  			}
   184  		}
   185  		if err := wr.Flush(); err != nil {
   186  			t.Fatalf("%v: %v", i, err)
   187  		}
   188  		sc := deprecated.NewLegacyPackedScanner(out, scopts)
   189  		data := make([]js, tc)
   190  		next := 0
   191  		for sc.Scan() {
   192  			if err := sc.Unmarshal(&data[next]); err != nil {
   193  				t.Fatalf("%v: %v", next, err)
   194  			}
   195  			next++
   196  		}
   197  		if err := sc.Err(); err != nil {
   198  			t.Fatalf("%v: %v", i, err)
   199  		}
   200  		for i, d := range data {
   201  			w := &js{i, fmt.Sprintf("%v", i)}
   202  			if got, want := &d, w; !reflect.DeepEqual(got, want) {
   203  				t.Errorf("%v: got %v, want %v", i, got, want)
   204  			}
   205  		}
   206  		if got, want := len(data), tc; got != want {
   207  			t.Errorf("%v: got %v, want %v", i, got, want)
   208  		}
   209  	}
   210  
   211  	if got, want := indexedBufs, wantIndexedBufs; got != want {
   212  		t.Errorf("got %v, want %v", got, want)
   213  	}
   214  	if got, want := indexedObjs, wantIndexedObjs; got != want {
   215  		t.Errorf("got %v, want %v", got, want)
   216  	}
   217  }
   218  
   219  func TestPackedMax(t *testing.T) {
   220  	max := deprecated.MaxPackedItems
   221  	deprecated.MaxPackedItems = 2
   222  	defer func() {
   223  		deprecated.MaxPackedItems = max
   224  	}()
   225  	out := &bytes.Buffer{}
   226  	wropts := deprecated.LegacyPackedWriterOpts{
   227  		MaxItems: 200,
   228  		MaxBytes: 0,
   229  	}
   230  	wr := deprecated.NewLegacyPackedWriter(out, wropts)
   231  	wr.Write([]byte("hello 1"))
   232  	wr.Write([]byte("hello 2"))
   233  	wr.Write([]byte("hello 3"))
   234  	wr.Write([]byte("hello 4"))
   235  	wr.Write([]byte("hello 5"))
   236  	if err := wr.Flush(); err != nil {
   237  		t.Fatal(err)
   238  	}
   239  	if got, want := bytes.Count(out.Bytes(), internal.MagicPacked[:]), 3; got != want {
   240  		t.Errorf("got %v, want %v", got, want)
   241  	}
   242  }
   243  
   244  func TestPackedErrors(t *testing.T) {
   245  	wropts := deprecated.LegacyPackedWriterOpts{
   246  		MaxItems: 2,
   247  		MaxBytes: 10,
   248  	}
   249  	buf := &bytes.Buffer{}
   250  	wr := deprecated.NewLegacyPackedWriter(buf, wropts)
   251  
   252  	bigbuf := [50]byte{}
   253  	_, err := wr.Write(bigbuf[:])
   254  	expect.HasSubstr(t, err, "buffer is too large 50 > 10")
   255  
   256  	wropts.Marshal = func(scratch []byte, v interface{}) ([]byte, error) {
   257  		return json.Marshal(v)
   258  	}
   259  	wr = deprecated.NewLegacyPackedWriter(buf, wropts)
   260  	_, err = wr.Marshal(bigbuf[:])
   261  	expect.HasSubstr(t, err, "buffer is too large 70 > 10")
   262  
   263  	writeError := func(offset int, short bool, msg string) {
   264  		ew := &fakeWriter{errAt: offset, short: short}
   265  		wropts := deprecated.LegacyPackedWriterOpts{MaxItems: 1, MaxBytes: 10}
   266  		wr := deprecated.NewLegacyPackedWriter(ew, wropts)
   267  		_, err := wr.Write([]byte("hello"))
   268  		expect.NoError(t, err, "first write succeeds")
   269  		_, err = wr.Write([]byte("hello"))
   270  		expect.HasSubstr(t, err, msg)
   271  	}
   272  	writeError(0, false, "recordio: failed to write header")
   273  	writeError(22, false, "recordio: failed to write record")
   274  	writeError(25, true, "recordio: buffered write too short")
   275  
   276  	wropts.Transform = func(bufs [][]byte) ([]byte, error) {
   277  		return nil, fmt.Errorf("transform oops")
   278  	}
   279  	wr = deprecated.NewLegacyPackedWriter(buf, wropts)
   280  	wr.Write([]byte("oh"))
   281  	err = wr.Flush()
   282  	expect.HasSubstr(t, err, "transform oops")
   283  
   284  	buf.Reset()
   285  	wr = deprecated.NewLegacyPackedWriter(buf, deprecated.LegacyPackedWriterOpts{})
   286  	wr.Write([]byte(""))
   287  	if err := wr.Flush(); err != nil {
   288  		t.Fatal(err)
   289  	}
   290  
   291  	shortReadError := func(offset int, msg string) {
   292  		f := buf.Bytes()
   293  		tmp := make([]byte, len(f))
   294  		copy(tmp, f)
   295  		s := deprecated.NewLegacyPackedScanner(bytes.NewBuffer(tmp[:offset]),
   296  			deprecated.LegacyPackedScannerOpts{})
   297  		if s.Scan() {
   298  			t.Errorf("expected false")
   299  		}
   300  		expect.HasSubstr(t, s.Err(), msg)
   301  	}
   302  	shortReadError(1, "unexpected EOF")
   303  	shortReadError(19, "unexpected EOF")
   304  	shortReadError(20, "short/long record")
   305  
   306  	scopts := deprecated.LegacyPackedScannerOpts{}
   307  	scopts.Transform = func(scratch, buf []byte) ([]byte, error) {
   308  		return nil, fmt.Errorf("transform oops")
   309  	}
   310  	sc := deprecated.NewLegacyPackedScanner(bytes.NewBuffer(buf.Bytes()), scopts)
   311  	if sc.Scan() {
   312  		t.Fatal("expected false")
   313  	}
   314  	if sc.Scan() {
   315  		t.Fatal("expect false")
   316  	}
   317  	expect.HasSubstr(t, sc.Err(), "transform oops")
   318  }
   319  
   320  func readAll(t *testing.T, buf *bytes.Buffer, opts deprecated.LegacyPackedScannerOpts) []string {
   321  	sc := deprecated.NewLegacyPackedScanner(buf, opts)
   322  	var read []string
   323  	for sc.Scan() {
   324  		read = append(read, string(sc.Bytes()))
   325  	}
   326  	assert.NoError(t, sc.Err())
   327  	return read
   328  }
   329  
   330  func TestPackedTransform(t *testing.T) {
   331  	// Prepend __ and append ++ to every record.
   332  	wropts := deprecated.LegacyPackedWriterOpts{
   333  		MaxItems: 2,
   334  		Transform: func(bufs [][]byte) ([]byte, error) {
   335  			r := []byte("__")
   336  			for _, b := range bufs {
   337  				r = append(r, b...)
   338  			}
   339  			return append(r, []byte("++")...), nil
   340  		},
   341  	}
   342  	buf := &bytes.Buffer{}
   343  	wr := deprecated.NewLegacyPackedWriter(buf, wropts)
   344  	data := []string{"Hello", "World", "How", "Are", "You?"}
   345  	for _, d := range data {
   346  		wr.Write([]byte(d))
   347  	}
   348  	wr.Flush()
   349  
   350  	// Scan with the __ and ++ in place. Each item in the each
   351  	// record is the same original size, but the contents are
   352  	// 'shifted' by the leading__
   353  	expected := []string{"__Hel", "loWor", "__H", "owA", "__Yo"}
   354  	saved := make([]byte, len(buf.Bytes()))
   355  	copy(saved, buf.Bytes())
   356  
   357  	read := readAll(t, buf, deprecated.LegacyPackedScannerOpts{})
   358  	if got, want := len(expected), len(expected); got != want {
   359  		t.Errorf("got %v, want %v", got, want)
   360  	}
   361  
   362  	for i, r := range read {
   363  		if got, want := r, expected[i]; got != want {
   364  			t.Errorf("%d: got %v, want %v", i, got, want)
   365  		}
   366  	}
   367  
   368  	scopts := deprecated.LegacyPackedScannerOpts{}
   369  	// Strip the leading ++ and trailing ++ while scanning
   370  	scopts.Transform = func(scratch, buf []byte) ([]byte, error) {
   371  		return buf[2 : len(buf)-2], nil
   372  	}
   373  	buf = bytes.NewBuffer(saved)
   374  	read = readAll(t, buf, scopts)
   375  	if got, want := len(read), len(data); got != want {
   376  		t.Errorf("got %v, want %v", got, want)
   377  	}
   378  
   379  	for i, r := range read {
   380  		if got, want := r, data[i]; got != want {
   381  			t.Errorf("%d: got %v, want %v", i, got, want)
   382  		}
   383  	}
   384  }
   385  
   386  func createBufs(nBufs, maxSize int) (map[string][]byte, error) {
   387  	rand.Seed(time.Now().UnixNano())
   388  	out := map[string][]byte{}
   389  	for i := 0; i < nBufs; i++ {
   390  		size := rand.Intn(maxSize)
   391  		if size == 0 {
   392  			size = maxSize / 2
   393  		}
   394  		buf := make([]byte, size)
   395  		n, err := rand.Read(buf)
   396  		if err != nil || n != cap(buf) {
   397  			return nil, fmt.Errorf("failed to generate %d bytes of random data: %d != %d: %v", cap(buf), n, cap(buf), err)
   398  		}
   399  		s := sha256.Sum256(buf)
   400  		k := hex.EncodeToString(s[:])
   401  		if _, present := out[k]; present {
   402  			// avoid dups.
   403  			i--
   404  			continue
   405  		}
   406  		out[k] = buf
   407  	}
   408  	return out, nil
   409  }
   410  
   411  func TestPackedConcurrentWrites(t *testing.T) {
   412  	nbufs := 200
   413  	maxBufSize := 1 * 1024 * 1024
   414  	data, err := createBufs(nbufs, maxBufSize)
   415  	assert.NoError(t, err)
   416  
   417  	tmpdir, cleanup := testutil.TempDir(t, "", "encrypted-test")
   418  	defer testutil.NoCleanupOnError(t, cleanup, "tmpdir: ", tmpdir)
   419  	buf := &bytes.Buffer{}
   420  	wr := deprecated.NewLegacyPackedWriter(buf, deprecated.LegacyPackedWriterOpts{MaxItems: 10})
   421  
   422  	var wg sync.WaitGroup
   423  	wg.Add(len(data))
   424  	ch := make(chan error, nbufs)
   425  	for _, v := range data {
   426  		go func(b []byte) {
   427  			_, err := wr.Write(b)
   428  			ch <- err
   429  			wg.Done()
   430  		}(v)
   431  	}
   432  
   433  	wg.Wait()
   434  	wr.Flush()
   435  	close(ch)
   436  	for err := range ch {
   437  		assert.NoError(t, err)
   438  	}
   439  
   440  	scanner := deprecated.NewLegacyPackedScanner(buf, deprecated.LegacyPackedScannerOpts{})
   441  	for scanner.Scan() {
   442  		buf := scanner.Bytes()
   443  		sum := sha256.Sum256(buf)
   444  		key := hex.EncodeToString(sum[:])
   445  		if _, present := data[key]; !present {
   446  			t.Errorf("corrupt/wrong data %v is not a sha256 of one of the test bufs", key)
   447  			continue
   448  		}
   449  		data[key] = nil
   450  	}
   451  	assert.NoError(t, scanner.Err())
   452  
   453  	for k, v := range data {
   454  		if v != nil {
   455  			t.Errorf("failed to read buffer with sha256 of %v", k)
   456  		}
   457  	}
   458  }