
     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
//     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package expression
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"strings"
    22  	""
    23  	""
    24  )
// ProcedureReference contains the state for a single CALL statement of a stored procedure.
// Variables, cursors and handlers are held in a linked stack of scopes; lookups start at
// InnermostScope and walk outward through each scope's Parent.
type ProcedureReference struct {
	// InnermostScope is the most recently pushed scope. NewProcedureReference sets it to an
	// empty root scope; popping the root scope leaves it nil.
	InnermostScope *procedureScope
	// height is incremented by PushScope and decremented by PopScope; the root scope is height 0.
	height         int
}

// procedureScope is one level of the scope stack: the declarations made directly in it plus
// a link to the enclosing scope.
type procedureScope struct {
	// Parent is the enclosing scope, or nil for the root scope.
	Parent    *procedureScope
	// variables maps lowercased variable/parameter names to their current state.
	variables map[string]*procedureVariableReferenceValue
	// Cursors maps lowercased cursor names to their current state.
	Cursors   map[string]*procedureCursorReferenceValue
	// Handlers are appended in the order InitializeHandler is called for this scope.
	Handlers  []*procedureHandlerReferenceValue
}

// procedureVariableReferenceValue is the stored state of one declared variable or parameter.
type procedureVariableReferenceValue struct {
	// Name is the lowercased variable name.
	Name       string
	// Value has already been converted to SqlType.
	Value      interface{}
	SqlType    sql.Type
	// HasBeenSet becomes true once SetVariable assigns a value after initialization.
	HasBeenSet bool
}

// procedureCursorReferenceValue is the stored state of one declared cursor.
type procedureCursorReferenceValue struct {
	// Name is the lowercased cursor name.
	Name       string
	// SelectStmt is the cursor's query; its schema is returned by FetchCursor.
	SelectStmt sql.Node
	// RowIter is nil while the cursor is closed. Closing in this file resets it to nil;
	// NOTE(review): opening (setting it) happens outside this file — confirm where.
	RowIter    sql.RowIter
}

