github.com/grailbio/base@v0.0.11/recordio/recordioutil/flags_test.go (about)

     1  // Copyright 2017 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package recordioutil_test
     6  
     7  import (
     8  	"flag"
     9  	"io/ioutil"
    10  	"os"
    11  	"path/filepath"
    12  	"testing"
    13  
    14  	"github.com/grailbio/base/recordio/recordioutil"
    15  	"github.com/grailbio/testutil"
    16  	"github.com/grailbio/testutil/expect"
    17  	"github.com/klauspost/compress/flate"
    18  )
    19  
    20  func testCL(flags *recordioutil.WriterFlags, args ...string) error {
    21  	fs := flag.NewFlagSet("test", flag.ContinueOnError)
    22  	recordioutil.RegisterWriterFlags(fs, flags)
    23  	return fs.Parse(args)
    24  }
    25  
    26  func TestFlags(t *testing.T) {
    27  	comp := func(level string, expected int) {
    28  		flags := &recordioutil.WriterFlags{}
    29  		expect.NoError(t, testCL(flags, "--recordio-compression-level="+level))
    30  		expect.EQ(t, flags.CompressionFlag.Level, expected)
    31  		expect.EQ(t, flags.CompressionFlag.String(), level)
    32  		expect.EQ(t, flags.CompressionFlag.Specified, true)
    33  	}
    34  	for _, c := range []struct {
    35  		cl string
    36  		v  int
    37  	}{
    38  		{"none", flate.NoCompression},
    39  		{"fastest", flate.BestSpeed},
    40  		{"best", flate.BestCompression},
    41  		{"default", flate.DefaultCompression},
    42  		{"huffman-only", flate.HuffmanOnly},
    43  	} {
    44  		comp(c.cl, c.v)
    45  	}
    46  
    47  	kd := `{"registry":"reg", "keyid":"ff"}`
    48  	tmpdir, cleanup := testutil.TempDir(t, "", "recorodioutil")
    49  	defer cleanup()
    50  	kf := filepath.Join(tmpdir, "kd")
    51  	if err := ioutil.WriteFile(kf, []byte(kd), os.FileMode(0777)); err != nil {
    52  		t.Fatal(err)
    53  	}
    54  
    55  	flags := &recordioutil.WriterFlags{}
    56  	flags.ItemsPerRecord = 10
    57  	if got, want := flags.ItemsPerRecord, uint(10); got != want {
    58  		t.Errorf("got %v, want %v", got, want)
    59  	}
    60  	expect.NoError(t, testCL(flags, "--recordio-compression-level=best", "--recordio-MiB-per-record=33", "--recordio-items-per-record=66"))
    61  	expect.EQ(t, flags.MegaBytesPerRecord, uint(33))
    62  	expect.EQ(t, flags.ItemsPerRecord, uint(66))
    63  	opts := recordioutil.WriterOptsFromFlags(flags)
    64  	expected := recordioutil.WriterOpts{
    65  		MaxItems:   66,
    66  		MaxBytes:   33 * 1024 * 1024,
    67  		FlateLevel: flate.BestCompression,
    68  	}
    69  	expect.EQ(t, opts, expected)
    70  }
    71  
    72  func TestFlagsErrors(t *testing.T) {
    73  	flags := &recordioutil.WriterFlags{}
    74  	err := testCL(flags, "--recordio-compression-level=x")
    75  	expect.HasSubstr(t, err, "unrecognised compression option")
    76  
    77  	defer func() {
    78  		if r := recover(); r != nil {
    79  			t.Logf("Recovered %v", r)
    80  		} else {
    81  			t.Fatal("failed to panic")
    82  		}
    83  	}()
    84  	flags.CompressionFlag.Level = 33
    85  	_ = flags.CompressionFlag.String()
    86  }