vitess.io/vitess@v0.16.2/go/vt/vtgate/planbuilder/operators/operator_test.go (about) 1 /* 2 Copyright 2021 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package operators 18 19 import ( 20 "bufio" 21 "fmt" 22 "io" 23 "os" 24 "sort" 25 "strings" 26 "testing" 27 28 "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" 29 30 "vitess.io/vitess/go/vt/vtgate/engine" 31 32 "vitess.io/vitess/go/vt/vtgate/vindexes" 33 34 "github.com/stretchr/testify/assert" 35 "github.com/stretchr/testify/require" 36 37 "vitess.io/vitess/go/vt/sqlparser" 38 "vitess.io/vitess/go/vt/vtgate/semantics" 39 ) 40 41 type lineCountingReader struct { 42 line int 43 r *bufio.Reader 44 } 45 46 func (lcr *lineCountingReader) nextLine() (string, error) { 47 queryBytes, err := lcr.r.ReadBytes('\n') 48 lcr.line++ 49 return string(queryBytes), err 50 } 51 52 func readTestCase(lcr *lineCountingReader) (testCase, error) { 53 query := "" 54 var err error 55 for query == "" || query == "\n" || strings.HasPrefix(query, "#") { 56 query, err = lcr.nextLine() 57 if err != nil { 58 return testCase{}, err 59 } 60 } 61 62 tc := testCase{query: query, line: lcr.line} 63 64 for { 65 jsonPart, err := lcr.nextLine() 66 if err != nil { 67 if err == io.EOF { 68 return tc, fmt.Errorf("test data is bad. expectation not finished") 69 } 70 return tc, err 71 } 72 if jsonPart == "}\n" { 73 tc.expected += "}" 74 break 75 } 76 tc.expected += jsonPart 77 } 78 return tc, nil 79 } 80 81 type testCase struct { 82 line int 83 query, expected string 84 } 85 86 func TestOperator(t *testing.T) { 87 fd, err := os.OpenFile("operator_test_data.txt", os.O_RDONLY, 0) 88 require.NoError(t, err) 89 r := bufio.NewReader(fd) 90 91 hash, _ := vindexes.NewHash("user_index", map[string]string{}) 92 si := &semantics.FakeSI{VindexTables: map[string]vindexes.Vindex{"user_index": hash}} 93 lcr := &lineCountingReader{r: r} 94 for { 95 tc, err := readTestCase(lcr) 96 if err == io.EOF { 97 break 98 } 99 t.Run(fmt.Sprintf("%d:%s", tc.line, tc.query), func(t *testing.T) { 100 require.NoError(t, err) 101 stmt, err := sqlparser.Parse(tc.query) 102 require.NoError(t, err) 103 semTable, err := semantics.Analyze(stmt, "", si) 104 require.NoError(t, err) 105 ctx := plancontext.NewPlanningContext(nil, semTable, nil, 0) 106 optree, err := createLogicalOperatorFromAST(ctx, stmt) 107 require.NoError(t, err) 108 optree, err = Compact(ctx, optree) 109 require.NoError(t, err) 110 output := testString(optree) 111 assert.Equal(t, tc.expected, output) 112 if t.Failed() { 113 fmt.Println(output) 114 } 115 }) 116 } 117 } 118 119 func testString(op interface{}) string { // TODO 120 switch op := op.(type) { 121 case *QueryGraph: 122 return fmt.Sprintf("QueryGraph: %s", op.testString()) 123 case *Join: 124 leftStr := indent(testString(op.LHS)) 125 rightStr := indent(testString(op.RHS)) 126 if op.LeftJoin { 127 return fmt.Sprintf("OuterJoin: {\n\tInner: %s\n\tOuter: %s\n\tPredicate: %s\n}", leftStr, rightStr, sqlparser.String(op.Predicate)) 128 } 129 return fmt.Sprintf("Join: {\n\tLHS: %s\n\tRHS: %s\n\tPredicate: %s\n}", leftStr, rightStr, sqlparser.String(op.Predicate)) 130 case *Derived: 131 inner := indent(testString(op.Source)) 132 query := sqlparser.String(op.Query) 133 return fmt.Sprintf("Derived %s: {\n\tQuery: %s\n\tInner:%s\n}", op.Alias, query, inner) 134 case *SubQuery: 135 var inners []string 136 for _, sqOp := range op.Inner { 137 subquery := fmt.Sprintf("{\n\tType: %s", engine.PulloutOpcode(sqOp.ExtractedSubquery.OpCode).String()) 138 if sqOp.ExtractedSubquery.GetArgName() != "" { 139 subquery += fmt.Sprintf("\n\tArgName: %s", sqOp.ExtractedSubquery.GetArgName()) 140 } 141 subquery += fmt.Sprintf("\n\tQuery: %s\n}", indent(testString(sqOp.Inner))) 142 subquery = indent(subquery) 143 inners = append(inners, subquery) 144 } 145 outer := indent(testString(op.Outer)) 146 join := strings.Join(inners, "\n") 147 sprintf := fmt.Sprintf("SubQuery: {\n\tSubQueries: [\n%s]\n\tOuter: %s\n}", join, outer) 148 return sprintf 149 case *Vindex: 150 value := sqlparser.String(op.Value) 151 return fmt.Sprintf("Vindex: {\n\tName: %s\n\tValue: %s\n}", op.Vindex.String(), value) 152 case *Union: 153 var inners []string 154 for _, source := range op.Sources { 155 inners = append(inners, indent(testString(source))) 156 } 157 if len(op.Ordering) > 0 { 158 inners = append(inners, indent(sqlparser.String(op.Ordering)[1:])) 159 } 160 dist := "" 161 if op.Distinct { 162 dist = "(distinct)" 163 } 164 return fmt.Sprintf("Concatenate%s {\n%s\n}", dist, strings.Join(inners, ",\n")) 165 case *Update: 166 tbl := "table: " + op.QTable.testString() 167 var assignments []string 168 // sort to produce stable results, otherwise test is flaky 169 keys := make([]string, 0, len(op.Assignments)) 170 for k := range op.Assignments { 171 keys = append(keys, k) 172 } 173 sort.Strings(keys) 174 for _, k := range keys { 175 assignments = append(assignments, fmt.Sprintf("\t%s = %s", k, sqlparser.String(op.Assignments[k]))) 176 } 177 return fmt.Sprintf("Update {\n\t%s\nassignments:\n%s\n}", tbl, strings.Join(assignments, "\n")) 178 case *Horizon: 179 src := indent(testString(op.Source)) 180 return fmt.Sprintf("Horizon {\n\tQuery: \"%s\"\n\tInner:%s\n}", sqlparser.String(op.Select), src) 181 } 182 panic(fmt.Sprintf("%T", op)) 183 } 184 185 func indent(s string) string { 186 lines := strings.Split(s, "\n") 187 for i, line := range lines { 188 lines[i] = "\t" + line 189 } 190 return strings.Join(lines, "\n") 191 } 192 193 // the following code is only used by tests 194 195 func (qt *QueryTable) testString() string { 196 var alias string 197 if !qt.Alias.As.IsEmpty() { 198 alias = " AS " + sqlparser.String(qt.Alias.As) 199 } 200 var preds []string 201 for _, predicate := range qt.Predicates { 202 preds = append(preds, sqlparser.String(predicate)) 203 } 204 var where string 205 if len(preds) > 0 { 206 where = " where " + strings.Join(preds, " and ") 207 } 208 209 return fmt.Sprintf("\t%v:%s%s%s", qt.ID, sqlparser.String(qt.Table), alias, where) 210 } 211 212 func (qg *QueryGraph) testString() string { 213 return fmt.Sprintf(`{ 214 Tables: 215 %s%s%s 216 }`, strings.Join(qg.tableNames(), "\n"), qg.crossPredicateString(), qg.noDepsString()) 217 } 218 219 func (qg *QueryGraph) crossPredicateString() string { 220 if len(qg.innerJoins) == 0 { 221 return "" 222 } 223 var joinPreds []string 224 for _, join := range qg.innerJoins { 225 deps, predicates := join.deps, join.exprs 226 var expressions []string 227 for _, expr := range predicates { 228 expressions = append(expressions, sqlparser.String(expr)) 229 } 230 231 exprConcat := strings.Join(expressions, " and ") 232 joinPreds = append(joinPreds, fmt.Sprintf("\t%v - %s", deps, exprConcat)) 233 } 234 sort.Strings(joinPreds) 235 return fmt.Sprintf("\nJoinPredicates:\n%s", strings.Join(joinPreds, "\n")) 236 } 237 238 func (qg *QueryGraph) tableNames() []string { 239 var tables []string 240 for _, t := range qg.Tables { 241 tables = append(tables, t.testString()) 242 } 243 return tables 244 } 245 246 func (qg *QueryGraph) noDepsString() string { 247 if qg.NoDeps == nil { 248 return "" 249 } 250 return fmt.Sprintf("\nForAll: %s", sqlparser.String(qg.NoDeps)) 251 }