github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/ccl/importccl/read_import_avro_test.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Licensed as a CockroachDB Enterprise file under the Cockroach Community
     4  // License (the "License"); you may not use this file except in compliance with
     5  // the License. You may obtain a copy of the License at
     6  //
     7  //     https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
     8  
     9  package importccl
    10  
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io/ioutil"
	"math/rand"
	"os"
	"strings"
	"testing"

	"github.com/cockroachdb/cockroach/pkg/roachpb"
	"github.com/cockroachdb/cockroach/pkg/settings/cluster"
	"github.com/cockroachdb/cockroach/pkg/sql/parser"
	"github.com/cockroachdb/cockroach/pkg/sql/row"
	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
	"github.com/linkedin/goavro"
	"github.com/stretchr/testify/require"
)
    31  
    32  // avroGen interface is an interface for generating avro test field.
// avroGen is the interface implemented by avro test field generators.
// Each generator describes one field: its name, how to produce values for
// it, and how it appears (if at all) in the avro schema and SQL table.
type avroGen interface {
	// Name returns the field/column name.
	Name() string
	// Gen returns the next native avro value for this field.
	Gen() interface{}
	AvroT() interface{} // nil if avro records should omit this field
	SQLT() interface{}  // nil if this column should not be created
}
    39  
// Base type for avro data generators. It carries the field name plus flags
// that let a test drop the field from the avro schema or the SQL table to
// exercise strict/relaxed import behavior.
type namedField struct {
	name        string
	excludeAvro bool // if set, AvroT() returns nil (field absent from avro schema)
	excludeSQL  bool // if set, SQLT() returns nil (column absent from SQL table)
}
    46  
// Name returns the name of the generated field/column.
func (g *namedField) Name() string {
	return g.name
}
    50  
// nilOrStrGen generates nil or a string, chosen pseudo-randomly per value.
type nilOrStrGen struct {
	namedField
}
    55  
    56  func (g *nilOrStrGen) Gen() interface{} {
    57  	id := rand.Int()
    58  	if id%2 == 0 {
    59  		return nil
    60  	}
    61  	return map[string]interface{}{"string": fmt.Sprintf("%s %d", g.name, id)}
    62  }
    63  
    64  func (g *nilOrStrGen) AvroT() interface{} {
    65  	if g.excludeAvro {
    66  		return nil
    67  	}
    68  	return []string{"null", "string"}
    69  }
    70  
    71  func (g *nilOrStrGen) SQLT() interface{} {
    72  	if g.excludeSQL {
    73  		return nil
    74  	}
    75  	return "string"
    76  }
    77  
// seqGen generates a sequence number; values are monotonically increasing
// starting at 1.
type seqGen struct {
	namedField
	seq int // last value returned by Gen
}
    83  
// Gen returns the next value in the sequence (1, 2, 3, ...).
func (g *seqGen) Gen() interface{} {
	g.seq++
	return g.seq
}
    88  
    89  func (g *seqGen) AvroT() interface{} {
    90  	if g.excludeAvro {
    91  		return nil
    92  	}
    93  	return "int"
    94  }
    95  
    96  func (g *seqGen) SQLT() interface{} {
    97  	if g.excludeSQL {
    98  		return nil
    99  	}
   100  	return "int"
   101  }
   102  
// intArrGen generates an array of integers (or nils) to exercise
// array-typed column import.
type intArrGen struct {
	namedField
}
   107  
   108  func (g *intArrGen) AvroT() interface{} {
   109  	return []interface{}{
   110  		// Each element is either a null or an array.
   111  		"null",
   112  		// And each array element is either a long or a null.
   113  		map[string]interface{}{"type": "array", "items": []string{"null", "long"}}}
   114  }
   115  