// procedureHandlerReferenceValue is one declared handler.
type procedureHandlerReferenceValue struct {
	Stmt        sql.Node
	// NOTE(review): IsExit is never set in this file (InitializeHandler records Action
	// instead); confirm whether anything still reads it.
	IsExit      bool
	Action      DeclareHandlerAction
	Cond        HandlerCondition
	// ScopeHeight is the ProcedureReference height at the time the handler was declared.
	ScopeHeight int
}
// ProcedureReferencable indicates that a sql.Node takes a *ProcedureReference and returns a
// new copy of itself with the reference set.
type ProcedureReferencable interface {
	WithParamReference(pRef *ProcedureReference) sql.Node
}
    61  // InitializeVariable sets the initial value for the variable.
    62  func (ppr *ProcedureReference) InitializeVariable(name string, sqlType sql.Type, val interface{}) error {
    63  	convertedVal, _, err := sqlType.Convert(val)
    64  	if err != nil {
    65  		return err
    66  	}
    67  	lowerName := strings.ToLower(name)
    68  	ppr.InnermostScope.variables[lowerName] = &procedureVariableReferenceValue{
    69  		Name:       lowerName,
    70  		Value:      convertedVal,
    71  		SqlType:    sqlType,
    72  		HasBeenSet: false,
    73  	}
    74  	return nil
    75  }
    77  // InitializeCursor sets the initial state for the cursor.
    78  func (ppr *ProcedureReference) InitializeCursor(name string, selectStmt sql.Node) {
    79  	lowerName := strings.ToLower(name)
    80  	ppr.InnermostScope.Cursors[lowerName] = &procedureCursorReferenceValue{
    81  		Name:       lowerName,
    82  		SelectStmt: selectStmt,
    83  		RowIter:    nil,
    84  	}
    85  }
    87  // InitializeHandler sets the given handler's statement.
    88  func (ppr *ProcedureReference) InitializeHandler(stmt sql.Node, action DeclareHandlerAction, cond HandlerCondition) {
    89  	ppr.InnermostScope.Handlers = append(ppr.InnermostScope.Handlers, &procedureHandlerReferenceValue{
    90  		Stmt:        stmt,
    91  		Cond:        cond,
    92  		Action:      action,
    93  		ScopeHeight: ppr.height,
    94  	})
    95  }
    97  // GetVariableValue returns the value of the given parameter.
    98  func (ppr *ProcedureReference) GetVariableValue(name string) (interface{}, error) {
    99  	lowerName := strings.ToLower(name)
   100  	scope := ppr.InnermostScope
   101  	for scope != nil {
   102  		if varRefVal, ok := scope.variables[lowerName]; ok {
   103  			return varRefVal.Value, nil
   104  		}
   105  		scope = scope.Parent
   106  	}
   107  	return nil, fmt.Errorf("cannot find value for parameter `%s`", name)
   108  }
   110  // GetVariableType returns the type of the given parameter. Returns the NULL type if the type cannot be found.
   111  func (ppr *ProcedureReference) GetVariableType(name string) sql.Type {
   112  	if ppr == nil {
   113  		return types.Null
   114  	}
   115  	lowerName := strings.ToLower(name)
   116  	scope := ppr.InnermostScope
   117  	for scope != nil {
   118  		if varRefVal, ok := scope.variables[lowerName]; ok {
   119  			return varRefVal.SqlType
   120  		}
   121  		scope = scope.Parent
   122  	}
   123  	return types.Null
   124  }
   126  // SetVariable updates the value of the given parameter.
   127  func (ppr *ProcedureReference) SetVariable(name string, val interface{}, valType sql.Type) error {
   128  	lowerName := strings.ToLower(name)
   129  	scope := ppr.InnermostScope
   130  	for scope != nil {
   131  		if varRefVal, ok := scope.variables[lowerName]; ok {
   132  			//TODO: do some actual type checking using the given value's type
   133  			val, _, err := varRefVal.SqlType.Convert(val)
   134  			if err != nil {
   135  				return err
   136  			}
   137  			varRefVal.Value = val
   138  			varRefVal.HasBeenSet = true
   139  			return nil
   140  		}
   141  		scope = scope.Parent
   142  	}
   143  	return fmt.Errorf("cannot find value for parameter `%s`", name)
   144  }
   146  // VariableHasBeenSet returns whether the parameter has had its value altered from the initial value.
   147  func (ppr *ProcedureReference) VariableHasBeenSet(name string) bool {
   148  	lowerName := strings.ToLower(name)
   149  	scope := ppr.InnermostScope
   150  	for scope != nil {
   151  		if varRefVal, ok := scope.variables[lowerName]; ok {
   152  			return varRefVal.HasBeenSet
   153  		}
   154  		scope = scope.Parent
   155  	}
   156  	return false
   157  }
   159  // CloseCursor closes the designated cursor.
   160  func (ppr *ProcedureReference) CloseCursor(ctx *sql.Context, name string) error {
   161  	lowerName := strings.ToLower(name)
   162  	scope := ppr.InnermostScope
   163  	for scope != nil {
   164  		if cursorRefVal, ok := scope.Cursors[lowerName]; ok {
   165  			if cursorRefVal.RowIter == nil {
   166  				return sql.ErrCursorNotOpen.New(name)
   167  			}
   168  			err := cursorRefVal.RowIter.Close(ctx)
   169  			cursorRefVal.RowIter = nil
   170  			return err
   171  		}
   172  		scope = scope.Parent
   173  	}
   174  	return fmt.Errorf("cannot find cursor `%s`", name)
   175  }
   177  // FetchCursor returns the next row from the designated cursor.
   178  func (ppr *ProcedureReference) FetchCursor(ctx *sql.Context, name string) (sql.Row, sql.Schema, error) {
   179  	lowerName := strings.ToLower(name)
   180  	scope := ppr.InnermostScope
   181  	for scope != nil {
   182  		if cursorRefVal, ok := scope.Cursors[lowerName]; ok {
   183  			if cursorRefVal.RowIter == nil {
   184  				return nil, nil, sql.ErrCursorNotOpen.New(name)
   185  			}
   186  			row, err := cursorRefVal.RowIter.Next(ctx)
   187  			return row, cursorRefVal.SelectStmt.Schema(), err
   188  		}
   189  		scope = scope.Parent
   190  	}
   191  	return nil, nil, fmt.Errorf("cannot find cursor `%s`", name)
   192  }
   194  // PushScope creates a new scope inside the current one.
   195  func (ppr *ProcedureReference) PushScope() {
   196  	ppr.InnermostScope = &procedureScope{
   197  		Parent:    ppr.InnermostScope,
   198  		variables: make(map[string]*procedureVariableReferenceValue),
   199  		Cursors:   make(map[string]*procedureCursorReferenceValue),
   200  		Handlers:  nil,
   201  	}
   202  	ppr.height++
   203  }
   205  // PopScope removes the innermost scope, returning to its parent. Also closes all open cursors.
   206  func (ppr *ProcedureReference) PopScope(ctx *sql.Context) error {
   207  	var err error
   208  	if ppr.InnermostScope == nil {
   209  		return fmt.Errorf("attempted to pop an empty scope")
   210  	}
   211  	for _, cursorRefVal := range ppr.InnermostScope.Cursors {
   212  		if cursorRefVal.RowIter != nil {
   213  			nErr := cursorRefVal.RowIter.Close(ctx)
   214  			cursorRefVal.RowIter = nil
   215  			if err == nil {
   216  				err = nErr
   217  			}
   218  		}
   219  	}
   220  	ppr.InnermostScope = ppr.InnermostScope.Parent
   221  	ppr.height--
   222  	return nil
   223  }
   225  // CloseAllCursors closes all cursors that are still open.
   226  func (ppr *ProcedureReference) CloseAllCursors(ctx *sql.Context) error {
   227  	var err error
   228  	scope := ppr.InnermostScope
   229  	for scope != nil {
   230  		for _, cursorRefVal := range scope.Cursors {
   231  			if cursorRefVal.RowIter != nil {
   232  				nErr := cursorRefVal.RowIter.Close(ctx)
   233  				cursorRefVal.RowIter = nil
   234  				if err == nil {
   235  					err = nErr
   236  				}
   237  			}
   238  		}
   239  		scope = scope.Parent
   240  	}
   241  	return err
   242  }
   244  // CurrentHeight returns the current height of the scope stack.
   245  func (ppr *ProcedureReference) CurrentHeight() int {
   246  	return ppr.height
   247  }
   249  func NewProcedureReference() *ProcedureReference {
   250  	return &ProcedureReference{
   251  		InnermostScope: &procedureScope{
   252  			Parent:    nil,
   253  			variables: make(map[string]*procedureVariableReferenceValue),
   254  			Cursors:   make(map[string]*procedureCursorReferenceValue),
   255  			Handlers:  nil,
   256  		},
   257  		height: 0,
   258  	}
   259  }
