github.com/grailbio/base@v0.0.11/recordio/recordio_test.go (about) 1 // Copyright 2017 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache-2.0 3 // license that can be found in the LICENSE file. 4 5 package recordio_test 6 7 import ( 8 "flag" 9 "os" 10 "testing" 11 12 "github.com/grailbio/base/grail" 13 "github.com/grailbio/base/log" 14 "github.com/grailbio/base/recordio" 15 "github.com/grailbio/base/recordio/deprecated" 16 "github.com/grailbio/base/recordio/recordioflate" 17 "github.com/grailbio/base/recordio/recordioutil" 18 "github.com/grailbio/base/recordio/recordiozstd" 19 "github.com/grailbio/testutil/assert" 20 ) 21 22 var ( 23 pathFlag = flag.String("path", "/tmp/test.recordio", "Recordio file to use during benchmarking") 24 numRecordsFlag = flag.Int("num-records", 3, "Number of records to write") 25 recordSizeFlag = flag.Int("record-size", 64, "Byte size of each record") 26 recordsPerBlockFlag = flag.Int("records-per-block", 1024, "Number of records per block") 27 fileVersionFlag = flag.Int("file", 2, "recordio version") 28 trailerFlag = flag.Bool("trailer", true, "Add trailer") 29 compressFlag = flag.String("compress", "zstd", "Compress blocks using the given transformer") 30 packFlag = flag.Bool("pack", true, "Pack (legcy) items") 31 ) 32 33 var recordTemplate []byte 34 35 func generateRecord(length, seed int) []byte { 36 if len(recordTemplate) < length*8 { 37 recordTemplate = make([]byte, length*8) 38 for i := 0; i < len(recordTemplate); i++ { 39 recordTemplate[i] = byte('0' + (i % 64)) 40 } 41 } 42 startIndex := seed % (len(recordTemplate) - length + 1) 43 return recordTemplate[startIndex : startIndex+length] 44 } 45 46 func init() { 47 recordiozstd.Init() 48 recordioflate.Init() 49 } 50 51 func BenchmarkRead(b *testing.B) { 52 if *pathFlag == "" { 53 b.Skip("--path is empty") 54 return 55 } 56 var nRecords, nBytes int64 57 for i := 0; i < b.N; i++ { 58 nRecords, nBytes = 0, 0 59 in, err := os.Open(*pathFlag) 60 assert.NoError(b, err) 61 r := recordio.NewScanner(in, recordio.ScannerOpts{}) 62 for r.Scan() { 63 nBytes += int64(len(r.Get().([]byte))) 64 nRecords++ 65 } 66 assert.NoError(b, r.Err()) 67 assert.NoError(b, in.Close()) 68 } 69 b.Logf("Read %d records, %d bytes (%f bytes/record)", 70 nRecords, nBytes, float64(nBytes)/float64(nRecords)) 71 } 72 73 func BenchmarkWrite(b *testing.B) { 74 if *pathFlag == "" { 75 b.Skip("--path is empty") 76 return 77 } 78 for i := 0; i < b.N; i++ { 79 out, err := os.Create(*pathFlag) 80 assert.NoError(b, err) 81 82 switch { 83 case *fileVersionFlag == 2: 84 opts := recordio.WriterOpts{} 85 if i == 0 { 86 opts.Index = func(loc recordio.ItemLocation, v interface{}) error { 87 log.Debug.Printf("Index: item %v, loc %v", string(v.([]uint8)), loc) 88 return nil 89 } 90 } 91 if *compressFlag != "" { 92 opts.Transformers = []string{*compressFlag} 93 } 94 rw := recordio.NewWriter(out, opts) 95 rw.AddHeader("intflag", 12345) 96 rw.AddHeader("uintflag", uint64(12345)) 97 rw.AddHeader("strflag", "Hello") 98 rw.AddHeader("boolflag", true) 99 if *trailerFlag { 100 rw.AddHeader(recordio.KeyTrailer, true) 101 } 102 for j := 0; j < *numRecordsFlag; j++ { 103 rw.Append(generateRecord(*recordSizeFlag, j)) 104 if j%*recordsPerBlockFlag == *recordsPerBlockFlag-1 { 105 rw.Flush() 106 } 107 } 108 if *trailerFlag { 109 rw.SetTrailer([]byte("Trailer")) 110 } 111 assert.NoError(b, rw.Finish()) 112 case *fileVersionFlag == 1 && !*packFlag: 113 opts := deprecated.LegacyWriterOpts{} 114 if *compressFlag != "" { 115 panic("Legacy unpacked format does not support --compress flag") 116 } 117 rw := deprecated.NewLegacyWriter(out, opts) 118 for j := 0; j < *numRecordsFlag; j++ { 119 n, err := rw.Write(generateRecord(*recordSizeFlag, j)) 120 assert.NoError(b, err) 121 assert.EQ(b, n, *recordSizeFlag) 122 } 123 case *fileVersionFlag == 1 && *packFlag: 124 opts := deprecated.LegacyPackedWriterOpts{} 125 switch *compressFlag { 126 case "": 127 case "flate": 128 opts.Transform = recordioutil.NewFlateTransform(-1).CompressTransform 129 default: 130 panic(*compressFlag) 131 } 132 rw := deprecated.NewLegacyPackedWriter(out, opts) 133 for j := 0; j < *numRecordsFlag; j++ { 134 n, err := rw.Write(generateRecord(*recordSizeFlag, j)) 135 assert.NoError(b, err) 136 assert.EQ(b, n, *recordSizeFlag) 137 if j%*recordsPerBlockFlag == *recordsPerBlockFlag-1 { 138 log.Printf("FLUSH! %d", *recordsPerBlockFlag) 139 assert.NoError(b, rw.Flush()) 140 } 141 } 142 assert.NoError(b, rw.Flush()) 143 default: 144 panic(*fileVersionFlag) 145 } 146 assert.NoError(b, out.Close()) 147 } 148 } 149 150 func TestMain(m *testing.M) { 151 shutdown := grail.Init() 152 status := m.Run() 153 shutdown() 154 os.Exit(status) 155 }