// SQLT returns the SQL array column type for this field.
func (g *intArrGen) SQLT() interface{} {
	return "int[]"
}
   119  
   120  func (g *intArrGen) Gen() interface{} {
   121  	id := rand.Int()
   122  	if id%2 == 0 {
   123  		return nil
   124  	}
   125  	var arr []interface{}
   126  	var val interface{}
   127  	// Generate few integers, with some nils thrown in for good measure.
   128  	for i := 0; i < 1+id%10; i++ {
   129  		if i%3 == 0 {
   130  			val = nil
   131  		} else {
   132  			val = map[string]interface{}{"long": i}
   133  		}
   134  		arr = append(arr, val)
   135  	}
   136  	return map[string]interface{}{"array": arr}
   137  }
   138  
// A testHelper to generate avro data. It bundles the generated avro schema,
// the matching table descriptor, and the codec/settings needed to encode
// records and convert them to datums.
type testHelper struct {
	schemaJSON  string                    // avro schema as JSON text
	schemaTable *sqlbase.TableDescriptor  // table matching the generators' SQL columns
	codec       *goavro.Codec             // codec compiled from schemaJSON
	gens        []avroGen                 // field generators used to produce records
	settings    *cluster.Settings
	evalCtx     tree.EvalContext
}
   148  
// defaultGens is the generator set used when a test does not supply its
// own: a sequential integer id plus two nullable string columns.
var defaultGens = []avroGen{
	&seqGen{namedField: namedField{name: "uid"}},
	&nilOrStrGen{namedField{name: "uname"}},
	&nilOrStrGen{namedField{name: "notes"}},
}
   154  
   155  func newTestHelper(t *testing.T, gens ...avroGen) *testHelper {
   156  	if len(gens) == 0 {
   157  		gens = defaultGens
   158  	}
   159  
   160  	// Generate avro schema specification as well as CREATE TABLE statement
   161  	// based on the specified generators.
   162  	schema := map[string]interface{}{
   163  		"type": "record",
   164  		"name": "users",
   165  	}
   166  	var avroFields []map[string]interface{}
   167  	createStmt := "CREATE TABLE users ("
   168  
   169  	for i, gen := range gens {
   170  		avroT := gen.AvroT()
   171  		sqlT := gen.SQLT()
   172  		if avroT != nil {
   173  			avroFields = append(avroFields, map[string]interface{}{
   174  				"name": gen.Name(),
   175  				"type": avroT,
   176  			})
   177  		}
   178  
   179  		if sqlT != nil {
   180  			createStmt += fmt.Sprintf("%s %s", gen.Name(), sqlT)
   181  			if i < len(gens)-1 {
   182  				createStmt += ","
   183  			}
   184  		}
   185  	}
   186  
   187  	createStmt += ")"
   188  	schema["fields"] = avroFields
   189  	schemaJSON, err := json.Marshal(schema)
   190  	require.NoError(t, err)
   191  
   192  	codec, err := goavro.NewCodec(string(schemaJSON))
   193  	require.NoError(t, err)
   194  	st := cluster.MakeTestingClusterSettings()
   195  	evalCtx := tree.MakeTestingEvalContext(st)
   196  
   197  	return &testHelper{
   198  		schemaJSON:  string(schemaJSON),
   199  		schemaTable: descForTable(t, createStmt, 10, 20, NoFKs),
   200  		codec:       codec,
   201  		gens:        gens,
   202  		settings:    st,
   203  		evalCtx:     evalCtx,
   204  	}
   205  }
   206  
// testRecordStream pairs an import producer/consumer with a datum row
// converter so tests can pull rows and convert them in one step.
type testRecordStream struct {
	producer importRowProducer
	consumer importRowConsumer
	rowNum   int64 // number of rows consumed so far; passed to FillDatums
	conv     *row.DatumRowConverter
}
   213  
   214  // Combine Row() with FillDatums for error checking.
   215  func (t *testRecordStream) Row() error {
   216  	r, err := t.producer.Row()
   217  	if err == nil {
   218  		t.rowNum++
   219  		err = t.consumer.FillDatums(r, t.rowNum, t.conv)
   220  	}
   221  	return err
   222  }
   223  
   224  // Generates test data with the specified format and returns avroRowStream object.
