github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/performance/import_benchmarker/testdef.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package import_benchmarker 16 17 import ( 18 "bufio" 19 "bytes" 20 "context" 21 "database/sql" 22 "fmt" 23 "math/rand" 24 "os" 25 "strconv" 26 "strings" 27 "testing" 28 "time" 29 30 "github.com/cespare/xxhash" 31 "github.com/creasty/defaults" 32 sql2 "github.com/dolthub/go-mysql-server/sql" 33 gmstypes "github.com/dolthub/go-mysql-server/sql/types" 34 "github.com/dolthub/vitess/go/sqltypes" 35 ast "github.com/dolthub/vitess/go/vt/sqlparser" 36 "github.com/stretchr/testify/require" 37 yaml "gopkg.in/yaml.v3" 38 39 driver "github.com/dolthub/dolt/go/libraries/doltcore/dtestutils/sql_server_driver" 40 ) 41 42 const defaultBatchSize = 500 43 44 // TestDef is the top-level definition of tests to run. 45 type TestDef struct { 46 Tests []ImportTest `yaml:"tests"` 47 Opts *Opts `yaml:"opts"` 48 } 49 50 type Opts struct { 51 Seed int `yaml:"seed"` 52 } 53 54 // ImportTest is a single test to run. The Repos and MultiRepos will be created, and 55 // any Servers defined within them will be started. The interactions and 56 // assertions defined in Conns will be run. 57 type ImportTest struct { 58 Name string `yaml:"name"` 59 Repos []driver.TestRepo `yaml:"repos"` 60 Tables []Table `yaml:"tables"` 61 62 // Skip the entire test with this reason. 63 Skip string `yaml:"skip"` 64 65 Results *ImportResults 66 files map[uint64]*os.File 67 tmpdir string 68 } 69 70 type Table struct { 71 Name string `yaml:"name"` 72 Schema string `yaml:"schema"` 73 Rows int `default:"200000" yaml:"rows"` 74 Fmt string `default:"csv" yaml:"fmt"` 75 Shuffle bool `default:"false" yaml:"shuffle"` 76 Batch bool `default:"false" yaml:"batch"` 77 TargetTable string 78 } 79 80 func (s *Table) UnmarshalYAML(unmarshal func(interface{}) error) error { 81 defaults.Set(s) 82 83 type plain Table 84 if err := unmarshal((*plain)(s)); err != nil { 85 return err 86 } 87 88 return nil 89 } 90 91 func ParseTestsFile(path string) (TestDef, error) { 92 contents, err := os.ReadFile(path) 93 if err != nil { 94 return TestDef{}, err 95 } 96 dec := yaml.NewDecoder(bytes.NewReader(contents)) 97 dec.KnownFields(true) 98 var res TestDef 99 err = dec.Decode(&res) 100 return res, err 101 } 102 103 func MakeRepo(rs driver.RepoStore, r driver.TestRepo) (driver.Repo, error) { 104 repo, err := rs.MakeRepo(r.Name) 105 if err != nil { 106 return driver.Repo{}, err 107 } 108 return repo, nil 109 } 110 111 func MakeServer(dc driver.DoltCmdable, s *driver.Server) (*driver.SqlServer, error) { 112 if s == nil { 113 return nil, nil 114 } 115 opts := []driver.SqlServerOpt{driver.WithArgs(s.Args...)} 116 if s.Port != 0 { 117 opts = append(opts, driver.WithPort(s.Port)) 118 } 119 server, err := driver.StartSqlServer(dc, opts...) 120 if err != nil { 121 return nil, err 122 } 123 124 return server, nil 125 } 126 127 type ImportResult struct { 128 detail string 129 server string 130 test string 131 time float64 132 rows int 133 fmt string 134 sorted bool 135 batch bool 136 } 137 138 func (r ImportResult) String() string { 139 return fmt.Sprintf("- %s/%s/%s: %.2fs\n", r.test, r.server, r.detail, r.time) 140 } 141 142 type ImportResults struct { 143 res []ImportResult 144 } 145 146 func (r *ImportResults) append(ir ImportResult) { 147 r.res = append(r.res, ir) 148 } 149 150 func (r *ImportResults) String() string { 151 b := strings.Builder{} 152 b.WriteString("Results:\n") 153 for _, x := range r.res { 154 b.WriteString(x.String()) 155 } 156 return b.String() 157 } 158 159 func (r *ImportResults) SqlDump() string { 160 b := strings.Builder{} 161 b.WriteString(`CREATE TABLE IF NOT EXISTS import_perf_results ( 162 test_name varchar(64), 163 server varchar(64), 164 detail varchar(64), 165 row_cnt int, 166 time double, 167 file_format varchar(8), 168 sorted bool, 169 batch bool, 170 primary key (test_name, detail, server) 171 ); 172 `) 173 174 b.WriteString("insert into import_perf_results values\n") 175 for i, r := range r.res { 176 if i > 0 { 177 b.WriteString(",\n ") 178 } 179 var sorted int 180 if r.sorted { 181 sorted = 1 182 } 183 var batch int 184 if r.batch { 185 batch = 1 186 } 187 b.WriteString(fmt.Sprintf( 188 "('%s', '%s', '%s', %d, %.2f, '%s', %b, %b)", 189 r.test, r.server, r.detail, r.rows, r.time, r.fmt, sorted, batch)) 190 } 191 b.WriteString(";\n") 192 193 return b.String() 194 } 195 196 func (test *ImportTest) InitWithTmpDir(s string) { 197 test.tmpdir = s 198 test.files = make(map[uint64]*os.File) 199 } 200 201 // Run executes an import configuration. Test parallelism makes 202 // runtimes resulting from this method unsuitable for reporting. 203 func (test *ImportTest) Run(t *testing.T) { 204 if test.Skip != "" { 205 t.Skip(test.Skip) 206 } 207 var err error 208 if test.Results == nil { 209 test.Results = new(ImportResults) 210 tmp, err := os.MkdirTemp("", "repo-store-") 211 if err != nil { 212 require.NoError(t, err) 213 } 214 test.InitWithTmpDir(tmp) 215 } 216 217 u, err := driver.NewDoltUser() 218 for _, r := range test.Repos { 219 if r.ExternalServer != nil { 220 err := test.RunExternalServerTests(r.Name, r.ExternalServer) 221 require.NoError(t, err) 222 } else if r.Server != nil { 223 err = test.RunSqlServerTests(r, u) 224 require.NoError(t, err) 225 } else { 226 err = test.RunCliTests(r, u) 227 require.NoError(t, err) 228 } 229 } 230 fmt.Println(test.Results.String()) 231 } 232 233 // RunExternalServerTests connects to a single externally provided server to run every test 234 func (test *ImportTest) RunExternalServerTests(repoName string, s *driver.ExternalServer) error { 235 return test.IterImportTables(test.Tables, func(tab Table, f *os.File) error { 236 db, err := driver.ConnectDB(s.User, s.Password, s.Name, s.Host, s.Port, nil) 237 if err != nil { 238 return err 239 } 240 defer db.Close() 241 switch tab.Fmt { 242 case "csv": 243 return test.benchLoadData(repoName, db, tab, f) 244 case "sql": 245 return test.benchSql(repoName, db, tab, f) 246 default: 247 return fmt.Errorf("unexpected table import format: %s", tab.Fmt) 248 } 249 }) 250 } 251 252 // RunSqlServerTests creates a new repo and server for every import test. 253 func (test *ImportTest) RunSqlServerTests(repo driver.TestRepo, user driver.DoltUser) error { 254 return test.IterImportTables(test.Tables, func(tab Table, f *os.File) error { 255 //make a new server for every test 256 server, err := newServer(user, repo) 257 if err != nil { 258 return err 259 } 260 defer server.GracefulStop() 261 262 db, err := server.DB(driver.Connection{User: "root", Pass: ""}) 263 if err != nil { 264 return err 265 } 266 err = modifyServerForImport(db) 267 if err != nil { 268 return err 269 } 270 271 switch tab.Fmt { 272 case "csv": 273 return test.benchLoadData(repo.Name, db, tab, f) 274 case "sql": 275 return test.benchSql(repo.Name, db, tab, f) 276 default: 277 return fmt.Errorf("unexpected table import format: %s", tab.Fmt) 278 } 279 }) 280 } 281 282 func newServer(u driver.DoltUser, r driver.TestRepo) (*driver.SqlServer, error) { 283 rs, err := u.MakeRepoStore() 284 if err != nil { 285 return nil, err 286 } 287 // start dolt server 288 repo, err := MakeRepo(rs, r) 289 if err != nil { 290 return nil, err 291 } 292 server, err := MakeServer(repo, r.Server) 293 if err != nil { 294 return nil, err 295 } 296 if server != nil { 297 server.DBName = r.Name 298 } 299 return server, nil 300 } 301 302 func modifyServerForImport(db *sql.DB) error { 303 _, err := db.Exec("SET GLOBAL local_infile=1 ") 304 if err != nil { 305 return err 306 } 307 return nil 308 } 309 310 func (test *ImportTest) benchLoadData(repoName string, db *sql.DB, tab Table, f *os.File) error { 311 ctx := context.Background() 312 conn, err := db.Conn(ctx) 313 if err != nil { 314 return err 315 } 316 defer conn.Close() 317 318 rows, err := conn.QueryContext(ctx, tab.Schema) 319 if err == nil { 320 rows.Close() 321 } else { 322 return err 323 } 324 325 start := time.Now() 326 327 q := fmt.Sprintf(` 328 LOAD DATA LOCAL INFILE '%s' INTO TABLE xy 329 FIELDS TERMINATED BY ',' ENCLOSED BY '' 330 LINES TERMINATED BY '\n' 331 IGNORE 1 LINES;`, f.Name()) 332 333 rows, err = conn.QueryContext(ctx, q) 334 if err == nil { 335 rows.Close() 336 } else { 337 return err 338 } 339 340 runtime := time.Since(start) 341 342 test.Results.append(ImportResult{ 343 test: test.Name, 344 server: repoName, 345 detail: tab.Name, 346 time: runtime.Seconds(), 347 rows: tab.Rows, 348 fmt: tab.Fmt, 349 sorted: !tab.Shuffle, 350 batch: tab.Batch, 351 }) 352 353 rows, err = conn.QueryContext( 354 ctx, 355 fmt.Sprintf("drop table %s;", tab.TargetTable), 356 ) 357 if err == nil { 358 rows.Close() 359 } else { 360 return err 361 } 362 363 return nil 364 } 365 366 func (test *ImportTest) benchSql(repoName string, db *sql.DB, tab Table, f *os.File) error { 367 ctx := context.Background() 368 conn, err := db.Conn(ctx) 369 if err != nil { 370 return err 371 } 372 defer conn.Close() 373 374 rows, err := conn.QueryContext(ctx, tab.Schema) 375 if err == nil { 376 rows.Close() 377 } else { 378 return err 379 } 380 381 defer conn.ExecContext( 382 ctx, 383 fmt.Sprintf("drop table %s;", tab.TargetTable), 384 ) 385 386 f.Seek(0, 0) 387 s := bufio.NewScanner(f) 388 s.Split(ScanQueries) 389 start := time.Now() 390 391 for lineno := 1; s.Scan(); lineno++ { 392 line := s.Text() 393 var br bool 394 switch { 395 case line == "": 396 return fmt.Errorf("unexpected blank line, line number: %d", lineno) 397 case line == "\n": 398 br = true 399 default: 400 } 401 if br { 402 break 403 } 404 405 if err := s.Err(); err != nil { 406 return fmt.Errorf("%s:%d: %v", f.Name(), lineno, err) 407 } 408 409 _, err := conn.ExecContext(ctx, line) 410 if err != nil { 411 return err 412 } 413 414 } 415 416 runtime := time.Since(start) 417 418 test.Results.append(ImportResult{ 419 test: test.Name, 420 server: repoName, 421 detail: tab.Name, 422 time: runtime.Seconds(), 423 rows: tab.Rows, 424 fmt: tab.Fmt, 425 sorted: !tab.Shuffle, 426 batch: tab.Batch, 427 }) 428 429 if err == nil { 430 rows.Close() 431 } else { 432 return err 433 } 434 435 return nil 436 } 437 438 func ScanQueries(data []byte, atEOF bool) (advance int, token []byte, err error) { 439 if atEOF && len(data) == 0 { 440 return 0, nil, nil 441 } 442 if i := bytes.IndexByte(data, ';'); i >= 0 { 443 // We have a full newline-terminated line. 444 return i + 1, dropCR(data[0:i]), nil 445 } 446 // If we're at EOF, we have a final, non-terminated line. Return it. 447 if atEOF { 448 return len(data), dropCR(data), nil 449 } 450 // Request more data. 451 return 0, nil, nil 452 } 453 454 func dropCR(data []byte) []byte { 455 if len(data) > 0 && data[len(data)-1] == '\r' { 456 return data[0 : len(data)-1] 457 } 458 return data 459 } 460 461 // RunCliTests runs each import test on a new dolt repo to avoid accumulated 462 // startup costs over time between tests. 463 func (test *ImportTest) RunCliTests(r driver.TestRepo, user driver.DoltUser) error { 464 return test.IterImportTables(test.Tables, func(tab Table, f *os.File) error { 465 var err error 466 467 rs, err := user.MakeRepoStore() 468 if err != nil { 469 return err 470 } 471 472 repo, err := MakeRepo(rs, r) 473 if err != nil { 474 return err 475 } 476 477 err = repo.DoltExec("sql", "-q", tab.Schema) 478 if err != nil { 479 return err 480 } 481 482 // start timer 483 start := time.Now() 484 485 cmd := repo.DoltCmd("table", "import", "-r", "--file-type", tab.Fmt, tab.TargetTable, f.Name()) 486 _, err = cmd.StdoutPipe() 487 if err != nil { 488 return err 489 } 490 cmd.Stderr = cmd.Stdout 491 err = cmd.Run() 492 if err != nil { 493 return fmt.Errorf("%w: %s", err, cmd.Stderr) 494 } 495 496 // end timer, append result 497 runtime := time.Since(start) 498 499 test.Results.append(ImportResult{ 500 test: test.Name, 501 server: r.Name, 502 detail: tab.Name, 503 time: runtime.Seconds(), 504 rows: tab.Rows, 505 fmt: tab.Fmt, 506 sorted: !tab.Shuffle, 507 batch: tab.Batch, 508 }) 509 510 // reset repo at end 511 return repo.DoltExec("sql", "-q", fmt.Sprintf("drop table %s", tab.TargetTable)) 512 }) 513 } 514 515 func (test *ImportTest) IterImportTables(tables []Table, cb func(t Table, f *os.File) error) error { 516 for _, t := range tables { 517 key, err := tableKey(t) 518 if err != nil { 519 return err 520 } 521 table, names, types := parseTableAndSchema(t.Schema) 522 t.TargetTable = table 523 524 if f, ok := test.files[key]; ok { 525 // short circuit if we've already made file for schema/row count 526 err = cb(t, f) 527 if err != nil { 528 return err 529 } 530 continue 531 } 532 533 rows := make([]string, 0, t.Rows) 534 genRows(types, t.Rows, t.Fmt, func(r []string) { 535 switch t.Fmt { 536 case "csv": 537 rows = append(rows, strings.Join(r, ",")) 538 case "sql": 539 rows = append(rows, fmt.Sprintf("(%s)", strings.Join(r, ", "))) 540 default: 541 panic(fmt.Sprintf("unknown format: %s", t.Fmt)) 542 } 543 }) 544 545 if t.Shuffle { 546 rand.Shuffle(len(rows), func(i, j int) { rows[i], rows[j] = rows[j], rows[i] }) 547 } 548 549 f, err := os.CreateTemp(test.tmpdir, "import-data-") 550 if err != nil { 551 return err 552 } 553 554 switch t.Fmt { 555 case "csv": 556 fmt.Fprintf(f, "%s\n", strings.Join(names, ",")) 557 for _, r := range rows { 558 fmt.Fprintf(f, "%s\n", r) 559 } 560 case "sql": 561 if t.Batch { 562 batchSize := defaultBatchSize 563 var i int 564 for i+batchSize < len(rows) { 565 fmt.Fprintf(f, newBatch(t.TargetTable, rows[i:i+batchSize])) 566 i += batchSize 567 } 568 if i < len(rows) { 569 fmt.Fprintf(f, newBatch(t.TargetTable, rows[i:])) 570 } 571 } else { 572 for _, r := range rows { 573 fmt.Fprintf(f, fmt.Sprintf("INSERT INTO %s VALUES %s;\n", t.TargetTable, r)) 574 } 575 } 576 default: 577 panic(fmt.Sprintf("unknown format: %s", t.Fmt)) 578 } 579 580 // cache file for schema and row count 581 test.files[key] = f 582 583 err = cb(t, f) 584 if err != nil { 585 return err 586 } 587 } 588 return nil 589 } 590 591 func newBatch(name string, rows []string) string { 592 b := strings.Builder{} 593 b.WriteString(fmt.Sprintf("INSERT INTO %s VALUES\n", name)) 594 for _, r := range rows[:len(rows)-1] { 595 b.WriteString(" ") 596 b.WriteString(r) 597 b.WriteString(",\n") 598 } 599 b.WriteString(" ") 600 b.WriteString(rows[len(rows)-1]) 601 b.WriteString(";\n") 602 603 return b.String() 604 } 605 606 func tableKey(t Table) (uint64, error) { 607 hash := xxhash.New() 608 _, err := hash.Write([]byte(t.Schema)) 609 if err != nil { 610 return 0, err 611 } 612 if _, err := hash.Write([]byte(fmt.Sprintf("%#v,", t.Rows))); err != nil { 613 return 0, err 614 } 615 if err != nil { 616 return 0, err 617 } 618 _, err = hash.Write([]byte(t.Fmt)) 619 if err != nil { 620 return 0, err 621 } 622 return hash.Sum64(), nil 623 } 624 625 func parseTableAndSchema(q string) (string, []string, []sql2.Type) { 626 stmt, _, err := ast.ParseOne(q) 627 if err != nil { 628 panic(fmt.Sprintf("invalid query: %s; %s", q, err)) 629 } 630 var types []sql2.Type 631 var names []string 632 var table string 633 switch n := stmt.(type) { 634 case *ast.DDL: 635 table = n.Table.String() 636 for _, col := range n.TableSpec.Columns { 637 names = append(names, col.Name.String()) 638 typ, err := gmstypes.ColumnTypeToType(&col.Type) 639 if err != nil { 640 panic(fmt.Sprintf("unexpected error reading type: %s", err)) 641 } 642 types = append(types, typ) 643 } 644 default: 645 panic(fmt.Sprintf("expected CREATE TABLE, found: %s", q)) 646 } 647 return table, names, types 648 } 649 650 func genRows(types []sql2.Type, n int, fmt string, cb func(r []string)) { 651 // generate |n| rows with column types 652 for i := 0; i < n; i++ { 653 row := make([]string, len(types)) 654 for j, t := range types { 655 switch fmt { 656 case "sql": 657 switch t.Type() { 658 case sqltypes.Blob, sqltypes.VarChar, sqltypes.Timestamp, sqltypes.Date: 659 row[j] = "'" + genValue(i, t) + "'" 660 default: 661 row[j] = genValue(i, t) 662 } 663 default: 664 row[j] = genValue(i, t) 665 } 666 } 667 cb(row) 668 } 669 } 670 671 func genValue(i int, typ sql2.Type) string { 672 switch typ.Type() { 673 case sqltypes.Blob: 674 return fmt.Sprintf("blob %d", i) 675 case sqltypes.VarChar: 676 return fmt.Sprintf("varchar %d", i) 677 case sqltypes.Int8, sqltypes.Int16, sqltypes.Int32, sqltypes.Int64: 678 return strconv.Itoa(i) 679 case sqltypes.Float32, sqltypes.Float64: 680 return strconv.FormatFloat(float64(i), 'E', -1, 32) 681 case sqltypes.Bit: 682 return strconv.Itoa(i) 683 case sqltypes.Geometry: 684 return `{"type": "Point", "coordinates": [1,2]}` 685 case sqltypes.Timestamp: 686 return "2019-12-31T12:00:00Z" 687 case sqltypes.Date: 688 return "2019-12-31T00:00:00Z" 689 default: 690 panic(fmt.Sprintf("expected type, found: %s", typ)) 691 } 692 } 693 694 func RunTestsFile(t *testing.T, path string) { 695 def, err := ParseTestsFile(path) 696 require.NoError(t, err) 697 for _, test := range def.Tests { 698 t.Run(test.Name, test.Run) 699 } 700 }