vitess.io/vitess@v0.16.2/go/vt/vttablet/tabletserver/planbuilder/plan_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package planbuilder 18 19 import ( 20 "bufio" 21 "bytes" 22 "encoding/json" 23 "fmt" 24 "io" 25 "log" 26 "os" 27 "path/filepath" 28 "strings" 29 "testing" 30 31 "vitess.io/vitess/go/vt/vtgate/evalengine" 32 33 "github.com/stretchr/testify/require" 34 35 "vitess.io/vitess/go/vt/sqlparser" 36 "vitess.io/vitess/go/vt/tableacl" 37 "vitess.io/vitess/go/vt/vttablet/tabletserver/schema" 38 ) 39 40 // MarshalJSON returns a JSON of the given Plan. 41 // This is only for testing. 42 func (p *Plan) MarshalJSON() ([]byte, error) { 43 mplan := struct { 44 PlanID PlanType 45 TableName sqlparser.IdentifierCS `json:",omitempty"` 46 Permissions []Permission `json:",omitempty"` 47 FieldQuery *sqlparser.ParsedQuery `json:",omitempty"` 48 FullQuery *sqlparser.ParsedQuery `json:",omitempty"` 49 NextCount string `json:",omitempty"` 50 WhereClause *sqlparser.ParsedQuery `json:",omitempty"` 51 NeedsReservedConn bool `json:",omitempty"` 52 }{ 53 PlanID: p.PlanID, 54 TableName: p.TableName(), 55 Permissions: p.Permissions, 56 FullQuery: p.FullQuery, 57 WhereClause: p.WhereClause, 58 } 59 if p.NextCount != nil { 60 mplan.NextCount = evalengine.FormatExpr(p.NextCount) 61 } 62 if p.NeedsReservedConn { 63 mplan.NeedsReservedConn = true 64 } 65 return json.Marshal(&mplan) 66 } 67 68 func TestPlan(t *testing.T) { 69 testPlan(t, "exec_cases.txt") 70 } 71 72 func TestDDLPlan(t *testing.T) { 73 testPlan(t, "ddl_cases.txt") 74 } 75 76 func testPlan(t *testing.T, fileName string) { 77 t.Helper() 78 testSchema := loadSchema("schema_test.json") 79 for tcase := range iterateExecFile(fileName) { 80 t.Run(tcase.input, func(t *testing.T) { 81 if strings.Contains(tcase.options, "PassthroughDMLs") { 82 PassthroughDMLs = true 83 } 84 var plan *Plan 85 var err error 86 statement, err := sqlparser.Parse(tcase.input) 87 if err == nil { 88 plan, err = Build(statement, testSchema, "dbName", false) 89 } 90 PassthroughDMLs = false 91 92 var out string 93 if err != nil { 94 out = err.Error() 95 } else { 96 bout, err := json.Marshal(plan) 97 require.NoError(t, err, "Error marshalling %v: %v", plan, err) 98 out = string(bout) 99 } 100 if out != tcase.output { 101 t.Errorf("Line:%v\ngot = %s\nwant = %s", tcase.lineno, out, tcase.output) 102 if err != nil { 103 out = fmt.Sprintf("\"%s\"", out) 104 } else { 105 bout, _ := json.MarshalIndent(plan, "", " ") 106 out = string(bout) 107 } 108 fmt.Printf("\"in> %s\"\nout>%s\nexpected: %s\n\n", tcase.input, out, tcase.output) 109 } 110 }) 111 } 112 } 113 114 func TestPlanInReservedConn(t *testing.T) { 115 testSchema := loadSchema("schema_test.json") 116 for tcase := range iterateExecFile("exec_cases.txt") { 117 t.Run(tcase.input, func(t *testing.T) { 118 if strings.Contains(tcase.options, "PassthroughDMLs") { 119 PassthroughDMLs = true 120 } 121 var plan *Plan 122 var err error 123 statement, err := sqlparser.Parse(tcase.input) 124 if err == nil { 125 plan, err = Build(statement, testSchema, "dbName", false) 126 } 127 PassthroughDMLs = false 128 129 var out string 130 if err != nil { 131 out = err.Error() 132 } else { 133 bout, err := json.Marshal(plan) 134 if err != nil { 135 t.Fatalf("Error marshalling %v: %v", plan, err) 136 } 137 out = string(bout) 138 } 139 if out != tcase.output { 140 t.Errorf("Line:%v\ngot = %s\nwant = %s", tcase.lineno, out, tcase.output) 141 if err != nil { 142 out = fmt.Sprintf("\"%s\"", out) 143 } else { 144 bout, _ := json.MarshalIndent(plan, "", " ") 145 out = string(bout) 146 } 147 fmt.Printf("\"%s\"\n%s\n\n", tcase.input, out) 148 } 149 }) 150 } 151 } 152 153 func TestCustom(t *testing.T) { 154 testSchemas, _ := filepath.Glob("testdata/*_schema.json") 155 if len(testSchemas) == 0 { 156 t.Log("No schemas to test") 157 return 158 } 159 for _, schemFile := range testSchemas { 160 schem := loadSchema(schemFile) 161 t.Logf("Testing schema %s", schemFile) 162 files, err := filepath.Glob(strings.Replace(schemFile, "schema.json", "*.txt", -1)) 163 if err != nil { 164 log.Fatal(err) 165 } 166 if len(files) == 0 { 167 t.Fatalf("No test files for %s", schemFile) 168 } 169 for _, file := range files { 170 t.Logf("Testing file %s", file) 171 for tcase := range iterateExecFile(file) { 172 statement, err := sqlparser.Parse(tcase.input) 173 if err != nil { 174 t.Fatalf("Got error: %v, parsing sql: %v", err.Error(), tcase.input) 175 } 176 plan, err := Build(statement, schem, "dbName", false) 177 var out string 178 if err != nil { 179 out = err.Error() 180 } else { 181 bout, err := json.Marshal(plan) 182 if err != nil { 183 t.Fatalf("Error marshalling %v: %v", plan, err) 184 } 185 out = string(bout) 186 } 187 if out != tcase.output { 188 t.Errorf("File: %s: Line:%v\ngot = %s\nwant = %s", file, tcase.lineno, out, tcase.output) 189 } 190 } 191 } 192 } 193 } 194 195 func TestStreamPlan(t *testing.T) { 196 testSchema := loadSchema("schema_test.json") 197 for tcase := range iterateExecFile("stream_cases.txt") { 198 plan, err := BuildStreaming(tcase.input, testSchema) 199 var out string 200 if err != nil { 201 out = err.Error() 202 } else { 203 bout, err := json.Marshal(plan) 204 if err != nil { 205 t.Fatalf("Error marshalling %v: %v", plan, err) 206 } 207 out = string(bout) 208 } 209 if out != tcase.output { 210 t.Errorf("Line:%v\ngot = %s\nwant = %s", tcase.lineno, out, tcase.output) 211 } 212 } 213 } 214 215 func TestMessageStreamingPlan(t *testing.T) { 216 testSchema := loadSchema("schema_test.json") 217 plan, err := BuildMessageStreaming("msg", testSchema) 218 require.NoError(t, err) 219 bout, _ := json.Marshal(plan) 220 planJSON := string(bout) 221 222 wantPlan := &Plan{ 223 PlanID: PlanMessageStream, 224 Table: testSchema["msg"], 225 Permissions: []Permission{{ 226 TableName: "msg", 227 Role: tableacl.WRITER, 228 }}, 229 } 230 bout, _ = json.Marshal(wantPlan) 231 wantJSON := string(bout) 232 233 if planJSON != wantJSON { 234 t.Errorf("BuildMessageStreaming: \n%s, want\n%s", planJSON, wantJSON) 235 } 236 237 _, err = BuildMessageStreaming("absent", testSchema) 238 want := "table absent not found in schema" 239 if err == nil || err.Error() != want { 240 t.Errorf("BuildMessageStreaming(absent) error: %v, want %s", err, want) 241 } 242 243 _, err = BuildMessageStreaming("a", testSchema) 244 want = "'a' is not a message table" 245 if err == nil || err.Error() != want { 246 t.Errorf("BuildMessageStreaming(absent) error: %v, want %s", err, want) 247 } 248 } 249 250 func TestLockPlan(t *testing.T) { 251 testSchema := loadSchema("schema_test.json") 252 for tcase := range iterateExecFile("lock_cases.txt") { 253 t.Run(tcase.input, func(t *testing.T) { 254 var plan *Plan 255 var err error 256 statement, err := sqlparser.Parse(tcase.input) 257 if err == nil { 258 plan, err = Build(statement, testSchema, "dbName", false) 259 } 260 261 var out string 262 if err != nil { 263 out = err.Error() 264 } else { 265 bout, err := json.Marshal(plan) 266 if err != nil { 267 t.Fatalf("Error marshalling %v: %v", plan, err) 268 } 269 out = string(bout) 270 } 271 if out != tcase.output { 272 t.Errorf("Line:%v\ngot = %s\nwant = %s", tcase.lineno, out, tcase.output) 273 if err != nil { 274 out = fmt.Sprintf("\"%s\"", out) 275 } else { 276 bout, _ := json.MarshalIndent(plan, "", " ") 277 out = string(bout) 278 } 279 fmt.Printf("\"in> %s\"\nout>%s\nexpected: %s\n\n", tcase.input, out, tcase.output) 280 } 281 }) 282 } 283 } 284 285 func loadSchema(name string) map[string]*schema.Table { 286 b, err := os.ReadFile(locateFile(name)) 287 if err != nil { 288 panic(err) 289 } 290 tables := make([]*schema.Table, 0, 10) 291 err = json.Unmarshal(b, &tables) 292 if err != nil { 293 panic(err) 294 } 295 s := make(map[string]*schema.Table) 296 for _, t := range tables { 297 s[t.Name.String()] = t 298 } 299 return s 300 } 301 302 type testCase struct { 303 file string 304 lineno int 305 options string 306 input string 307 output string 308 } 309 310 func iterateExecFile(name string) (testCaseIterator chan testCase) { 311 name = locateFile(name) 312 fd, err := os.OpenFile(name, os.O_RDONLY, 0) 313 if err != nil { 314 panic(fmt.Sprintf("Could not open file %s", name)) 315 } 316 testCaseIterator = make(chan testCase) 317 go func() { 318 defer close(testCaseIterator) 319 320 r := bufio.NewReader(fd) 321 lineno := 0 322 options := "" 323 for { 324 binput, err := r.ReadBytes('\n') 325 if err != nil { 326 if err != io.EOF { 327 fmt.Printf("Line: %d\n", lineno) 328 panic(fmt.Errorf("Error reading file %s: %s", name, err.Error())) 329 } 330 break 331 } 332 lineno++ 333 input := string(binput) 334 if input == "" || input == "\n" || input[0] == '#' || strings.HasPrefix(input, "Length:") { 335 // fmt.Printf("%s\n", input) 336 continue 337 } 338 339 if strings.HasPrefix(input, "options:") { 340 options = input[8:] 341 continue 342 } 343 err = json.Unmarshal(binput, &input) 344 if err != nil { 345 fmt.Printf("Line: %d, input: %s\n", lineno, binput) 346 panic(err) 347 } 348 input = strings.Trim(input, "\"") 349 var output []byte 350 for { 351 l, err := r.ReadBytes('\n') 352 lineno++ 353 if err != nil { 354 fmt.Printf("Line: %d\n", lineno) 355 panic(fmt.Errorf("Error reading file %s: %s", name, err.Error())) 356 } 357 output = append(output, l...) 358 if l[0] == '}' { 359 output = output[:len(output)-1] 360 b := bytes.NewBuffer(make([]byte, 0, 64)) 361 if err := json.Compact(b, output); err == nil { 362 output = b.Bytes() 363 } 364 break 365 } 366 if l[0] == '"' { 367 output = output[1 : len(output)-2] 368 break 369 } 370 } 371 testCaseIterator <- testCase{name, lineno, options, input, string(output)} 372 options = "" 373 } 374 }() 375 return testCaseIterator 376 } 377 378 func locateFile(name string) string { 379 return "testdata/" + name 380 }