github.com/grailbio/base@v0.0.11/recordio/deprecated/packer_test.go (about)

     1  // Copyright 2017 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package deprecated_test
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"strings"
    11  	"testing"
    12  
    13  	"github.com/grailbio/base/recordio/deprecated"
    14  	"github.com/grailbio/testutil"
    15  	"github.com/grailbio/testutil/expect"
    16  )
    17  
    18  func expectStats(t *testing.T, depth int, wr *deprecated.Packer, eni, enb int) {
    19  	ni, nb := wr.Stored()
    20  	if got, want := ni, eni; got != want {
    21  		t.Errorf("%v: got %v, want %v", testutil.Caller(depth), got, want)
    22  	}
    23  	if got, want := nb, enb; got != want {
    24  		t.Errorf("%v: got %v, want %v", testutil.Caller(depth), got, want)
    25  	}
    26  }
    27  
    28  func countBytes(b [][]byte) int {
    29  	s := 0
    30  	for _, l := range b {
    31  		s += len(l)
    32  	}
    33  	return s
    34  }
    35  
    36  func cmpComplete(t *testing.T, wr *deprecated.Packer, rd *deprecated.Unpacker, wBufs [][]byte) {
    37  	wDS := countBytes(wBufs)
    38  	expectStats(t, 2, wr, len(wBufs), wDS)
    39  	hdr, gDS, gBufs, err := wr.Pack()
    40  	if err != nil {
    41  		t.Fatalf("%v: %v", testutil.Caller(1), err)
    42  	}
    43  	if got, want := gDS, wDS; got != want {
    44  		t.Errorf("%v: got %v, want %v", testutil.Caller(1), got, want)
    45  	}
    46  	if got, want := len(gBufs), len(wBufs); got != want {
    47  		t.Errorf("%v: got %v, want %v", testutil.Caller(1), got, want)
    48  	}
    49  
    50  	for i, l := range gBufs {
    51  		if got, want := l, wBufs[i]; !bytes.Equal(got, want) {
    52  			t.Errorf("%v: got %s, want %s", testutil.Caller(1), got, want)
    53  		}
    54  	}
    55  
    56  	rbuf := bytes.Join(append([][]byte{hdr}, gBufs...), nil)
    57  	rBufs, err := rd.Unpack(rbuf)
    58  	if err != nil {
    59  		t.Fatal(err)
    60  	}
    61  
    62  	if got, want := len(rBufs), len(wBufs); got != want {
    63  		t.Errorf("%v: got %v, want %v", testutil.Caller(1), got, want)
    64  	}
    65  
    66  	for i, l := range rBufs {
    67  		if got, want := l, wBufs[i]; !bytes.Equal(got, want) {
    68  			t.Errorf("%v: got %s, want %s", testutil.Caller(1), got, want)
    69  		}
    70  	}
    71  }
    72  
    73  func TestPacker(t *testing.T) {
    74  	wr := deprecated.NewPacker(deprecated.PackerOpts{})
    75  	rd := deprecated.NewUnpacker(deprecated.UnpackerOpts{})
    76  	// Pack on empty has no effect.
    77  	_, _, _, err := wr.Pack()
    78  	if err != nil {
    79  		t.Fatal(err)
    80  	}
    81  	expectStats(t, 1, wr, 0, 0)
    82  	msg := []string{"hello", "world"}
    83  	bufs := [][]byte{}
    84  	for _, d := range msg {
    85  		wr.Write([]byte(d))
    86  		bufs = append(bufs, []byte(d))
    87  	}
    88  
    89  	expectStats(t, 1, wr, 2, 10)
    90  	cmpComplete(t, wr, rd, bufs)
    91  	expectStats(t, 1, wr, 0, 0)
    92  	// Pack is not idempotent.
    93  	hdr, _, _, err := wr.Pack()
    94  	if err != nil {
    95  		t.Fatal(err)
    96  	}
    97  	if got, want := len(hdr), 0; got != want {
    98  		t.Errorf("got %v, want %v", got, want)
    99  	}
   100  	expectStats(t, 1, wr, 0, 0)
   101  
   102  	msg = []string{"and", "again", "hello", "there"}
   103  	bufs = [][]byte{}
   104  	for _, d := range msg {
   105  		wr.Write([]byte(d))
   106  		bufs = append(bufs, []byte(d))
   107  	}
   108  	expectStats(t, 1, wr, 4, 18)
   109  	cmpComplete(t, wr, rd, bufs)
   110  }
   111  
   112  func TestPackerReuse(t *testing.T) {
   113  	buffers := make([][]byte, 1, 3)
   114  	wr := deprecated.NewPacker(deprecated.PackerOpts{
   115  		Buffers: buffers[1:],
   116  	})
   117  	msg := []string{"hello", "world"}
   118  	for _, d := range msg {
   119  		wr.Write([]byte(d))
   120  	}
   121  	hdr, size, bufs, _ := wr.Pack()
   122  
   123  	record := bytes.Join(append([][]byte{hdr}, bufs...), nil)
   124  
   125  	buffers[0] = hdr
   126  	buffers = buffers[:len(bufs)+1]
   127  	if got, want := hdr, buffers[0]; !bytes.Equal(got, want) {
   128  		t.Errorf("got %v, want %v", got, want)
   129  	}
   130  	nsize := 0
   131  	for i, b := range bufs {
   132  		if got, want := b, buffers[i+1]; !bytes.Equal(got, want) {
   133  			t.Errorf("got %v, want %v", got, want)
   134  		}
   135  		nsize += len(b)
   136  	}
   137  	if got, want := size, nsize; got != want {
   138  		t.Errorf("got %v, want %v", got, want)
   139  	}
   140  	if got, want := cap(bufs)+1, cap(buffers); got != want {
   141  		t.Errorf("got %v, want %v", got, want)
   142  	}
   143  
   144  	rdbuffers := make([][]byte, 0, 2)
   145  	rd := deprecated.NewUnpacker(deprecated.UnpackerOpts{
   146  		Buffers: rdbuffers,
   147  	})
   148  	bufs, _ = rd.Unpack(record)
   149  
   150  	if got, want := cap(bufs), cap(rdbuffers); got != want {
   151  		t.Errorf("got %v, want %v", got, want)
   152  	}
   153  
   154  	rdbuffers = rdbuffers[:len(bufs)]
   155  	for i, b := range bufs {
   156  		if got, want := b, rdbuffers[i]; !bytes.Equal(got, want) {
   157  			t.Errorf("got %v, want %v", got, want)
   158  		}
   159  		nsize += len(b)
   160  	}
   161  
   162  	// If the number of buffers written exceeds the capacity of
   163  	// the originally supplied Buffers slice, a new one will be
   164  	// created and used by append.
   165  	msg = []string{"hello", "world", "oh", "the", "buffer", "grows"}
   166  	for _, d := range msg {
   167  		wr.Write([]byte(d))
   168  	}
   169  	hdr, _, bufs, _ = wr.Pack()
   170  	record = bytes.Join(append([][]byte{hdr}, bufs...), nil)
   171  
   172  	if got, want := cap(bufs), cap(buffers); got <= want {
   173  		t.Errorf("got %v, want > %v", got, want)
   174  	}
   175  
   176  	// unpack will create a new slice too.
   177  	bufs, _ = rd.Unpack(record)
   178  	if got, want := cap(bufs), cap(rdbuffers); got <= want {
   179  		t.Errorf("got %v, want > %v", got, want)
   180  	}
   181  
   182  }
   183  
   184  func TestPackerTransform(t *testing.T) {
   185  	// Prepend __ and append ++ to every record.
   186  	wropts := deprecated.PackerOpts{
   187  		Transform: func(bufs [][]byte) ([]byte, error) {
   188  			r := []byte("__")
   189  			for _, b := range bufs {
   190  				r = append(r, b...)
   191  			}
   192  			return append(r, []byte("++")...), nil
   193  		},
   194  	}
   195  	wr := deprecated.NewPacker(wropts)
   196  	data := []string{"Hello", "World", "How", "Are", "You?"}
   197  	for _, d := range data {
   198  		wr.Write([]byte(d))
   199  	}
   200  	hdr, _, bufs, _ := wr.Pack()
   201  	record := bytes.Join(append([][]byte{hdr}, bufs...), nil)
   202  
   203  	// Flattening out the buffers and prepending __ and appending ++
   204  	if got, want := string(bytes.Join(bufs, nil)), "__"+strings.Join(data, "")+"++"; got != want {
   205  		t.Errorf("got %v, want %v", got, want)
   206  	}
   207  
   208  	// Scan with the __ and ++ in place. Each item in the each
   209  	// record is the same original size, but the contents are
   210  	// 'shifted' by the leading__
   211  	expected := []string{"__Hel", "loWor", "ldH", "owA", "reYo"}
   212  
   213  	rdopts := deprecated.UnpackerOpts{}
   214  	rd := deprecated.NewUnpacker(rdopts)
   215  	read, _ := rd.Unpack(record)
   216  
   217  	if got, want := len(expected), len(expected); got != want {
   218  		t.Errorf("got %v, want %v", got, want)
   219  	}
   220  
   221  	for i, r := range read {
   222  		if got, want := string(r), expected[i]; got != want {
   223  			t.Errorf("%d: got %v, want %v", i, got, want)
   224  		}
   225  	}
   226  
   227  	rdopts = deprecated.UnpackerOpts{
   228  		// Strip the leading ++ and trailing ++ while scanning
   229  		Transform: func(scratch, buf []byte) ([]byte, error) {
   230  			return buf[2 : len(buf)-2], nil
   231  		},
   232  	}
   233  	rd = deprecated.NewUnpacker(rdopts)
   234  	read, _ = rd.Unpack(record)
   235  	if got, want := len(read), len(data); got != want {
   236  		t.Errorf("got %v, want %v", got, want)
   237  	}
   238  
   239  	for i, r := range read {
   240  		if got, want := string(r), data[i]; got != want {
   241  			t.Errorf("%d: got %v, want %v", i, got, want)
   242  		}
   243  	}
   244  }
   245  
   246  func TestPackerTransformErrors(t *testing.T) {
   247  	wropts := deprecated.PackerOpts{
   248  		Transform: func(bufs [][]byte) ([]byte, error) {
   249  			return nil, fmt.Errorf("transform oops")
   250  		},
   251  	}
   252  	wr := deprecated.NewPacker(wropts)
   253  	wr.Write([]byte("oh"))
   254  	_, _, _, err := wr.Pack()
   255  	expect.HasSubstr(t, err, "transform oops")
   256  
   257  	wropts.Transform = nil
   258  	wr = deprecated.NewPacker(wropts)
   259  	wr.Write([]byte("oh"))
   260  	wr.Write([]byte("ah"))
   261  	hdr, _, bufs, _ := wr.Pack()
   262  
   263  	record := bytes.Join(append([][]byte{hdr}, bufs...), nil)
   264  
   265  	rdopts := deprecated.UnpackerOpts{}
   266  	rdopts.Transform = func(scratch, buf []byte) ([]byte, error) {
   267  		return nil, fmt.Errorf("transform oops")
   268  	}
   269  
   270  	rd := deprecated.NewUnpacker(rdopts)
   271  	_, err = rd.Unpack(record)
   272  	expect.HasSubstr(t, err, "transform oops")
   273  
   274  	rdopts.Transform = func(scratch, buf []byte) ([]byte, error) {
   275  		return nil, nil
   276  	}
   277  
   278  	rd = deprecated.NewUnpacker(rdopts)
   279  	_, err = rd.Unpack(record)
   280  	expect.HasSubstr(t, err, "offset greater than buf size")
   281  }
   282  
   283  func TestPackerErrors(t *testing.T) {
   284  	wr := deprecated.NewPacker(deprecated.PackerOpts{})
   285  	msg := []string{"hello", "world"}
   286  	for _, d := range msg {
   287  		wr.Write([]byte(d))
   288  	}
   289  	hdr, _, bufs, _ := wr.Pack()
   290  	record := bytes.Join(append([][]byte{hdr}, bufs...), nil)
   291  
   292  	shortReadError := func(offset int, msg string) {
   293  		rd := deprecated.NewUnpacker(deprecated.UnpackerOpts{})
   294  		_, err := rd.Unpack(record[:offset])
   295  		expect.HasSubstr(t, err, msg)
   296  	}
   297  	shortReadError(1, "failed to read crc32")
   298  	shortReadError(4, "failed to read number of packed items")
   299  	shortReadError(5, "likely corrupt data, failed to read size of packed item")
   300  	shortReadError(10, "offset greater than buf size")
   301  
   302  	corruptionError := func(offset int, msg string, ow ...byte) {
   303  		tmp := make([]byte, len(record))
   304  		copy(tmp, record)
   305  		for i, v := range ow {
   306  			tmp[offset+i] = v
   307  		}
   308  		rd := deprecated.NewUnpacker(
   309  			deprecated.UnpackerOpts{})
   310  		_, err := rd.Unpack(tmp)
   311  		expect.HasSubstr(t, err, msg)
   312  	}
   313  	tmp := record[2]
   314  	corruptionError(2, "crc check failed - corrupt packed record header", tmp+1)
   315  	corruptionError(4, "likely corrupt data, number of packed items exceeds", 0x7f)
   316  	corruptionError(4, "likely corrupt data, failed to read size of packed item", 0x0f)
   317  	corruptionError(5, "crc check failed - corrupt packed record header", 0x7f)
   318  }
   319  
   320  func TestObjectPacker(t *testing.T) {
   321  	objects := make([]interface{}, 1000)
   322  	op := deprecated.NewObjectPacker(objects, recordioMarshal, deprecated.ObjectPackerOpts{})
   323  	op.Marshal(&TestPB{"hello"})
   324  	op.Marshal(&TestPB{"world"})
   325  	objs, _ := op.Contents()
   326  	if got, want := len(objs), 2; got != want {
   327  		t.Errorf("got %v, want %v", got, want)
   328  	}
   329  	if got, want := objs[1].(*TestPB).Message, "world"; got != want {
   330  		t.Errorf("got %v, want %v", got, want)
   331  	}
   332  	if got, want := objects[0].(*TestPB).Message, "hello"; got != want {
   333  		t.Errorf("got %v, want %v", got, want)
   334  	}
   335  }
   336  
   337  func TestObjectPackerErrors(t *testing.T) {
   338  	// ObjectPacker is tested in concurrent_test
   339  	objects := make([]interface{}, 1000)
   340  	op := deprecated.NewObjectPacker(objects, func(scratch []byte, v interface{}) ([]byte, error) {
   341  		return nil, fmt.Errorf("marshal oops")
   342  	}, deprecated.ObjectPackerOpts{})
   343  	err := op.Marshal(&TestPB{"hello"})
   344  	expect.HasSubstr(t, err, "marshal oops")
   345  }