github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/recordio/recordio_test.go (about)

     1  // Copyright 2017 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package recordio_test
     6  
     7  import (
     8  	"flag"
     9  	"os"
    10  	"testing"
    11  
    12  	"github.com/Schaudge/grailbase/grail"
    13  	"github.com/Schaudge/grailbase/log"
    14  	"github.com/Schaudge/grailbase/recordio"
    15  	"github.com/Schaudge/grailbase/recordio/deprecated"
    16  	"github.com/Schaudge/grailbase/recordio/recordioflate"
    17  	"github.com/Schaudge/grailbase/recordio/recordioutil"
    18  	"github.com/Schaudge/grailbase/recordio/recordiozstd"
    19  	"github.com/grailbio/testutil/assert"
    20  )
    21  
    22  var (
    23  	pathFlag            = flag.String("path", "/tmp/test.recordio", "Recordio file to use during benchmarking")
    24  	numRecordsFlag      = flag.Int("num-records", 3, "Number of records to write")
    25  	recordSizeFlag      = flag.Int("record-size", 64, "Byte size of each record")
    26  	recordsPerBlockFlag = flag.Int("records-per-block", 1024, "Number of records per block")
    27  	fileVersionFlag     = flag.Int("file", 2, "recordio version")
    28  	trailerFlag         = flag.Bool("trailer", true, "Add trailer")
    29  	compressFlag        = flag.String("compress", "zstd", "Compress blocks using the given transformer")
    30  	packFlag            = flag.Bool("pack", true, "Pack (legcy) items")
    31  )
    32  
    33  var recordTemplate []byte
    34  
    35  func generateRecord(length, seed int) []byte {
    36  	if len(recordTemplate) < length*8 {
    37  		recordTemplate = make([]byte, length*8)
    38  		for i := 0; i < len(recordTemplate); i++ {
    39  			recordTemplate[i] = byte('0' + (i % 64))
    40  		}
    41  	}
    42  	startIndex := seed % (len(recordTemplate) - length + 1)
    43  	return recordTemplate[startIndex : startIndex+length]
    44  }
    45  
    46  func init() {
    47  	recordiozstd.Init()
    48  	recordioflate.Init()
    49  }
    50  
    51  func BenchmarkRead(b *testing.B) {
    52  	if *pathFlag == "" {
    53  		b.Skip("--path is empty")
    54  		return
    55  	}
    56  	var nRecords, nBytes int64
    57  	for i := 0; i < b.N; i++ {
    58  		nRecords, nBytes = 0, 0
    59  		in, err := os.Open(*pathFlag)
    60  		assert.NoError(b, err)
    61  		r := recordio.NewScanner(in, recordio.ScannerOpts{})
    62  		for r.Scan() {
    63  			nBytes += int64(len(r.Get().([]byte)))
    64  			nRecords++
    65  		}
    66  		assert.NoError(b, r.Err())
    67  		assert.NoError(b, in.Close())
    68  	}
    69  	b.Logf("Read %d records, %d bytes (%f bytes/record)",
    70  		nRecords, nBytes, float64(nBytes)/float64(nRecords))
    71  }
    72  
    73  func BenchmarkWrite(b *testing.B) {
    74  	if *pathFlag == "" {
    75  		b.Skip("--path is empty")
    76  		return
    77  	}
    78  	for i := 0; i < b.N; i++ {
    79  		out, err := os.Create(*pathFlag)
    80  		assert.NoError(b, err)
    81  
    82  		switch {
    83  		case *fileVersionFlag == 2:
    84  			opts := recordio.WriterOpts{}
    85  			if i == 0 {
    86  				opts.Index = func(loc recordio.ItemLocation, v interface{}) error {
    87  					log.Debug.Printf("Index: item %v, loc %v", string(v.([]uint8)), loc)
    88  					return nil
    89  				}
    90  			}
    91  			if *compressFlag != "" {
    92  				opts.Transformers = []string{*compressFlag}
    93  			}
    94  			rw := recordio.NewWriter(out, opts)
    95  			rw.AddHeader("intflag", 12345)
    96  			rw.AddHeader("uintflag", uint64(12345))
    97  			rw.AddHeader("strflag", "Hello")
    98  			rw.AddHeader("boolflag", true)
    99  			if *trailerFlag {
   100  				rw.AddHeader(recordio.KeyTrailer, true)
   101  			}
   102  			for j := 0; j < *numRecordsFlag; j++ {
   103  				rw.Append(generateRecord(*recordSizeFlag, j))
   104  				if j%*recordsPerBlockFlag == *recordsPerBlockFlag-1 {
   105  					rw.Flush()
   106  				}
   107  			}
   108  			if *trailerFlag {
   109  				rw.SetTrailer([]byte("Trailer"))
   110  			}
   111  			assert.NoError(b, rw.Finish())
   112  		case *fileVersionFlag == 1 && !*packFlag:
   113  			opts := deprecated.LegacyWriterOpts{}
   114  			if *compressFlag != "" {
   115  				panic("Legacy unpacked format does not support --compress flag")
   116  			}
   117  			rw := deprecated.NewLegacyWriter(out, opts)
   118  			for j := 0; j < *numRecordsFlag; j++ {
   119  				n, err := rw.Write(generateRecord(*recordSizeFlag, j))
   120  				assert.NoError(b, err)
   121  				assert.EQ(b, n, *recordSizeFlag)
   122  			}
   123  		case *fileVersionFlag == 1 && *packFlag:
   124  			opts := deprecated.LegacyPackedWriterOpts{}
   125  			switch *compressFlag {
   126  			case "":
   127  			case "flate":
   128  				opts.Transform = recordioutil.NewFlateTransform(-1).CompressTransform
   129  			default:
   130  				panic(*compressFlag)
   131  			}
   132  			rw := deprecated.NewLegacyPackedWriter(out, opts)
   133  			for j := 0; j < *numRecordsFlag; j++ {
   134  				n, err := rw.Write(generateRecord(*recordSizeFlag, j))
   135  				assert.NoError(b, err)
   136  				assert.EQ(b, n, *recordSizeFlag)
   137  				if j%*recordsPerBlockFlag == *recordsPerBlockFlag-1 {
   138  					log.Printf("FLUSH! %d", *recordsPerBlockFlag)
   139  					assert.NoError(b, rw.Flush())
   140  				}
   141  			}
   142  			assert.NoError(b, rw.Flush())
   143  		default:
   144  			panic(*fileVersionFlag)
   145  		}
   146  		assert.NoError(b, out.Close())
   147  	}
   148  }
   149  
   150  func TestMain(m *testing.M) {
   151  	shutdown := grail.Init()
   152  	status := m.Run()
   153  	shutdown()
   154  	os.Exit(status)
   155  }