// Source: github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/procedurereference.go

// Copyright 2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
	"errors"
	"fmt"
	"strings"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/types"
)

// ProcedureReference contains the state for a single CALL statement of a stored procedure.
// The state is a stack of scopes: InnermostScope is the top of the stack and each scope links
// to its enclosing scope through Parent.
type ProcedureReference struct {
	InnermostScope *procedureScope
	// height is incremented by PushScope and decremented by PopScope; it starts at 0 for the
	// root scope created by NewProcedureReference.
	height int
}

// procedureScope holds the variables, cursors, and handlers declared in one scope of a
// stored procedure body. Variable and cursor lookups that miss in a scope continue in Parent.
type procedureScope struct {
	Parent *procedureScope
	// variables is keyed by the lowercased variable name.
	variables map[string]*procedureVariableReferenceValue
	// Cursors is keyed by the lowercased cursor name.
	Cursors  map[string]*procedureCursorReferenceValue
	Handlers []*procedureHandlerReferenceValue
}

// procedureVariableReferenceValue is the stored state of one declared variable or parameter.
type procedureVariableReferenceValue struct {
	Name    string
	Value   interface{}
	SqlType sql.Type
	// HasBeenSet is false after InitializeVariable and becomes true once SetVariable assigns it.
	HasBeenSet bool
}

// procedureCursorReferenceValue is the stored state of one declared cursor.
type procedureCursorReferenceValue struct {
	Name       string
	SelectStmt sql.Node
	// RowIter is nil while the cursor is not open; CloseCursor and FetchCursor report
	// sql.ErrCursorNotOpen in that case.
	RowIter sql.RowIter
}

// procedureHandlerReferenceValue is one declared condition handler.
type procedureHandlerReferenceValue struct {
	Stmt sql.Node
	// NOTE(review): IsExit is not set by InitializeHandler in this file; Action carries the
	// handler's kind — confirm whether IsExit is still read anywhere.
	IsExit bool
	Action DeclareHandlerAction
	Cond   HandlerCondition
	// ScopeHeight is the scope-stack height at the time the handler was declared.
	ScopeHeight int
}

// ProcedureReferencable indicates that a sql.Node takes a *ProcedureReference and returns a new copy with the reference set.
type ProcedureReferencable interface {
	WithParamReference(pRef *ProcedureReference) sql.Node
}

