github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/recordio/deprecated/recordio_test.go (about)

     1  package deprecated_test
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"os"
     9  	"reflect"
    10  	"strings"
    11  	"testing"
    12  
    13  	"github.com/Schaudge/grailbase/recordio/deprecated"
    14  	"github.com/Schaudge/grailbase/recordio/internal"
    15  	"github.com/grailbio/testutil/assert"
    16  	"github.com/grailbio/testutil/expect"
    17  	"github.com/klauspost/compress/gzip"
    18  	"v.io/x/lib/gosh"
    19  )
    20  
    21  func cat(args ...[]byte) []byte {
    22  	r := []byte{}
    23  	for _, a := range args {
    24  		r = append(r, a...)
    25  	}
    26  	return r
    27  }
    28  
    29  func cmp(a, b []byte) bool {
    30  	if len(a) != len(b) {
    31  		return false
    32  	}
    33  	for ia, va := range a {
    34  		if b[ia] != va {
    35  			return false
    36  		}
    37  	}
    38  	return true
    39  }
    40  
    41  func TestRecordioSimpleWriteRead(t *testing.T) {
    42  	c := func(args []byte, m string) []byte {
    43  		return cat(internal.MagicLegacyUnpacked[:], args, []byte(m))
    44  	}
    45  	for i, tc := range []struct {
    46  		in  string
    47  		out []byte
    48  	}{
    49  		{"", c([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0x69, 0xDF, 0x22, 0x65}, "")},
    50  		{"a", c([]byte{1, 0, 0, 0, 0, 0, 0, 0, 0xF7, 0xDF, 0x88, 0xA9}, "a")},
    51  		{"hello\n", c([]byte{6, 0, 0, 0, 0, 0, 0, 0, 0xEE, 0xD6, 0x4D, 0xA3}, "hello\n")},
    52  	} {
    53  		out := &bytes.Buffer{}
    54  		rw := deprecated.NewLegacyWriter(out, deprecated.LegacyWriterOpts{})
    55  		n, err := rw.Write([]byte(tc.in))
    56  		assert.NoError(t, err)
    57  		assert.EQ(t, n, len(tc.in))
    58  		if got, want := out.Bytes(), tc.out; !cmp(got, want) {
    59  			t.Errorf("%d: got %v, want %v", i, got, want)
    60  		}
    61  
    62  		s := deprecated.NewLegacyScanner(out, deprecated.LegacyScannerOpts{})
    63  		assert.True(t, s.Scan())
    64  		b := s.Bytes()
    65  		if got, want := b, []byte(tc.in); !cmp(got, want) {
    66  			t.Errorf("%d: got %v, want %v", i, got, want)
    67  		}
    68  		if got, want := s.Scan(), false; got != want {
    69  			t.Errorf("%d: got %v, want %v", i, got, want)
    70  		}
    71  		if err := s.Err(); err != nil {
    72  			t.Errorf("%d: %v", i, err)
    73  		}
    74  	}
    75  
    76  	out := &bytes.Buffer{}
    77  	rw := deprecated.NewLegacyWriter(out, deprecated.LegacyWriterOpts{})
    78  	n, err := rw.WriteSlices([]byte("hello"), []byte(" "), []byte("world"))
    79  	if err != nil {
    80  		t.Fatal(err)
    81  	}
    82  	if got, want := n, 11; got != want {
    83  		t.Errorf("got %v, want %v", got, want)
    84  	}
    85  	s := deprecated.NewLegacyScanner(out, deprecated.LegacyScannerOpts{})
    86  	s.Scan()
    87  	if got, want := s.Bytes(), []byte("hello world"); !bytes.Equal(got, want) {
    88  		t.Errorf("got %s, want %s", got, want)
    89  	}
    90  
    91  	n, err = rw.WriteSlices(nil, []byte("hello"), []byte(" "), []byte("world"))
    92  	if err != nil {
    93  		t.Fatal(err)
    94  	}
    95  	if got, want := n, 11; got != want {
    96  		t.Errorf("got %v, want %v", got, want)
    97  	}
    98  	s = deprecated.NewLegacyScanner(out, deprecated.LegacyScannerOpts{})
    99  	s.Scan()
   100  	if got, want := s.Bytes(), []byte("hello world"); !bytes.Equal(got, want) {
   101  		t.Errorf("got %s, want %s", got, want)
   102  	}
   103  }
   104  
   105  type fakeWriter struct {
   106  	errAt int
   107  	s     int
   108  	short bool // set for a 'short writes'
   109  }
   110  
   111  func (fw *fakeWriter) Write(p []byte) (n int, err error) {
   112  	fw.s += len(p)
   113  	if fw.s > fw.errAt {
   114  		if fw.short {
   115  			return 0, nil
   116  		}
   117  		return 0, fmt.Errorf("error at %d", fw.errAt)
   118  	}
   119  	return len(p), nil
   120  }
   121  
   122  type fakeReader struct {
   123  	buf   *bytes.Buffer
   124  	errAt int
   125  	s     int
   126  }
   127  
   128  func (fr *fakeReader) Read(p []byte) (n int, err error) {
   129  	b := fr.buf.Next(len(p))
   130  	fr.s += len(p)
   131  	if fr.s > fr.errAt {
   132  		return 0, fmt.Errorf("fail at %d", fr.errAt)
   133  	}
   134  	copy(p, b)
   135  	return len(b), nil
   136  }
   137  
   138  func TestRecordioErrors(t *testing.T) {
   139  	writeError := func(offset int, msg string) {
   140  		ew := &fakeWriter{errAt: offset}
   141  		rw := deprecated.NewLegacyWriter(ew, deprecated.LegacyWriterOpts{})
   142  		_, err := rw.Write([]byte("hello"))
   143  		expect.HasSubstr(t, err, msg)
   144  	}
   145  	writeError(0, "failed to write header")
   146  	writeError(21, "failed to write record")
   147  
   148  	marshalError := func(offset int, msg string) {
   149  		ew := &fakeWriter{errAt: offset}
   150  		rw := deprecated.NewLegacyWriter(ew, deprecated.LegacyWriterOpts{
   151  			Marshal: recordioMarshal,
   152  		})
   153  		_, err := rw.Marshal(&TestPB{"oops"})
   154  		expect.HasSubstr(t, err, msg)
   155  	}
   156  	marshalError(0, "failed to write header")
   157  	marshalError(21, "failed to write record")
   158  
   159  	buf := &bytes.Buffer{}
   160  	rw := deprecated.NewLegacyWriter(buf, deprecated.LegacyWriterOpts{})
   161  	rw.Write([]byte("hello\n"))
   162  
   163  	corruptionError := func(offset int, msg string) {
   164  		f := buf.Bytes()
   165  		tmp := make([]byte, len(f))
   166  		copy(tmp, f)
   167  		tmp[offset] = 0xff
   168  		s := deprecated.NewLegacyScanner(bytes.NewBuffer(tmp), deprecated.LegacyScannerOpts{})
   169  		if s.Scan() {
   170  			t.Errorf("expected false")
   171  		}
   172  		expect.HasSubstr(t, s.Err(), msg)
   173  	}
   174  
   175  	corruptionError(0, "invalid magic number")
   176  	corruptionError(10, "crc check failed")
   177  	corruptionError(17, "crc check failed")
   178  
   179  	shortReadError := func(offset int, msg string) {
   180  		f := buf.Bytes()
   181  		tmp := make([]byte, len(f))
   182  		copy(tmp, f)
   183  		s := deprecated.NewLegacyScanner(bytes.NewBuffer(tmp[:offset]), deprecated.LegacyScannerOpts{})
   184  		if s.Scan() {
   185  			t.Errorf("expected false")
   186  		}
   187  		expect.HasSubstr(t, s.Err(), msg)
   188  	}
   189  	shortReadError(1, "unexpected EOF")
   190  	shortReadError(19, "unexpected EOF")
   191  	shortReadError(20, "short/long record")
   192  
   193  	readError := func(offset int, msg string) {
   194  		f := buf.Bytes()
   195  		tmp := make([]byte, len(f))
   196  		copy(tmp, f)
   197  		rdr := &fakeReader{buf: bytes.NewBuffer(tmp), errAt: offset}
   198  		s := deprecated.NewLegacyScanner(rdr, deprecated.LegacyScannerOpts{})
   199  		if s.Scan() {
   200  			t.Errorf("expected false")
   201  		}
   202  		expect.HasSubstr(t, s.Err(), msg)
   203  	}
   204  	readError(1, "failed to read header")
   205  	readError(9, "failed to read header")
   206  	readError(19, "failed to read header")
   207  	readError(20, "failed to read record")
   208  
   209  	defer func(oldSize uint64) {
   210  		internal.MaxReadRecordSize = oldSize
   211  	}(internal.MaxReadRecordSize)
   212  	internal.MaxReadRecordSize = 100
   213  	buf.Reset()
   214  	rw = deprecated.NewLegacyWriter(buf, deprecated.LegacyWriterOpts{})
   215  	rw.Write([]byte(strings.Repeat("a", 101)))
   216  	s := deprecated.NewLegacyScanner(buf, deprecated.LegacyScannerOpts{})
   217  	if got, want := s.Scan(), false; got != want {
   218  		t.Errorf("got %v, want %v", got, want)
   219  	}
   220  	expect.HasSubstr(t, s.Err(), "unreasonably large read")
   221  
   222  	ew := &fakeWriter{errAt: 25}
   223  	wr := deprecated.NewLegacyWriter(ew, deprecated.LegacyWriterOpts{})
   224  	_, err := wr.WriteSlices([]byte("hello"), []byte(" "), []byte("world"))
   225  	expect.HasSubstr(t, err, "recordio: failed to write record")
   226  }
   227  
   228  func TestRecordioEmpty(t *testing.T) {
   229  	out := &bytes.Buffer{}
   230  	s := deprecated.NewLegacyScanner(out, deprecated.LegacyScannerOpts{})
   231  	if got, want := s.Scan(), false; got != want {
   232  		t.Errorf("got %v, want %v", got, want)
   233  	}
   234  	if s.Bytes() != nil {
   235  		t.Errorf("expected nil slice")
   236  	}
   237  	if got, want := s.Scan(), false; got != want {
   238  		t.Errorf("got %v, want %v", got, want)
   239  	}
   240  	if err := s.Err(); err != nil {
   241  		t.Errorf("%v", err)
   242  	}
   243  }
   244  
   245  func TestRecordioMultiple(t *testing.T) {
   246  	out := &bytes.Buffer{}
   247  	rw := deprecated.NewLegacyWriter(out, deprecated.LegacyWriterOpts{})
   248  	expected := []string{"", "hello", "world", "", "last record"}
   249  	for _, str := range expected {
   250  		n, err := rw.Write([]byte(str))
   251  		assert.NoError(t, err)
   252  		assert.True(t, n == len(str))
   253  	}
   254  
   255  	read := func(s deprecated.LegacyScanner) string {
   256  		if !s.Scan() {
   257  			expect.NoError(t, s.Err())
   258  			return "eof"
   259  		}
   260  		return string(s.Bytes())
   261  	}
   262  	s := deprecated.NewLegacyScanner(bytes.NewReader(out.Bytes()), deprecated.LegacyScannerOpts{})
   263  	for _, str := range expected {
   264  		expect.EQ(t, read(s), str)
   265  	}
   266  	expect.EQ(t, "eof", read(s))
   267  
   268  	s.Reset(bytes.NewReader(out.Bytes()))
   269  	for _, str := range expected[:2] {
   270  		expect.EQ(t, read(s), str)
   271  	}
   272  	s.Reset(bytes.NewReader(out.Bytes()))
   273  	for _, str := range expected {
   274  		expect.EQ(t, read(s), str)
   275  	}
   276  	expect.EQ(t, read(s), "eof")
   277  }
   278  
   279  type lazyReader struct {
   280  	rd io.Reader
   281  }
   282  
   283  func (lz *lazyReader) Read(p []byte) (n int, err error) {
   284  	lazy := len(p)
   285  	if lazy > 10 {
   286  		lazy -= 10
   287  	}
   288  	return io.ReadFull(lz.rd, p[:lazy])
   289  }
   290  
   291  func TestRecordioWriteRead(t *testing.T) {
   292  	gosh := gosh.NewShell(t)
   293  	f := gosh.MakeTempFile()
   294  	rw := deprecated.NewLegacyWriter(f, deprecated.LegacyWriterOpts{})
   295  	contents := []string{"hello", "world", "!"}
   296  	for _, rec := range contents {
   297  		rw.Write([]byte(rec))
   298  	}
   299  	name := f.Name()
   300  	f.Close()
   301  
   302  	o := gosh.Cmd("gzip", "--keep", name).CombinedOutput()
   303  	if got, want := o, ""; got != want {
   304  		t.Errorf("got %v, want %v", got, want)
   305  	}
   306  
   307  	readall := func(rd io.Reader) ([][]byte, error) {
   308  		r := [][]byte{}
   309  		s := deprecated.NewLegacyScanner(rd, deprecated.LegacyScannerOpts{})
   310  		for s.Scan() {
   311  			r = append(r, s.Bytes())
   312  		}
   313  		return r, s.Err()
   314  	}
   315  
   316  	for _, n := range []string{name} {
   317  		rd, err := os.Open(n)
   318  		if err != nil {
   319  			t.Fatalf("%v: %v", n, err)
   320  		}
   321  		gz, err := os.Open(n + ".gz")
   322  		if err != nil {
   323  			t.Fatalf("%v: %v", n, err)
   324  		}
   325  		gzrd, err := gzip.NewReader(gz)
   326  		if err != nil {
   327  			t.Fatalf("%v: %v", n, err)
   328  		}
   329  
   330  		raw, err := readall(rd)
   331  		if err != nil {
   332  			t.Fatalf("%v (raw): read %d records, %v", n, len(raw), err)
   333  		}
   334  
   335  		compressed, err := readall(gzrd)
   336  		if err != nil {
   337  			t.Fatalf("%v (gzip): read %d records (%d raw), %v", n, len(compressed), len(raw), err)
   338  		}
   339  
   340  		buf, err := ioutil.ReadFile(n)
   341  		if err != nil {
   342  			t.Fatal(err)
   343  		}
   344  		lzr := &lazyReader{rd: bytes.NewBuffer(buf)}
   345  		lazy, err := readall(lzr)
   346  		if err != nil {
   347  			t.Fatalf("%v (lazyreader): read %d records (%d raw), %v", n, len(lazy), len(raw), err)
   348  
   349  		}
   350  		if got, want := len(raw), len(compressed); got != want {
   351  			t.Errorf("%v: got %v, want %v", n, got, want)
   352  		}
   353  		if got, want := raw, compressed; !reflect.DeepEqual(got, want) {
   354  			t.Errorf("%v: got %v, want %v", n, got, want)
   355  		}
   356  		if got, want := raw, lazy; !reflect.DeepEqual(got, want) {
   357  			t.Errorf("%v: got %v, want %v", n, got, want)
   358  		}
   359  	}
   360  }