github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/ccl/importccl/read_import_mysql_test.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Licensed as a CockroachDB Enterprise file under the Cockroach Community
     4  // License (the "License"); you may not use this file except in compliance with
     5  // the License. You may obtain a copy of the License at
     6  //
     7  //     https://github.com/cockroachdb/cockroach/blob/master/licenses/CCL.txt
     8  
     9  package importccl
    10  
    11  import (
    12  	"bytes"
    13  	"context"
    14  	"encoding/hex"
    15  	"fmt"
    16  	"io/ioutil"
    17  	"os"
    18  	"path/filepath"
    19  	"reflect"
    20  	"strings"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/cockroachdb/cockroach/pkg/sql/execinfrapb"
    25  	"github.com/cockroachdb/cockroach/pkg/sql/row"
    26  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    27  	"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
    28  	"github.com/cockroachdb/cockroach/pkg/sql/types"
    29  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    30  	"github.com/cockroachdb/cockroach/pkg/util/protoutil"
    31  	"github.com/kr/pretty"
    32  	mysql "vitess.io/vitess/go/vt/sqlparser"
    33  )
    34  
    35  func TestMysqldumpDataReader(t *testing.T) {
    36  	defer leaktest.AfterTest(t)()
    37  
    38  	files := getMysqldumpTestdata(t)
    39  
    40  	ctx := context.Background()
    41  	table := descForTable(t, `CREATE TABLE simple (i INT PRIMARY KEY, s text, b bytea)`, 10, 20, NoFKs)
    42  	tables := map[string]*execinfrapb.ReadImportDataSpec_ImportTable{"simple": {Desc: table}}
    43  
    44  	kvCh := make(chan row.KVBatch, 10)
    45  	converter, err := newMysqldumpReader(ctx, kvCh, tables, testEvalCtx)
    46  
    47  	if err != nil {
    48  		t.Fatal(err)
    49  	}
    50  
    51  	var res []tree.Datums
    52  	converter.debugRow = func(row tree.Datums) {
    53  		res = append(res, append(tree.Datums{}, row...))
    54  	}
    55  
    56  	in, err := os.Open(files.simple)
    57  	if err != nil {
    58  		t.Fatal(err)
    59  	}
    60  	defer in.Close()
    61  	wrapped := &fileReader{Reader: in, counter: byteCounter{r: in}}
    62  
    63  	if err := converter.readFile(ctx, wrapped, 1, 0, nil); err != nil {
    64  		t.Fatal(err)
    65  	}
    66  	close(kvCh)
    67  
    68  	if expected, actual := len(simpleTestRows), len(res); expected != actual {
    69  		t.Fatalf("expected %d rows, got %d: %v", expected, actual, res)
    70  	}
    71  	for i, expected := range simpleTestRows {
    72  		row := res[i]
    73  		if actual := *row[0].(*tree.DInt); expected.i != int(actual) {
    74  			t.Fatalf("row %d: expected i = %d, got %d", i, expected.i, actual)
    75  		}
    76  		if expected.s != injectNull {
    77  			if actual := *row[1].(*tree.DString); expected.s != string(actual) {
    78  				t.Fatalf("row %d: expected s = %q, got %q", i, expected.i, actual)
    79  			}
    80  		} else if row[1] != tree.DNull {
    81  			t.Fatalf("row %d: expected b = NULL, got %T: %v", i, row[1], row[1])
    82  		}
    83  		if expected.b != nil {
    84  			if actual := []byte(*row[2].(*tree.DBytes)); !bytes.Equal(expected.b, actual) {
    85  				t.Fatalf("row %d: expected b = %v, got %v", i, hex.EncodeToString(expected.b), hex.EncodeToString(actual))
    86  			}
    87  		} else if row[2] != tree.DNull {
    88  			t.Fatalf("row %d: expected b = NULL, got %T: %v", i, row[2], row[2])
    89  		}
    90  	}
    91  }
    92  
