github.com/grailbio/base@v0.0.11/recordio/deprecated/index_test.go (about)

     1  // Copyright 2017 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package deprecated_test
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"math/rand"
    11  	"reflect"
    12  	"testing"
    13  
    14  	"github.com/golang/protobuf/proto"
    15  	"github.com/grailbio/base/recordio/deprecated"
    16  	"github.com/grailbio/testutil/expect"
    17  )
    18  
    19  func TestRecordioIndex(t *testing.T) {
    20  	type indexEntry struct {
    21  		offset, extent uint64
    22  		p              []byte
    23  		v              interface{}
    24  		first, flushed bool
    25  	}
    26  	nItems := 13
    27  
    28  	seenOffset := map[uint64]bool{}
    29  	index := []*indexEntry{}
    30  	buf := &bytes.Buffer{}
    31  	wropts := deprecated.LegacyWriterOpts{
    32  		Marshal: recordioMarshal,
    33  		Index: func(offset, extent uint64, v interface{}, p []byte) error {
    34  			if seenOffset[offset] {
    35  				t.Errorf("duplicate index entry")
    36  			}
    37  			seenOffset[offset] = true
    38  			index = append(index, &indexEntry{offset, extent, p, v, true, false})
    39  			return nil
    40  		},
    41  	}
    42  	wr := deprecated.NewLegacyWriter(buf, wropts)
    43  
    44  	data := []string{}
    45  	for i := 0; i < nItems; i++ {
    46  		msg := fmt.Sprintf("hello: %v", rand.Int())
    47  		data = append(data, msg)
    48  		wr.Marshal(&TestPB{msg})
    49  	}
    50  	if got, want := len(index), nItems; got != want {
    51  		t.Fatalf("got %v, want %v", got, want)
    52  	}
    53  
    54  	underlying := bytes.NewReader(buf.Bytes())
    55  	scopts := deprecated.LegacyScannerOpts{
    56  		Unmarshal: recordioUnmarshal,
    57  	}
    58  	for i, entry := range index {
    59  		br, _ := deprecated.NewRangeReader(underlying, int64(entry.offset), int64(entry.extent))
    60  		sc := deprecated.NewLegacyScanner(br, scopts)
    61  		msg := &TestPB{}
    62  		if !sc.Scan() {
    63  			t.Fatalf("%v: %v", i, sc.Err())
    64  		}
    65  		sc.Unmarshal(msg)
    66  		if got, want := msg, entry.v; !reflect.DeepEqual(got, want) {
    67  			t.Errorf("%v: got %v, want %v", entry.offset, got, want)
    68  		}
    69  	}
    70  
    71  	wropts.Marshal = nil
    72  	buf.Reset()
    73  	seenOffset = map[uint64]bool{}
    74  	index = []*indexEntry{}
    75  	wr = deprecated.NewLegacyWriter(buf, wropts)
    76  
    77  	raw := [][]byte{}
    78  	for i := 0; i < nItems; i++ {
    79  		msg := fmt.Sprintf("hello: %v", rand.Int())
    80  		p, err := proto.Marshal(&TestPB{msg})
    81  		if err != nil {
    82  			t.Fatal(err)
    83  		}
    84  		raw = append(raw, p)
    85  	}
    86  
    87  	saved := [][]byte{}
    88  	wr.Write(raw[0:1][0])
    89  	saved = append(saved, raw[0:1][0])
    90  	wr.Write(raw[1:2][0])
    91  	saved = append(saved, raw[1:2][0])
    92  	wr.WriteSlices(raw[2], raw[3:10]...)
    93  	saved = append(saved, bytes.Join(raw[2:10], nil))
    94  	wr.WriteSlices(nil, raw[11:nItems]...)
    95  	saved = append(saved, bytes.Join(raw[11:nItems], nil))
    96  
    97  	if got, want := len(index), 4; got != want {
    98  		t.Fatalf("got %v, want %v", got, want)
    99  	}
   100  
   101  	underlying = bytes.NewReader(buf.Bytes())
   102  	for i, entry := range index {
   103  		br, _ := deprecated.NewRangeReader(underlying, int64(entry.offset), int64(entry.extent))
   104  		sc := deprecated.NewLegacyScanner(br, deprecated.LegacyScannerOpts{})
   105  		if got, want := entry.v, interface{}(nil); got != want {
   106  			t.Errorf("%d: got %v, want %v", i, got, want)
   107  		}
   108  		if got := entry.p; got != nil {
   109  			t.Errorf("%d: got %v not nil", i, got)
   110  		}
   111  		if !sc.Scan() {
   112  			t.Fatalf("%v: %v", i, sc.Err())
   113  		}
   114  		if got, want := sc.Bytes(), saved[i]; !bytes.Equal(got, want) {
   115  			t.Errorf("%d: got %v, want %v", i, got, want)
   116  		}
   117  	}
   118  
   119  	// Packed recordio indices.
   120  	buf.Reset()
   121  	seenOffset = map[uint64]bool{}
   122  	index = []*indexEntry{}
   123  	pwropts := deprecated.LegacyPackedWriterOpts{
   124  		MaxItems: 2,
   125  	}
   126  	pwropts.Marshal = recordioMarshal
   127  	pwropts.Index = func(record, recordSize, nitems uint64) (deprecated.ItemIndexFunc, error) {
   128  		if seenOffset[record] {
   129  			t.Errorf("duplicate index entry")
   130  		}
   131  		if nitems > 2 {
   132  			t.Errorf("too many items")
   133  		}
   134  		seenOffset[record] = true
   135  		index = append(index, &indexEntry{offset: record, extent: recordSize, p: nil, v: nil, first: true, flushed: false})
   136  		return func(offset, extent uint64, v interface{}, p []byte) error {
   137  			index = append(index, &indexEntry{offset: offset, extent: extent, p: p, v: v, first: false, flushed: false})
   138  			return nil
   139  		}, nil
   140  	}
   141  
   142  	pwropts.Flushed = func() error {
   143  		index = append(index, &indexEntry{flushed: true})
   144  		return nil
   145  	}
   146  
   147  	pwr := deprecated.NewLegacyPackedWriter(buf, pwropts)
   148  	for i := 0; i < nItems; i++ {
   149  		pwr.Marshal(&TestPB{data[i]})
   150  	}
   151  	pwr.Flush()
   152  	nrecords := nItems / int(pwropts.MaxItems)
   153  	if (nItems % int(pwropts.MaxItems)) > 0 {
   154  		nrecords++
   155  	}
   156  
   157  	// number of items, number of records, including flushes.
   158  	indexSize := nItems + 2*nrecords
   159  	if got, want := len(index), indexSize; got != want {
   160  		t.Errorf("got %v, want %v", got, want)
   161  	}
   162  
   163  	underlying = bytes.NewReader(buf.Bytes())
   164  	nfirst := 0
   165  	nflushed := 0
   166  	var sc deprecated.LegacyPackedScanner
   167  	pscopts := deprecated.LegacyPackedScannerOpts{}
   168  	pscopts.Unmarshal = recordioUnmarshal
   169  
   170  	// Random access to a record, sequential scan within it.
   171  	for i, entry := range index {
   172  		if entry.first {
   173  			nfirst++
   174  			br, _ := deprecated.NewRangeReader(underlying, int64(entry.offset), int64(entry.extent))
   175  			sc = deprecated.NewLegacyPackedScanner(br, pscopts)
   176  			continue
   177  		}
   178  		if entry.flushed {
   179  			nflushed++
   180  			continue
   181  		}
   182  		if !sc.Scan() {
   183  			t.Fatalf("%v: %v", i, sc.Err())
   184  		}
   185  		if got, want := sc.Bytes(), entry.p; !bytes.Equal(got, want) {
   186  			t.Errorf("%v: got %x, want %x", i, got, want)
   187  		}
   188  		msg := &TestPB{}
   189  		sc.Unmarshal(msg)
   190  		if got, want := msg, entry.v; !reflect.DeepEqual(got, want) {
   191  			t.Errorf("%v: got %v, want %v", i, got, want)
   192  		}
   193  	}
   194  
   195  	if got, want := nfirst, nrecords; got != want {
   196  		t.Errorf("got %v, want %v", got, want)
   197  	}
   198  	if got, want := nflushed, nrecords; got != want {
   199  		t.Errorf("got %v, want %v", got, want)
   200  	}
   201  
   202  	underlying = bytes.NewReader(buf.Bytes())
   203  	record := []byte{}
   204  	// Random access to a record, random access to items within it.
   205  	for i, entry := range index {
   206  		if entry.first {
   207  			br, _ := deprecated.NewRangeReader(underlying, int64(entry.offset), int64(entry.extent))
   208  			sc = deprecated.NewLegacyPackedScanner(br, pscopts)
   209  			if !sc.Scan() {
   210  				t.Fatalf("%v: %v", i, sc.Err())
   211  			}
   212  			record = sc.Bytes()
   213  			continue
   214  		}
   215  		if entry.flushed {
   216  			continue
   217  		}
   218  		item := record[entry.offset : entry.offset+entry.extent]
   219  		if got, want := item, entry.p; !bytes.Equal(got, want) {
   220  			t.Errorf("%v:%v: got %x, want %x", i, entry.offset, got, want)
   221  		}
   222  		msg := &TestPB{}
   223  		recordioUnmarshal(item, msg)
   224  		if got, want := msg, entry.v; !reflect.DeepEqual(got, want) {
   225  			t.Errorf("%v: %v: got %v, want %v", i, entry.offset, got, want)
   226  		}
   227  	}
   228  
   229  }
   230  
   231  func TestIndexErrors(t *testing.T) {
   232  	buf := &bytes.Buffer{}
   233  
   234  	wropts := deprecated.LegacyWriterOpts{
   235  		Index: func(offset, extent uint64, v interface{}, p []byte) error {
   236  			return fmt.Errorf("index oops")
   237  		},
   238  	}
   239  	wr := deprecated.NewLegacyWriter(buf, wropts)
   240  	_, err := wr.Write([]byte("hello"))
   241  	expect.HasSubstr(t, err, "index oops")
   242  
   243  	_, err = wr.WriteSlices([]byte("hello"), []byte("world"))
   244  	expect.HasSubstr(t, err, "index oops")
   245  
   246  	wropts = deprecated.LegacyWriterOpts{
   247  		Marshal: recordioMarshal,
   248  		Index: func(offset, extent uint64, v interface{}, p []byte) error {
   249  			return fmt.Errorf("index oops")
   250  		},
   251  	}
   252  	wr = deprecated.NewLegacyWriter(buf, wropts)
   253  	_, err = wr.Marshal(&TestPB{"x"})
   254  	expect.HasSubstr(t, err, "index oops")
   255  
   256  	wr = deprecated.NewLegacyWriter(buf, wropts)
   257  	_, err = wr.Write([]byte("hello"))
   258  	expect.HasSubstr(t, err, "index oops")
   259  
   260  	pwropts := deprecated.LegacyPackedWriterOpts{
   261  		Marshal: recordioMarshal,
   262  		Index: func(record, recordSize, items uint64) (deprecated.ItemIndexFunc, error) {
   263  			if items != 1 {
   264  				t.Errorf("got %v, want 1", items)
   265  			}
   266  			return nil, fmt.Errorf("packed record index oops")
   267  		},
   268  		MaxItems: 1,
   269  	}
   270  	pwr := deprecated.NewLegacyPackedWriter(buf, pwropts)
   271  	_, err = pwr.Marshal(&TestPB{"x"})
   272  	expect.NoError(t, err, "Marshall packed writer")
   273  	err = pwr.Flush()
   274  	expect.HasSubstr(t, err, "packed record index oops")
   275  
   276  	pwropts.Index = func(record, recordSize, items uint64) (deprecated.ItemIndexFunc, error) {
   277  		return func(offset, extent uint64, v interface{}, p []byte) error {
   278  			return fmt.Errorf("packed item index oops")
   279  		}, nil
   280  	}
   281  	pwr = deprecated.NewLegacyPackedWriter(buf, pwropts)
   282  	_, err = pwr.Marshal(&TestPB{"x"})
   283  	expect.NoError(t, err, "Marshall packed writer first try")
   284  	_, err = pwr.Marshal(&TestPB{"y"})
   285  	expect.HasSubstr(t, err, "packed item index oops")
   286  
   287  	pwropts.Index = nil
   288  	pwropts.Flushed = func() error {
   289  		return fmt.Errorf("packed flush oops")
   290  	}
   291  	pwr = deprecated.NewLegacyPackedWriter(buf, pwropts)
   292  	_, err = pwr.Marshal(&TestPB{"z"})
   293  	expect.NoError(t, err, "Marshall packed writer")
   294  	err = pwr.Flush()
   295  	expect.HasSubstr(t, err, "packed flush oops")
   296  }