github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/external_procedure.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"reflect"
    19  	"strconv"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  	"github.com/dolthub/go-mysql-server/sql/expression"
    23  )
    24  
    25  var (
    26  	boolType = reflect.TypeOf(bool(false))
    27  	intType  = reflect.TypeOf(int(0))
    28  	uintType = reflect.TypeOf(uint(0))
    29  )
    30  
// ExternalProcedure is the sql.Node container for sql.ExternalStoredProcedureDetails.
type ExternalProcedure struct {
	// The embedded details supply the Name, Schema, ReadOnly, and Function values that the methods of this node
	// read.
	sql.ExternalStoredProcedureDetails
	// ParamDefinitions holds the parameter declarations. RowIter indexes it in parallel with Params, so entry i
	// describes Params[i].
	ParamDefinitions []ProcedureParam
	// Params holds the parameter expressions. RowIter evaluates them to obtain the arguments and writes results
	// back into those whose definition has an INOUT or OUT direction.
	Params []*expression.ProcedureParam
}
    37  
    38  var _ sql.Node = (*ExternalProcedure)(nil)
    39  var _ sql.Expressioner = (*ExternalProcedure)(nil)
    40  var _ sql.CollationCoercible = (*ExternalProcedure)(nil)
    41  
// Resolved implements the interface sql.Node. It unconditionally reports true.
func (n *ExternalProcedure) Resolved() bool {
	return true
}
    46  
    47  func (n *ExternalProcedure) IsReadOnly() bool {
    48  	return n.ExternalStoredProcedureDetails.ReadOnly
    49  }
    50  
    51  // String implements the interface sql.Node.
    52  func (n *ExternalProcedure) String() string {
    53  	return n.ExternalStoredProcedureDetails.Name
    54  }
    55  
// Schema implements the interface sql.Node. It returns the Schema field of the embedded
// sql.ExternalStoredProcedureDetails.
func (n *ExternalProcedure) Schema() sql.Schema {
	// The embedded struct must be named explicitly: a bare n.Schema selects this method, not the embedded field.
	return n.ExternalStoredProcedureDetails.Schema
}
    60  
// Children implements the interface sql.Node. This node never has children, so the result is always nil.
func (n *ExternalProcedure) Children() []sql.Node {
	return nil
}
    65  
    66  // WithChildren implements the interface sql.Node.
    67  func (n *ExternalProcedure) WithChildren(children ...sql.Node) (sql.Node, error) {
    68  	if len(children) != 0 {
    69  		return nil, sql.ErrInvalidChildrenNumber.New(n, len(children), 0)
    70  	}
    71  	return n, nil
    72  }
    73  
    74  // Expressions implements the interface sql.Expressioner.
    75  func (n *ExternalProcedure) Expressions() []sql.Expression {
    76  	exprs := make([]sql.Expression, len(n.Params))
    77  	for i, param := range n.Params {
    78  		exprs[i] = param
    79  	}
    80  	return exprs
    81  }
    82  
    83  // WithExpressions implements the interface sql.Expressioner.
    84  func (n *ExternalProcedure) WithExpressions(expressions ...sql.Expression) (sql.Node, error) {
    85  	if len(expressions) != len(n.Params) {
    86  		return nil, sql.ErrInvalidExpressionNumber.New(n, len(expressions), len(n.Params))
    87  	}
    88  	newParams := make([]*expression.ProcedureParam, len(expressions))
    89  	for i, expr := range expressions {
    90  		newParams[i] = expr.(*expression.ProcedureParam)
    91  	}
    92  	nn := *n
    93  	nn.Params = newParams
    94  	return &nn, nil
    95  }
    96  
// CheckPrivileges implements the interface sql.Node. No privilege check is currently performed: neither ctx nor
// opChecker is consulted and the result is always true.
func (n *ExternalProcedure) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
	// TODO: when DEFINER is implemented for stored procedures then this should be added
	return true
}
   102  