// newRecordStream generates test data with the specified format and returns
// a testRecordStream wired to read it back through the avro import pipeline.
func (th *testHelper) newRecordStream(
	t *testing.T, format roachpb.AvroOptions_Format, strict bool, numRecords int,
) *testRecordStream {
	// Ensure datum converter doesn't flush (since
	// we're using nil kv channel for this test).
	defer row.TestingSetDatumRowConverterBatchSize(numRecords + 1)()

	opts := roachpb.AvroOptions{
		Format:     format,
		StrictMode: strict,
	}

	records := bytes.NewBufferString("")
	if format == roachpb.AvroOptions_OCF {
		// OCF files embed their schema, so no schema JSON or separator is set.
		th.genOcfData(t, numRecords, records)
	} else {
		opts.RecordSeparator = '\n'
		opts.SchemaJSON = th.schemaJSON
		th.genRecordsData(t, format, numRecords, opts.RecordSeparator, records)
	}

	// nil kv channel: these tests never flush converted rows (see the
	// batch-size override above).
	avro, err := newAvroInputReader(nil, th.schemaTable, opts, 0, 1, &th.evalCtx)
	require.NoError(t, err)
	producer, consumer, err := newImportAvroPipeline(avro, &fileReader{Reader: records})
	require.NoError(t, err)

	conv, err := row.NewDatumRowConverter(
		context.Background(), th.schemaTable, nil, th.evalCtx.Copy(), nil)
	require.NoError(t, err)
	return &testRecordStream{
		producer: producer,
		consumer: consumer,
		conv:     conv,
	}
}
   260  
   261  func (th *testHelper) genAvroRecord() interface{} {
   262  	rec := make(map[string]interface{})
   263  	for _, gen := range th.gens {
   264  		if gen.AvroT() != nil {
   265  			rec[gen.Name()] = gen.Gen()
   266  		}
   267  	}
   268  	return rec
   269  }
   270  
   271  // Generates OCF test data.
   272  func (th *testHelper) genOcfData(t *testing.T, numRecords int, records *bytes.Buffer) {
   273  	ocf, err := goavro.NewOCFWriter(goavro.OCFConfig{
   274  		W:      records,
   275  		Codec:  th.codec,
   276  		Schema: th.schemaJSON,
   277  	})
   278  
   279  	for i := 0; err == nil && i < numRecords; i++ {
   280  		err = ocf.Append([]interface{}{th.genAvroRecord()})
   281  	}
   282  	require.NoError(t, err)
   283  }
   284  
   285  // Generates test data with the specified format and returns avroRowStream object.
   286  func (th *testHelper) genRecordsData(
   287  	t *testing.T,
   288  	format roachpb.AvroOptions_Format,
   289  	numRecords int,
   290  	recSeparator rune,
   291  	records *bytes.Buffer,
   292  ) {
   293  	var data []byte
   294  	var err error
   295  
   296  	for i := 0; i < numRecords; i++ {
   297  		rec := th.genAvroRecord()
   298  
   299  		if format == roachpb.AvroOptions_JSON_RECORDS {
   300  			data, err = th.codec.TextualFromNative(nil, rec)
   301  		} else if format == roachpb.AvroOptions_BIN_RECORDS {
   302  			data, err = th.codec.BinaryFromNative(nil, rec)
   303  		} else {
   304  			t.Fatal("unexpected avro format")
   305  		}
   306  
   307  		require.NoError(t, err)
   308  
   309  		records.Write(data)
   310  		if recSeparator != 0 {
   311  			records.WriteRune(recSeparator)
   312  		}
   313  	}
   314  }
   315  
   316  func TestReadsAvroRecords(t *testing.T) {
   317  	defer leaktest.AfterTest(t)()
   318  	th := newTestHelper(t)
   319  
   320  	formats := []roachpb.AvroOptions_Format{
   321  		roachpb.AvroOptions_BIN_RECORDS,
   322  		roachpb.AvroOptions_JSON_RECORDS,
   323  	}
   324  
   325  	for _, format := range formats {
   326  		for _, readSize := range []int{1, 16, 33, 64, 1024} {
   327  			for _, skip := range []bool{false, true} {
   328  				t.Run(fmt.Sprintf("%v-%v-skip=%v", format, readSize, skip), func(t *testing.T) {
   329  					stream := th.newRecordStream(t, format, false, 10)
   330  					stream.producer.(*avroRecordStream).readSize = readSize
   331  
   332  					var rowIdx int64
   333  					for stream.producer.Scan() {
   334  						var err error
   335  						if skip {
   336  							err = stream.producer.Skip()
   337  						} else {
   338  							err = stream.Row()
   339  						}
   340  						require.NoError(t, err)
   341  						rowIdx++
   342  					}
   343  
   344  					require.NoError(t, stream.producer.Err())
   345  					require.EqualValues(t, 10, rowIdx)
   346  				})
   347  			}
   348  		}
   349  	}
   350  }
   351  
   352  func TestReadsAvroOcf(t *testing.T) {
   353  	defer leaktest.AfterTest(t)()
   354  	th := newTestHelper(t)
   355  
   356  	for _, skip := range []bool{false, true} {
   357  		t.Run(fmt.Sprintf("skip=%v", skip), func(t *testing.T) {
   358  			stream := th.newRecordStream(t, roachpb.AvroOptions_OCF, false, 10)
   359  			var rowIdx int64
   360  			for stream.producer.Scan() {
   361  				var err error
   362  				if skip {
   363  					err = stream.producer.Skip()
   364  				} else {
   365  					err = stream.Row()
   366  				}
   367  				require.NoError(t, err)
   368  				rowIdx++
   369  			}
   370  
   371  			require.NoError(t, stream.producer.Err())
   372  			require.EqualValues(t, 10, rowIdx)
   373  		})
   374  	}
   375  }
   376  
