github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/performance/import_benchmarker/testdef.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package import_benchmarker
    16  
    17  import (
    18  	"bufio"
    19  	"bytes"
    20  	"context"
    21  	"database/sql"
    22  	"fmt"
    23  	"math/rand"
    24  	"os"
    25  	"strconv"
    26  	"strings"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/cespare/xxhash"
    31  	"github.com/creasty/defaults"
    32  	sql2 "github.com/dolthub/go-mysql-server/sql"
    33  	gmstypes "github.com/dolthub/go-mysql-server/sql/types"
    34  	"github.com/dolthub/vitess/go/sqltypes"
    35  	ast "github.com/dolthub/vitess/go/vt/sqlparser"
    36  	"github.com/stretchr/testify/require"
    37  	yaml "gopkg.in/yaml.v3"
    38  
    39  	driver "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/sql_server_driver"
    40  )
    41  
// defaultBatchSize is the number of rows placed in each multi-row INSERT
// statement when a "sql" format table is generated with batching enabled.
const defaultBatchSize = 500
    43  
// TestDef is the top-level definition of tests to run, as decoded from a
// YAML tests file by ParseTestsFile.
type TestDef struct {
	// Tests is the list of import tests, read from the "tests" key.
	Tests []ImportTest `yaml:"tests"`
	// Opts holds optional file-level settings, read from the "opts" key.
	Opts  *Opts        `yaml:"opts"`
}
    49  
// Opts are file-level options decoded from the "opts" key of a tests file.
type Opts struct {
	// Seed is read from the "seed" key.
	// NOTE(review): nothing in the visible code reads this value; presumably
	// it was meant to seed row shuffling — confirm before relying on it.
	Seed int `yaml:"seed"`
}
    53  
// ImportTest is a single import benchmark. Every table in Tables is imported
// once for each repo in Repos; Run decides per repo whether to use an
// external server, a freshly started sql-server, or the dolt CLI.
type ImportTest struct {
	Name   string            `yaml:"name"`
	Repos  []driver.TestRepo `yaml:"repos"`
	Tables []Table           `yaml:"tables"`

	// Skip the entire test with this reason.
	Skip string `yaml:"skip"`

	// Results accumulates one ImportResult per completed table import.
	Results *ImportResults
	// files caches generated data files, keyed by the value tableKey returns.
	files   map[uint64]*os.File
	// tmpdir is the directory generated data files are created in.
	tmpdir  string
}
    69  
// Table describes one table import to benchmark. The values in the `default`
// tags are applied by UnmarshalYAML before the YAML values are decoded.
type Table struct {
	// Name labels this import in the reported results.
	Name        string `yaml:"name"`
	// Schema is the CREATE TABLE statement. It is executed to create the
	// target table and parsed to derive the table name and column types.
	Schema      string `yaml:"schema"`
	// Rows is the number of rows to generate.
	Rows        int    `default:"200000" yaml:"rows"`
	// Fmt is the data file format; the code handles "csv" and "sql".
	Fmt         string `default:"csv" yaml:"fmt"`
	// Shuffle randomizes the order of the generated rows.
	Shuffle     bool   `default:"false" yaml:"shuffle"`
	// Batch groups "sql" rows into multi-row INSERT statements.
	Batch       bool   `default:"false" yaml:"batch"`
	// TargetTable is the table name parsed from Schema; it is filled in by
	// IterImportTables rather than read from a tagged YAML key.
	TargetTable string
}
    79  
    80  func (s *Table) UnmarshalYAML(unmarshal func(interface{}) error) error {
    81  	defaults.Set(s)
    82  
    83  	type plain Table
    84  	if err := unmarshal((*plain)(s)); err != nil {
    85  		return err
    86  	}
    87  
    88  	return nil
    89  }
    90  
    91  func ParseTestsFile(path string) (TestDef, error) {
    92  	contents, err := os.ReadFile(path)
    93  	if err != nil {
    94  		return TestDef{}, err
    95  	}
    96  	dec := yaml.NewDecoder(bytes.NewReader(contents))
    97  	dec.KnownFields(true)
    98  	var res TestDef
    99  	err = dec.Decode(&res)
   100  	return res, err
   101  }
   102  
   103  func MakeRepo(rs driver.RepoStore, r driver.TestRepo) (driver.Repo, error) {
   104  	repo, err := rs.MakeRepo(r.Name)
   105  	if err != nil {
   106  		return driver.Repo{}, err
   107  	}
   108  	return repo, nil
   109  }
   110  
   111  func MakeServer(dc driver.DoltCmdable, s *driver.Server) (*driver.SqlServer, error) {
   112  	if s == nil {
   113  		return nil, nil
   114  	}
   115  	opts := []driver.SqlServerOpt{driver.WithArgs(s.Args...)}
   116  	if s.Port != 0 {
   117  		opts = append(opts, driver.WithPort(s.Port))
   118  	}
   119  	server, err := driver.StartSqlServer(dc, opts...)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  
   124  	return server, nil
   125  }
   126  