// ProcedureParam represents the parameter of a stored procedure or stored function.
// Its value is not stored on the expression itself. Eval and Set look the variable up by name
// in the ProcedureReference attached via WithParamReference.
type ProcedureParam struct {
	// name is the lowercased parameter name used for lookups.
	name       string
	// pRef holds the variable's value. NewProcedureParam leaves it nil; WithParamReference
	// sets it on a copy.
	pRef       *ProcedureReference
	// typ is the type reported by Type.
	typ        sql.Type
	// NOTE(review): hasBeenSet is never read or written in this file; confirm whether other
	// code uses it.
	hasBeenSet bool
}

// Compile-time checks that *ProcedureParam satisfies these interfaces.
var _ sql.Expression = (*ProcedureParam)(nil)
var _ sql.CollationCoercible = (*ProcedureParam)(nil)
   272  // NewProcedureParam creates a new ProcedureParam expression.
   273  func NewProcedureParam(name string, typ sql.Type) *ProcedureParam {
   274  	return &ProcedureParam{name: strings.ToLower(name), typ: typ}
   275  }
   277  // Children implements the sql.Expression interface.
   278  func (*ProcedureParam) Children() []sql.Expression {
   279  	return nil
   280  }
   282  // Resolved implements the sql.Expression interface.
   283  func (*ProcedureParam) Resolved() bool {
   284  	return true
   285  }
   287  // IsNullable implements the sql.Expression interface.
   288  func (*ProcedureParam) IsNullable() bool {
   289  	return false
   290  }
   292  // Type implements the sql.Expression interface.
   293  func (pp *ProcedureParam) Type() sql.Type {
   294  	return pp.typ
   295  }
   297  // CollationCoercibility implements the sql.CollationCoercible interface.
   298  func (pp *ProcedureParam) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   299  	collation, _ = pp.pRef.GetVariableType(
   300  	return collation, 2
   301  }
   303  // Name implements the Nameable interface.
   304  func (pp *ProcedureParam) Name() string {
   305  	return
   306  }
   308  // String implements the sql.Expression interface.
   309  func (pp *ProcedureParam) String() string {
   310  	return
   311  }
   313  // Eval implements the sql.Expression interface.
   314  func (pp *ProcedureParam) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) {
   315  	return pp.pRef.GetVariableValue(
   316  }
   318  // WithChildren implements the sql.Expression interface.
   319  func (pp *ProcedureParam) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   320  	if len(children) != 0 {
   321  		return nil, sql.ErrInvalidChildrenNumber.New(pp, len(children), 0)
   322  	}
   323  	return pp, nil
   324  }
   326  // WithParamReference returns a new *ProcedureParam containing the given *ProcedureReference.
   327  func (pp *ProcedureParam) WithParamReference(pRef *ProcedureReference) *ProcedureParam {
   328  	npp := *pp
   329  	npp.pRef = pRef
   330  	return &npp
   331  }
   333  // Set sets the value of this procedure parameter to the given value.
   334  func (pp *ProcedureParam) Set(val interface{}, valType sql.Type) error {
   335  	return pp.pRef.SetVariable(, val, valType)
   336  }
