github.com/dolthub/go-mysql-server@v0.18.0/enginetest/sqllogictest/check/check.go (about)

     1  package main
     2  
     3  import (
     4  	"bufio"
     5  	"fmt"
     6  	"os"
     7  	"strings"
     8  
     9  	"github.com/dolthub/go-mysql-server/enginetest/sqllogictest/utils"
    10  
    11  	_ "github.com/go-sql-driver/mysql"
    12  	"github.com/gocraft/dbr/v2"
    13  )
    14  
    15  func main() {
    16  	args := os.Args[1:]
    17  	if len(args) != 1 {
    18  		panic("expected 1 arg")
    19  	}
    20  	file, err := os.Open(args[0])
    21  	if err != nil {
    22  		panic(err)
    23  	}
    24  	defer file.Close()
    25  
    26  	conn, err := dbr.Open("mysql", fmt.Sprintf("root:root@tcp(localhost:3306)/"), nil)
    27  	//conn, err := dbr.Open("mysql", fmt.Sprintf("dolt@tcp(localhost:3307)/"), nil)
    28  	if err != nil {
    29  		panic(err)
    30  	}
    31  	_, err = conn.Exec("drop database if exists tmp")
    32  	if err != nil {
    33  		panic(err)
    34  	}
    35  	_, err = conn.Exec("create database tmp")
    36  	if err != nil {
    37  		panic(err)
    38  	}
    39  	_, err = conn.Exec("use tmp")
    40  	if err != nil {
    41  		panic(err)
    42  	}
    43  	defer conn.Close()
    44  
    45  	scanner := bufio.NewScanner(file)
    46  	for {
    47  		if !scanner.Scan() {
    48  			break
    49  		}
    50  		line := scanner.Text()
    51  		if len(line) == 0 {
    52  			continue
    53  		}
    54  		if strings.HasPrefix(line, "#") {
    55  			continue
    56  		}
    57  		if line == "statement ok" {
    58  			stmt := utils.ReadStmt(scanner)
    59  			if _, err := conn.Exec(stmt); err != nil {
    60  				panic(fmt.Sprintf("%s \nerr: %v", stmt, err))
    61  			}
    62  			continue
    63  		}
    64  		if strings.HasPrefix(line, "statement error") {
    65  			stmt := utils.ReadStmt(scanner)
    66  			if _, err := conn.Query(stmt); err == nil {
    67  				panic(fmt.Sprintf("%s \nexpected error, but got none", stmt))
    68  			}
    69  			continue
    70  		}
    71  		if strings.HasPrefix(line, "query") {
    72  			if err := handleQuery(scanner, conn); err != nil {
    73  				panic(err)
    74  			}
    75  			continue
    76  		}
    77  	}
    78  
    79  	fmt.Println("All tests passed")
    80  }
    81  
    82  func handleQuery(scanner *bufio.Scanner, conn *dbr.Connection) error {
    83  	query := utils.ReadQuery(scanner)
    84  	_, rows, err := executeQuery(conn, query)
    85  	if err != nil {
    86  		panic(fmt.Sprintf("%s \nerr: %v ", query, err))
    87  	}
    88  	expectedRows := utils.ReadResults(scanner)
    89  	err = compareRows(rows, expectedRows)
    90  	if err != nil {
    91  		return fmt.Errorf("%s \nerr: %v ", query, err)
    92  	}
    93  	return nil
    94  }
    95  
    96  // executeQuery executes the given query and returns the columns and rows (flattened)
    97  func executeQuery(conn *dbr.Connection, query string) ([]string, []string, error) {
    98  	res, err := conn.Query(query)
    99  	if err != nil {
   100  		return nil, nil, err
   101  	}
   102  
   103  	cols, _ := res.Columns()
   104  	numCols := len(cols)
   105  
   106  	rows := make([]string, 0)
   107  	rowBuf := make([]interface{}, numCols)
   108  	for j := 0; j < numCols; j++ {
   109  		rowBuf[j] = new(interface{})
   110  	}
   111  	for res.Next() {
   112  		if err := res.Scan(rowBuf...); err != nil {
   113  			return nil, nil, err
   114  		}
   115  		for i := 0; i < numCols; i++ {
   116  			rawVal := *rowBuf[i].(*interface{})
   117  			if rawVal == nil {
   118  				rows = append(rows, "NULL")
   119  			} else {
   120  				rows = append(rows, fmt.Sprintf("%s", rawVal))
   121  			}
   122  		}
   123  	}
   124  
   125  	return cols, rows, nil
   126  }
   127  
   128  func compareCols(cols, expectedCols []string) {
   129  	if len(cols) != len(expectedCols) {
   130  		panic(fmt.Sprintf("column lengths not equal: actual: %v, expected %v", len(cols), len(expectedCols)))
   131  	}
   132  	for i := 0; i < len(cols); i++ {
   133  		if expectedCols[i] != "" && cols[i] != expectedCols[i] {
   134  			panic(fmt.Sprintf("column %d not equal: actual: %v, expected %v", i, cols[i], expectedCols[i]))
   135  		}
   136  	}
   137  }
   138  
   139  func compareRows(rows, expectedRows []string) error {
   140  	if len(rows) != len(expectedRows) {
   141  		return fmt.Errorf("row lengths not equal: actual: %v, expected %v", len(rows), len(expectedRows))
   142  	}
   143  	for i := 0; i < len(rows); i++ {
   144  		if rows[i] != expectedRows[i] {
   145  			return fmt.Errorf("row %d not equal: actual: %v, expected %v", i, rows, expectedRows)
   146  		}
   147  	}
   148  	return nil
   149  }
   150  
   151  func printRows(rows [][]string) {
   152  	for _, row := range rows {
   153  		for _, col := range row {
   154  			fmt.Printf("%s\t", col)
   155  		}
   156  		fmt.Println()
   157  	}
   158  }