// expectedParent is the parent ID argument that the schema tests in this file
// pass both to descForTable (when building expected descriptors) and, via
// readMysqlCreateFrom, to readMysqlCreateTable (when reading dump files).
const expectedParent = 52
    94  
    95  func readFile(t *testing.T, name string) string {
    96  	body, err := ioutil.ReadFile(filepath.Join("testdata", "mysqldump", name))
    97  	if err != nil {
    98  		t.Fatal(err)
    99  	}
   100  	return string(body)
   101  }
   102  
   103  func readMysqlCreateFrom(
   104  	t *testing.T, path, name string, id sqlbase.ID, fks fkHandler,
   105  ) *sqlbase.TableDescriptor {
   106  	t.Helper()
   107  	f, err := os.Open(path)
   108  	if err != nil {
   109  		t.Fatal(err)
   110  	}
   111  	defer f.Close()
   112  
   113  	tbl, err := readMysqlCreateTable(context.Background(), f, testEvalCtx, nil, id, expectedParent, name, fks, map[sqlbase.ID]int64{})
   114  	if err != nil {
   115  		t.Fatal(err)
   116  	}
   117  	return tbl[len(tbl)-1]
   118  }
   119  
   120  func TestMysqldumpSchemaReader(t *testing.T) {
   121  	defer leaktest.AfterTest(t)()
   122  
   123  	files := getMysqldumpTestdata(t)
   124  
   125  	simpleTable := descForTable(t, readFile(t, `simple.cockroach-schema.sql`), expectedParent, 52, NoFKs)
   126  	referencedSimple := descForTable(t, readFile(t, `simple.cockroach-schema.sql`), expectedParent, 52, NoFKs)
   127  	fks := fkHandler{
   128  		allowed:  true,
   129  		resolver: fkResolver(map[string]*sqlbase.MutableTableDescriptor{referencedSimple.Name: sqlbase.NewMutableCreatedTableDescriptor(*referencedSimple)}),
   130  	}
   131  
   132  	t.Run("simple", func(t *testing.T) {
   133  		expected := simpleTable
   134  		got := readMysqlCreateFrom(t, files.simple, "", 51, NoFKs)
   135  		compareTables(t, expected, got)
   136  	})
   137  
   138  	t.Run("second", func(t *testing.T) {
   139  		secondTable := descForTable(t, readFile(t, `second.cockroach-schema.sql`), expectedParent, 53, fks)
   140  		expected := secondTable
   141  		got := readMysqlCreateFrom(t, files.second, "", 53, fks)
   142  		compareTables(t, expected, got)
   143  	})
   144  
   145  	t.Run("everything", func(t *testing.T) {
   146  		expected := descForTable(t, readFile(t, `everything.cockroach-schema.sql`), expectedParent, 53, NoFKs)
   147  		got := readMysqlCreateFrom(t, files.everything, "", 53, NoFKs)
   148  		compareTables(t, expected, got)
   149  	})
   150  
   151  	t.Run("simple-in-multi", func(t *testing.T) {
   152  		expected := simpleTable
   153  		got := readMysqlCreateFrom(t, files.wholeDB, "simple", 51, NoFKs)
   154  		compareTables(t, expected, got)
   155  	})
   156  
   157  	t.Run("third-in-multi", func(t *testing.T) {
   158  		skip := fkHandler{allowed: true, skip: true, resolver: make(fkResolver)}
   159  		expected := descForTable(t, readFile(t, `third.cockroach-schema.sql`), expectedParent, 52, skip)
   160  		got := readMysqlCreateFrom(t, files.wholeDB, "third", 51, skip)
   161  		compareTables(t, expected, got)
   162  	})
   163  }
   164  
   165  func compareTables(t *testing.T, expected, got *sqlbase.TableDescriptor) {
   166  	colNames := func(cols []sqlbase.ColumnDescriptor) string {
   167  		names := make([]string, len(cols))
   168  		for i := range cols {
   169  			names[i] = cols[i].Name
   170  		}
   171  		return strings.Join(names, ", ")
   172  	}
   173  	idxNames := func(indexes []sqlbase.IndexDescriptor) string {
   174  		names := make([]string, len(indexes))
   175  		for i := range indexes {
   176  			names[i] = indexes[i].Name
   177  		}
   178  		return strings.Join(names, ", ")
   179  	}
   180  
   181  	// Attempt to verify the pieces individually, and return more helpful errors
   182  	// if an individual column or index does not match. If the pieces look right
   183  	// when compared individually, move on to compare the whole table desc as
   184  	// rendered to a string via `%+v`, as a more comprehensive check.
   185  
   186  	if expectedCols, gotCols := expected.Columns, got.Columns; len(gotCols) != len(expectedCols) {
   187  		t.Fatalf("expected columns (%d):\n%v\ngot columns (%d):\n%v\n",
   188  			len(expectedCols), colNames(expectedCols), len(gotCols), colNames(gotCols),
   189  		)
   190  	}
   191  	for i := range expected.Columns {
   192  		e, g := expected.Columns[i].SQLString(), got.Columns[i].SQLString()
   193  		if e != g {
   194  			t.Fatalf("column %d (%q): expected\n%s\ngot\n%s\n", i, expected.Columns[i].Name, e, g)
   195  		}
   196  	}
   197  
   198  	if expectedIdx, gotIdx := expected.Indexes, got.Indexes; len(expectedIdx) != len(gotIdx) {
   199  		t.Fatalf("expected indexes (%d):\n%v\ngot indexes (%d):\n%v\n",
   200  			len(expectedIdx), idxNames(expectedIdx), len(gotIdx), idxNames(gotIdx),
   201  		)
   202  	}
   203  	for i := range expected.Indexes {
   204  		tableName := &sqlbase.AnonymousTable
   205  		e, g := expected.Indexes[i].SQLString(tableName), got.Indexes[i].SQLString(tableName)
   206  		if e != g {
   207  			t.Fatalf("index %d: expected\n%s\ngot\n%s\n", i, e, g)
   208  		}
   209  	}
   210  
   211  	// Our attempts to check parts individually (and return readable errors if
   212  	// they didn't match) found nothing.
   213  	expectedBytes, err := protoutil.Marshal(expected)
   214  	if err != nil {
   215  		t.Fatal(err)
   216  	}
   217  
   218  	gotBytes, err := protoutil.Marshal(got)
   219  	if err != nil {
   220  		t.Fatal(err)
   221  	}
   222  	if !bytes.Equal(expectedBytes, gotBytes) {
   223  		t.Fatalf("expected\n%+v\n, got\n%+v\ndiff: %v", expected, got, pretty.Diff(expected, got))
   224  	}
   225  }
   226  
   227  func TestMysqlValueToDatum(t *testing.T) {
   228  	defer leaktest.AfterTest(t)()
   229  
   230  	date := func(s string) tree.Datum {
   231  		d, err := tree.ParseDDate(nil, s)
   232  		if err != nil {
   233  			t.Fatal(err)
   234  		}
   235  		return d
   236  	}
   237  	ts := func(s string) tree.Datum {
   238  		d, err := tree.ParseDTimestamp(nil, s, time.Microsecond)
   239  		if err != nil {
   240  			t.Fatal(err)
   241  		}
   242  		return d
   243  	}
   244  	tests := []struct {
   245  		raw  mysql.Expr
   246  		typ  *types.T
   247  		want tree.Datum
   248  	}{
   249  		{raw: mysql.NewStrVal([]byte("0000-00-00")), typ: types.Date, want: tree.DNull},
   250  		{raw: mysql.NewStrVal([]byte("2010-01-01")), typ: types.Date, want: date("2010-01-01")},
   251  		{raw: mysql.NewStrVal([]byte("0000-00-00 00:00:00")), typ: types.Timestamp, want: tree.DNull},
   252  		{raw: mysql.NewStrVal([]byte("2010-01-01 00:00:00")), typ: types.Timestamp, want: ts("2010-01-01 00:00:00")},
   253  	}
   254  	evalContext := tree.NewTestingEvalContext(nil)
   255  	for _, tc := range tests {
   256  		t.Run(fmt.Sprintf("%v", tc.raw), func(t *testing.T) {
   257  			got, err := mysqlValueToDatum(tc.raw, tc.typ, evalContext)
   258  			if err != nil {
   259  				t.Fatal(err)
   260  			}
   261  			if !reflect.DeepEqual(got, tc.want) {
   262  				t.Errorf("got %v, want %v", got, tc.want)
   263  			}
   264  		})
   265  	}
   266  }