// TestRelaxedAndStrictImport verifies that relaxed import tolerates
// missing or extra avro fields relative to the table, while strict mode
// rejects them, across every avro input format.
func TestRelaxedAndStrictImport(t *testing.T) {
	defer leaktest.AfterTest(t)()

	tests := []struct {
		name         string
		strict       bool
		excludeAvro  bool // drop a field from the avro schema (table has an extra column)
		excludeTable bool // drop a column from the table (avro has an extra field)
	}{
		{"relaxed-tolerates-missing-fields", false, true, false},
		{"relaxed-tolerates-extra-fields", false, false, true},
		{"relaxed-tolerates-missing-or-extra-fields", false, true, true},
		{"strict-returns-error-missing-fields", true, true, false},
		{"strict-returns-error-extra-fields", true, false, true},
		{"strict-returns-error-missing-or-extra-fields", true, true, true},
	}

	// Iterate over every declared avro format; map iteration order is
	// random but every (format, test) pair runs regardless.
	for f := range roachpb.AvroOptions_Format_name {
		for _, test := range tests {
			format := roachpb.AvroOptions_Format(f)
			t.Run(fmt.Sprintf("%s-%s", format, test.name), func(t *testing.T) {
				f1 := &seqGen{namedField: namedField{name: "f1"}}
				f2 := &seqGen{namedField: namedField{name: "f2"}}
				f1.excludeSQL = test.excludeTable
				f2.excludeAvro = test.excludeAvro

				th := newTestHelper(t, f1, f2)
				stream := th.newRecordStream(t, format, test.strict, 1)

				if !stream.producer.Scan() {
					t.Fatal("expected a record, found none")
				}
				err := stream.Row()
				if test.strict && err == nil {
					t.Fatal("expected to fail, but alas")
				}
				if !test.strict && err != nil {
					t.Fatal("expected to succeed, but alas;", err)
				}
			})
		}
	}
}
   420  
   421  func TestHandlesArrayData(t *testing.T) {
   422  	defer leaktest.AfterTest(t)()
   423  	th := newTestHelper(t, &intArrGen{namedField{
   424  		name: "arr_of_ints",
   425  	}})
   426  
   427  	stream := th.newRecordStream(t, roachpb.AvroOptions_OCF, false, 10)
   428  	var rowIdx int64
   429  	for stream.producer.Scan() {
   430  		if err := stream.Row(); err != nil {
   431  			t.Fatal(err)
   432  		}
   433  		rowIdx++
   434  	}
   435  
   436  	require.NoError(t, stream.producer.Err())
   437  	require.EqualValues(t, 10, rowIdx)
   438  }
   439  
