github.com/dolthub/go-mysql-server@v0.18.0/sql/planbuilder/orderby.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package planbuilder
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/vitess/go/sqltypes"
    22  	ast "github.com/dolthub/vitess/go/vt/sqlparser"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/expression"
    26  	"github.com/dolthub/go-mysql-server/sql/plan"
    27  	"github.com/dolthub/go-mysql-server/sql/transform"
    28  	"github.com/dolthub/go-mysql-server/sql/types"
    29  )
    30  
    31  func (b *Builder) analyzeOrderBy(fromScope, projScope *scope, order ast.OrderBy) (outScope *scope) {
    32  	// Order by resolves to
    33  	// 1) alias in projScope
    34  	// 2) column name in fromScope
    35  	// 3) index into projection scope
    36  
    37  	// if regular col, make sure in aggOut or add to extra cols
    38  
    39  	outScope = fromScope.replace()
    40  	for _, o := range order {
    41  		var descending bool
    42  		switch strings.ToLower(o.Direction) {
    43  		default:
    44  			err := errInvalidSortOrder.New(o.Direction)
    45  			b.handleErr(err)
    46  		case ast.AscScr:
    47  			descending = false
    48  		case ast.DescScr:
    49  			descending = true
    50  		}
    51  
    52  		switch e := o.Expr.(type) {
    53  		case *ast.ColName:
    54  			// check for projection alias first
    55  			dbName := strings.ToLower(e.Qualifier.Qualifier.String())
    56  			tblName := strings.ToLower(e.Qualifier.Name.String())
    57  			colName := strings.ToLower(e.Name.String())
    58  			c, ok := projScope.resolveColumn(dbName, tblName, colName, false, false)
    59  			if ok {
    60  				c.descending = descending
    61  				outScope.addColumn(c)
    62  				continue
    63  			}
    64  
    65  			// fromScope col
    66  			c, ok = fromScope.resolveColumn(dbName, tblName, colName, true, false)
    67  			if !ok {
    68  				err := sql.ErrColumnNotFound.New(e.Name)
    69  				b.handleErr(err)
    70  			}
    71  			c.descending = descending
    72  			c.scalar = c.scalarGf()
    73  			outScope.addColumn(c)
    74  			fromScope.addExtraColumn(c)
    75  		case *ast.SQLVal:
    76  			// integer literal into projScope
    77  			// else throw away
    78  			expr := b.normalizeValArg(e)
    79  			if val, ok := expr.(*ast.SQLVal); ok && val.Type == ast.IntVal {
    80  				lit := b.convertInt(string(val.Val), 10)
    81  				idx, _, err := types.Int64.Convert(lit.Value())
    82  				if err != nil {
    83  					b.handleErr(err)
    84  				}
    85  				intIdx, ok := idx.(int64)
    86  				if !ok {
    87  					b.handleErr(fmt.Errorf("expected integer order by literal"))
    88  				}
    89  				// negative intIdx is allowed in MySQL, and is treated as a no-op
    90  				if intIdx < 0 {
    91  					continue
    92  				}
    93  				if projScope == nil || len(projScope.cols) == 0 {
    94  					err := fmt.Errorf("invalid order by ordinal context")
    95  					b.handleErr(err)
    96  				}
    97  				// MySQL throws a column not found for intIdx = 0 and intIdx > len(cols)
    98  				if intIdx > int64(len(projScope.cols)) || intIdx == 0 {
    99  					err := sql.ErrColumnNotFound.New(fmt.Sprintf("%d", intIdx))
   100  					b.handleErr(err)
   101  				}
   102  				target := projScope.cols[intIdx-1]
   103  				scalar := target.scalar
   104  				if scalar == nil {
   105  					scalar = target.scalarGf()
   106  				}
   107  				if a, ok := target.scalar.(*expression.Alias); ok && a.Unreferencable() && fromScope.groupBy != nil {
   108  					for _, c := range fromScope.groupBy.outScope.cols {
   109  						if target.id == c.id {
   110  							target = c
   111  						}
   112  					}
   113  				}
   114  				outScope.addColumn(scopeColumn{
   115  					tableId:    target.tableId,
   116  					col:        target.col,
   117  					scalar:     scalar,
   118  					typ:        target.typ,
   119  					nullable:   target.nullable,
   120  					descending: descending,
   121  					id:         target.id,
   122  				})
   123  			}
   124  		default:
   125  			// track order by col
   126  			// replace aggregations with refs
   127  			// pick up auxiliary cols
   128  			expr := b.buildScalar(fromScope, e)
   129  			_, ok := outScope.getExpr(expr.String(), true)
   130  			if ok {
   131  				continue
   132  			}
   133  			// aggregate ref -> expr.String() in
   134  			// or compound expression
   135  			expr, _, _ = transform.Expr(expr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
   136  				//  get fields outside of aggs need to be in extra cols
   137  				switch e := e.(type) {
   138  				case *expression.GetField:
   139  					c, ok := fromScope.resolveColumn("", strings.ToLower(e.Table()), strings.ToLower(e.Name()), true, false)
   140  					if !ok {
   141  						err := sql.ErrColumnNotFound.New(e.Name)
   142  						b.handleErr(err)
   143  					}
   144  					fromScope.addExtraColumn(c)
   145  				case sql.WindowAdaptableExpression:
   146  					// has to have been ref'd already
   147  					id, ok := fromScope.getExpr(e.String(), true)
   148  					if !ok {
   149  						err := fmt.Errorf("faild to ref aggregate expression: %s", e.String())
   150  						b.handleErr(err)
   151  					}
   152  					return expression.NewGetField(int(id), e.Type(), e.String(), e.IsNullable()), transform.NewTree, nil
   153  				default:
   154  				}
   155  				return e, transform.SameTree, nil
   156  			})
   157  			col := scopeColumn{
   158  				col:        expr.String(),
   159  				scalar:     expr,
   160  				typ:        expr.Type(),
   161  				nullable:   expr.IsNullable(),
   162  				descending: descending,
   163  			}
   164  			outScope.newColumn(col)
   165  		}
   166  	}
   167  	return
   168  }
   169  
   170  func (b *Builder) normalizeValArg(e *ast.SQLVal) ast.Expr {
   171  	if e.Type != ast.ValArg || b.bindCtx == nil {
   172  		return e
   173  	}
   174  	name := strings.TrimPrefix(string(e.Val), ":")
   175  	if b.bindCtx.Bindings == nil {
   176  		err := fmt.Errorf("bind variable not provided: '%s'", name)
   177  		b.handleErr(err)
   178  	}
   179  	bv, ok := b.bindCtx.GetSubstitute(name)
   180  	if !ok {
   181  		err := fmt.Errorf("bind variable not provided: '%s'", name)
   182  		b.handleErr(err)
   183  	}
   184  
   185  	val, err := sqltypes.BindVariableToValue(bv)
   186  	if err != nil {
   187  		b.handleErr(err)
   188  	}
   189  	expr, err := ast.ExprFromValue(val)
   190  	switch e := expr.(type) {
   191  	case *ast.SQLVal:
   192  		return e
   193  	case *ast.NullVal:
   194  		return e
   195  	default:
   196  		err := fmt.Errorf("unknown ast.Expr: %T", e)
   197  		b.handleErr(err)
   198  	}
   199  	return nil
   200  }
   201  
   202  func (b *Builder) buildOrderBy(inScope, orderByScope *scope) {
   203  	if len(orderByScope.cols) == 0 {
   204  		return
   205  	}
   206  	var sortFields sql.SortFields
   207  	for _, c := range orderByScope.cols {
   208  		so := sql.Ascending
   209  		if c.descending {
   210  			so = sql.Descending
   211  		}
   212  		scalar := c.scalar
   213  		if scalar == nil {
   214  			scalar = c.scalarGf()
   215  		}
   216  		sf := sql.SortField{
   217  			Column: scalar,
   218  			Order:  so,
   219  		}
   220  		sortFields = append(sortFields, sf)
   221  	}
   222  	sort := plan.NewSort(sortFields, inScope.node)
   223  	inScope.node = sort
   224  	return
   225  }