github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/proc_iters.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowexec 16 17 import ( 18 "errors" 19 "fmt" 20 "io" 21 "strings" 22 23 "github.com/dolthub/vitess/go/mysql" 24 25 "github.com/dolthub/go-mysql-server/sql" 26 "github.com/dolthub/go-mysql-server/sql/expression" 27 "github.com/dolthub/go-mysql-server/sql/plan" 28 ) 29 30 // ifElseIter is the row iterator for *IfElseBlock. 31 type ifElseIter struct { 32 branchIter sql.RowIter 33 sch sql.Schema 34 branchNode sql.Node 35 } 36 37 var _ plan.BlockRowIter = (*ifElseIter)(nil) 38 39 // Next implements the sql.RowIter interface. 40 func (i *ifElseIter) Next(ctx *sql.Context) (sql.Row, error) { 41 if err := startTransaction(ctx); err != nil { 42 return nil, err 43 } 44 45 return i.branchIter.Next(ctx) 46 } 47 48 // Close implements the sql.RowIter interface. 49 func (i *ifElseIter) Close(ctx *sql.Context) error { 50 return i.branchIter.Close(ctx) 51 } 52 53 // RepresentingNode implements the sql.BlockRowIter interface. 54 func (i *ifElseIter) RepresentingNode() sql.Node { 55 return i.branchNode 56 } 57 58 // Schema implements the sql.BlockRowIter interface. 59 func (i *ifElseIter) Schema() sql.Schema { 60 return i.sch 61 } 62 63 // beginEndIter is the sql.RowIter of *BeginEndBlock. 64 type beginEndIter struct { 65 *plan.BeginEndBlock 66 rowIter sql.RowIter 67 } 68 69 var _ sql.RowIter = (*beginEndIter)(nil) 70 71 // Next implements the interface sql.RowIter. 72 func (b *beginEndIter) Next(ctx *sql.Context) (sql.Row, error) { 73 if err := startTransaction(ctx); err != nil { 74 return nil, err 75 } 76 77 row, err := b.rowIter.Next(ctx) 78 if err != nil { 79 if controlFlow, ok := err.(loopError); ok && strings.ToLower(controlFlow.Label) == strings.ToLower(b.Label) { 80 if controlFlow.IsExit { 81 err = nil 82 } else { 83 err = fmt.Errorf("encountered ITERATE on BEGIN...END, which should should have been caught by the analyzer") 84 } 85 } 86 if nErr := b.Pref.PopScope(ctx); nErr != nil && err == io.EOF { 87 err = nErr 88 } 89 if errors.Is(err, expression.FetchEOF) { 90 err = io.EOF 91 } 92 return nil, err 93 } 94 return row, nil 95 } 96 97 // Close implements the interface sql.RowIter. 98 func (b *beginEndIter) Close(ctx *sql.Context) error { 99 return b.rowIter.Close(ctx) 100 } 101 102 // callIter is the row iterator for *Call. 103 type callIter struct { 104 call *plan.Call 105 innerIter sql.RowIter 106 } 107 108 // Next implements the sql.RowIter interface. 109 func (iter *callIter) Next(ctx *sql.Context) (sql.Row, error) { 110 return iter.innerIter.Next(ctx) 111 } 112 113 // Close implements the sql.RowIter interface. 114 func (iter *callIter) Close(ctx *sql.Context) error { 115 err := iter.innerIter.Close(ctx) 116 if err != nil { 117 return err 118 } 119 err = iter.call.Pref.CloseAllCursors(ctx) 120 if err != nil { 121 return err 122 } 123 124 // Set all user and system variables from INOUT and OUT params 125 for i, param := range iter.call.Procedure.Params { 126 if param.Direction == plan.ProcedureParamDirection_Inout || 127 (param.Direction == plan.ProcedureParamDirection_Out && iter.call.Pref.VariableHasBeenSet(param.Name)) { 128 val, err := iter.call.Pref.GetVariableValue(param.Name) 129 if err != nil { 130 return err 131 } 132 133 typ := iter.call.Pref.GetVariableType(param.Name) 134 135 switch callParam := iter.call.Params[i].(type) { 136 case *expression.UserVar: 137 err = ctx.SetUserVariable(ctx, callParam.Name, val, typ) 138 if err != nil { 139 return err 140 } 141 case *expression.SystemVar: 142 // This should have been caught by the analyzer, so a major bug exists somewhere 143 return fmt.Errorf("unable to set `%s` as it is a system variable", callParam.Name) 144 case *expression.ProcedureParam: 145 err = callParam.Set(val, param.Type) 146 if err != nil { 147 return err 148 } 149 } 150 } else if param.Direction == plan.ProcedureParamDirection_Out { // VariableHasBeenSet was false 151 // For OUT only, if a var was not set within the procedure body, then we set the vars to nil. 152 // If the var had a value before the call then it is basically removed. 153 switch callParam := iter.call.Params[i].(type) { 154 case *expression.UserVar: 155 err = ctx.SetUserVariable(ctx, callParam.Name, nil, iter.call.Pref.GetVariableType(param.Name)) 156 if err != nil { 157 return err 158 } 159 case *expression.SystemVar: 160 // This should have been caught by the analyzer, so a major bug exists somewhere 161 return fmt.Errorf("unable to set `%s` as it is a system variable", callParam.Name) 162 case *expression.ProcedureParam: 163 err := callParam.Set(nil, param.Type) 164 if err != nil { 165 return err 166 } 167 } 168 } 169 } 170 return nil 171 } 172 173 type elseCaseErrorIter struct{} 174 175 var _ sql.RowIter = elseCaseErrorIter{} 176 177 // Next implements the interface sql.RowIter. 178 func (e elseCaseErrorIter) Next(ctx *sql.Context) (sql.Row, error) { 179 return nil, mysql.NewSQLError(1339, "20000", "Case not found for CASE statement") 180 } 181 182 // Close implements the interface sql.RowIter. 183 func (e elseCaseErrorIter) Close(context *sql.Context) error { 184 return nil 185 } 186 187 // openIter is the sql.RowIter of *Open. 188 type openIter struct { 189 pRef *expression.ProcedureReference 190 name string 191 row sql.Row 192 b *BaseBuilder 193 } 194 195 var _ sql.RowIter = (*openIter)(nil) 196 197 // Next implements the interface sql.RowIter. 198 func (o *openIter) Next(ctx *sql.Context) (sql.Row, error) { 199 if err := o.openCursor(ctx, o.pRef, o.name, o.row); err != nil { 200 return nil, err 201 } 202 return nil, io.EOF 203 } 204 205 func (o *openIter) openCursor(ctx *sql.Context, ref *expression.ProcedureReference, name string, row sql.Row) error { 206 lowerName := strings.ToLower(name) 207 scope := ref.InnermostScope 208 for scope != nil { 209 if cursorRefVal, ok := scope.Cursors[lowerName]; ok { 210 if cursorRefVal.RowIter != nil { 211 return sql.ErrCursorAlreadyOpen.New(name) 212 } 213 var err error 214 cursorRefVal.RowIter, err = o.b.buildNodeExec(ctx, cursorRefVal.SelectStmt, row) 215 return err 216 } 217 scope = scope.Parent 218 } 219 return fmt.Errorf("cannot find cursor `%s`", name) 220 } 221 222 // Close implements the interface sql.RowIter. 223 func (o *openIter) Close(ctx *sql.Context) error { 224 return nil 225 } 226 227 // closeIter is the sql.RowIter of *Close. 228 type closeIter struct { 229 pRef *expression.ProcedureReference 230 name string 231 } 232 233 var _ sql.RowIter = (*closeIter)(nil) 234 235 // Next implements the interface sql.RowIter. 236 func (c *closeIter) Next(ctx *sql.Context) (sql.Row, error) { 237 if err := c.pRef.CloseCursor(ctx, c.name); err != nil { 238 return nil, err 239 } 240 return nil, io.EOF 241 } 242 243 // Close implements the interface sql.RowIter. 244 func (c *closeIter) Close(ctx *sql.Context) error { 245 return nil 246 } 247 248 // loopError is an error used to control a loop's flow. 249 type loopError struct { 250 Label string 251 IsExit bool 252 } 253 254 var _ error = loopError{} 255 256 // Error implements the interface error. As long as the analysis step is implemented correctly, this should never be seen. 257 func (l loopError) Error() string { 258 option := "exited" 259 if !l.IsExit { 260 option = "continued" 261 } 262 return fmt.Sprintf("should have %s the loop `%s` but it was somehow not found in the call stack", option, l.Label) 263 } 264 265 // loopAcquireRowIter is a helper function for LOOP that conditionally acquires a new sql.RowIter. If a loop exit is 266 // encountered, `exitIter` determines whether to return an empty iterator or an io.EOF error. 267 func (b *BaseBuilder) loopAcquireRowIter(ctx *sql.Context, row sql.Row, label string, block *plan.Block, exitIter bool) (sql.RowIter, error) { 268 blockIter, err := b.buildBlock(ctx, block, row) 269 if controlFlow, ok := err.(loopError); ok && strings.ToLower(controlFlow.Label) == strings.ToLower(label) { 270 if controlFlow.IsExit { 271 if exitIter { 272 return sql.RowsToRowIter(), nil 273 } else { 274 return nil, io.EOF 275 } 276 } else { 277 err = io.EOF 278 } 279 } 280 if err == io.EOF { 281 blockIter = sql.RowsToRowIter() 282 err = nil 283 } 284 return blockIter, err 285 } 286 287 // leaveIter is the sql.RowIter of *Leave. 288 type leaveIter struct { 289 Label string 290 } 291 292 var _ sql.RowIter = (*leaveIter)(nil) 293 294 // Next implements the interface sql.RowIter. 295 func (l *leaveIter) Next(ctx *sql.Context) (sql.Row, error) { 296 return nil, loopError{ 297 Label: l.Label, 298 IsExit: true, 299 } 300 } 301 302 // Close implements the interface sql.RowIter. 303 func (l *leaveIter) Close(ctx *sql.Context) error { 304 return nil 305 } 306 307 // iterateIter is the sql.RowIter of *Iterate. 308 type iterateIter struct { 309 Label string 310 } 311 312 var _ sql.RowIter = (*iterateIter)(nil) 313 314 // Next implements the interface sql.RowIter. 315 func (i *iterateIter) Next(ctx *sql.Context) (sql.Row, error) { 316 return nil, loopError{ 317 Label: i.Label, 318 IsExit: false, 319 } 320 } 321 322 // Close implements the interface sql.RowIter. 323 func (i *iterateIter) Close(ctx *sql.Context) error { 324 return nil 325 } 326 327 // startTransaction begins a new transaction if necessary, e.g. if a statement in a stored procedure committed the 328 // current one 329 func startTransaction(ctx *sql.Context) error { 330 if ctx.GetTransaction() == nil { 331 ts, ok := ctx.Session.(sql.TransactionSession) 332 if ok { 333 tx, err := ts.StartTransaction(ctx, sql.ReadWrite) 334 if err != nil { 335 return err 336 } 337 338 ctx.SetTransaction(tx) 339 } 340 } 341 342 return nil 343 }