// ImportResult records the outcome of one timed table import.
type ImportResult struct {
	detail string  // name of the imported table definition (Table.Name)
	server string  // name of the repo the import ran against
	test   string  // name of the enclosing ImportTest
	time   float64 // elapsed import time in seconds
	rows   int     // configured row count (Table.Rows)
	fmt    string  // data file format (Table.Fmt)
	sorted bool    // true when the rows were not shuffled
	batch  bool    // value of Table.Batch for this import
}
   137  
   138  func (r ImportResult) String() string {
   139  	return fmt.Sprintf("- %s/%s/%s: %.2fs\n", r.test, r.server, r.detail, r.time)
   140  }
   141  
// ImportResults is an ordered collection of ImportResult values that can be
// rendered as a text summary (String) or as SQL statements (SqlDump).
type ImportResults struct {
	res []ImportResult
}
   145  
   146  func (r *ImportResults) append(ir ImportResult) {
   147  	r.res = append(r.res, ir)
   148  }
   149  
   150  func (r *ImportResults) String() string {
   151  	b := strings.Builder{}
   152  	b.WriteString("Results:\n")
   153  	for _, x := range r.res {
   154  		b.WriteString(x.String())
   155  	}
   156  	return b.String()
   157  }
   158  
   159  func (r *ImportResults) SqlDump() string {
   160  	b := strings.Builder{}
   161  	b.WriteString(`CREATE TABLE IF NOT EXISTS import_perf_results (
   162    test_name varchar(64),
   163    server varchar(64),
   164    detail varchar(64),
   165    row_cnt int,
   166    time double,
   167    file_format varchar(8),
   168    sorted bool,
   169    batch bool,
   170    primary key (test_name, detail, server)
   171  );
   172  `)
   173  
   174  	b.WriteString("insert into import_perf_results values\n")
   175  	for i, r := range r.res {
   176  		if i > 0 {
   177  			b.WriteString(",\n  ")
   178  		}
   179  		var sorted int
   180  		if r.sorted {
   181  			sorted = 1
   182  		}
   183  		var batch int
   184  		if r.batch {
   185  			batch = 1
   186  		}
   187  		b.WriteString(fmt.Sprintf(
   188  			"('%s', '%s', '%s', %d, %.2f, '%s', %b, %b)",
   189  			r.test, r.server, r.detail, r.rows, r.time, r.fmt, sorted, batch))
   190  	}
   191  	b.WriteString(";\n")
   192  
   193  	return b.String()
   194  }
   195  
   196  func (test *ImportTest) InitWithTmpDir(s string) {
   197  	test.tmpdir = s
   198  	test.files = make(map[uint64]*os.File)
   199  }
   200  
   201  // Run executes an import configuration. Test parallelism makes
   202  // runtimes resulting from this method unsuitable for reporting.
   203  func (test *ImportTest) Run(t *testing.T) {
   204  	if test.Skip != "" {
   205  		t.Skip(test.Skip)
   206  	}
   207  	var err error
   208  	if test.Results == nil {
   209  		test.Results = new(ImportResults)
   210  		tmp, err := os.MkdirTemp("", "repo-store-")
   211  		if err != nil {
   212  			require.NoError(t, err)
   213  		}
   214  		test.InitWithTmpDir(tmp)
   215  	}
   216  
   217  	u, err := driver.NewDoltUser()
   218  	for _, r := range test.Repos {
   219  		if r.ExternalServer != nil {
   220  			err := test.RunExternalServerTests(r.Name, r.ExternalServer)
   221  			require.NoError(t, err)
   222  		} else if r.Server != nil {
   223  			err = test.RunSqlServerTests(r, u)
   224  			require.NoError(t, err)
   225  		} else {
   226  			err = test.RunCliTests(r, u)
   227  			require.NoError(t, err)
   228  		}
   229  	}
   230  	fmt.Println(test.Results.String())
   231  }
   232  
   233  // RunExternalServerTests connects to a single externally provided server to run every test
   234  func (test *ImportTest) RunExternalServerTests(repoName string, s *driver.ExternalServer) error {
   235  	return test.IterImportTables(test.Tables, func(tab Table, f *os.File) error {
   236  		db, err := driver.ConnectDB(s.User, s.Password, s.Name, s.Host, s.Port, nil)
   237  		if err != nil {
   238  			return err
   239  		}
   240  		defer db.Close()
   241  		switch tab.Fmt {
   242  		case "csv":
   243  			return test.benchLoadData(repoName, db, tab, f)
   244  		case "sql":
   245  			return test.benchSql(repoName, db, tab, f)
   246  		default:
   247  			return fmt.Errorf("unexpected table import format: %s", tab.Fmt)
   248  		}
   249  	})
   250  }
   251  
   252  // RunSqlServerTests creates a new repo and server for every import test.
   253  func (test *ImportTest) RunSqlServerTests(repo driver.TestRepo, user driver.DoltUser) error {
   254  	return test.IterImportTables(test.Tables, func(tab Table, f *os.File) error {
   255  		//make a new server for every test
   256  		server, err := newServer(user, repo)
   257  		if err != nil {
   258  			return err
   259  		}
   260  		defer server.GracefulStop()
   261  
   262  		db, err := server.DB(driver.Connection{User: "root", Pass: ""})
   263  		if err != nil {
   264  			return err
   265  		}
   266  		err = modifyServerForImport(db)
   267  		if err != nil {
   268  			return err
   269  		}
   270  
   271  		switch tab.Fmt {
   272  		case "csv":
   273  			return test.benchLoadData(repo.Name, db, tab, f)
   274  		case "sql":
   275  			return test.benchSql(repo.Name, db, tab, f)
   276  		default:
   277  			return fmt.Errorf("unexpected table import format: %s", tab.Fmt)
   278  		}
   279  	})
   280  }
   281  
   282  func newServer(u driver.DoltUser, r driver.TestRepo) (*driver.SqlServer, error) {
   283  	rs, err := u.MakeRepoStore()
   284  	if err != nil {
   285  		return nil, err
   286  	}
   287  	// start dolt server
   288  	repo, err := MakeRepo(rs, r)
   289  	if err != nil {
   290  		return nil, err
   291  	}
   292  	server, err := MakeServer(repo, r.Server)
   293  	if err != nil {
   294  		return nil, err
   295  	}
   296  	if server != nil {
   297  		server.DBName = r.Name
   298  	}
   299  	return server, nil
   300  }
   301  
   302  func modifyServerForImport(db *sql.DB) error {
   303  	_, err := db.Exec("SET GLOBAL local_infile=1 ")
   304  	if err != nil {
   305  		return err
   306  	}
   307  	return nil
   308  }
   309  
   310  func (test *ImportTest) benchLoadData(repoName string, db *sql.DB, tab Table, f *os.File) error {
   311  	ctx := context.Background()
   312  	conn, err := db.Conn(ctx)
   313  	if err != nil {
   314  		return err
   315  	}
   316  	defer conn.Close()
   317  
   318  	rows, err := conn.QueryContext(ctx, tab.Schema)
   319  	if err == nil {
   320  		rows.Close()
   321  	} else {
   322  		return err
   323  	}
   324  
   325  	start := time.Now()
   326  
   327  	q := fmt.Sprintf(`
   328  LOAD DATA LOCAL INFILE '%s' INTO TABLE xy
   329  FIELDS TERMINATED BY ',' ENCLOSED BY ''
   330  LINES TERMINATED BY '\n'
   331  IGNORE 1 LINES;`, f.Name())
   332  
   333  	rows, err = conn.QueryContext(ctx, q)
   334  	if err == nil {
   335  		rows.Close()
   336  	} else {
   337  		return err
   338  	}
   339  
   340  	runtime := time.Since(start)
   341  
   342  	test.Results.append(ImportResult{
   343  		test:   test.Name,
   344  		server: repoName,
   345  		detail: tab.Name,
   346  		time:   runtime.Seconds(),
   347  		rows:   tab.Rows,
   348  		fmt:    tab.Fmt,
   349  		sorted: !tab.Shuffle,
   350  		batch:  tab.Batch,
   351  	})
   352  
   353  	rows, err = conn.QueryContext(
   354  		ctx,
   355  		fmt.Sprintf("drop table %s;", tab.TargetTable),
   356  	)
   357  	if err == nil {
   358  		rows.Close()
   359  	} else {
   360  		return err
   361  	}
   362  
   363  	return nil
   364  }
   365  
   366  func (test *ImportTest) benchSql(repoName string, db *sql.DB, tab Table, f *os.File) error {
   367  	ctx := context.Background()
   368  	conn, err := db.Conn(ctx)
   369  	if err != nil {
   370  		return err
   371  	}
   372  	defer conn.Close()
   373  
   374  	rows, err := conn.QueryContext(ctx, tab.Schema)
   375  	if err == nil {
   376  		rows.Close()
   377  	} else {
   378  		return err
   379  	}
   380  
   381  	defer conn.ExecContext(
   382  		ctx,
   383  		fmt.Sprintf("drop table %s;", tab.TargetTable),
   384  	)
   385  
   386  	f.Seek(0, 0)
   387  	s := bufio.NewScanner(f)
   388  	s.Split(ScanQueries)
   389  	start := time.Now()
   390  
   391  	for lineno := 1; s.Scan(); lineno++ {
   392  		line := s.Text()
   393  		var br bool
   394  		switch {
   395  		case line == "":
   396  			return fmt.Errorf("unexpected blank line, line number: %d", lineno)
   397  		case line == "\n":
   398  			br = true
   399  		default:
   400  		}
   401  		if br {
   402  			break
   403  		}
   404  
   405  		if err := s.Err(); err != nil {
   406  			return fmt.Errorf("%s:%d: %v", f.Name(), lineno, err)
   407  		}
   408  
   409  		_, err := conn.ExecContext(ctx, line)
   410  		if err != nil {
   411  			return err
   412  		}
   413  
   414  	}
   415  
   416  	runtime := time.Since(start)
   417  
   418  	test.Results.append(ImportResult{
   419  		test:   test.Name,
   420  		server: repoName,
   421  		detail: tab.Name,
   422  		time:   runtime.Seconds(),
   423  		rows:   tab.Rows,
   424  		fmt:    tab.Fmt,
   425  		sorted: !tab.Shuffle,
   426  		batch:  tab.Batch,
   427  	})
   428  
   429  	if err == nil {
   430  		rows.Close()
   431  	} else {
   432  		return err
   433  	}
   434  
   435  	return nil
   436  }
   437  
   438  func ScanQueries(data []byte, atEOF bool) (advance int, token []byte, err error) {
   439  	if atEOF && len(data) == 0 {
   440  		return 0, nil, nil
   441  	}
   442  	if i := bytes.IndexByte(data, ';'); i >= 0 {
   443  		// We have a full newline-terminated line.
   444  		return i + 1, dropCR(data[0:i]), nil
   445  	}
   446  	// If we're at EOF, we have a final, non-terminated line. Return it.
   447  	if atEOF {
   448  		return len(data), dropCR(data), nil
   449  	}
   450  	// Request more data.
   451  	return 0, nil, nil
   452  }
   453  
   454  func dropCR(data []byte) []byte {
   455  	if len(data) > 0 && data[len(data)-1] == '\r' {
   456  		return data[0 : len(data)-1]
   457  	}
   458  	return data
   459  }
   460  
   461  // RunCliTests runs each import test on a new dolt repo to avoid accumulated
   462  // startup costs over time between tests.
   463  func (test *ImportTest) RunCliTests(r driver.TestRepo, user driver.DoltUser) error {
   464  	return test.IterImportTables(test.Tables, func(tab Table, f *os.File) error {
   465  		var err error
   466  
   467  		rs, err := user.MakeRepoStore()
   468  		if err != nil {
   469  			return err
   470  		}
   471  
   472  		repo, err := MakeRepo(rs, r)
   473  		if err != nil {
   474  			return err
   475  		}
   476  
   477  		err = repo.DoltExec("sql", "-q", tab.Schema)
   478  		if err != nil {
   479  			return err
   480  		}
   481  
   482  		// start timer
   483  		start := time.Now()
   484  
   485  		cmd := repo.DoltCmd("table", "import", "-r", "--file-type", tab.Fmt, tab.TargetTable, f.Name())
   486  		_, err = cmd.StdoutPipe()
   487  		if err != nil {
   488  			return err
   489  		}
   490  		cmd.Stderr = cmd.Stdout
   491  		err = cmd.Run()
   492  		if err != nil {
   493  			return fmt.Errorf("%w: %s", err, cmd.Stderr)
   494  		}
   495  
   496  		// end timer, append result
   497  		runtime := time.Since(start)
   498  
   499  		test.Results.append(ImportResult{
   500  			test:   test.Name,
   501  			server: r.Name,
   502  			detail: tab.Name,
   503  			time:   runtime.Seconds(),
   504  			rows:   tab.Rows,
   505  			fmt:    tab.Fmt,
   506  			sorted: !tab.Shuffle,
   507  			batch:  tab.Batch,
   508  		})
   509  
   510  		// reset repo at end
   511  		return repo.DoltExec("sql", "-q", fmt.Sprintf("drop table %s", tab.TargetTable))
   512  	})
   513  }
   514  
   515  func (test *ImportTest) IterImportTables(tables []Table, cb func(t Table, f *os.File) error) error {
   516  	for _, t := range tables {
   517  		key, err := tableKey(t)
   518  		if err != nil {
   519  			return err
   520  		}
   521  		table, names, types := parseTableAndSchema(t.Schema)
   522  		t.TargetTable = table
   523  
   524  		if f, ok := test.files[key]; ok {
   525  			// short circuit if we've already made file for schema/row count
   526  			err = cb(t, f)
   527  			if err != nil {
   528  				return err
   529  			}
   530  			continue
   531  		}
   532  
   533  		rows := make([]string, 0, t.Rows)
   534  		genRows(types, t.Rows, t.Fmt, func(r []string) {
   535  			switch t.Fmt {
   536  			case "csv":
   537  				rows = append(rows, strings.Join(r, ","))
   538  			case "sql":
   539  				rows = append(rows, fmt.Sprintf("(%s)", strings.Join(r, ", ")))
   540  			default:
   541  				panic(fmt.Sprintf("unknown format: %s", t.Fmt))
   542  			}
   543  		})
   544  
   545  		if t.Shuffle {
   546  			rand.Shuffle(len(rows), func(i, j int) { rows[i], rows[j] = rows[j], rows[i] })
   547  		}
   548  
   549  		f, err := os.CreateTemp(test.tmpdir, "import-data-")
   550  		if err != nil {
   551  			return err
   552  		}
   553  
   554  		switch t.Fmt {
   555  		case "csv":
   556  			fmt.Fprintf(f, "%s\n", strings.Join(names, ","))
   557  			for _, r := range rows {
   558  				fmt.Fprintf(f, "%s\n", r)
   559  			}
   560  		case "sql":
   561  			if t.Batch {
   562  				batchSize := defaultBatchSize
   563  				var i int
   564  				for i+batchSize < len(rows) {
   565  					fmt.Fprintf(f, newBatch(t.TargetTable, rows[i:i+batchSize]))
   566  					i += batchSize
   567  				}
   568  				if i < len(rows) {
   569  					fmt.Fprintf(f, newBatch(t.TargetTable, rows[i:]))
   570  				}
   571  			} else {
   572  				for _, r := range rows {
   573  					fmt.Fprintf(f, fmt.Sprintf("INSERT INTO %s VALUES %s;\n", t.TargetTable, r))
   574  				}
   575  			}
   576  		default:
   577  			panic(fmt.Sprintf("unknown format: %s", t.Fmt))
   578  		}
   579  
   580  		// cache file for schema and row count
   581  		test.files[key] = f
   582  
   583  		err = cb(t, f)
   584  		if err != nil {
   585  			return err
   586  		}
   587  	}
   588  	return nil
   589  }
   590  
   591  func newBatch(name string, rows []string) string {
   592  	b := strings.Builder{}
   593  	b.WriteString(fmt.Sprintf("INSERT INTO %s VALUES\n", name))
   594  	for _, r := range rows[:len(rows)-1] {
   595  		b.WriteString("  ")
   596  		b.WriteString(r)
   597  		b.WriteString(",\n")
   598  	}
   599  	b.WriteString("  ")
   600  	b.WriteString(rows[len(rows)-1])
   601  	b.WriteString(";\n")
   602  
   603  	return b.String()
   604  }
   605  
   606  func tableKey(t Table) (uint64, error) {
   607  	hash := xxhash.New()
   608  	_, err := hash.Write([]byte(t.Schema))
   609  	if err != nil {
   610  		return 0, err
   611  	}
   612  	if _, err := hash.Write([]byte(fmt.Sprintf("%#v,", t.Rows))); err != nil {
   613  		return 0, err
   614  	}
   615  	if err != nil {
   616  		return 0, err
   617  	}
   618  	_, err = hash.Write([]byte(t.Fmt))
   619  	if err != nil {
   620  		return 0, err
   621  	}
   622  	return hash.Sum64(), nil
   623  }
   624  
   625  func parseTableAndSchema(q string) (string, []string, []sql2.Type) {
   626  	stmt, _, err := ast.ParseOne(q)
   627  	if err != nil {
   628  		panic(fmt.Sprintf("invalid query: %s; %s", q, err))
   629  	}
   630  	var types []sql2.Type
   631  	var names []string
   632  	var table string
   633  	switch n := stmt.(type) {
   634  	case *ast.DDL:
   635  		table = n.Table.String()
   636  		for _, col := range n.TableSpec.Columns {
   637  			names = append(names, col.Name.String())
   638  			typ, err := gmstypes.ColumnTypeToType(&col.Type)
   639  			if err != nil {
   640  				panic(fmt.Sprintf("unexpected error reading type: %s", err))
   641  			}
   642  			types = append(types, typ)
   643  		}
   644  	default:
   645  		panic(fmt.Sprintf("expected CREATE TABLE, found: %s", q))
   646  	}
   647  	return table, names, types
   648  }
   649  
   650  func genRows(types []sql2.Type, n int, fmt string, cb func(r []string)) {
   651  	// generate |n| rows with column types
   652  	for i := 0; i < n; i++ {
   653  		row := make([]string, len(types))
   654  		for j, t := range types {
   655  			switch fmt {
   656  			case "sql":
   657  				switch t.Type() {
   658  				case sqltypes.Blob, sqltypes.VarChar, sqltypes.Timestamp, sqltypes.Date:
   659  					row[j] = "'" + genValue(i, t) + "'"
   660  				default:
   661  					row[j] = genValue(i, t)
   662  				}
   663  			default:
   664  				row[j] = genValue(i, t)
   665  			}
   666  		}
   667  		cb(row)
   668  	}
   669  }
   670  
   671  func genValue(i int, typ sql2.Type) string {
   672  	switch typ.Type() {
   673  	case sqltypes.Blob:
   674  		return fmt.Sprintf("blob %d", i)
   675  	case sqltypes.VarChar:
   676  		return fmt.Sprintf("varchar %d", i)
   677  	case sqltypes.Int8, sqltypes.Int16, sqltypes.Int32, sqltypes.Int64:
   678  		return strconv.Itoa(i)
   679  	case sqltypes.Float32, sqltypes.Float64:
   680  		return strconv.FormatFloat(float64(i), 'E', -1, 32)
   681  	case sqltypes.Bit:
   682  		return strconv.Itoa(i)
   683  	case sqltypes.Geometry:
   684  		return `{"type": "Point", "coordinates": [1,2]}`
   685  	case sqltypes.Timestamp:
   686  		return "2019-12-31T12:00:00Z"
   687  	case sqltypes.Date:
   688  		return "2019-12-31T00:00:00Z"
   689  	default:
   690  		panic(fmt.Sprintf("expected type, found: %s", typ))
   691  	}
   692  }
   693  
   694  func RunTestsFile(t *testing.T, path string) {
   695  	def, err := ParseTestsFile(path)
   696  	require.NoError(t, err)
   697  	for _, test := range def.Tests {
   698  		t.Run(test.Name, test.Run)
   699  	}
   700  }