github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/select.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package planbuilder
    16  
    17  import (
    18  	"fmt"
    19  
    20  	ast "github.com/dolthub/vitess/go/vt/sqlparser"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/expression"
    24  	"github.com/dolthub/go-mysql-server/sql/mysql_db"
    25  	"github.com/dolthub/go-mysql-server/sql/plan"
    26  	"github.com/dolthub/go-mysql-server/sql/types"
    27  )
    28  
    29  func (b *Builder) buildSelectStmt(inScope *scope, s ast.SelectStatement) (outScope *scope) {
    30  	switch s := s.(type) {
    31  	case *ast.Select:
    32  		if s.With != nil {
    33  			cteScope := b.buildWith(inScope, s.With)
    34  			return b.buildSelect(cteScope, s)
    35  		}
    36  		return b.buildSelect(inScope, s)
    37  	case *ast.SetOp:
    38  		if s.With != nil {
    39  			cteScope := b.buildWith(inScope, s.With)
    40  			return b.buildSetOp(cteScope, s)
    41  		}
    42  		return b.buildSetOp(inScope, s)
    43  	case *ast.ParenSelect:
    44  		return b.buildSelectStmt(inScope, s.Select)
    45  	default:
    46  		b.handleErr(fmt.Errorf("unknown select statement %T", s))
    47  	}
    48  	return
    49  }
    50  
    51  func (b *Builder) buildSelect(inScope *scope, s *ast.Select) (outScope *scope) {
    52  	// General order of binding:
    53  	// 1) Get definitions in FROM.
    54  	// 2) Build WHERE, which can only reference FROM columns.
    55  	// 3) Bookkeep aggregation/window function usage in higher-scopes
    56  	//    (GROUP BY, WINDOW, HAVING, SELECT, ORDER BY).
    57  	// 4) Construct either i) aggregation, ii) window, or iii) projection over
    58  	//    FROM clause providing expressions used in (2) (including aliases).
    59  	// 5) Build top-level scopes, replacing aggregation and aliases with
    60  	//    projections from (4).
    61  	// 6) Finish with final target projections.
    62  	fromScope := b.buildFrom(inScope, s.From)
    63  	if cn, ok := fromScope.node.(sql.CommentedNode); ok && len(s.Comments) > 0 {
    64  		fromScope.node = cn.WithComment(string(s.Comments[0]))
    65  	}
    66  
    67  	// Resolve and fold named window definitions
    68  	b.buildNamedWindows(fromScope, s.Window)
    69  
    70  	b.buildWhere(fromScope, s.Where)
    71  	// select *, (SELECT t2.i) from t1 left join using t2 on i;
    72  	// select t1.*, t2.*, t2.* from ...
    73  	//
    74  	projScope := fromScope.push()
    75  
    76  	// Aggregates in select list added to fromScope.groupBy.outCols.
    77  	// Args to aggregates added to fromScope.groupBy.inCols.
    78  	b.analyzeProjectionList(fromScope, projScope, s.SelectExprs)
    79  
    80  	// Find aggregations in order by
    81  	orderByScope := b.analyzeOrderBy(fromScope, projScope, s.OrderBy)
    82  
    83  	// Find aggregations in having
    84  	b.analyzeHaving(fromScope, projScope, s.Having)
    85  
    86  	// At this point we've recorded dependencies for higher-level scopes,
    87  	// so we can build the FROM clause
    88  	if b.needsAggregation(fromScope, s) {
    89  		groupingCols := b.buildGroupingCols(fromScope, projScope, s.GroupBy, s.SelectExprs)
    90  		outScope = b.buildAggregation(fromScope, projScope, groupingCols)
    91  	} else if fromScope.windowFuncs != nil {
    92  		outScope = b.buildWindow(fromScope, projScope)
    93  	} else {
    94  		outScope = b.buildInnerProj(fromScope, projScope)
    95  	}
    96  
    97  	// At this point, we've combined table relations, performed aggregations,
    98  	// and projected aliases used in higher-level clauses. Aliases and agg
    99  	// expressions in higher level scopes will be replaced with GetField
   100  	// references.
   101  
   102  	b.buildHaving(fromScope, projScope, outScope, s.Having)
   103  
   104  	b.buildOrderBy(outScope, orderByScope)
   105  
   106  	// Last level projection restricts outputs to target projections.
   107  	b.buildProjection(outScope, projScope)
   108  	outScope = projScope
   109  
   110  	b.buildDistinct(outScope, s.QueryOpts.Distinct)
   111  
   112  	// OFFSET and LIMIT are last
   113  	offset := b.buildOffset(outScope, s.Limit)
   114  	if offset != nil {
   115  		outScope.node = plan.NewOffset(offset, outScope.node)
   116  	}
   117  	limit := b.buildLimit(outScope, s.Limit)
   118  	if limit != nil {
   119  		l := plan.NewLimit(limit, outScope.node)
   120  		l.CalcFoundRows = s.QueryOpts.SQLCalcFoundRows
   121  		outScope.node = l
   122  	}
   123  
   124  	return
   125  }
   126  
   127  func (b *Builder) buildLimit(inScope *scope, limit *ast.Limit) sql.Expression {
   128  	if limit != nil {
   129  		return b.buildLimitVal(inScope, limit.Rowcount)
   130  	}
   131  	return nil
   132  }
   133  
   134  func (b *Builder) buildOffset(inScope *scope, limit *ast.Limit) sql.Expression {
   135  	if limit != nil && limit.Offset != nil {
   136  		e := b.buildLimitVal(inScope, limit.Offset)
   137  		if lit, ok := e.(*expression.Literal); ok {
   138  			// Check if offset starts at 0, if so, we can just remove the offset node.
   139  			// Only cast to int8, as a larger int type just means a non-zero offset.
   140  			if val, err := lit.Eval(b.ctx, nil); err == nil {
   141  				if v, ok := val.(int64); ok && v == 0 {
   142  					return nil
   143  				}
   144  			}
   145  		}
   146  		return e
   147  	}
   148  	return nil
   149  }
   150  
   151  // buildLimitVal resolves a literal numeric type or a numeric
   152  // prodecure parameter
   153  func (b *Builder) buildLimitVal(inScope *scope, e ast.Expr) sql.Expression {
   154  	switch e := e.(type) {
   155  	case *ast.ColName:
   156  		if inScope.procActive() {
   157  			if col, ok := inScope.proc.GetVar(e.String()); ok {
   158  				// proc param is OK
   159  				if pp, ok := col.scalarGf().(*expression.ProcedureParam); ok {
   160  					if !pp.Type().Promote().Equals(types.Int64) && !pp.Type().Promote().Equals(types.Uint64) {
   161  						err := fmt.Errorf("the variable '%s' has a non-integer based type: %s", pp.Name(), pp.Type().String())
   162  						b.handleErr(err)
   163  					}
   164  					return pp
   165  				}
   166  			}
   167  		}
   168  		err := fmt.Errorf("limit expression expected to be numeric or prodecure parameter, found invalid column: %s", e.String())
   169  		b.handleErr(err)
   170  	default:
   171  		l := b.buildScalar(inScope, e)
   172  		return b.typeCoerceLiteral(l)
   173  	}
   174  	return nil
   175  }
   176  
   177  func (b *Builder) typeCoerceLiteral(e sql.Expression) sql.Expression {
   178  	// todo this should be in a module that can generically coerce to a type or type class
   179  	switch e := e.(type) {
   180  	case *expression.Literal:
   181  		val, _, err := types.Int64.Convert(e.Value())
   182  		if err != nil {
   183  			err = fmt.Errorf("%s: %w", err.Error(), sql.ErrInvalidTypeForLimit.New(types.Int64, e.Type()))
   184  		}
   185  		return expression.NewLiteral(val, types.Int64)
   186  	case *expression.BindVar:
   187  		return e
   188  	default:
   189  		err := sql.ErrInvalidTypeForLimit.New(expression.Literal{}, e)
   190  		b.handleErr(err)
   191  	}
   192  	return nil
   193  }
   194  
   195  // buildDistinct creates a new plan.Distinct node if the query has a DISTINCT option.
   196  // If the query has both DISTINCT and ALL, an error is returned.
   197  func (b *Builder) buildDistinct(inScope *scope, distinct bool) {
   198  	if distinct {
   199  		inScope.node = plan.NewDistinct(inScope.node)
   200  	}
   201  }
   202  
   203  func (b *Builder) currentDb() sql.Database {
   204  	if b.currentDatabase == nil {
   205  		if b.ctx.GetCurrentDatabase() == "" {
   206  			err := sql.ErrNoDatabaseSelected.New()
   207  			b.handleErr(err)
   208  		}
   209  		database, err := b.cat.Database(b.ctx, b.ctx.GetCurrentDatabase())
   210  		if err != nil {
   211  			b.handleErr(err)
   212  		}
   213  
   214  		if privilegedDatabase, ok := database.(mysql_db.PrivilegedDatabase); ok {
   215  			database = privilegedDatabase.Unwrap()
   216  		}
   217  		b.currentDatabase = database
   218  	}
   219  	return b.currentDatabase
   220  }
   221  
   222  func (b *Builder) renameSource(scope *scope, table string, cols []string) {
   223  	if table != "" {
   224  		scope.setTableAlias(table)
   225  	}
   226  	if len(cols) > 0 {
   227  		scope.setColAlias(cols)
   228  	}
   229  	for i, c := range scope.cols {
   230  		c.scalar = nil
   231  		scope.cols[i] = c
   232  	}
   233  }