github.com/dolthub/go-mysql-server@v0.18.0/enginetest/scriptgen/setup/main.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package setup 16 17 import ( 18 "bufio" 19 "bytes" 20 "fmt" 21 "io" 22 "os" 23 "strings" 24 25 ast "github.com/dolthub/vitess/go/vt/sqlparser" 26 ) 27 28 type setupSource interface { 29 Next() (bool, error) 30 Close() error 31 Data() Testdata 32 } 33 34 type Testdata struct { 35 pos string // file and line number 36 cmd string // exec, query, ... 37 Sql string 38 stmt ast.Statement 39 expected string 40 } 41 42 type SetupScript []string 43 44 type fileSetup struct { 45 path string 46 file *os.File 47 Scanner *lineScanner 48 data Testdata 49 rewrite *bytes.Buffer 50 } 51 52 func NewFileSetup(path string) (*fileSetup, error) { 53 file, err := os.Open(path) 54 if err != nil { 55 return nil, err 56 } 57 return &fileSetup{ 58 path: path, 59 file: file, 60 Scanner: newLineScanner(file), 61 rewrite: &bytes.Buffer{}, 62 }, nil 63 } 64 65 var _ setupSource = (*fileSetup)(nil) 66 67 func (f *fileSetup) Data() Testdata { 68 return f.data 69 } 70 71 func (f *fileSetup) Next() (bool, error) { 72 f.data = Testdata{} 73 for f.Scanner.Scan() { 74 line := f.Scanner.Text() 75 f.emit(line) 76 77 fields := strings.Fields(line) 78 if len(fields) == 0 { 79 continue 80 } 81 cmd := fields[0] 82 if strings.HasPrefix(cmd, "#") { 83 // Skip comment lines. 84 continue 85 } 86 f.data.pos = fmt.Sprintf("%s:%d", f.path, f.Scanner.line) 87 f.data.cmd = cmd 88 89 var buf bytes.Buffer 90 var separator bool 91 for f.Scanner.Scan() { 92 line := f.Scanner.Text() 93 if strings.TrimSpace(line) == "" { 94 break 95 } 96 97 f.emit(line) 98 if line == "----" { 99 separator = true 100 break 101 } 102 buf.WriteString(line + "\n") 103 } 104 if f.Scanner.Err() != nil { 105 return false, f.Scanner.Err() 106 } 107 108 f.data.Sql = strings.TrimSpace(buf.String()) 109 stmt, err := ast.Parse(f.data.Sql) 110 if err != nil { 111 fmt.Printf("errored at %s: \n%s", f.data.pos, f.data.Sql) 112 return false, err 113 } 114 f.data.stmt = stmt 115 116 if separator { 117 buf.Reset() 118 for f.Scanner.Scan() { 119 line := f.Scanner.Text() 120 if strings.TrimSpace(line) == "" { 121 break 122 } 123 fmt.Fprintln(&buf, line) 124 } 125 f.data.expected = buf.String() 126 } 127 return true, nil 128 } 129 return false, io.EOF 130 } 131 132 func (f *fileSetup) emit(s string) { 133 if f.rewrite != nil { 134 f.rewrite.WriteString(s) 135 f.rewrite.WriteString("\n") 136 } 137 } 138 139 func (f *fileSetup) Close() error { 140 return f.file.Close() 141 } 142 143 type lineScanner struct { 144 *bufio.Scanner 145 line int 146 } 147 148 func newLineScanner(r io.Reader) *lineScanner { 149 buf := make([]byte, 0, 64*1024) 150 s := bufio.NewScanner(r) 151 s.Buffer(buf, 1024*1024) 152 153 return &lineScanner{ 154 Scanner: s, 155 line: 0, 156 } 157 } 158 159 func (l *lineScanner) Scan() bool { 160 ok := l.Scanner.Scan() 161 if ok { 162 l.line++ 163 } 164 return ok 165 } 166 167 type stringSetup struct { 168 setup []string 169 pos int 170 data Testdata 171 } 172 173 var _ setupSource = (*stringSetup)(nil) 174 175 func NewStringSetup(s ...string) []setupSource { 176 return []setupSource{ 177 stringSetup{ 178 setup: s, 179 pos: 0, 180 data: Testdata{}, 181 }, 182 } 183 } 184 185 func (s stringSetup) Next() (bool, error) { 186 if s.pos > len(s.setup) { 187 return false, io.EOF 188 } 189 190 stmt, err := ast.Parse(s.setup[s.pos]) 191 if err != nil { 192 return false, err 193 } 194 195 d := Testdata{ 196 pos: fmt.Sprintf("line %d, query: '%s'", s.pos, s.setup[s.pos]), 197 cmd: "exec", 198 Sql: s.setup[s.pos], 199 stmt: stmt, 200 } 201 s.data = d 202 s.pos++ 203 return true, nil 204 } 205 206 func (s stringSetup) Close() error { 207 s.setup = nil 208 return nil 209 } 210 211 func (s stringSetup) Data() Testdata { 212 return s.data 213 }