// UnresolvedProcedureParam represents an unresolved parameter of a stored procedure or stored function.
// It is a placeholder: it reports itself as unresolved, and evaluating it returns an error.
type UnresolvedProcedureParam struct {
	// name is stored lowercased by NewUnresolvedProcedureParam.
	name string
}

// Compile-time checks that *UnresolvedProcedureParam satisfies these interfaces.
var _ sql.Expression = (*UnresolvedProcedureParam)(nil)
var _ sql.CollationCoercible = (*UnresolvedProcedureParam)(nil)
   346  // NewUnresolvedProcedureParam creates a new UnresolvedProcedureParam expression.
   347  func NewUnresolvedProcedureParam(name string) *UnresolvedProcedureParam {
   348  	return &UnresolvedProcedureParam{name: strings.ToLower(name)}
   349  }
   351  // Children implements the sql.Expression interface.
   352  func (*UnresolvedProcedureParam) Children() []sql.Expression {
   353  	return nil
   354  }
   356  // Resolved implements the sql.Expression interface.
   357  func (*UnresolvedProcedureParam) Resolved() bool {
   358  	return false
   359  }
   361  // IsNullable implements the sql.Expression interface.
   362  func (*UnresolvedProcedureParam) IsNullable() bool {
   363  	return false
   364  }
   366  // Type implements the sql.Expression interface.
   367  func (*UnresolvedProcedureParam) Type() sql.Type {
   368  	return types.Null
   369  }
   371  // CollationCoercibility implements the interface sql.CollationCoercible.
   372  func (*UnresolvedProcedureParam) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   373  	return sql.Collation_binary, 7
   374  }
   376  // Name implements the Nameable interface.
   377  func (upp *UnresolvedProcedureParam) Name() string {
   378  	return
   379  }
   381  // String implements the sql.Expression interface.
   382  func (upp *UnresolvedProcedureParam) String() string {
   383  	return
   384  }
   386  // Eval implements the sql.Expression interface.
   387  func (upp *UnresolvedProcedureParam) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) {
   388  	return nil, fmt.Errorf("attempted to use unresolved procedure param '%s'",
   389  }
   391  // WithChildren implements the sql.Expression interface.
   392  func (upp *UnresolvedProcedureParam) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   393  	if len(children) != 0 {
   394  		return nil, sql.ErrInvalidChildrenNumber.New(upp, len(children), 0)
   395  	}
   396  	return upp, nil
   397  }
// FetchEOF signals that a cursor's fetch iterator has been exhausted. It is a dedicated
// sentinel, distinct from io.EOF, so the loop implementation can tell an exhausted FETCH apart
// from an ordinary io.EOF. HandlerCondition.Matches treats it as a NOT FOUND condition.
var FetchEOF = errors.New("exhausted fetch iterator")
// HandlerConditionType classifies which errors a declared handler responds to; see
// HandlerCondition.Matches.
type HandlerConditionType uint8

const (
	// HandlerConditionUnknown is the zero value; Matches never reports a match for it.
	HandlerConditionUnknown HandlerConditionType = iota
	// HandlerConditionNotFound matches FetchEOF (an exhausted cursor fetch).
	HandlerConditionNotFound
	// HandlerConditionSqlException matches any error other than FetchEOF.
	HandlerConditionSqlException
)
// HandlerCondition describes the condition under which a declared handler fires.
type HandlerCondition struct {
	// NOTE(review): presumably an SQLSTATE prefix to match; Matches does not consult it —
	// confirm how it is used.
	SqlStatePrefix string
	// Type selects which class of error Matches accepts.
	Type           HandlerConditionType
}
// DeclareHandlerAction is the action declared for a handler: continue, exit, or undo.
type DeclareHandlerAction byte

const (
	DeclareHandlerAction_Continue DeclareHandlerAction = iota
	DeclareHandlerAction_Exit
	DeclareHandlerAction_Undo
)
   424  func (c *HandlerCondition) Matches(err error) bool {
   425  	if errors.Is(err, FetchEOF) {
   426  		return c.Type == HandlerConditionNotFound
   427  	} else {
   428  		return c.Type == HandlerConditionSqlException
   429  	}
   430  }