// limitAvroStream is an importRowProducer that serves exactly `limit` rows
// for benchmarking, rewinding and re-reading the underlying input file as
// many times as needed to reach that count.
type limitAvroStream struct {
	avro       *avroInputReader
	limit      int              // remaining rows to serve; decremented on each Scan
	readStream importRowProducer // current pass over the input; nil forces a reopen
	input      *os.File
	err        error
}
   447  
// Skip is a no-op; the benchmark stream never skips rows.
func (l *limitAvroStream) Skip() error {
	return nil
}
   451  
// Progress always reports 0; progress tracking is irrelevant for benchmarks.
func (l *limitAvroStream) Progress() float32 {
	return 0
}
   455  
   456  func (l *limitAvroStream) reopenStream() {
   457  	_, l.err = l.input.Seek(0, 0)
   458  	if l.err == nil {
   459  		producer, _, err := newImportAvroPipeline(l.avro, &fileReader{Reader: l.input})
   460  		l.err = err
   461  		l.readStream = producer
   462  	}
   463  }
   464  
// Scan reports whether another row is available, reopening the underlying
// stream as many times as needed until `limit` rows total have been served.
// Each call consumes exactly one unit of the limit (decremented up front),
// so Scan returns true at most `limit` times.
func (l *limitAvroStream) Scan() bool {
	l.limit--
	for l.limit >= 0 && l.err == nil {
		if l.readStream == nil {
			l.reopenStream()
			if l.err != nil {
				return false
			}
		}

		if l.readStream.Scan() {
			return true
		}

		// Force reopen the stream until we read enough data.
		l.err = l.readStream.Err()
		l.readStream = nil
	}
	return false
}
   485  
// Err returns the first error encountered while scanning or reopening.
func (l *limitAvroStream) Err() error {
	return l.err
}
   489  
// Row delegates to the current underlying stream; only valid after a
// successful Scan (which guarantees readStream is non-nil).
func (l *limitAvroStream) Row() (interface{}, error) {
	return l.readStream.Row()
}
   493  
   494  var _ importRowProducer = &limitAvroStream{}
   495  
   496  // goos: darwin
   497  // goarch: amd64
   498  // pkg: github.com/cockroachdb/cockroach/pkg/ccl/importccl
   499  // BenchmarkOCFImport-16    	  500000	      2612 ns/op	  45.93 MB/s
   500  // BenchmarkOCFImport-16    	  500000	      2607 ns/op	  46.03 MB/s
   501  // BenchmarkOCFImport-16    	  500000	      2719 ns/op	  44.13 MB/s
   502  // BenchmarkOCFImport-16    	  500000	      2825 ns/op	  42.47 MB/s
   503  // BenchmarkOCFImport-16    	  500000	      2924 ns/op	  41.03 MB/s
   504  // BenchmarkOCFImport-16    	  500000	      2917 ns/op	  41.14 MB/s
   505  // BenchmarkOCFImport-16    	  500000	      2926 ns/op	  41.01 MB/s
   506  // BenchmarkOCFImport-16    	  500000	      2954 ns/op	  40.61 MB/s
   507  // BenchmarkOCFImport-16    	  500000	      2942 ns/op	  40.78 MB/s
   508  // BenchmarkOCFImport-16    	  500000	      2987 ns/op	  40.17 MB/s
