github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/proc.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"io"
    21  	"strings"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/go-mysql-server/sql/expression"
    25  	"github.com/dolthub/go-mysql-server/sql/plan"
    26  )
    27  
    28  func (b *BaseBuilder) buildCaseStatement(ctx *sql.Context, n *plan.CaseStatement, row sql.Row) (sql.RowIter, error) {
    29  	caseValue, err := n.Expr.Eval(ctx, row)
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  
    34  	for _, ifConditional := range n.IfElse.IfConditionals {
    35  		whenValue, err := ifConditional.Condition.Eval(ctx, row)
    36  		if err != nil {
    37  			return nil, err
    38  		}
    39  		comparison, err := n.Expr.Type().Compare(caseValue, whenValue)
    40  		if err != nil {
    41  			return nil, err
    42  		}
    43  		if comparison != 0 {
    44  			continue
    45  		}
    46  
    47  		return b.buildCaseIter(ctx, row, ifConditional, ifConditional.Body)
    48  	}
    49  
    50  	// All conditions failed so we run the else
    51  	return b.buildCaseIter(ctx, row, n.IfElse.Else, n.IfElse.Else)
    52  }
    53  
    54  func (b *BaseBuilder) buildCaseIter(ctx *sql.Context, row sql.Row, iterNode sql.Node, bodyNode sql.Node) (sql.RowIter, error) {
    55  	// All conditions failed so we run the else
    56  	branchIter, err := b.buildNodeExec(ctx, iterNode, row)
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  	// If the branchIter is already a block iter, then we don't need to construct our own, as its contained
    61  	// node and schema will be a better representation of the iterated rows.
    62  	if blockRowIter, ok := branchIter.(plan.BlockRowIter); ok {
    63  		return blockRowIter, nil
    64  	}
    65  	return &ifElseIter{
    66  		branchIter: branchIter,
    67  		sch:        bodyNode.Schema(),
    68  		branchNode: bodyNode,
    69  	}, nil
    70  }
    71  
    72  func (b *BaseBuilder) buildIfElseBlock(ctx *sql.Context, n *plan.IfElseBlock, row sql.Row) (sql.RowIter, error) {
    73  	var branchIter sql.RowIter
    74  
    75  	var err error
    76  	for _, ifConditional := range n.IfConditionals {
    77  		condition, err := ifConditional.Condition.Eval(ctx, row)
    78  		if err != nil {
    79  			return nil, err
    80  		}
    81  		var passedCondition bool
    82  		if condition != nil {
    83  			passedCondition, err = sql.ConvertToBool(ctx, condition)
    84  			if err != nil {
    85  				return nil, err
    86  			}
    87  		}
    88  		if !passedCondition {
    89  			continue
    90  		}
    91  
    92  		// TODO: this should happen at iteration time, but this call is where the actual iteration happens
    93  		err = startTransaction(ctx)
    94  		if err != nil {
    95  			return nil, err
    96  		}
    97  
    98  		branchIter, err = b.buildNodeExec(ctx, ifConditional, row)
    99  		if err != nil {
   100  			return nil, err
   101  		}
   102  		// If the branchIter is already a block iter, then we don't need to construct our own, as its contained
   103  		// node and schema will be a better representation of the iterated rows.
   104  		if blockRowIter, ok := branchIter.(plan.BlockRowIter); ok {
   105  			return blockRowIter, nil
   106  		}
   107  		return &ifElseIter{
   108  			branchIter: branchIter,
   109  			sch:        ifConditional.Body.Schema(),
   110  			branchNode: ifConditional.Body,
   111  		}, nil
   112  	}
   113  
   114  	// TODO: this should happen at iteration time, but this call is where the actual iteration happens
   115  	err = startTransaction(ctx)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  
   120  	// All conditions failed so we run the else
   121  	branchIter, err = b.buildNodeExec(ctx, n.Else, row)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  	// If the branchIter is already a block iter, then we don't need to construct our own, as its contained
   126  	// node and schema will be a better representation of the iterated rows.
   127  	if blockRowIter, ok := branchIter.(plan.BlockRowIter); ok {
   128  		return blockRowIter, nil
   129  	}
   130  	return &ifElseIter{
   131  		branchIter: branchIter,
   132  		sch:        n.Else.Schema(),
   133  		branchNode: n.Else,
   134  	}, nil
   135  }
   136  
