github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/resolve_external_stored_procedures.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"reflect"
    19  	"strconv"
    20  	"time"
    21  
    22  	"github.com/shopspring/decimal"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/expression"
    26  	"github.com/dolthub/go-mysql-server/sql/plan"
    27  	"github.com/dolthub/go-mysql-server/sql/types"
    28  )
    29  
    30  var (
    31  	// ctxType is the reflect.Type of a *sql.Context.
    32  	ctxType = reflect.TypeOf((*sql.Context)(nil))
    33  	// ctxType is the reflect.Type of a sql.RowIter.
    34  	rowIterType = reflect.TypeOf((*sql.RowIter)(nil)).Elem()
    35  	// ctxType is the reflect.Type of an error.
    36  	errorType = reflect.TypeOf((*error)(nil)).Elem()
    37  	// externalStoredProcedurePointerTypes maps a non-pointer type to a sql.Type for external stored procedures.
    38  	externalStoredProcedureTypes = map[reflect.Type]sql.Type{
    39  		reflect.TypeOf(int(0)):            types.Int64,
    40  		reflect.TypeOf(int8(0)):           types.Int8,
    41  		reflect.TypeOf(int16(0)):          types.Int16,
    42  		reflect.TypeOf(int32(0)):          types.Int32,
    43  		reflect.TypeOf(int64(0)):          types.Int64,
    44  		reflect.TypeOf(uint(0)):           types.Uint64,
    45  		reflect.TypeOf(uint8(0)):          types.Uint8,
    46  		reflect.TypeOf(uint16(0)):         types.Uint16,
    47  		reflect.TypeOf(uint32(0)):         types.Uint32,
    48  		reflect.TypeOf(uint64(0)):         types.Uint64,
    49  		reflect.TypeOf(float32(0)):        types.Float32,
    50  		reflect.TypeOf(float64(0)):        types.Float64,
    51  		reflect.TypeOf(bool(false)):       types.Int8,
    52  		reflect.TypeOf(string("")):        types.LongText,
    53  		reflect.TypeOf([]byte{}):          types.LongBlob,
    54  		reflect.TypeOf(time.Time{}):       types.DatetimeMaxPrecision,
    55  		reflect.TypeOf(decimal.Decimal{}): types.InternalDecimalType,
    56  	}
    57  	// externalStoredProcedurePointerTypes maps a pointer type to a sql.Type for external stored procedures.
    58  	externalStoredProcedurePointerTypes = map[reflect.Type]sql.Type{
    59  		reflect.TypeOf((*int)(nil)):             types.Int64,
    60  		reflect.TypeOf((*int8)(nil)):            types.Int8,
    61  		reflect.TypeOf((*int16)(nil)):           types.Int16,
    62  		reflect.TypeOf((*int32)(nil)):           types.Int32,
    63  		reflect.TypeOf((*int64)(nil)):           types.Int64,
    64  		reflect.TypeOf((*uint)(nil)):            types.Uint64,
    65  		reflect.TypeOf((*uint8)(nil)):           types.Uint8,
    66  		reflect.TypeOf((*uint16)(nil)):          types.Uint16,
    67  		reflect.TypeOf((*uint32)(nil)):          types.Uint32,
    68  		reflect.TypeOf((*uint64)(nil)):          types.Uint64,
    69  		reflect.TypeOf((*float32)(nil)):         types.Float32,
    70  		reflect.TypeOf((*float64)(nil)):         types.Float64,
    71  		reflect.TypeOf((*bool)(nil)):            types.Int8,
    72  		reflect.TypeOf((*string)(nil)):          types.LongText,
    73  		reflect.TypeOf((*[]byte)(nil)):          types.LongBlob,
    74  		reflect.TypeOf((*time.Time)(nil)):       types.DatetimeMaxPrecision,
    75  		reflect.TypeOf((*decimal.Decimal)(nil)): types.InternalDecimalType,
    76  	}
    77  )
    78  
    79  func init() {
    80  	if strconv.IntSize == 32 {
    81  		externalStoredProcedureTypes[reflect.TypeOf(int(0))] = types.Int32
    82  		externalStoredProcedureTypes[reflect.TypeOf(uint(0))] = types.Uint32
    83  		externalStoredProcedurePointerTypes[reflect.TypeOf((*int)(nil))] = types.Int32
    84  		externalStoredProcedurePointerTypes[reflect.TypeOf((*uint)(nil))] = types.Uint32
    85  	}
    86  }
    87  
    88  // resolveExternalStoredProcedure resolves external stored procedures, converting them to the format expected of
    89  // normal stored procedures.
    90  func resolveExternalStoredProcedure(_ *sql.Context, externalProcedure sql.ExternalStoredProcedureDetails) (*plan.Procedure, error) {
    91  	funcVal := reflect.ValueOf(externalProcedure.Function)
    92  	funcType := funcVal.Type()
    93  	if funcType.Kind() != reflect.Func {
    94  		return nil, sql.ErrExternalProcedureNonFunction.New(externalProcedure.Function)
    95  	}
    96  	if funcType.NumIn() == 0 {
    97  		return nil, sql.ErrExternalProcedureMissingContextParam.New()
    98  	}
    99  	if funcType.NumOut() != 2 {
   100  		return nil, sql.ErrExternalProcedureReturnTypes.New()
   101  	}
   102  	if funcType.In(0) != ctxType {
   103  		return nil, sql.ErrExternalProcedureMissingContextParam.New()
   104  	}
   105  	if funcType.Out(0) != rowIterType {
   106  		return nil, sql.ErrExternalProcedureFirstReturn.New()
   107  	}
   108  	if funcType.Out(1) != errorType {
   109  		return nil, sql.ErrExternalProcedureSecondReturn.New()
   110  	}
   111  	funcIsVariadic := funcType.IsVariadic()
   112  
   113  	paramDefinitions := make([]plan.ProcedureParam, funcType.NumIn()-1)
   114  	paramReferences := make([]*expression.ProcedureParam, len(paramDefinitions))
   115  	for i := 0; i < len(paramDefinitions); i++ {
   116  		funcParamType := funcType.In(i + 1)
   117  		paramName := "A" + strconv.FormatInt(int64(i), 10)
   118  		paramIsVariadic := false
   119  		if funcIsVariadic && i == len(paramDefinitions)-1 {
   120  			paramIsVariadic = true
   121  			funcParamType = funcParamType.Elem()
   122  			if funcParamType.Kind() == reflect.Ptr {
   123  				return nil, sql.ErrExternalProcedurePointerVariadic.New()
   124  			}
   125  		}
   126  
   127  		if sqlType, ok := externalStoredProcedureTypes[funcParamType]; ok {
   128  			paramDefinitions[i] = plan.ProcedureParam{
   129  				Direction: plan.ProcedureParamDirection_In,
   130  				Name:      paramName,
   131  				Type:      sqlType,
   132  				Variadic:  paramIsVariadic,
   133  			}
   134  			paramReferences[i] = expression.NewProcedureParam(paramName, sqlType)
   135  		} else if sqlType, ok = externalStoredProcedurePointerTypes[funcParamType]; ok {
   136  			paramDefinitions[i] = plan.ProcedureParam{
   137  				Direction: plan.ProcedureParamDirection_Inout,
   138  				Name:      paramName,
   139  				Type:      sqlType,
   140  				Variadic:  paramIsVariadic,
   141  			}
   142  			paramReferences[i] = expression.NewProcedureParam(paramName, sqlType)
   143  		} else {
   144  			return nil, sql.ErrExternalProcedureInvalidParamType.New(funcParamType.String())
   145  		}
   146  	}
   147  
   148  	procedure := plan.NewProcedure(
   149  		externalProcedure.Name,
   150  		"root",
   151  		paramDefinitions,
   152  		plan.ProcedureSecurityContext_Definer,
   153  		"External stored procedure",
   154  		nil,
   155  		externalProcedure.FakeCreateProcedureStmt(),
   156  		&plan.ExternalProcedure{
   157  			ExternalStoredProcedureDetails: externalProcedure,
   158  			ParamDefinitions:               paramDefinitions,
   159  			Params:                         paramReferences,
   160  		},
   161  		time.Unix(1, 0),
   162  		time.Unix(1, 0),
   163  	)
   164  	return procedure, nil
   165  }