// BenchmarkOCFImport measures avro import throughput for OCF-formatted
// input (schema embedded in the file, so no separate schema JSON needed).
func BenchmarkOCFImport(b *testing.B) {
	benchmarkAvroImport(b, roachpb.AvroOptions{
		Format: roachpb.AvroOptions_OCF,
	}, "testdata/avro/stock-10000.ocf")
}
   514  
   515  // goos: darwin
   516  // goarch: amd64
   517  // pkg: github.com/cockroachdb/cockroach/pkg/ccl/importccl
   518  // BenchmarkBinaryJSONImport-16    	  500000	      3021 ns/op	  39.71 MB/s
   519  // BenchmarkBinaryJSONImport-16    	  500000	      2991 ns/op	  40.11 MB/s
   520  // BenchmarkBinaryJSONImport-16    	  500000	      3056 ns/op	  39.26 MB/s
   521  // BenchmarkBinaryJSONImport-16    	  500000	      3075 ns/op	  39.02 MB/s
   522  // BenchmarkBinaryJSONImport-16    	  500000	      3052 ns/op	  39.31 MB/s
   523  // BenchmarkBinaryJSONImport-16    	  500000	      3101 ns/op	  38.69 MB/s
   524  // BenchmarkBinaryJSONImport-16    	  500000	      3119 ns/op	  38.47 MB/s
   525  // BenchmarkBinaryJSONImport-16    	  500000	      3237 ns/op	  37.06 MB/s
   526  // BenchmarkBinaryJSONImport-16    	  500000	      3215 ns/op	  37.32 MB/s
   527  // BenchmarkBinaryJSONImport-16    	  500000	      3235 ns/op	  37.09 MB/s
// BenchmarkBinaryJSONImport measures avro import throughput for binary
// record input, which requires supplying the schema JSON separately.
func BenchmarkBinaryJSONImport(b *testing.B) {
	schemaBytes, err := ioutil.ReadFile("testdata/avro/stock-schema.json")
	require.NoError(b, err)

	benchmarkAvroImport(b, roachpb.AvroOptions{
		Format:     roachpb.AvroOptions_BIN_RECORDS,
		SchemaJSON: string(schemaBytes),
	}, "testdata/avro/stock-10000.bjson")
}
   537  
// benchmarkAvroImport runs the parallel avro import pipeline over testData
// (a TPCC stock table fixture) b.N times, draining produced KVs into a
// no-op channel so only conversion cost is measured.
func benchmarkAvroImport(b *testing.B, avroOpts roachpb.AvroOptions, testData string) {
	ctx := context.Background()

	b.SetBytes(120) // Raw input size. With 8 indexes, expect more on output side.

	stmt, err := parser.ParseOne(`CREATE TABLE stock (
    s_i_id       integer       not null,
    s_w_id       integer       not null,
    s_quantity   integer,
    s_dist_01    char(24),
    s_dist_02    char(24),
    s_dist_03    char(24),
    s_dist_04    char(24),
    s_dist_05    char(24),
    s_dist_06    char(24),
    s_dist_07    char(24),
    s_dist_08    char(24),
    s_dist_09    char(24),
    s_dist_10    char(24),
    s_ytd        integer,
    s_order_cnt  integer,
    s_remote_cnt integer,
    s_data       varchar(50),
		primary key (s_w_id, s_i_id),
    index stock_item_fk_idx (s_i_id))
  `)

	require.NoError(b, err)

	create := stmt.AST.(*tree.CreateTable)
	st := cluster.MakeTestingClusterSettings()
	evalCtx := tree.MakeTestingEvalContext(st)

	tableDesc, err := MakeSimpleTableDescriptor(ctx, st, create, sqlbase.ID(100), sqlbase.ID(100), NoFKs, 1)
	require.NoError(b, err)

	kvCh := make(chan row.KVBatch)
	// no-op drain kvs channel.
	go func() {
		for range kvCh {
		}
	}()

	input, err := os.Open(testData)
	require.NoError(b, err)

	avro, err := newAvroInputReader(kvCh, tableDesc.TableDesc(), avroOpts, 0, 0, &evalCtx)
	require.NoError(b, err)

	// limitStream re-reads the fixture as many times as needed to serve
	// exactly b.N rows.
	limitStream := &limitAvroStream{
		avro:  avro,
		limit: b.N,
		input: input,
	}
	_, consumer, err := newImportAvroPipeline(avro, &fileReader{Reader: input})
	require.NoError(b, err)
	b.ResetTimer()
	require.NoError(
		b, runParallelImport(ctx, avro.importContext, &importFileContext{}, limitStream, consumer))
	// Closing kvCh terminates the drain goroutine above.
	close(kvCh)
}
   598  }