vitess.io/vitess@v0.16.2/go/vt/vtgate/planbuilder/operators/SQL_builder.go (about)

     1  /*
     2  Copyright 2022 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package operators
    18  
    19  import (
    20  	"fmt"
    21  	"sort"
    22  	"strings"
    23  
    24  	"vitess.io/vitess/go/vt/vtgate/planbuilder/operators/ops"
    25  
    26  	"vitess.io/vitess/go/vt/sqlparser"
    27  	"vitess.io/vitess/go/vt/vterrors"
    28  	"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
    29  	"vitess.io/vitess/go/vt/vtgate/semantics"
    30  )
    31  
    32  type (
    33  	queryBuilder struct {
    34  		ctx        *plancontext.PlanningContext
    35  		sel        sqlparser.SelectStatement
    36  		tableNames []string
    37  	}
    38  )
    39  
    40  func ToSQL(ctx *plancontext.PlanningContext, op ops.Operator) (sqlparser.SelectStatement, error) {
    41  	q := &queryBuilder{ctx: ctx}
    42  	err := buildQuery(op, q)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  	q.sortTables()
    47  	return q.sel, nil
    48  }
    49  
    50  func (qb *queryBuilder) addTable(db, tableName, alias string, tableID semantics.TableSet, hints sqlparser.IndexHints) {
    51  	tableExpr := sqlparser.TableName{
    52  		Name:      sqlparser.NewIdentifierCS(tableName),
    53  		Qualifier: sqlparser.NewIdentifierCS(db),
    54  	}
    55  	qb.addTableExpr(tableName, alias, tableID, tableExpr, hints, nil)
    56  }
    57  
    58  func (qb *queryBuilder) addTableExpr(
    59  	tableName, alias string,
    60  	tableID semantics.TableSet,
    61  	tblExpr sqlparser.SimpleTableExpr,
    62  	hints sqlparser.IndexHints,
    63  	columnAliases sqlparser.Columns,
    64  ) {
    65  	if qb.sel == nil {
    66  		qb.sel = &sqlparser.Select{}
    67  	}
    68  	sel := qb.sel.(*sqlparser.Select)
    69  	elems := &sqlparser.AliasedTableExpr{
    70  		Expr:       tblExpr,
    71  		Partitions: nil,
    72  		As:         sqlparser.NewIdentifierCS(alias),
    73  		Hints:      hints,
    74  		Columns:    columnAliases,
    75  	}
    76  	qb.ctx.SemTable.ReplaceTableSetFor(tableID, elems)
    77  	sel.From = append(sel.From, elems)
    78  	qb.sel = sel
    79  	qb.tableNames = append(qb.tableNames, tableName)
    80  }
    81  
    82  func (qb *queryBuilder) addPredicate(expr sqlparser.Expr) {
    83  	if _, toBeSkipped := qb.ctx.SkipPredicates[expr]; toBeSkipped {
    84  		// This is a predicate that was added to the RHS of an ApplyJoin.
    85  		// The original predicate will be added, so we don't have to add this here
    86  		return
    87  	}
    88  
    89  	sel := qb.sel.(*sqlparser.Select)
    90  	_, isSubQuery := expr.(*sqlparser.ExtractedSubquery)
    91  	var addPred func(sqlparser.Expr)
    92  
    93  	if sqlparser.ContainsAggregation(expr) && !isSubQuery {
    94  		addPred = sel.AddHaving
    95  	} else {
    96  		addPred = sel.AddWhere
    97  	}
    98  	for _, exp := range sqlparser.SplitAndExpression(nil, expr) {
    99  		addPred(exp)
   100  	}
   101  }
   102  
   103  func (qb *queryBuilder) addProjection(projection *sqlparser.AliasedExpr) {
   104  	sel := qb.sel.(*sqlparser.Select)
   105  	sel.SelectExprs = append(sel.SelectExprs, projection)
   106  }
   107  
   108  func (qb *queryBuilder) joinInnerWith(other *queryBuilder, onCondition sqlparser.Expr) {
   109  	sel := qb.sel.(*sqlparser.Select)
   110  	otherSel := other.sel.(*sqlparser.Select)
   111  	sel.From = append(sel.From, otherSel.From...)
   112  	sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...)
   113  
   114  	var predicate sqlparser.Expr
   115  	if sel.Where != nil {
   116  		predicate = sel.Where.Expr
   117  	}
   118  	if otherSel.Where != nil {
   119  		predExprs := sqlparser.SplitAndExpression(nil, predicate)
   120  		otherExprs := sqlparser.SplitAndExpression(nil, otherSel.Where.Expr)
   121  		predicate = qb.ctx.SemTable.AndExpressions(append(predExprs, otherExprs...)...)
   122  	}
   123  	if predicate != nil {
   124  		sel.Where = &sqlparser.Where{Type: sqlparser.WhereClause, Expr: predicate}
   125  	}
   126  
   127  	qb.addPredicate(onCondition)
   128  }
   129  
   130  func (qb *queryBuilder) joinOuterWith(other *queryBuilder, onCondition sqlparser.Expr) {
   131  	sel := qb.sel.(*sqlparser.Select)
   132  	otherSel := other.sel.(*sqlparser.Select)
   133  	var lhs sqlparser.TableExpr
   134  	if len(sel.From) == 1 {
   135  		lhs = sel.From[0]
   136  	} else {
   137  		lhs = &sqlparser.ParenTableExpr{Exprs: sel.From}
   138  	}
   139  	var rhs sqlparser.TableExpr
   140  	if len(otherSel.From) == 1 {
   141  		rhs = otherSel.From[0]
   142  	} else {
   143  		rhs = &sqlparser.ParenTableExpr{Exprs: otherSel.From}
   144  	}
   145  	sel.From = []sqlparser.TableExpr{&sqlparser.JoinTableExpr{
   146  		LeftExpr:  lhs,
   147  		RightExpr: rhs,
   148  		Join:      sqlparser.LeftJoinType,
   149  		Condition: &sqlparser.JoinCondition{
   150  			On: onCondition,
   151  		},
   152  	}}
   153  
   154  	sel.SelectExprs = append(sel.SelectExprs, otherSel.SelectExprs...)
   155  	var predicate sqlparser.Expr
   156  	if sel.Where != nil {
   157  		predicate = sel.Where.Expr
   158  	}
   159  	if otherSel.Where != nil {
   160  		predicate = qb.ctx.SemTable.AndExpressions(predicate, otherSel.Where.Expr)
   161  	}
   162  	if predicate != nil {
   163  		sel.Where = &sqlparser.Where{Type: sqlparser.WhereClause, Expr: predicate}
   164  	}
   165  }
   166  
   167  func (qb *queryBuilder) rewriteExprForDerivedTable(expr sqlparser.Expr, dtName string) {
   168  	_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
   169  		col, ok := node.(*sqlparser.ColName)
   170  		if !ok {
   171  			return true, nil
   172  		}
   173  		hasTable := qb.hasTable(col.Qualifier.Name.String())
   174  		if hasTable {
   175  			col.Qualifier = sqlparser.TableName{
   176  				Name: sqlparser.NewIdentifierCS(dtName),
   177  			}
   178  		}
   179  		return true, nil
   180  	}, expr)
   181  }
   182  
   183  func (qb *queryBuilder) hasTable(tableName string) bool {
   184  	for _, name := range qb.tableNames {
   185  		if strings.EqualFold(tableName, name) {
   186  			return true
   187  		}
   188  	}
   189  	return false
   190  }
   191  
   192  func (qb *queryBuilder) sortTables() {
   193  	_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
   194  		sel, isSel := node.(*sqlparser.Select)
   195  		if !isSel {
   196  			return true, nil
   197  		}
   198  		ts := &tableSorter{
   199  			sel: sel,
   200  			tbl: qb.ctx.SemTable,
   201  		}
   202  		sort.Sort(ts)
   203  		return true, nil
   204  	}, qb.sel)
   205  
   206  }
   207  
   208  type tableSorter struct {
   209  	sel *sqlparser.Select
   210  	tbl *semantics.SemTable
   211  }
   212  
   213  // Len implements the Sort interface
   214  func (ts *tableSorter) Len() int {
   215  	return len(ts.sel.From)
   216  }
   217  
   218  // Less implements the Sort interface
   219  func (ts *tableSorter) Less(i, j int) bool {
   220  	lhs := ts.sel.From[i]
   221  	rhs := ts.sel.From[j]
   222  	left, ok := lhs.(*sqlparser.AliasedTableExpr)
   223  	if !ok {
   224  		return i < j
   225  	}
   226  	right, ok := rhs.(*sqlparser.AliasedTableExpr)
   227  	if !ok {
   228  		return i < j
   229  	}
   230  
   231  	return ts.tbl.TableSetFor(left).TableOffset() < ts.tbl.TableSetFor(right).TableOffset()
   232  }
   233  
   234  // Swap implements the Sort interface
   235  func (ts *tableSorter) Swap(i, j int) {
   236  	ts.sel.From[i], ts.sel.From[j] = ts.sel.From[j], ts.sel.From[i]
   237  }
   238  
   239  func (h *Horizon) toSQL(qb *queryBuilder) error {
   240  	err := stripDownQuery(h.Select, qb.sel)
   241  	if err != nil {
   242  		return err
   243  	}
   244  	_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
   245  		if aliasedExpr, ok := node.(sqlparser.SelectExpr); ok {
   246  			removeKeyspaceFromSelectExpr(aliasedExpr)
   247  		}
   248  		return true, nil
   249  	}, qb.sel)
   250  	return nil
   251  }
   252  
   253  func removeKeyspaceFromSelectExpr(expr sqlparser.SelectExpr) {
   254  	switch expr := expr.(type) {
   255  	case *sqlparser.AliasedExpr:
   256  		sqlparser.RemoveKeyspaceFromColName(expr.Expr)
   257  	case *sqlparser.StarExpr:
   258  		expr.TableName.Qualifier = sqlparser.NewIdentifierCS("")
   259  	}
   260  }
   261  
   262  func stripDownQuery(from, to sqlparser.SelectStatement) error {
   263  	var err error
   264  
   265  	switch node := from.(type) {
   266  	case *sqlparser.Select:
   267  		toNode, ok := to.(*sqlparser.Select)
   268  		if !ok {
   269  			return vterrors.VT13001("AST did not match")
   270  		}
   271  		toNode.Distinct = node.Distinct
   272  		toNode.GroupBy = node.GroupBy
   273  		toNode.Having = node.Having
   274  		toNode.OrderBy = node.OrderBy
   275  		toNode.Comments = node.Comments
   276  		toNode.SelectExprs = node.SelectExprs
   277  		for _, expr := range toNode.SelectExprs {
   278  			removeKeyspaceFromSelectExpr(expr)
   279  		}
   280  	case *sqlparser.Union:
   281  		toNode, ok := to.(*sqlparser.Union)
   282  		if !ok {
   283  			return vterrors.VT13001("AST did not match")
   284  		}
   285  		err = stripDownQuery(node.Left, toNode.Left)
   286  		if err != nil {
   287  			return err
   288  		}
   289  		err = stripDownQuery(node.Right, toNode.Right)
   290  		if err != nil {
   291  			return err
   292  		}
   293  		toNode.OrderBy = node.OrderBy
   294  	default:
   295  		return vterrors.VT13001(fmt.Sprintf("this should not happen - we have covered all implementations of SelectStatement %T", from))
   296  	}
   297  	return nil
   298  }
   299  
   300  func buildQuery(op ops.Operator, qb *queryBuilder) error {
   301  	switch op := op.(type) {
   302  	case *Table:
   303  		dbName := ""
   304  
   305  		if op.QTable.IsInfSchema {
   306  			dbName = op.QTable.Table.Qualifier.String()
   307  		}
   308  		qb.addTable(dbName, op.QTable.Table.Name.String(), op.QTable.Alias.As.String(), TableID(op), op.QTable.Alias.Hints)
   309  		for _, pred := range op.QTable.Predicates {
   310  			qb.addPredicate(pred)
   311  		}
   312  		for _, name := range op.Columns {
   313  			qb.addProjection(&sqlparser.AliasedExpr{Expr: name})
   314  		}
   315  	case *ApplyJoin:
   316  		err := buildQuery(op.LHS, qb)
   317  		if err != nil {
   318  			return err
   319  		}
   320  		// If we are going to add the predicate used in join here
   321  		// We should not add the predicate's copy of when it was split into
   322  		// two parts. To avoid this, we use the SkipPredicates map.
   323  		for _, expr := range qb.ctx.JoinPredicates[op.Predicate] {
   324  			qb.ctx.SkipPredicates[expr] = nil
   325  		}
   326  		qbR := &queryBuilder{ctx: qb.ctx}
   327  		err = buildQuery(op.RHS, qbR)
   328  		if err != nil {
   329  			return err
   330  		}
   331  		if op.LeftJoin {
   332  			qb.joinOuterWith(qbR, op.Predicate)
   333  		} else {
   334  			qb.joinInnerWith(qbR, op.Predicate)
   335  		}
   336  	case *Filter:
   337  		err := buildQuery(op.Source, qb)
   338  		if err != nil {
   339  			return err
   340  		}
   341  		for _, pred := range op.Predicates {
   342  			qb.addPredicate(pred)
   343  		}
   344  	case *Derived:
   345  		err := buildQuery(op.Source, qb)
   346  		if err != nil {
   347  			return err
   348  		}
   349  		sel := qb.sel.(*sqlparser.Select) // we can only handle SELECT in derived tables at the moment
   350  		qb.sel = nil
   351  		sqlparser.RemoveKeyspace(op.Query)
   352  		opQuery := op.Query.(*sqlparser.Select)
   353  		sel.Limit = opQuery.Limit
   354  		sel.OrderBy = opQuery.OrderBy
   355  		sel.GroupBy = opQuery.GroupBy
   356  		sel.Having = mergeHaving(sel.Having, opQuery.Having)
   357  		sel.SelectExprs = opQuery.SelectExprs
   358  		qb.addTableExpr(op.Alias, op.Alias, TableID(op), &sqlparser.DerivedTable{
   359  			Select: sel,
   360  		}, nil, op.ColumnAliases)
   361  		for _, col := range op.Columns {
   362  			qb.addProjection(&sqlparser.AliasedExpr{Expr: col})
   363  		}
   364  	case *Horizon:
   365  		err := buildQuery(op.Source, qb)
   366  		if err != nil {
   367  			return err
   368  		}
   369  
   370  		err = stripDownQuery(op.Select, qb.sel)
   371  		if err != nil {
   372  			return err
   373  		}
   374  		_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
   375  			if aliasedExpr, ok := node.(sqlparser.SelectExpr); ok {
   376  				removeKeyspaceFromSelectExpr(aliasedExpr)
   377  			}
   378  			return true, nil
   379  		}, qb.sel)
   380  		return nil
   381  
   382  	default:
   383  		return vterrors.VT13001(fmt.Sprintf("do not know how to turn %T into SQL", op))
   384  	}
   385  	return nil
   386  }
   387  
   388  func mergeHaving(h1, h2 *sqlparser.Where) *sqlparser.Where {
   389  	switch {
   390  	case h1 == nil && h2 == nil:
   391  		return nil
   392  	case h1 == nil:
   393  		return h2
   394  	case h2 == nil:
   395  		return h1
   396  	default:
   397  		h1.Expr = sqlparser.AndExpressions(h1.Expr, h2.Expr)
   398  		return h1
   399  	}
   400  }