// buildBeginEndBlock executes a BEGIN...END block. It pushes a new scope onto n.Pref and then builds the
// iterator for n.Block; building is where the block's statements actually run.
//
// On success, the body iterator is wrapped in a beginEndIter and the pushed scope is NOT popped here.
// NOTE(review): presumably beginEndIter pops it later; confirm, since that type is not visible here.
//
// On a build error, the following is applied in order:
//   - A loopError whose label matches n.Label (compared lower-cased) is consumed. With IsExit set, the error is
//     cleared. Otherwise it is replaced with an "ITERATE on BEGIN...END" error.
//   - Any other error is offered to the handlers of n.Pref.InnermostScope, searched from the end of
//     scope.Handlers backwards. The first handler whose condition matches is built with a nil row and drained.
//     If its action is DeclareHandlerAction_Exit, an empty iterator and a nil error are returned. For any
//     other action, an empty iterator and io.EOF are returned.
//   - If no handler matched and the error wraps io.EOF, an empty iterator and a nil error are returned.
//   - Otherwise the scope is popped. A PopScope error is reported only when there is no other error. A
//     FetchEOF error is suppressed when n.Pref.CurrentHeight() is 1 after the pop. The remaining error
//     (possibly nil) is returned with an empty iterator.
//
// NOTE(review): the handler path and the io.EOF path return without calling PopScope, which leaves the
// scope pushed at the top of this function on n.Pref. Confirm whether something else is expected to pop it.
func (b *BaseBuilder) buildBeginEndBlock(ctx *sql.Context, n *plan.BeginEndBlock, row sql.Row) (sql.RowIter, error) {
	n.Pref.PushScope()
	rowIter, err := b.buildNodeExec(ctx, n.Block, row)
	if err != nil {
		if controlFlow, ok := err.(loopError); ok && strings.ToLower(controlFlow.Label) == strings.ToLower(n.Label) {
			if controlFlow.IsExit {
				err = nil
			} else {
				err = fmt.Errorf("encountered ITERATE on BEGIN...END, which should should have been caught by the analyzer")
			}
		} else {
			scope := n.Pref.InnermostScope
			// Search the handlers from the end of the slice backwards; the first match wins.
			for i := len(scope.Handlers) - 1; i >= 0; i-- {
				if !scope.Handlers[i].Cond.Matches(err) {
					continue
				}
				// NOTE(review): originalScope and scope are read from the same field with nothing in between
				// that visibly modifies it, so the assignment below appears to be a no-op. The deferred
				// restore still matters if running the handler changes InnermostScope.
				originalScope := n.Pref.InnermostScope
				defer func() {
					n.Pref.InnermostScope = originalScope
				}()
				n.Pref.InnermostScope = scope
				handlerRefVal := scope.Handlers[i]

				// The handler statement is built with a nil row rather than |row|.
				handlerRowIter, err := b.buildNodeExec(ctx, handlerRefVal.Stmt, nil)
				if err != nil {
					return sql.RowsToRowIter(), err
				}
				// NOTE(review): the error returned by Close is ignored.
				defer handlerRowIter.Close(ctx)

				// Drain the handler so that all of its statements execute; its rows are discarded.
				for {
					_, err := handlerRowIter.Next(ctx)
					if err == io.EOF {
						break
					} else if err != nil {
						return sql.RowsToRowIter(), err
					}
				}
				if scope.Handlers[i].Action == expression.DeclareHandlerAction_Exit {
					return sql.RowsToRowIter(), nil
				}
				// Non-EXIT handler: io.EOF is returned to the caller. Presumably this signals that execution
				// should continue after the failing statement; confirm against the callers.
				return sql.RowsToRowIter(), io.EOF
			}
		}
		if errors.Is(err, io.EOF) {
			return sql.RowsToRowIter(), nil
		}
		// Only report a PopScope failure when there is no more important error to return.
		if nErr := n.Pref.PopScope(ctx); err == nil && nErr != nil {
			err = nErr
		}
		if errors.Is(err, expression.FetchEOF) && n.Pref.CurrentHeight() == 1 {
			// Don't return the fetch error in the first BEGIN block, though MySQL returns:
			// ERROR 1329 (02000): No data - zero rows fetched, selected, or processed
			return sql.RowsToRowIter(), nil
		}
		return sql.RowsToRowIter(), err
	}
	return &beginEndIter{
		BeginEndBlock: n,
		rowIter:       rowIter,
	}, nil
}
   198  
   199  func (b *BaseBuilder) buildIfConditional(ctx *sql.Context, n *plan.IfConditional, row sql.Row) (sql.RowIter, error) {
   200  	return b.buildNodeExec(ctx, n.Body, row)
   201  }
   202  
   203  func (b *BaseBuilder) buildProcedureResolvedTable(ctx *sql.Context, n *plan.ProcedureResolvedTable, row sql.Row) (sql.RowIter, error) {
   204  	rt, err := n.NewestTable(ctx)
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  	return b.buildResolvedTable(ctx, rt, row)
   209  }
   210  
   211  func (b *BaseBuilder) buildCall(ctx *sql.Context, n *plan.Call, row sql.Row) (sql.RowIter, error) {
   212  	for i, paramExpr := range n.Params {
   213  		val, err := paramExpr.Eval(ctx, row)
   214  		if err != nil {
   215  			return nil, err
   216  		}
   217  		paramName := n.Procedure.Params[i].Name
   218  		paramType := n.Procedure.Params[i].Type
   219  		err = n.Pref.InitializeVariable(paramName, paramType, val)
   220  		if err != nil {
   221  			return nil, err
   222  		}
   223  	}
   224  
   225  	n.Pref.PushScope()
   226  	defer n.Pref.PopScope(ctx)
   227  
   228  	innerIter, err := b.buildNodeExec(ctx, n.Procedure, row)
   229  	if err != nil {
   230  		return nil, err
   231  	}
   232  	return &callIter{
   233  		call:      n,
   234  		innerIter: innerIter,
   235  	}, nil
   236  }
   237  
   238  // buildLoop builds and returns an iterator that can be used to iterate over the result set returned from the
   239  // specified loop, |n|, for the specified row, |row|. Note that because of how we execute stored procedures and cache
   240  // the results in order to only send back the LAST result set (instead of supporting multiple results sets from
   241  // stored procedures, like MySQL does), building the iterator here also implicitly means that we're executing the
   242  // loop logic and caching the result set in memory. This will obviously be an issue for very large result sets.
   243  // Unfortunately, we can't know at analysis time what the last result set returned will be, since conditional logic
   244  // in stored procedures can't be known until execution time, hence why we end up caching result sets when we
   245  // see them and just playing back the last one. Adding support for MySQL's multiple result set behavior and better
   246  // matching MySQL on which statements are allowed to return result sets from a stored procedure seems like it could
   247  // potentially allow us to get rid of that caching.
   248  func (b *BaseBuilder) buildLoop(ctx *sql.Context, n *plan.Loop, row sql.Row) (sql.RowIter, error) {
   249  	// Acquiring the RowIter will actually execute the loop body once (because of how we cache/scan for the right
   250  	// SELECT result set to return), so we grab the iter ONLY if we're supposed to run through the loop body once
   251  	// before evaluating the condition
   252  	var loopBodyIter sql.RowIter
   253  	if n.OnceBeforeEval {
   254  		var err error
   255  		loopBodyIter, err = b.loopAcquireRowIter(ctx, row, n.Label, n.Block, true)
   256  		if err != nil {
   257  			return nil, err
   258  		}
   259  	}
   260  
   261  	var returnRows []sql.Row
   262  	var returnNode sql.Node
   263  	var returnSch sql.Schema
   264  	selectSeen := false
   265  
   266  	// It's technically valid to make an infinite loop, but we don't want to actually allow that
   267  	const maxIterationCount = 10_000_000_000
   268  
   269  	for loopIteration := 0; loopIteration <= maxIterationCount; loopIteration++ {
   270  		if loopIteration >= maxIterationCount {
   271  			return nil, fmt.Errorf("infinite LOOP detected")
   272  		}
   273  
   274  		// If the condition is false, then we stop evaluation
   275  		condition, err := n.Condition.Eval(ctx, nil)
   276  		if err != nil {
   277  			return nil, err
   278  		}
   279  		conditionBool, err := sql.ConvertToBool(ctx, condition)
   280  		if err != nil {
   281  			return nil, err
   282  		}
   283  		if !conditionBool {
   284  			// loopBodyIter should only be set if this is the first time through the loop and the loop has a
   285  			// OnceBeforeEval condition. This ensures we return a result set, without us having to drain the iterator,
   286  			// recache rows, and return a new iterator.
   287  			if loopBodyIter != nil {
   288  				return loopBodyIter, nil
   289  			} else {
   290  				break
   291  			}
   292  		}
   293  
   294  		if loopBodyIter == nil {
   295  			var err error
   296  			loopBodyIter, err = b.loopAcquireRowIter(ctx, nil, strings.ToLower(n.Label), n.Block, false)
   297  			if err == io.EOF {
   298  				break
   299  			} else if err != nil {
   300  				return nil, err
   301  			}
   302  		}
   303  
   304  		includeResultSet := false
   305  
   306  		var subIterNode sql.Node = n.Block
   307  		subIterSch := n.Block.Schema()
   308  		if blockRowIter, ok := loopBodyIter.(plan.BlockRowIter); ok {
   309  			subIterNode = blockRowIter.RepresentingNode()
   310  			subIterSch = blockRowIter.Schema()
   311  
   312  			if plan.NodeRepresentsSelect(subIterNode) {
   313  				selectSeen = true
   314  				includeResultSet = true
   315  				returnNode = subIterNode
   316  				returnSch = subIterSch
   317  			} else if !selectSeen {
   318  				includeResultSet = true
   319  				returnNode = subIterNode
   320  				returnSch = subIterSch
   321  			}
   322  		}
   323  
   324  		// Wrap the caching code in an inline function so that we can use defer to safely dispose of the cache
   325  		err = func() error {
   326  			rowCache, disposeFunc := ctx.Memory.NewRowsCache()
   327  			defer disposeFunc()
   328  
   329  			nextRow, err := loopBodyIter.Next(ctx)
   330  			for ; err == nil; nextRow, err = loopBodyIter.Next(ctx) {
   331  				rowCache.Add(nextRow)
   332  			}
   333  			if err != io.EOF {
   334  				return err
   335  			}
   336  
   337  			err = loopBodyIter.Close(ctx)
   338  			if err != nil {
   339  				return err
   340  			}
   341  			loopBodyIter = nil
   342  
   343  			if includeResultSet {
   344  				returnRows = rowCache.Get()
   345  			}
   346  			return nil
   347  		}()
   348  
   349  		if err != nil {
   350  			if err == io.EOF {
   351  				// no-op for an EOF, just execute the next loop iteration
   352  			} else if controlFlow, ok := err.(loopError); ok && strings.ToLower(controlFlow.Label) == n.Label {
   353  				if controlFlow.IsExit {
   354  					break
   355  				}
   356  			} else {
   357  				// If the error wasn't a control flow error signaling to start the next loop iteration or to
   358  				// exit the loop, then it must be a real error, so just return it.
   359  				return nil, err
   360  			}
   361  		}
   362  	}
   363  
   364  	return &blockIter{
   365  		internalIter: sql.RowsToRowIter(returnRows...),
   366  		repNode:      returnNode,
   367  		sch:          returnSch,
   368  	}, nil
   369  }
   370  
   371  func (b *BaseBuilder) buildElseCaseError(ctx *sql.Context, n plan.ElseCaseError, row sql.Row) (sql.RowIter, error) {
   372  	return elseCaseErrorIter{}, nil
   373  }
   374  
   375  func (b *BaseBuilder) buildOpen(ctx *sql.Context, n *plan.Open, row sql.Row) (sql.RowIter, error) {
   376  	return &openIter{pRef: n.Pref, name: n.Name, row: row}, nil
   377  }
   378  
   379  func (b *BaseBuilder) buildClose(ctx *sql.Context, n *plan.Close, row sql.Row) (sql.RowIter, error) {
   380  	return &closeIter{pRef: n.Pref, name: n.Name}, nil
   381  }
   382  
   383  func (b *BaseBuilder) buildLeave(ctx *sql.Context, n *plan.Leave, row sql.Row) (sql.RowIter, error) {
   384  	return &leaveIter{n.Label}, nil
   385  }
   386  
   387  func (b *BaseBuilder) buildIterate(ctx *sql.Context, n *plan.Iterate, row sql.Row) (sql.RowIter, error) {
   388  	return &iterateIter{n.Label}, nil
   389  }
   390  
   391  func (b *BaseBuilder) buildWhile(ctx *sql.Context, n *plan.While, row sql.Row) (sql.RowIter, error) {
   392  	return b.buildLoop(ctx, n.Loop, row)
   393  }