vitess.io/vitess@v0.16.2/go/vt/vtgate/planbuilder/operators/operator_test.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package operators
    18  
    19  import (
    20  	"bufio"
    21  	"fmt"
    22  	"io"
    23  	"os"
    24  	"sort"
    25  	"strings"
    26  	"testing"
    27  
    28  	"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
    29  
    30  	"vitess.io/vitess/go/vt/vtgate/engine"
    31  
    32  	"vitess.io/vitess/go/vt/vtgate/vindexes"
    33  
    34  	"github.com/stretchr/testify/assert"
    35  	"github.com/stretchr/testify/require"
    36  
    37  	"vitess.io/vitess/go/vt/sqlparser"
    38  	"vitess.io/vitess/go/vt/vtgate/semantics"
    39  )
    40  
    41  type lineCountingReader struct {
    42  	line int
    43  	r    *bufio.Reader
    44  }
    45  
    46  func (lcr *lineCountingReader) nextLine() (string, error) {
    47  	queryBytes, err := lcr.r.ReadBytes('\n')
    48  	lcr.line++
    49  	return string(queryBytes), err
    50  }
    51  
    52  func readTestCase(lcr *lineCountingReader) (testCase, error) {
    53  	query := ""
    54  	var err error
    55  	for query == "" || query == "\n" || strings.HasPrefix(query, "#") {
    56  		query, err = lcr.nextLine()
    57  		if err != nil {
    58  			return testCase{}, err
    59  		}
    60  	}
    61  
    62  	tc := testCase{query: query, line: lcr.line}
    63  
    64  	for {
    65  		jsonPart, err := lcr.nextLine()
    66  		if err != nil {
    67  			if err == io.EOF {
    68  				return tc, fmt.Errorf("test data is bad. expectation not finished")
    69  			}
    70  			return tc, err
    71  		}
    72  		if jsonPart == "}\n" {
    73  			tc.expected += "}"
    74  			break
    75  		}
    76  		tc.expected += jsonPart
    77  	}
    78  	return tc, nil
    79  }
    80  
    81  type testCase struct {
    82  	line            int
    83  	query, expected string
    84  }
    85  
    86  func TestOperator(t *testing.T) {
    87  	fd, err := os.OpenFile("operator_test_data.txt", os.O_RDONLY, 0)
    88  	require.NoError(t, err)
    89  	r := bufio.NewReader(fd)
    90  
    91  	hash, _ := vindexes.NewHash("user_index", map[string]string{})
    92  	si := &semantics.FakeSI{VindexTables: map[string]vindexes.Vindex{"user_index": hash}}
    93  	lcr := &lineCountingReader{r: r}
    94  	for {
    95  		tc, err := readTestCase(lcr)
    96  		if err == io.EOF {
    97  			break
    98  		}
    99  		t.Run(fmt.Sprintf("%d:%s", tc.line, tc.query), func(t *testing.T) {
   100  			require.NoError(t, err)
   101  			stmt, err := sqlparser.Parse(tc.query)
   102  			require.NoError(t, err)
   103  			semTable, err := semantics.Analyze(stmt, "", si)
   104  			require.NoError(t, err)
   105  			ctx := plancontext.NewPlanningContext(nil, semTable, nil, 0)
   106  			optree, err := createLogicalOperatorFromAST(ctx, stmt)
   107  			require.NoError(t, err)
   108  			optree, err = Compact(ctx, optree)
   109  			require.NoError(t, err)
   110  			output := testString(optree)
   111  			assert.Equal(t, tc.expected, output)
   112  			if t.Failed() {
   113  				fmt.Println(output)
   114  			}
   115  		})
   116  	}
   117  }
   118  
   119  func testString(op interface{}) string { // TODO
   120  	switch op := op.(type) {
   121  	case *QueryGraph:
   122  		return fmt.Sprintf("QueryGraph: %s", op.testString())
   123  	case *Join:
   124  		leftStr := indent(testString(op.LHS))
   125  		rightStr := indent(testString(op.RHS))
   126  		if op.LeftJoin {
   127  			return fmt.Sprintf("OuterJoin: {\n\tInner: %s\n\tOuter: %s\n\tPredicate: %s\n}", leftStr, rightStr, sqlparser.String(op.Predicate))
   128  		}
   129  		return fmt.Sprintf("Join: {\n\tLHS: %s\n\tRHS: %s\n\tPredicate: %s\n}", leftStr, rightStr, sqlparser.String(op.Predicate))
   130  	case *Derived:
   131  		inner := indent(testString(op.Source))
   132  		query := sqlparser.String(op.Query)
   133  		return fmt.Sprintf("Derived %s: {\n\tQuery: %s\n\tInner:%s\n}", op.Alias, query, inner)
   134  	case *SubQuery:
   135  		var inners []string
   136  		for _, sqOp := range op.Inner {
   137  			subquery := fmt.Sprintf("{\n\tType: %s", engine.PulloutOpcode(sqOp.ExtractedSubquery.OpCode).String())
   138  			if sqOp.ExtractedSubquery.GetArgName() != "" {
   139  				subquery += fmt.Sprintf("\n\tArgName: %s", sqOp.ExtractedSubquery.GetArgName())
   140  			}
   141  			subquery += fmt.Sprintf("\n\tQuery: %s\n}", indent(testString(sqOp.Inner)))
   142  			subquery = indent(subquery)
   143  			inners = append(inners, subquery)
   144  		}
   145  		outer := indent(testString(op.Outer))
   146  		join := strings.Join(inners, "\n")
   147  		sprintf := fmt.Sprintf("SubQuery: {\n\tSubQueries: [\n%s]\n\tOuter: %s\n}", join, outer)
   148  		return sprintf
   149  	case *Vindex:
   150  		value := sqlparser.String(op.Value)
   151  		return fmt.Sprintf("Vindex: {\n\tName: %s\n\tValue: %s\n}", op.Vindex.String(), value)
   152  	case *Union:
   153  		var inners []string
   154  		for _, source := range op.Sources {
   155  			inners = append(inners, indent(testString(source)))
   156  		}
   157  		if len(op.Ordering) > 0 {
   158  			inners = append(inners, indent(sqlparser.String(op.Ordering)[1:]))
   159  		}
   160  		dist := ""
   161  		if op.Distinct {
   162  			dist = "(distinct)"
   163  		}
   164  		return fmt.Sprintf("Concatenate%s {\n%s\n}", dist, strings.Join(inners, ",\n"))
   165  	case *Update:
   166  		tbl := "table: " + op.QTable.testString()
   167  		var assignments []string
   168  		// sort to produce stable results, otherwise test is flaky
   169  		keys := make([]string, 0, len(op.Assignments))
   170  		for k := range op.Assignments {
   171  			keys = append(keys, k)
   172  		}
   173  		sort.Strings(keys)
   174  		for _, k := range keys {
   175  			assignments = append(assignments, fmt.Sprintf("\t%s = %s", k, sqlparser.String(op.Assignments[k])))
   176  		}
   177  		return fmt.Sprintf("Update {\n\t%s\nassignments:\n%s\n}", tbl, strings.Join(assignments, "\n"))
   178  	case *Horizon:
   179  		src := indent(testString(op.Source))
   180  		return fmt.Sprintf("Horizon {\n\tQuery: \"%s\"\n\tInner:%s\n}", sqlparser.String(op.Select), src)
   181  	}
   182  	panic(fmt.Sprintf("%T", op))
   183  }
   184  
   185  func indent(s string) string {
   186  	lines := strings.Split(s, "\n")
   187  	for i, line := range lines {
   188  		lines[i] = "\t" + line
   189  	}
   190  	return strings.Join(lines, "\n")
   191  }
   192  
   193  // the following code is only used by tests
   194  
   195  func (qt *QueryTable) testString() string {
   196  	var alias string
   197  	if !qt.Alias.As.IsEmpty() {
   198  		alias = " AS " + sqlparser.String(qt.Alias.As)
   199  	}
   200  	var preds []string
   201  	for _, predicate := range qt.Predicates {
   202  		preds = append(preds, sqlparser.String(predicate))
   203  	}
   204  	var where string
   205  	if len(preds) > 0 {
   206  		where = " where " + strings.Join(preds, " and ")
   207  	}
   208  
   209  	return fmt.Sprintf("\t%v:%s%s%s", qt.ID, sqlparser.String(qt.Table), alias, where)
   210  }
   211  
   212  func (qg *QueryGraph) testString() string {
   213  	return fmt.Sprintf(`{
   214  Tables:
   215  %s%s%s
   216  }`, strings.Join(qg.tableNames(), "\n"), qg.crossPredicateString(), qg.noDepsString())
   217  }
   218  
   219  func (qg *QueryGraph) crossPredicateString() string {
   220  	if len(qg.innerJoins) == 0 {
   221  		return ""
   222  	}
   223  	var joinPreds []string
   224  	for _, join := range qg.innerJoins {
   225  		deps, predicates := join.deps, join.exprs
   226  		var expressions []string
   227  		for _, expr := range predicates {
   228  			expressions = append(expressions, sqlparser.String(expr))
   229  		}
   230  
   231  		exprConcat := strings.Join(expressions, " and ")
   232  		joinPreds = append(joinPreds, fmt.Sprintf("\t%v - %s", deps, exprConcat))
   233  	}
   234  	sort.Strings(joinPreds)
   235  	return fmt.Sprintf("\nJoinPredicates:\n%s", strings.Join(joinPreds, "\n"))
   236  }
   237  
   238  func (qg *QueryGraph) tableNames() []string {
   239  	var tables []string
   240  	for _, t := range qg.Tables {
   241  		tables = append(tables, t.testString())
   242  	}
   243  	return tables
   244  }
   245  
   246  func (qg *QueryGraph) noDepsString() string {
   247  	if qg.NoDeps == nil {
   248  		return ""
   249  	}
   250  	return fmt.Sprintf("\nForAll: %s", sqlparser.String(qg.NoDeps))
   251  }