github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/proc_iters.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"io"
    21  	"strings"
    22  
    23  	"github.com/dolthub/vitess/go/mysql"
    24  
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  	"github.com/dolthub/go-mysql-server/sql/expression"
    27  	"github.com/dolthub/go-mysql-server/sql/plan"
    28  )
    29  
    30  // ifElseIter is the row iterator for *IfElseBlock.
    31  type ifElseIter struct {
    32  	branchIter sql.RowIter
    33  	sch        sql.Schema
    34  	branchNode sql.Node
    35  }
    36  
    37  var _ plan.BlockRowIter = (*ifElseIter)(nil)
    38  
    39  // Next implements the sql.RowIter interface.
    40  func (i *ifElseIter) Next(ctx *sql.Context) (sql.Row, error) {
    41  	if err := startTransaction(ctx); err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	return i.branchIter.Next(ctx)
    46  }
    47  
    48  // Close implements the sql.RowIter interface.
    49  func (i *ifElseIter) Close(ctx *sql.Context) error {
    50  	return i.branchIter.Close(ctx)
    51  }
    52  
    53  // RepresentingNode implements the sql.BlockRowIter interface.
    54  func (i *ifElseIter) RepresentingNode() sql.Node {
    55  	return i.branchNode
    56  }
    57  
    58  // Schema implements the sql.BlockRowIter interface.
    59  func (i *ifElseIter) Schema() sql.Schema {
    60  	return i.sch
    61  }
    62  
    63  // beginEndIter is the sql.RowIter of *BeginEndBlock.
    64  type beginEndIter struct {
    65  	*plan.BeginEndBlock
    66  	rowIter sql.RowIter
    67  }
    68  
    69  var _ sql.RowIter = (*beginEndIter)(nil)
    70  
    71  // Next implements the interface sql.RowIter.
    72  func (b *beginEndIter) Next(ctx *sql.Context) (sql.Row, error) {
    73  	if err := startTransaction(ctx); err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	row, err := b.rowIter.Next(ctx)
    78  	if err != nil {
    79  		if controlFlow, ok := err.(loopError); ok && strings.ToLower(controlFlow.Label) == strings.ToLower(b.Label) {
    80  			if controlFlow.IsExit {
    81  				err = nil
    82  			} else {
    83  				err = fmt.Errorf("encountered ITERATE on BEGIN...END, which should should have been caught by the analyzer")
    84  			}
    85  		}
    86  		if nErr := b.Pref.PopScope(ctx); nErr != nil && err == io.EOF {
    87  			err = nErr
    88  		}
    89  		if errors.Is(err, expression.FetchEOF) {
    90  			err = io.EOF
    91  		}
    92  		return nil, err
    93  	}
    94  	return row, nil
    95  }
    96  
    97  // Close implements the interface sql.RowIter.
    98  func (b *beginEndIter) Close(ctx *sql.Context) error {
    99  	return b.rowIter.Close(ctx)
   100  }
   101  
   102  // callIter is the row iterator for *Call.
   103  type callIter struct {
   104  	call      *plan.Call
   105  	innerIter sql.RowIter
   106  }
   107  
   108  // Next implements the sql.RowIter interface.
   109  func (iter *callIter) Next(ctx *sql.Context) (sql.Row, error) {
   110  	return iter.innerIter.Next(ctx)
   111  }
   112  
   113  // Close implements the sql.RowIter interface.
   114  func (iter *callIter) Close(ctx *sql.Context) error {
   115  	err := iter.innerIter.Close(ctx)
   116  	if err != nil {
   117  		return err
   118  	}
   119  	err = iter.call.Pref.CloseAllCursors(ctx)
   120  	if err != nil {
   121  		return err
   122  	}
   123  
   124  	// Set all user and system variables from INOUT and OUT params
   125  	for i, param := range iter.call.Procedure.Params {
   126  		if param.Direction == plan.ProcedureParamDirection_Inout ||
   127  			(param.Direction == plan.ProcedureParamDirection_Out && iter.call.Pref.VariableHasBeenSet(param.Name)) {
   128  			val, err := iter.call.Pref.GetVariableValue(param.Name)
   129  			if err != nil {
   130  				return err
   131  			}
   132  
   133  			typ := iter.call.Pref.GetVariableType(param.Name)
   134  
   135  			switch callParam := iter.call.Params[i].(type) {
   136  			case *expression.UserVar:
   137  				err = ctx.SetUserVariable(ctx, callParam.Name, val, typ)
   138  				if err != nil {
   139  					return err
   140  				}
   141  			case *expression.SystemVar:
   142  				// This should have been caught by the analyzer, so a major bug exists somewhere
   143  				return fmt.Errorf("unable to set `%s` as it is a system variable", callParam.Name)
   144  			case *expression.ProcedureParam:
   145  				err = callParam.Set(val, param.Type)
   146  				if err != nil {
   147  					return err
   148  				}
   149  			}
   150  		} else if param.Direction == plan.ProcedureParamDirection_Out { // VariableHasBeenSet was false
   151  			// For OUT only, if a var was not set within the procedure body, then we set the vars to nil.
   152  			// If the var had a value before the call then it is basically removed.
   153  			switch callParam := iter.call.Params[i].(type) {
   154  			case *expression.UserVar:
   155  				err = ctx.SetUserVariable(ctx, callParam.Name, nil, iter.call.Pref.GetVariableType(param.Name))
   156  				if err != nil {
   157  					return err
   158  				}
   159  			case *expression.SystemVar:
   160  				// This should have been caught by the analyzer, so a major bug exists somewhere
   161  				return fmt.Errorf("unable to set `%s` as it is a system variable", callParam.Name)
   162  			case *expression.ProcedureParam:
   163  				err := callParam.Set(nil, param.Type)
   164  				if err != nil {
   165  					return err
   166  				}
   167  			}
   168  		}
   169  	}
   170  	return nil
   171  }
   172  
   173  type elseCaseErrorIter struct{}
   174  
   175  var _ sql.RowIter = elseCaseErrorIter{}
   176  
   177  // Next implements the interface sql.RowIter.
   178  func (e elseCaseErrorIter) Next(ctx *sql.Context) (sql.Row, error) {
   179  	return nil, mysql.NewSQLError(1339, "20000", "Case not found for CASE statement")
   180  }
   181  
   182  // Close implements the interface sql.RowIter.
   183  func (e elseCaseErrorIter) Close(context *sql.Context) error {
   184  	return nil
   185  }
   186  
   187  // openIter is the sql.RowIter of *Open.
   188  type openIter struct {
   189  	pRef *expression.ProcedureReference
   190  	name string
   191  	row  sql.Row
   192  	b    *BaseBuilder
   193  }
   194  
   195  var _ sql.RowIter = (*openIter)(nil)
   196  
   197  // Next implements the interface sql.RowIter.
   198  func (o *openIter) Next(ctx *sql.Context) (sql.Row, error) {
   199  	if err := o.openCursor(ctx, o.pRef, o.name, o.row); err != nil {
   200  		return nil, err
   201  	}
   202  	return nil, io.EOF
   203  }
   204  
   205  func (o *openIter) openCursor(ctx *sql.Context, ref *expression.ProcedureReference, name string, row sql.Row) error {
   206  	lowerName := strings.ToLower(name)
   207  	scope := ref.InnermostScope
   208  	for scope != nil {
   209  		if cursorRefVal, ok := scope.Cursors[lowerName]; ok {
   210  			if cursorRefVal.RowIter != nil {
   211  				return sql.ErrCursorAlreadyOpen.New(name)
   212  			}
   213  			var err error
   214  			cursorRefVal.RowIter, err = o.b.buildNodeExec(ctx, cursorRefVal.SelectStmt, row)
   215  			return err
   216  		}
   217  		scope = scope.Parent
   218  	}
   219  	return fmt.Errorf("cannot find cursor `%s`", name)
   220  }
   221  
   222  // Close implements the interface sql.RowIter.
   223  func (o *openIter) Close(ctx *sql.Context) error {
   224  	return nil
   225  }
   226  
   227  // closeIter is the sql.RowIter of *Close.
   228  type closeIter struct {
   229  	pRef *expression.ProcedureReference
   230  	name string
   231  }
   232  
   233  var _ sql.RowIter = (*closeIter)(nil)
   234  
   235  // Next implements the interface sql.RowIter.
   236  func (c *closeIter) Next(ctx *sql.Context) (sql.Row, error) {
   237  	if err := c.pRef.CloseCursor(ctx, c.name); err != nil {
   238  		return nil, err
   239  	}
   240  	return nil, io.EOF
   241  }
   242  
   243  // Close implements the interface sql.RowIter.
   244  func (c *closeIter) Close(ctx *sql.Context) error {
   245  	return nil
   246  }
   247  
   248  // loopError is an error used to control a loop's flow.
   249  type loopError struct {
   250  	Label  string
   251  	IsExit bool
   252  }
   253  
   254  var _ error = loopError{}
   255  
   256  // Error implements the interface error. As long as the analysis step is implemented correctly, this should never be seen.
   257  func (l loopError) Error() string {
   258  	option := "exited"
   259  	if !l.IsExit {
   260  		option = "continued"
   261  	}
   262  	return fmt.Sprintf("should have %s the loop `%s` but it was somehow not found in the call stack", option, l.Label)
   263  }
   264  
   265  // loopAcquireRowIter is a helper function for LOOP that conditionally acquires a new sql.RowIter. If a loop exit is
   266  // encountered, `exitIter` determines whether to return an empty iterator or an io.EOF error.
   267  func (b *BaseBuilder) loopAcquireRowIter(ctx *sql.Context, row sql.Row, label string, block *plan.Block, exitIter bool) (sql.RowIter, error) {
   268  	blockIter, err := b.buildBlock(ctx, block, row)
   269  	if controlFlow, ok := err.(loopError); ok && strings.ToLower(controlFlow.Label) == strings.ToLower(label) {
   270  		if controlFlow.IsExit {
   271  			if exitIter {
   272  				return sql.RowsToRowIter(), nil
   273  			} else {
   274  				return nil, io.EOF
   275  			}
   276  		} else {
   277  			err = io.EOF
   278  		}
   279  	}
   280  	if err == io.EOF {
   281  		blockIter = sql.RowsToRowIter()
   282  		err = nil
   283  	}
   284  	return blockIter, err
   285  }
   286  
   287  // leaveIter is the sql.RowIter of *Leave.
   288  type leaveIter struct {
   289  	Label string
   290  }
   291  
   292  var _ sql.RowIter = (*leaveIter)(nil)
   293  
   294  // Next implements the interface sql.RowIter.
   295  func (l *leaveIter) Next(ctx *sql.Context) (sql.Row, error) {
   296  	return nil, loopError{
   297  		Label:  l.Label,
   298  		IsExit: true,
   299  	}
   300  }
   301  
   302  // Close implements the interface sql.RowIter.
   303  func (l *leaveIter) Close(ctx *sql.Context) error {
   304  	return nil
   305  }
   306  
   307  // iterateIter is the sql.RowIter of *Iterate.
   308  type iterateIter struct {
   309  	Label string
   310  }
   311  
   312  var _ sql.RowIter = (*iterateIter)(nil)
   313  
   314  // Next implements the interface sql.RowIter.
   315  func (i *iterateIter) Next(ctx *sql.Context) (sql.Row, error) {
   316  	return nil, loopError{
   317  		Label:  i.Label,
   318  		IsExit: false,
   319  	}
   320  }
   321  
   322  // Close implements the interface sql.RowIter.
   323  func (i *iterateIter) Close(ctx *sql.Context) error {
   324  	return nil
   325  }
   326  
   327  // startTransaction begins a new transaction if necessary, e.g. if a statement in a stored procedure committed the
   328  // current one
   329  func startTransaction(ctx *sql.Context) error {
   330  	if ctx.GetTransaction() == nil {
   331  		ts, ok := ctx.Session.(sql.TransactionSession)
   332  		if ok {
   333  			tx, err := ts.StartTransaction(ctx, sql.ReadWrite)
   334  			if err != nil {
   335  				return err
   336  			}
   337  
   338  			ctx.SetTransaction(tx)
   339  		}
   340  	}
   341  
   342  	return nil
   343  }