github.com/grailbio/base@v0.0.11/recordio/v2_test.go (about) 1 // Copyright 2018 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache-2.0 3 // license that can be found in the LICENSE file. 4 5 package recordio_test 6 7 import ( 8 "bytes" 9 "fmt" 10 "math/rand" 11 "strconv" 12 "sync/atomic" 13 "testing" 14 15 "github.com/grailbio/base/recordio" 16 "github.com/grailbio/base/recordio/deprecated" 17 "github.com/grailbio/base/recordio/internal" 18 "github.com/grailbio/base/recordio/recordioiov" 19 "github.com/grailbio/base/recordio/recordiozstd" 20 "github.com/grailbio/testutil/assert" 21 "github.com/grailbio/testutil/expect" 22 ) 23 24 func init() { recordiozstd.Init() } 25 26 func marshalString(scratch []byte, v interface{}) ([]byte, error) { 27 return []byte(v.(string)), nil 28 } 29 30 func unmarshalString(data []byte) (interface{}, error) { 31 return string(data), nil 32 } 33 34 func readAllV2(t *testing.T, buf *bytes.Buffer) (recordio.ParsedHeader, []string, string) { 35 sc := recordio.NewScanner(bytes.NewReader(buf.Bytes()), recordio.ScannerOpts{ 36 Unmarshal: unmarshalString, 37 }) 38 header := sc.Header() 39 trailer := string(sc.Trailer()) 40 var body []string 41 for sc.Scan() { 42 body = append(body, sc.Get().(string)) 43 } 44 expect.False(t, sc.Scan()) // Scan() calls after EOF should return false. 45 expect.NoError(t, sc.Err()) 46 return header, body, trailer 47 } 48 49 // Test reading a packed v1 file. 50 func TestReadV1Packed(t *testing.T) { 51 buf := &bytes.Buffer{} 52 w := deprecated.NewLegacyPackedWriter(buf, deprecated.LegacyPackedWriterOpts{}) 53 w.Write([]byte("Foo")) 54 w.Write([]byte("Baz")) 55 w.Flush() 56 57 _, body, trailer := readAllV2(t, buf) 58 expect.EQ(t, "", trailer) 59 expect.EQ(t, []string{"Foo", "Baz"}, body) 60 } 61 62 // Test reading an unpacked v1 file. 63 func TestReadV1Unpacked(t *testing.T) { 64 buf := &bytes.Buffer{} 65 w := deprecated.NewLegacyWriter(buf, deprecated.LegacyWriterOpts{}) 66 w.Write([]byte("Foo")) 67 w.Write([]byte("Baz")) 68 69 _, body, trailer := readAllV2(t, buf) 70 expect.EQ(t, "", trailer) 71 expect.EQ(t, []string{"Foo", "Baz"}, body) 72 } 73 74 func TestEmptyFile(t *testing.T) { 75 _, body, trailer := readAllV2(t, &bytes.Buffer{}) 76 expect.EQ(t, "", trailer) 77 expect.EQ(t, []string(nil), body) 78 } 79 80 func TestEmptyBody(t *testing.T) { 81 buf := &bytes.Buffer{} 82 wr := recordio.NewWriter(buf, recordio.WriterOpts{Marshal: marshalString}) 83 assert.NoError(t, wr.Finish()) 84 assert.EQ(t, len(buf.Bytes()), internal.ChunkSize) // one header chunk 85 header, body, trailer := readAllV2(t, buf) 86 assert.EQ(t, recordio.ParsedHeader(nil), header) 87 assert.EQ(t, []string(nil), body) 88 assert.EQ(t, "", trailer) 89 } 90 91 func TestFlushEmpty(t *testing.T) { 92 buf := &bytes.Buffer{} 93 wr := recordio.NewWriter(buf, recordio.WriterOpts{Marshal: marshalString}) 94 wr.Flush() 95 assert.NoError(t, wr.Finish()) 96 header, body, trailer := readAllV2(t, buf) 97 assert.EQ(t, recordio.ParsedHeader(nil), header) 98 assert.EQ(t, []string(nil), body) 99 assert.EQ(t, "", trailer) 100 } 101 102 func TestV2NonEmptyHeaderEmptyBody(t *testing.T) { 103 buf := &bytes.Buffer{} 104 wr := recordio.NewWriter(buf, recordio.WriterOpts{Marshal: marshalString}) 105 wr.AddHeader("Foo", "Hah") 106 assert.NoError(t, wr.Finish()) 107 assert.EQ(t, len(buf.Bytes()), internal.ChunkSize) // one header chunk 108 header, body, trailer := readAllV2(t, buf) 109 assert.EQ(t, recordio.ParsedHeader{recordio.KeyValue{"Foo", "Hah"}}, header) 110 assert.EQ(t, []string(nil), body) 111 assert.EQ(t, "", trailer) 112 } 113 114 func TestV2EmptyBodyNonEmptyTrailer(t *testing.T) { 115 buf := &bytes.Buffer{} 116 wr := recordio.NewWriter(buf, recordio.WriterOpts{Marshal: marshalString, KeyTrailer: true}) 117 wr.SetTrailer([]byte("TTT")) 118 assert.NoError(t, wr.Finish()) 119 assert.EQ(t, len(buf.Bytes()), 2*internal.ChunkSize) // header+trailer 120 header, body, trailer := readAllV2(t, buf) 121 assert.EQ(t, recordio.ParsedHeader{recordio.KeyValue{recordio.KeyTrailer, true}}, header) 122 assert.EQ(t, []string(nil), body) 123 assert.EQ(t, "TTT", trailer) 124 } 125 126 func TestV2LargeTrailer(t *testing.T) { 127 buf := &bytes.Buffer{} 128 wr := recordio.NewWriter(buf, recordio.WriterOpts{Marshal: marshalString, KeyTrailer: true}) 129 wr.Append("XX") 130 131 rnd := rand.New(rand.NewSource(0)) 132 largeData := randomString(internal.ChunkSize*10+100, rnd) 133 wr.SetTrailer([]byte(largeData)) 134 assert.NoError(t, wr.Finish()) 135 header, body, trailer := readAllV2(t, buf) 136 assert.EQ(t, recordio.ParsedHeader{recordio.KeyValue{recordio.KeyTrailer, true}}, header) 137 assert.EQ(t, []string{"XX"}, body) 138 assert.EQ(t, largeData, trailer) 139 } 140 141 func TestV2WriteRead(t *testing.T) { 142 buf := &bytes.Buffer{} 143 144 index := make(map[string]recordio.ItemLocation) 145 wr := recordio.NewWriter(buf, recordio.WriterOpts{ 146 Marshal: marshalString, 147 Index: func(loc recordio.ItemLocation, v interface{}) error { 148 index[v.(string)] = loc 149 return nil 150 }, 151 KeyTrailer: true, 152 }) 153 wr.AddHeader("hh0", "vv0") 154 wr.AddHeader("hh1", 12345) 155 wr.AddHeader("hh2", uint16(234)) 156 wr.Append("F0") 157 wr.Append("F1") 158 wr.Flush() 159 wr.Append("F2") 160 wr.Flush() 161 wr.Append("F3") 162 wr.SetTrailer([]byte("Trailer2")) 163 assert.NoError(t, wr.Finish()) 164 165 header, body, trailer := readAllV2(t, buf) 166 expect.EQ(t, recordio.ParsedHeader{ 167 recordio.KeyValue{"trailer", true}, 168 recordio.KeyValue{"hh0", "vv0"}, 169 recordio.KeyValue{"hh1", int64(12345)}, 170 recordio.KeyValue{"hh2", uint64(234)}, 171 }, header) 172 expect.EQ(t, trailer, "Trailer2") 173 expect.EQ(t, body, []string{"F0", "F1", "F2", "F3"}) 174 175 // Test seeking 176 expect.EQ(t, len(index), 4) 177 sc := recordio.NewScanner(bytes.NewReader(buf.Bytes()), recordio.ScannerOpts{ 178 Unmarshal: unmarshalString, 179 }) 180 181 for _, value := range body { 182 loc := index[value] 183 sc.Seek(loc) 184 expect.NoError(t, sc.Err()) 185 expect.True(t, sc.Scan()) 186 expect.EQ(t, sc.Get().(string), value) 187 } 188 } 189 190 func TestV2RestartWithSkipHeader(t *testing.T) { 191 ogBuf := &bytes.Buffer{} 192 index := make(map[string]recordio.ItemLocation) 193 194 writerOpts := recordio.WriterOpts{ 195 Marshal: marshalString, 196 Index: func(loc recordio.ItemLocation, v interface{}) error { 197 index[v.(string)] = loc 198 return nil 199 }, 200 KeyTrailer: true, 201 } 202 203 wr := recordio.NewWriter(ogBuf, writerOpts) 204 wr.AddHeader("hh0", "vv0") 205 wr.AddHeader("hh1", 12345) 206 wr.AddHeader("hh2", uint16(234)) 207 wr.Append("F0") 208 wr.Append("F1") 209 wr.Flush() 210 wr.Append("F2") 211 wr.Flush() 212 wr.Wait() 213 214 bytesWrittenSoFar := uint64(32768 * 3) // 3 blocks have been written, 1 for header, 2 for data 215 216 writerOpts.Index = func(loc recordio.ItemLocation, v interface{}) error { 217 loc.Block += bytesWrittenSoFar 218 index[v.(string)] = loc 219 return nil 220 } 221 writerOpts.SkipHeader = true 222 223 // new buffer with the originally written bytes pre-populated 224 restartBuf := bytes.NewBuffer(ogBuf.Bytes()) 225 restartWriter := recordio.NewWriter(restartBuf, writerOpts) 226 227 restartWriter.Append("F3") 228 restartWriter.SetTrailer([]byte("Trailer2")) 229 assert.NoError(t, restartWriter.Finish()) 230 231 header, body, trailer := readAllV2(t, restartBuf) 232 expect.EQ(t, recordio.ParsedHeader{ 233 recordio.KeyValue{"trailer", true}, 234 recordio.KeyValue{"hh0", "vv0"}, 235 recordio.KeyValue{"hh1", int64(12345)}, 236 recordio.KeyValue{"hh2", uint64(234)}, 237 }, header) 238 expect.EQ(t, trailer, "Trailer2") 239 expect.EQ(t, body, []string{"F0", "F1", "F2", "F3"}) 240 241 // Test seeking 242 expect.EQ(t, len(index), 4) 243 sc := recordio.NewScanner(bytes.NewReader(restartBuf.Bytes()), recordio.ScannerOpts{ 244 Unmarshal: unmarshalString, 245 }) 246 247 for _, value := range body { 248 loc := index[value] 249 sc.Seek(loc) 250 expect.NoError(t, sc.Err()) 251 expect.True(t, sc.Scan()) 252 expect.EQ(t, sc.Get().(string), value) 253 } 254 } 255 256 func TestV2NonExistentTransformer(t *testing.T) { 257 buf := &bytes.Buffer{} 258 wr := recordio.NewWriter(buf, recordio.WriterOpts{ 259 Marshal: marshalString, 260 Transformers: []string{"nonexistent"}, 261 }) 262 for i := 0; i < 1000; i++ { 263 wr.Append("data") 264 wr.Flush() 265 } 266 wr.Finish() 267 assert.Regexp(t, wr.Err(), "Transformer .* not found") 268 } 269 270 func TestV2TransformerError(t *testing.T) { 271 // A transformer that adds N to every byte. 272 recordio.RegisterTransformer("error", 273 func(config string) (recordio.TransformFunc, error) { 274 return func(scratch []byte, in [][]byte) ([]byte, error) { 275 return nil, fmt.Errorf("synthetic transformer error") 276 }, nil 277 }, 278 func(config string) (recordio.TransformFunc, error) { 279 t.Fail() 280 return nil, nil 281 }) 282 buf := &bytes.Buffer{} 283 wr := recordio.NewWriter(buf, recordio.WriterOpts{ 284 Marshal: marshalString, 285 Transformers: []string{"error"}, 286 }) 287 wr.Append("data") 288 wr.Finish() 289 assert.Regexp(t, wr.Err(), "synthetic transformer error") 290 } 291 292 func getBytewiseTransformFunc() func(scratch []byte, in [][]byte, tr func(uint8) uint8) ([]byte, error) { 293 return func(scratch []byte, in [][]byte, tr func(uint8) uint8) ([]byte, error) { 294 nBytes := recordioiov.TotalBytes(in) 295 out := recordioiov.Slice(scratch, nBytes) 296 n := 0 297 for _, buf := range in { 298 for i := range buf { 299 out[n] = tr(buf[i]) 300 n++ 301 } 302 } 303 return out, nil 304 } 305 } 306 307 func TestV2Transformer(t *testing.T) { 308 bytewiseTransform := getBytewiseTransformFunc() 309 var nPlus, nMinus, nXor int32 310 311 // A transformer that adds N to every byte. 312 recordio.RegisterTransformer("testplus", 313 func(config string) (recordio.TransformFunc, error) { 314 delta, err := strconv.Atoi(config) 315 if err != nil { 316 return nil, err 317 } 318 return func(scratch []byte, in [][]byte) ([]byte, error) { 319 atomic.AddInt32(&nPlus, 1) 320 return bytewiseTransform(scratch, in, func(b uint8) uint8 { return b + uint8(delta) }) 321 }, nil 322 }, 323 func(config string) (recordio.TransformFunc, error) { 324 delta, err := strconv.Atoi(config) 325 if err != nil { 326 return nil, err 327 } 328 return func(scratch []byte, in [][]byte) ([]byte, error) { 329 atomic.AddInt32(&nMinus, 1) 330 return bytewiseTransform(scratch, in, func(b uint8) uint8 { return b - uint8(delta) }) 331 }, nil 332 }) 333 334 // A transformer that xors every byte. 335 xorTransformerFactory := func(config string) (recordio.TransformFunc, error) { 336 delta, err := strconv.Atoi(config) 337 if err != nil { 338 return nil, err 339 } 340 return func(scratch []byte, in [][]byte) ([]byte, error) { 341 atomic.AddInt32(&nXor, 1) 342 return bytewiseTransform(scratch, in, func(b uint8) uint8 { return b ^ uint8(delta) }) 343 }, nil 344 } 345 recordio.RegisterTransformer("testxor", xorTransformerFactory, xorTransformerFactory) 346 347 buf := &bytes.Buffer{} 348 wr := recordio.NewWriter(buf, recordio.WriterOpts{ 349 Marshal: marshalString, 350 Transformers: []string{"testplus 3", "testxor 111"}, 351 KeyTrailer: true, 352 }) 353 354 wr.Append("F0") 355 wr.Append("F1") 356 wr.Flush() 357 wr.Append("F2") 358 wr.SetTrailer([]byte("Trailer2")) 359 assert.NoError(t, wr.Finish()) 360 assert.EQ(t, nPlus, int32(3)) // two data + one trailer block 361 assert.EQ(t, nXor, int32(3)) 362 363 header, body, _ := readAllV2(t, buf) 364 expect.EQ(t, recordio.ParsedHeader{ 365 recordio.KeyValue{"transformer", "testplus 3"}, 366 recordio.KeyValue{"transformer", "testxor 111"}, 367 recordio.KeyValue{"trailer", true}, 368 }, header) 369 expect.EQ(t, body, []string{"F0", "F1", "F2"}) 370 assert.EQ(t, nPlus, int32(3)) 371 assert.EQ(t, nXor, int32(6)) 372 } 373 374 func TestV2TransformerWithRestart(t *testing.T) { 375 bytewiseTransform := getBytewiseTransformFunc() 376 var nPlus, nMinus, nXor int32 377 378 // A transformer that adds N to every byte. 379 recordio.RegisterTransformer("restart-testplus", 380 func(config string) (recordio.TransformFunc, error) { 381 delta, err := strconv.Atoi(config) 382 if err != nil { 383 return nil, err 384 } 385 return func(scratch []byte, in [][]byte) ([]byte, error) { 386 atomic.AddInt32(&nPlus, 1) 387 return bytewiseTransform(scratch, in, func(b uint8) uint8 { return b + uint8(delta) }) 388 }, nil 389 }, 390 func(config string) (recordio.TransformFunc, error) { 391 delta, err := strconv.Atoi(config) 392 if err != nil { 393 return nil, err 394 } 395 return func(scratch []byte, in [][]byte) ([]byte, error) { 396 atomic.AddInt32(&nMinus, 1) 397 return bytewiseTransform(scratch, in, func(b uint8) uint8 { return b - uint8(delta) }) 398 }, nil 399 }) 400 401 // A transformer that xors every byte. 402 xorTransformerFactory := func(config string) (recordio.TransformFunc, error) { 403 delta, err := strconv.Atoi(config) 404 if err != nil { 405 return nil, err 406 } 407 return func(scratch []byte, in [][]byte) ([]byte, error) { 408 atomic.AddInt32(&nXor, 1) 409 return bytewiseTransform(scratch, in, func(b uint8) uint8 { return b ^ uint8(delta) }) 410 }, nil 411 } 412 recordio.RegisterTransformer("restart-testxor", xorTransformerFactory, xorTransformerFactory) 413 414 ogBuf := &bytes.Buffer{} 415 writerOpts := recordio.WriterOpts{ 416 Marshal: marshalString, 417 Transformers: []string{"restart-testplus 3", "restart-testxor 111"}, 418 KeyTrailer: true, 419 } 420 wr := recordio.NewWriter(ogBuf, writerOpts) 421 422 wr.Append("F0") 423 wr.Append("F1") 424 wr.Flush() 425 wr.Wait() 426 427 restartBuf := bytes.NewBuffer(ogBuf.Bytes()) 428 429 writerOpts.SkipHeader = true 430 wr = recordio.NewWriter(restartBuf, writerOpts) 431 432 wr.Append("F2") 433 wr.SetTrailer([]byte("Trailer2")) 434 assert.NoError(t, wr.Finish()) 435 436 assert.EQ(t, nPlus, int32(3)) // two data + one trailer block 437 assert.EQ(t, nXor, int32(3)) 438 439 header, body, _ := readAllV2(t, restartBuf) 440 expect.EQ(t, recordio.ParsedHeader{ 441 recordio.KeyValue{"transformer", "restart-testplus 3"}, 442 recordio.KeyValue{"transformer", "restart-testxor 111"}, 443 recordio.KeyValue{"trailer", true}, 444 }, header) 445 expect.EQ(t, body, []string{"F0", "F1", "F2"}) 446 assert.EQ(t, nPlus, int32(3)) 447 assert.EQ(t, nXor, int32(6)) 448 } 449 450 func randomString(n int, r *rand.Rand) string { 451 buf := make([]byte, n) 452 for i := 0; i < n; i++ { 453 buf[i] = uint8('A' + r.Intn(64)) 454 } 455 return string(buf) 456 } 457 458 func generateRandomRecordio(t *testing.T, rnd *rand.Rand, flushProbability float64, nRecords, datasize int, wopts recordio.WriterOpts) ([]byte, []string, map[string]recordio.ItemLocation) { 459 buf := &bytes.Buffer{} 460 items := make([]string, nRecords) 461 index := make(map[string]recordio.ItemLocation) 462 wopts.Marshal = marshalString 463 wopts.Index = func(loc recordio.ItemLocation, v interface{}) error { 464 index[v.(string)] = loc 465 return nil 466 } 467 wr := recordio.NewWriter(buf, wopts) 468 wr.AddHeader(recordio.KeyTrailer, true) 469 for i := 0; i < nRecords; i++ { 470 data := randomString(rnd.Intn(datasize)+1, rnd) 471 wr.Append(data) 472 items[i] = data 473 if rnd.Float64() < flushProbability { 474 wr.Flush() 475 } 476 assert.NoError(t, wr.Err()) 477 } 478 wr.SetTrailer([]byte("Trailer")) 479 assert.NoError(t, wr.Finish()) 480 return buf.Bytes(), items, index 481 } 482 483 func doShardedReads(t *testing.T, data []byte, stride, nshard int, items []string) int { 484 expected := items 485 var maxShardSize int 486 for shard := 0; shard < nshard; shard += stride { 487 limit := shard + stride 488 if limit > nshard { 489 limit = nshard 490 } 491 ropts := recordio.ScannerOpts{Unmarshal: unmarshalString} 492 sc := recordio.NewShardScanner(bytes.NewReader(data), ropts, shard, limit, nshard) 493 assert.EQ(t, "Trailer", string(sc.Trailer()), "Error: %v, shard %d/%d", sc.Err(), shard, nshard) 494 shardSize := 0 495 i := 0 496 for sc.Scan() { 497 assert.EQ(t, expected[0], sc.Get().(string), "i=%d, err %v, shard %d/%d", i, sc.Err(), shard, nshard) 498 expected = expected[1:] 499 shardSize++ 500 i++ 501 } 502 assert.NoError(t, sc.Err()) 503 if shardSize > maxShardSize { 504 maxShardSize = shardSize 505 } 506 } 507 assert.EQ(t, 0, len(expected)) 508 return maxShardSize 509 } 510 511 func doRandomTest( 512 t *testing.T, 513 seed int64, 514 flushProbability float64, 515 nshard int, 516 maxrecords int, 517 datasize int, 518 wopts recordio.WriterOpts) { 519 t.Run("r", func(t *testing.T) { 520 t.Parallel() 521 t.Logf("Start test with wopt %+v, nshards %d, maxrecords %d, datasize %d", wopts, nshard, maxrecords, datasize) 522 523 rnd := rand.New(rand.NewSource(seed)) 524 var nRecords int 525 if maxrecords > 0 { 526 nRecords = rnd.Intn(maxrecords) + 1 527 } 528 data, items, index := generateRandomRecordio(t, rnd, flushProbability, nRecords, datasize, wopts) 529 530 doShardedReads(t, data, 1, nshard, items) 531 532 ropts := recordio.ScannerOpts{Unmarshal: unmarshalString} 533 sc := recordio.NewScanner(bytes.NewReader(data), ropts) 534 for _, value := range items { 535 loc := index[value] 536 sc.Seek(loc) 537 expect.NoError(t, sc.Err()) 538 expect.True(t, sc.Scan()) 539 expect.EQ(t, value, sc.Get().(string)) 540 } 541 }) 542 } 543 544 func TestV2Random(t *testing.T) { 545 const ( 546 maxrecords = 2000 547 datasize = 30 548 ) 549 for wo := 0; wo < 2; wo++ { 550 opts := recordio.WriterOpts{} 551 if wo == 1 { 552 opts.Transformers = []string{"zstd"} 553 } 554 doRandomTest(t, 0, 0.001, 2000, maxrecords, 10<<10, opts) 555 doRandomTest(t, 0, 0.1, 1, maxrecords, datasize, opts) 556 doRandomTest(t, 0, 1.0, 1, maxrecords, datasize, opts) 557 doRandomTest(t, 0, 0.0, 1, maxrecords, datasize, opts) 558 559 opts.MaxFlushParallelism = 1 560 doRandomTest(t, 0, 0.1, 1, maxrecords, datasize, opts) 561 opts.MaxFlushParallelism = 0 562 doRandomTest(t, 0, 0.1, 1000, maxrecords, datasize, opts) 563 doRandomTest(t, 0, 1.0, 3, maxrecords, 30, opts) 564 doRandomTest(t, 0, 0.0, 2, maxrecords, 30, opts) 565 // Make sure we generate blocks big enough so that 566 // shards have to straddle block boundaries. 567 // Make sure that lots of shards with a single record reads correctly. 568 doRandomTest(t, 0, 0.001, 2000, 1, datasize, opts) 569 // Same with an empty recordio file. 570 doRandomTest(t, 0, 0.001, 2000, 0, datasize, opts) 571 } 572 } 573 574 func TestRandomLargeWrites(t *testing.T) { 575 rnd := rand.New(rand.NewSource(0)) 576 577 nRecords := 100000 578 data, items, _ := generateRandomRecordio(t, rnd, 0.01, nRecords, 1024, recordio.WriterOpts{}) 579 580 nShards := 10 581 maxShardSize := doShardedReads(t, data, 1, nShards, items) 582 assert.GT(t, maxShardSize, 8000, "max %d, nshard %d nRecords %d", maxShardSize, nShards, nRecords) 583 assert.LT(t, maxShardSize, 12000, "max %d, nshard %d nRecords %d", maxShardSize, nShards, nRecords) 584 585 // Use the same sharding, but use a large absolute shard value to detect possible rounding errors. 586 nShards = 1000000000 587 stride := nShards / 10 588 maxShardSize = doShardedReads(t, data, stride, nShards, items) 589 assert.GT(t, maxShardSize, 8000, "max %d, nshard %d nRecords %d", maxShardSize, nShards, nRecords) 590 assert.LT(t, maxShardSize, 12000, "max %d, nshard %d nRecords %d", maxShardSize, nShards, nRecords) 591 } 592 593 // TODO: test seeking to bogus location. 594 595 // TODO: test flushing with no data. 596 597 // TODO: benchmark