// CollationCoercibility implements the interface sql.CollationCoercible. It always returns sql.Collation_binary
// with a coercibility of 7; ctx is not consulted.
func (*ExternalProcedure) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
	return sql.Collation_binary, 7
}
   107  
   108  // RowIter implements the interface sql.Node.
   109  func (n *ExternalProcedure) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
   110  	// The function's structure has been verified by the analyzer, so no need to double-check any of it here
   111  	funcVal := reflect.ValueOf(n.Function)
   112  	funcType := funcVal.Type()
   113  	// The first parameter is always the context, but it doesn't exist as far as the stored procedures are concerned, so
   114  	// we prepend it here
   115  	funcParams := make([]reflect.Value, len(n.Params)+1)
   116  	funcParams[0] = reflect.ValueOf(ctx)
   117  
   118  	for i := range n.Params {
   119  		paramDefinition := n.ParamDefinitions[i]
   120  		var funcParamType reflect.Type
   121  		if paramDefinition.Variadic {
   122  			funcParamType = funcType.In(funcType.NumIn() - 1).Elem()
   123  		} else {
   124  			funcParamType = funcType.In(i + 1)
   125  		}
   126  		// Grab the passed-in variable and convert it to the type we expect
   127  		exprParamVal, err := n.Params[i].Eval(ctx, nil)
   128  		if err != nil {
   129  			return nil, err
   130  		}
   131  		exprParamVal, _, err = paramDefinition.Type.Convert(exprParamVal)
   132  		if err != nil {
   133  			return nil, err
   134  		}
   135  
   136  		funcParams[i+1], err = n.ProcessParam(ctx, funcParamType, exprParamVal)
   137  		if err != nil {
   138  			return nil, err
   139  		}
   140  	}
   141  	out := funcVal.Call(funcParams)
   142  
   143  	// Again, these types are enforced in the analyzer, so it's safe to assume their types here
   144  	if err, ok := out[1].Interface().(error); ok { // Only evaluates to true when error is not nil
   145  		return nil, err
   146  	}
   147  	for i, paramDefinition := range n.ParamDefinitions {
   148  		if paramDefinition.Direction == ProcedureParamDirection_Inout || paramDefinition.Direction == ProcedureParamDirection_Out {
   149  			exprParam := n.Params[i]
   150  			funcParamVal := funcParams[i+1].Elem().Interface()
   151  			err := exprParam.Set(funcParamVal, exprParam.Type())
   152  			if err != nil {
   153  				return nil, err
   154  			}
   155  		}
   156  	}
   157  	// It's not invalid to return a nil RowIter, as having no rows to return is expected of many stored procedures.
   158  	if rowIter, ok := out[0].Interface().(sql.RowIter); ok {
   159  		return rowIter, nil
   160  	}
   161  	return sql.RowsToRowIter(), nil
   162  }
   163  
   164  func (n *ExternalProcedure) ProcessParam(ctx *sql.Context, funcParamType reflect.Type, exprParamVal interface{}) (reflect.Value, error) {
   165  	funcParamCompType := funcParamType
   166  	if funcParamType.Kind() == reflect.Ptr {
   167  		funcParamCompType = funcParamType.Elem()
   168  	}
   169  	// Convert to bool, int, and uint as they differ from their sql.Type value
   170  	if exprParamVal != nil {
   171  		switch funcParamCompType {
   172  		case boolType:
   173  			val := false
   174  			if exprParamVal.(int8) != 0 {
   175  				val = true
   176  			}
   177  			exprParamVal = val
   178  		case intType:
   179  			if strconv.IntSize == 32 {
   180  				exprParamVal = int(exprParamVal.(int32))
   181  			} else {
   182  				exprParamVal = int(exprParamVal.(int64))
   183  			}
   184  		case uintType:
   185  			if strconv.IntSize == 32 {
   186  				exprParamVal = uint(exprParamVal.(uint32))
   187  			} else {
   188  				exprParamVal = uint(exprParamVal.(uint64))
   189  			}
   190  		}
   191  	}
   192  
   193  	if funcParamType.Kind() == reflect.Ptr { // Coincides with INOUT
   194  		funcParamVal := reflect.New(funcParamType.Elem())
   195  		if exprParamVal != nil {
   196  			funcParamVal.Elem().Set(reflect.ValueOf(exprParamVal))
   197  		}
   198  		return funcParamVal, nil
   199  	} else { // Coincides with IN
   200  		funcParamVal := reflect.New(funcParamType)
   201  		if exprParamVal != nil {
   202  			funcParamVal.Elem().Set(reflect.ValueOf(exprParamVal))
   203  		}
   204  		return funcParamVal.Elem(), nil
   205  	}
   206  }