github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/transaction.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package plan 16 17 import ( 18 "fmt" 19 20 "github.com/dolthub/go-mysql-server/sql" 21 ) 22 23 // transactionNode implements all the no-op methods of sql.Node 24 type transactionNode struct{} 25 26 func (transactionNode) Children() []sql.Node { 27 return nil 28 } 29 30 // CheckPrivileges implements the interface sql.Node. 31 func (transactionNode) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { 32 return true 33 } 34 35 // CollationCoercibility implements the interface sql.CollationCoercible. 36 func (*transactionNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 37 return sql.Collation_binary, 7 38 } 39 40 // Resolved implements the sql.Node interface. 41 func (transactionNode) Resolved() bool { 42 return true 43 } 44 45 func (transactionNode) IsReadOnly() bool { 46 return true 47 } 48 49 // Schema implements the sql.Node interface. 50 func (transactionNode) Schema() sql.Schema { 51 return nil 52 } 53 54 // StartTransaction explicitly starts a transaction. Transactions also start before any statement execution that 55 // doesn't have a transaction. Starting a transaction implicitly commits any in-progress one. 56 type StartTransaction struct { 57 transactionNode 58 TransChar sql.TransactionCharacteristic 59 } 60 61 var _ sql.Node = (*StartTransaction)(nil) 62 var _ sql.CollationCoercible = (*StartTransaction)(nil) 63 64 // NewStartTransaction creates a new StartTransaction node. 65 func NewStartTransaction(transactionChar sql.TransactionCharacteristic) *StartTransaction { 66 return &StartTransaction{ 67 TransChar: transactionChar, 68 } 69 } 70 71 // RowIter implements the sql.Node interface. 72 func (s *StartTransaction) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) { 73 ts, ok := ctx.Session.(sql.TransactionSession) 74 if !ok { 75 return sql.RowsToRowIter(), nil 76 } 77 78 currentTx := ctx.GetTransaction() 79 // A START TRANSACTION statement commits any pending work before beginning a new tx 80 // TODO: this work is wasted in the case that START TRANSACTION is the first statement after COMMIT 81 // an isDirty method on the transaction would allow us to avoid this 82 if currentTx != nil { 83 err := ts.CommitTransaction(ctx, currentTx) 84 if err != nil { 85 return nil, err 86 } 87 } 88 89 transaction, err := ts.StartTransaction(ctx, s.TransChar) 90 if err != nil { 91 return nil, err 92 } 93 94 ctx.SetTransaction(transaction) 95 // until this transaction is committed or rolled back, don't begin or commit any transactions automatically 96 ctx.SetIgnoreAutoCommit(true) 97 98 return sql.RowsToRowIter(), nil 99 } 100 101 func (s *StartTransaction) String() string { 102 return "Start Transaction" 103 } 104 105 // WithChildren implements the Node interface. 106 func (s *StartTransaction) WithChildren(children ...sql.Node) (sql.Node, error) { 107 if len(children) != 0 { 108 return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0) 109 } 110 111 return s, nil 112 } 113 114 // Commit commits the changes performed in a transaction. For sessions that don't implement sql.TransactionSession, 115 // this operation is a no-op. 116 type Commit struct { 117 transactionNode 118 } 119 120 var _ sql.Node = (*Commit)(nil) 121 var _ sql.CollationCoercible = (*Commit)(nil) 122 123 // NewCommit creates a new Commit node. 124 func NewCommit() *Commit { 125 return &Commit{} 126 } 127 128 // RowIter implements the sql.Node interface. 129 func (c *Commit) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) { 130 ts, ok := ctx.Session.(sql.TransactionSession) 131 if !ok { 132 return sql.RowsToRowIter(), nil 133 } 134 135 transaction := ctx.GetTransaction() 136 137 if transaction == nil { 138 return sql.RowsToRowIter(), nil 139 } 140 141 err := ts.CommitTransaction(ctx, transaction) 142 if err != nil { 143 return nil, err 144 } 145 146 ctx.SetIgnoreAutoCommit(false) 147 ctx.SetTransaction(nil) 148 149 return sql.RowsToRowIter(), nil 150 } 151 152 func (*Commit) String() string { return "COMMIT" } 153 154 // WithChildren implements the Node interface. 155 func (c *Commit) WithChildren(children ...sql.Node) (sql.Node, error) { 156 if len(children) != 0 { 157 return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) 158 } 159 160 return c, nil 161 } 162 163 // Rollback undoes the changes performed in the current transaction. For compatibility, sessions that don't implement 164 // sql.TransactionSession treat this as a no-op. 165 type Rollback struct { 166 transactionNode 167 } 168 169 var _ sql.Node = (*Rollback)(nil) 170 var _ sql.CollationCoercible = (*Rollback)(nil) 171 172 // NewRollback creates a new Rollback node. 173 func NewRollback() *Rollback { 174 return &Rollback{} 175 } 176 177 // RowIter implements the sql.Node interface. 178 func (r *Rollback) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) { 179 ts, ok := ctx.Session.(sql.TransactionSession) 180 if !ok { 181 return sql.RowsToRowIter(), nil 182 } 183 184 transaction := ctx.GetTransaction() 185 186 if transaction == nil { 187 return sql.RowsToRowIter(), nil 188 } 189 190 err := ts.Rollback(ctx, transaction) 191 if err != nil { 192 return nil, err 193 } 194 195 // Like Commit, Rollback ends the current transaction and a new one begins with the next statement 196 ctx.SetIgnoreAutoCommit(false) 197 ctx.SetTransaction(nil) 198 199 return sql.RowsToRowIter(), nil 200 } 201 202 func (*Rollback) String() string { return "ROLLBACK" } 203 204 // WithChildren implements the Node interface. 205 func (r *Rollback) WithChildren(children ...sql.Node) (sql.Node, error) { 206 if len(children) != 0 { 207 return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0) 208 } 209 210 return r, nil 211 } 212 213 // CreateSavepoint creates a savepoint with the given name. For sessions that don't implement sql.TransactionSession, 214 // this is a no-op. 215 type CreateSavepoint struct { 216 transactionNode 217 Name string 218 } 219 220 var _ sql.Node = (*CreateSavepoint)(nil) 221 var _ sql.CollationCoercible = (*CreateSavepoint)(nil) 222 223 // NewCreateSavepoint creates a new CreateSavepoint node. 224 func NewCreateSavepoint(name string) *CreateSavepoint { 225 return &CreateSavepoint{Name: name} 226 } 227 228 // RowIter implements the sql.Node interface. 229 func (c *CreateSavepoint) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) { 230 ts, ok := ctx.Session.(sql.TransactionSession) 231 if !ok { 232 return sql.RowsToRowIter(), nil 233 } 234 235 transaction := ctx.GetTransaction() 236 237 if transaction == nil { 238 return sql.RowsToRowIter(), nil 239 } 240 241 err := ts.CreateSavepoint(ctx, transaction, c.Name) 242 if err != nil { 243 return nil, err 244 } 245 246 return sql.RowsToRowIter(), nil 247 } 248 249 func (c *CreateSavepoint) String() string { return fmt.Sprintf("SAVEPOINT %s", c.Name) } 250 251 // WithChildren implements the Node interface. 252 func (c *CreateSavepoint) WithChildren(children ...sql.Node) (sql.Node, error) { 253 if len(children) != 0 { 254 return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0) 255 } 256 257 return c, nil 258 } 259 260 // RollbackSavepoint rolls back the current transaction to the given savepoint. For sessions that don't implement 261 // sql.TransactionSession, this is a no-op. 262 type RollbackSavepoint struct { 263 transactionNode 264 Name string 265 } 266 267 var _ sql.Node = (*RollbackSavepoint)(nil) 268 var _ sql.CollationCoercible = (*RollbackSavepoint)(nil) 269 270 // NewRollbackSavepoint creates a new RollbackSavepoint node. 271 func NewRollbackSavepoint(name string) *RollbackSavepoint { 272 return &RollbackSavepoint{ 273 Name: name, 274 } 275 } 276 277 // RowIter implements the sql.Node interface. 278 func (r *RollbackSavepoint) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) { 279 ts, ok := ctx.Session.(sql.TransactionSession) 280 if !ok { 281 return sql.RowsToRowIter(), nil 282 } 283 284 transaction := ctx.GetTransaction() 285 286 if transaction == nil { 287 return sql.RowsToRowIter(), nil 288 } 289 290 err := ts.RollbackToSavepoint(ctx, transaction, r.Name) 291 if err != nil { 292 return nil, err 293 } 294 295 return sql.RowsToRowIter(), nil 296 } 297 298 func (r *RollbackSavepoint) String() string { return fmt.Sprintf("ROLLBACK TO SAVEPOINT %s", r.Name) } 299 300 // WithChildren implements the Node interface. 301 func (r *RollbackSavepoint) WithChildren(children ...sql.Node) (sql.Node, error) { 302 if len(children) != 0 { 303 return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0) 304 } 305 306 return r, nil 307 } 308 309 // ReleaseSavepoint releases the given savepoint. For sessions that don't implement sql.TransactionSession, this is a 310 // no-op. 311 type ReleaseSavepoint struct { 312 transactionNode 313 Name string 314 } 315 316 var _ sql.Node = (*ReleaseSavepoint)(nil) 317 var _ sql.CollationCoercible = (*ReleaseSavepoint)(nil) 318 319 // NewReleaseSavepoint creates a new ReleaseSavepoint node. 320 func NewReleaseSavepoint(name string) *ReleaseSavepoint { 321 return &ReleaseSavepoint{ 322 Name: name, 323 } 324 } 325 326 // RowIter implements the sql.Node interface. 327 func (r *ReleaseSavepoint) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) { 328 ts, ok := ctx.Session.(sql.TransactionSession) 329 if !ok { 330 return sql.RowsToRowIter(), nil 331 } 332 333 transaction := ctx.GetTransaction() 334 335 if transaction == nil { 336 return sql.RowsToRowIter(), nil 337 } 338 339 err := ts.ReleaseSavepoint(ctx, transaction, r.Name) 340 if err != nil { 341 return nil, err 342 } 343 344 return sql.RowsToRowIter(), nil 345 } 346 347 func (r *ReleaseSavepoint) String() string { return fmt.Sprintf("RELEASE SAVEPOINT %s", r.Name) } 348 349 // WithChildren implements the Node interface. 350 func (r *ReleaseSavepoint) WithChildren(children ...sql.Node) (sql.Node, error) { 351 if len(children) != 0 { 352 return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0) 353 } 354 355 return r, nil 356 }