github.com/dolthub/go-mysql-server@v0.18.0/enginetest/sqllogictest/convert/convert.go (about)

     1  package main
     2  
     3  import (
     4  	"bufio"
     5  	"fmt"
     6  	"os"
     7  	"strings"
     8  
     9  	"github.com/dolthub/go-mysql-server/enginetest/sqllogictest/utils"
    10  
    11  	_ "github.com/go-sql-driver/mysql"
    12  )
    13  
    14  func main() {
    15  	args := os.Args[1:]
    16  
    17  	if len(args) != 2 {
    18  		panic("expected 2 args")
    19  	}
    20  
    21  	infile, err := os.Open(args[0])
    22  	if err != nil {
    23  		panic(err)
    24  	}
    25  	defer infile.Close()
    26  
    27  	outfile, err := os.Create(args[1])
    28  	if err != nil {
    29  		panic(err)
    30  	}
    31  	defer outfile.Close()
    32  
    33  	scanner := bufio.NewScanner(infile)
    34  	for {
    35  		if !scanner.Scan() {
    36  			break
    37  		}
    38  		line := scanner.Text()
    39  		if len(line) == 0 {
    40  			outfile.WriteString("\n")
    41  			continue
    42  		}
    43  		if strings.HasPrefix(line, "#") {
    44  			outfile.WriteString(line + "\n")
    45  			continue
    46  		}
    47  		if strings.HasPrefix(line, "statement ok") {
    48  			rewriteStmt(scanner, outfile)
    49  			continue
    50  		}
    51  		if strings.HasPrefix(line, "query error") || strings.HasPrefix(line, "statement error") {
    52  			rewriteError(scanner, outfile)
    53  			continue
    54  		}
    55  		if strings.HasPrefix(line, "query") {
    56  			rewriteQuery(scanner, line, outfile)
    57  			continue
    58  		}
    59  	}
    60  }
    61  
    62  func rewriteStmt(scanner *bufio.Scanner, outfile *os.File) {
    63  	var stmt string
    64  	once := true
    65  	for {
    66  		if !scanner.Scan() {
    67  			panic("expected statement")
    68  		}
    69  		part := scanner.Text()
    70  		if len(part) == 0 {
    71  			break
    72  		}
    73  
    74  		parts := strings.Split(part, "; ")
    75  		if len(parts) > 1 {
    76  			// multiple statements in one line for some reason
    77  			for _, p := range parts {
    78  				writeStmt(p, outfile)
    79  				once = true
    80  			}
    81  			continue
    82  		}
    83  
    84  		if once {
    85  			once = false
    86  			stmt += part
    87  		} else {
    88  			stmt += "\n" + part
    89  		}
    90  		if strings.HasSuffix(part, ";") {
    91  			writeStmt(stmt, outfile)
    92  			once = true
    93  			stmt = ""
    94  		}
    95  	}
    96  
    97  	if len(stmt) != 0 {
    98  		writeStmt(stmt, outfile)
    99  	}
   100  }
   101  
   102  func writeStmt(stmt string, outfile *os.File) {
   103  	outfile.WriteString("statement ok\n")
   104  	outfile.WriteString(stmt + "\n\n")
   105  }
   106  
   107  func rewriteError(scanner *bufio.Scanner, outfile *os.File) {
   108  	outfile.WriteString("statement error\n")
   109  
   110  	stmt := utils.ReadStmt(scanner)
   111  	outfile.WriteString(stmt + "\n\n")
   112  }
   113  
   114  func rewriteQuery(scanner *bufio.Scanner, line string, outfile *os.File) {
   115  	schema := strings.Split(line, " ")[1]
   116  	hasColNames := strings.Contains(line, "colnames")
   117  	hasRowSort := strings.Contains(line, "rowsort")
   118  
   119  	if hasRowSort {
   120  		// TODO: throw warning about putting order by in query
   121  	}
   122  
   123  	// expect query
   124  	query := utils.ReadQuery(scanner)
   125  
   126  	// expect colnames; drop them
   127  	if hasColNames && !scanner.Scan() {
   128  		panic("expected colnames")
   129  	}
   130  
   131  	// expect results
   132  	rows := utils.ReadResults(scanner)
   133  
   134  	// ignore queries with full outer join or full join
   135  	if strings.Contains(strings.ToLower(query), "full outer join") || strings.Contains(strings.ToLower(query), "full join") {
   136  		return
   137  	}
   138  
   139  	outfile.WriteString(fmt.Sprintf("query %s nosort\n", schema))
   140  	outfile.WriteString(query + "\n")
   141  	outfile.WriteString(utils.SEP + "\n")
   142  	for _, row := range rows {
   143  		outfile.WriteString(row + "\n")
   144  	}
   145  	outfile.WriteString("\n")
   146  }