// InitializeVariable declares a variable named name in the innermost scope, storing val
// converted to sqlType. Names are stored lowercased, and the variable starts out marked as not
// set. Any error from the conversion is returned and nothing is declared.
62 func (ppr *ProcedureReference) InitializeVariable(name string, sqlType sql.Type, val interface{}) error { 63 convertedVal, _, err := sqlType.Convert(val) 64 if err != nil { 65 return err 66 } 67 lowerName := strings.ToLower(name) 68 ppr.InnermostScope.variables[lowerName] = &procedureVariableReferenceValue{ 69 Name: lowerName, 70 Value: convertedVal, 71 SqlType: sqlType, 72 HasBeenSet: false, 73 } 74 return nil 75 } 76 77 // InitializeCursor sets the initial state for the cursor. 78 func (ppr *ProcedureReference) InitializeCursor(name string, selectStmt sql.Node) { 79 lowerName := strings.ToLower(name) 80 ppr.InnermostScope.Cursors[lowerName] = &procedureCursorReferenceValue{ 81 Name: lowerName, 82 SelectStmt: selectStmt, 83 RowIter: nil, 84 } 85 } 86 87 // InitializeHandler sets the given handler's statement. 88 func (ppr *ProcedureReference) InitializeHandler(stmt sql.Node, action DeclareHandlerAction, cond HandlerCondition) { 89 ppr.InnermostScope.Handlers = append(ppr.InnermostScope.Handlers, &procedureHandlerReferenceValue{ 90 Stmt: stmt, 91 Cond: cond, 92 Action: action, 93 ScopeHeight: ppr.height, 94 }) 95 } 96 97 // GetVariableValue returns the value of the given parameter. 98 func (ppr *ProcedureReference) GetVariableValue(name string) (interface{}, error) { 99 lowerName := strings.ToLower(name) 100 scope := ppr.InnermostScope 101 for scope != nil { 102 if varRefVal, ok := scope.variables[lowerName]; ok { 103 return varRefVal.Value, nil 104 } 105 scope = scope.Parent 106 } 107 return nil, fmt.Errorf("cannot find value for parameter `%s`", name) 108 } 109 110 // GetVariableType returns the type of the given parameter. Returns the NULL type if the type cannot be found. 
111 func (ppr *ProcedureReference) GetVariableType(name string) sql.Type { 112 if ppr == nil { 113 return types.Null 114 } 115 lowerName := strings.ToLower(name) 116 scope := ppr.InnermostScope 117 for scope != nil { 118 if varRefVal, ok := scope.variables[lowerName]; ok { 119 return varRefVal.SqlType 120 } 121 scope = scope.Parent 122 } 123 return types.Null 124 } 125 126 // SetVariable updates the value of the given parameter. 127 func (ppr *ProcedureReference) SetVariable(name string, val interface{}, valType sql.Type) error { 128 lowerName := strings.ToLower(name) 129 scope := ppr.InnermostScope 130 for scope != nil { 131 if varRefVal, ok := scope.variables[lowerName]; ok { 132 //TODO: do some actual type checking using the given value's type 133 val, _, err := varRefVal.SqlType.Convert(val) 134 if err != nil { 135 return err 136 } 137 varRefVal.Value = val 138 varRefVal.HasBeenSet = true 139 return nil 140 } 141 scope = scope.Parent 142 } 143 return fmt.Errorf("cannot find value for parameter `%s`", name) 144 } 145 146 // VariableHasBeenSet returns whether the parameter has had its value altered from the initial value. 147 func (ppr *ProcedureReference) VariableHasBeenSet(name string) bool { 148 lowerName := strings.ToLower(name) 149 scope := ppr.InnermostScope 150 for scope != nil { 151 if varRefVal, ok := scope.variables[lowerName]; ok { 152 return varRefVal.HasBeenSet 153 } 154 scope = scope.Parent 155 } 156 return false 157 } 158 159 // CloseCursor closes the designated cursor. 
160 func (ppr *ProcedureReference) CloseCursor(ctx *sql.Context, name string) error { 161 lowerName := strings.ToLower(name) 162 scope := ppr.InnermostScope 163 for scope != nil { 164 if cursorRefVal, ok := scope.Cursors[lowerName]; ok { 165 if cursorRefVal.RowIter == nil { 166 return sql.ErrCursorNotOpen.New(name) 167 } 168 err := cursorRefVal.RowIter.Close(ctx) 169 cursorRefVal.RowIter = nil 170 return err 171 } 172 scope = scope.Parent 173 } 174 return fmt.Errorf("cannot find cursor `%s`", name) 175 } 176 177 // FetchCursor returns the next row from the designated cursor. 178 func (ppr *ProcedureReference) FetchCursor(ctx *sql.Context, name string) (sql.Row, sql.Schema, error) { 179 lowerName := strings.ToLower(name) 180 scope := ppr.InnermostScope 181 for scope != nil { 182 if cursorRefVal, ok := scope.Cursors[lowerName]; ok { 183 if cursorRefVal.RowIter == nil { 184 return nil, nil, sql.ErrCursorNotOpen.New(name) 185 } 186 row, err := cursorRefVal.RowIter.Next(ctx) 187 return row, cursorRefVal.SelectStmt.Schema(), err 188 } 189 scope = scope.Parent 190 } 191 return nil, nil, fmt.Errorf("cannot find cursor `%s`", name) 192 } 193 194 // PushScope creates a new scope inside the current one. 195 func (ppr *ProcedureReference) PushScope() { 196 ppr.InnermostScope = &procedureScope{ 197 Parent: ppr.InnermostScope, 198 variables: make(map[string]*procedureVariableReferenceValue), 199 Cursors: make(map[string]*procedureCursorReferenceValue), 200 Handlers: nil, 201 } 202 ppr.height++ 203 } 204 205 // PopScope removes the innermost scope, returning to its parent. Also closes all open cursors. 
206 func (ppr *ProcedureReference) PopScope(ctx *sql.Context) error { 207 var err error 208 if ppr.InnermostScope == nil { 209 return fmt.Errorf("attempted to pop an empty scope") 210 } 211 for _, cursorRefVal := range ppr.InnermostScope.Cursors { 212 if cursorRefVal.RowIter != nil { 213 nErr := cursorRefVal.RowIter.Close(ctx) 214 cursorRefVal.RowIter = nil 215 if err == nil { 216 err = nErr 217 } 218 } 219 } 220 ppr.InnermostScope = ppr.InnermostScope.Parent 221 ppr.height-- 222 return nil 223 } 224 225 // CloseAllCursors closes all cursors that are still open. 226 func (ppr *ProcedureReference) CloseAllCursors(ctx *sql.Context) error { 227 var err error 228 scope := ppr.InnermostScope 229 for scope != nil { 230 for _, cursorRefVal := range scope.Cursors { 231 if cursorRefVal.RowIter != nil { 232 nErr := cursorRefVal.RowIter.Close(ctx) 233 cursorRefVal.RowIter = nil 234 if err == nil { 235 err = nErr 236 } 237 } 238 } 239 scope = scope.Parent 240 } 241 return err 242 } 243 244 // CurrentHeight returns the current height of the scope stack. 245 func (ppr *ProcedureReference) CurrentHeight() int { 246 return ppr.height 247 } 248 249 func NewProcedureReference() *ProcedureReference { 250 return &ProcedureReference{ 251 InnermostScope: &procedureScope{ 252 Parent: nil, 253 variables: make(map[string]*procedureVariableReferenceValue), 254 Cursors: make(map[string]*procedureCursorReferenceValue), 255 Handlers: nil, 256 }, 257 height: 0, 258 } 259 } 260 261 // ProcedureParam represents the parameter of a stored procedure or stored function. 262 type ProcedureParam struct { 263 name string 264 pRef *ProcedureReference 265 typ sql.Type 266 hasBeenSet bool 267 } 268 269 var _ sql.Expression = (*ProcedureParam)(nil) 270 var _ sql.CollationCoercible = (*ProcedureParam)(nil) 271 272 // NewProcedureParam creates a new ProcedureParam expression. 
273 func NewProcedureParam(name string, typ sql.Type) *ProcedureParam { 274 return &ProcedureParam{name: strings.ToLower(name), typ: typ} 275 } 276 277 // Children implements the sql.Expression interface. 278 func (*ProcedureParam) Children() []sql.Expression { 279 return nil 280 } 281 282 // Resolved implements the sql.Expression interface. 283 func (*ProcedureParam) Resolved() bool { 284 return true 285 } 286 287 // IsNullable implements the sql.Expression interface. 288 func (*ProcedureParam) IsNullable() bool { 289 return false 290 } 291 292 // Type implements the sql.Expression interface. 293 func (pp *ProcedureParam) Type() sql.Type { 294 return pp.typ 295 } 296 297 // CollationCoercibility implements the sql.CollationCoercible interface. 298 func (pp *ProcedureParam) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 299 collation, _ = pp.pRef.GetVariableType(pp.name).CollationCoercibility(ctx) 300 return collation, 2 301 } 302 303 // Name implements the Nameable interface. 304 func (pp *ProcedureParam) Name() string { 305 return pp.name 306 } 307 308 // String implements the sql.Expression interface. 309 func (pp *ProcedureParam) String() string { 310 return pp.name 311 } 312 313 // Eval implements the sql.Expression interface. 314 func (pp *ProcedureParam) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) { 315 return pp.pRef.GetVariableValue(pp.name) 316 } 317 318 // WithChildren implements the sql.Expression interface. 319 func (pp *ProcedureParam) WithChildren(children ...sql.Expression) (sql.Expression, error) { 320 if len(children) != 0 { 321 return nil, sql.ErrInvalidChildrenNumber.New(pp, len(children), 0) 322 } 323 return pp, nil 324 } 325 326 // WithParamReference returns a new *ProcedureParam containing the given *ProcedureReference. 
327 func (pp *ProcedureParam) WithParamReference(pRef *ProcedureReference) *ProcedureParam { 328 npp := *pp 329 npp.pRef = pRef 330 return &npp 331 } 332 333 // Set sets the value of this procedure parameter to the given value. 334 func (pp *ProcedureParam) Set(val interface{}, valType sql.Type) error { 335 return pp.pRef.SetVariable(pp.name, val, valType) 336 } 337 338 // UnresolvedProcedureParam represents an unresolved parameter of a stored procedure or stored function. 339 type UnresolvedProcedureParam struct { 340 name string 341 } 342 343 var _ sql.Expression = (*UnresolvedProcedureParam)(nil) 344 var _ sql.CollationCoercible = (*UnresolvedProcedureParam)(nil) 345 346 // NewUnresolvedProcedureParam creates a new UnresolvedProcedureParam expression. 347 func NewUnresolvedProcedureParam(name string) *UnresolvedProcedureParam { 348 return &UnresolvedProcedureParam{name: strings.ToLower(name)} 349 } 350 351 // Children implements the sql.Expression interface. 352 func (*UnresolvedProcedureParam) Children() []sql.Expression { 353 return nil 354 } 355 356 // Resolved implements the sql.Expression interface. 357 func (*UnresolvedProcedureParam) Resolved() bool { 358 return false 359 } 360 361 // IsNullable implements the sql.Expression interface. 362 func (*UnresolvedProcedureParam) IsNullable() bool { 363 return false 364 } 365 366 // Type implements the sql.Expression interface. 367 func (*UnresolvedProcedureParam) Type() sql.Type { 368 return types.Null 369 } 370 371 // CollationCoercibility implements the interface sql.CollationCoercible. 372 func (*UnresolvedProcedureParam) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 373 return sql.Collation_binary, 7 374 } 375 376 // Name implements the Nameable interface. 377 func (upp *UnresolvedProcedureParam) Name() string { 378 return upp.name 379 } 380 381 // String implements the sql.Expression interface. 
382 func (upp *UnresolvedProcedureParam) String() string { 383 return upp.name 384 } 385 386 // Eval implements the sql.Expression interface. 387 func (upp *UnresolvedProcedureParam) Eval(ctx *sql.Context, r sql.Row) (interface{}, error) { 388 return nil, fmt.Errorf("attempted to use unresolved procedure param '%s'", upp.name) 389 } 390 391 // WithChildren implements the sql.Expression interface. 392 func (upp *UnresolvedProcedureParam) WithChildren(children ...sql.Expression) (sql.Expression, error) { 393 if len(children) != 0 { 394 return nil, sql.ErrInvalidChildrenNumber.New(upp, len(children), 0) 395 } 396 return upp, nil 397 } 398 399 // FetchEOF is a special EOF error that lets the loop implementation 400 // differentiate between this io.EOF 401 var FetchEOF = errors.New("exhausted fetch iterator") 402 403 type HandlerConditionType uint8 404 405 const ( 406 HandlerConditionUnknown HandlerConditionType = iota 407 HandlerConditionNotFound 408 HandlerConditionSqlException 409 ) 410 411 type HandlerCondition struct { 412 SqlStatePrefix string 413 Type HandlerConditionType 414 } 415 416 type DeclareHandlerAction byte 417 418 const ( 419 DeclareHandlerAction_Continue DeclareHandlerAction = iota 420 DeclareHandlerAction_Exit 421 DeclareHandlerAction_Undo 422 ) 423 424 func (c *HandlerCondition) Matches(err error) bool { 425 if errors.Is(err, FetchEOF) { 426 return c.Type == HandlerConditionNotFound 427 } else { 428 return c.Type == HandlerConditionSqlException 429 } 430 }