github.com/dolthub/go-mysql-server@v0.18.0/sql/plan/transaction.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package plan
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/dolthub/go-mysql-server/sql"
    21  )
    22  
    23  // transactionNode implements all the no-op methods of sql.Node
    24  type transactionNode struct{}
    25  
    26  func (transactionNode) Children() []sql.Node {
    27  	return nil
    28  }
    29  
    30  // CheckPrivileges implements the interface sql.Node.
    31  func (transactionNode) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
    32  	return true
    33  }
    34  
    35  // CollationCoercibility implements the interface sql.CollationCoercible.
    36  func (*transactionNode) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    37  	return sql.Collation_binary, 7
    38  }
    39  
    40  // Resolved implements the sql.Node interface.
    41  func (transactionNode) Resolved() bool {
    42  	return true
    43  }
    44  
    45  func (transactionNode) IsReadOnly() bool {
    46  	return true
    47  }
    48  
    49  // Schema implements the sql.Node interface.
    50  func (transactionNode) Schema() sql.Schema {
    51  	return nil
    52  }
    53  
    54  // StartTransaction explicitly starts a transaction. Transactions also start before any statement execution that
    55  // doesn't have a transaction. Starting a transaction implicitly commits any in-progress one.
    56  type StartTransaction struct {
    57  	transactionNode
    58  	TransChar sql.TransactionCharacteristic
    59  }
    60  
    61  var _ sql.Node = (*StartTransaction)(nil)
    62  var _ sql.CollationCoercible = (*StartTransaction)(nil)
    63  
    64  // NewStartTransaction creates a new StartTransaction node.
    65  func NewStartTransaction(transactionChar sql.TransactionCharacteristic) *StartTransaction {
    66  	return &StartTransaction{
    67  		TransChar: transactionChar,
    68  	}
    69  }
    70  
    71  // RowIter implements the sql.Node interface.
    72  func (s *StartTransaction) RowIter(ctx *sql.Context, row sql.Row) (sql.RowIter, error) {
    73  	ts, ok := ctx.Session.(sql.TransactionSession)
    74  	if !ok {
    75  		return sql.RowsToRowIter(), nil
    76  	}
    77  
    78  	currentTx := ctx.GetTransaction()
    79  	// A START TRANSACTION statement commits any pending work before beginning a new tx
    80  	// TODO: this work is wasted in the case that START TRANSACTION is the first statement after COMMIT
    81  	//  an isDirty method on the transaction would allow us to avoid this
    82  	if currentTx != nil {
    83  		err := ts.CommitTransaction(ctx, currentTx)
    84  		if err != nil {
    85  			return nil, err
    86  		}
    87  	}
    88  
    89  	transaction, err := ts.StartTransaction(ctx, s.TransChar)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	ctx.SetTransaction(transaction)
    95  	// until this transaction is committed or rolled back, don't begin or commit any transactions automatically
    96  	ctx.SetIgnoreAutoCommit(true)
    97  
    98  	return sql.RowsToRowIter(), nil
    99  }
   100  
   101  func (s *StartTransaction) String() string {
   102  	return "Start Transaction"
   103  }
   104  
   105  // WithChildren implements the Node interface.
   106  func (s *StartTransaction) WithChildren(children ...sql.Node) (sql.Node, error) {
   107  	if len(children) != 0 {
   108  		return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 0)
   109  	}
   110  
   111  	return s, nil
   112  }
   113  
   114  // Commit commits the changes performed in a transaction. For sessions that don't implement sql.TransactionSession,
   115  // this operation is a no-op.
   116  type Commit struct {
   117  	transactionNode
   118  }
   119  
   120  var _ sql.Node = (*Commit)(nil)
   121  var _ sql.CollationCoercible = (*Commit)(nil)
   122  
   123  // NewCommit creates a new Commit node.
   124  func NewCommit() *Commit {
   125  	return &Commit{}
   126  }
   127  
   128  // RowIter implements the sql.Node interface.
   129  func (c *Commit) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) {
   130  	ts, ok := ctx.Session.(sql.TransactionSession)
   131  	if !ok {
   132  		return sql.RowsToRowIter(), nil
   133  	}
   134  
   135  	transaction := ctx.GetTransaction()
   136  
   137  	if transaction == nil {
   138  		return sql.RowsToRowIter(), nil
   139  	}
   140  
   141  	err := ts.CommitTransaction(ctx, transaction)
   142  	if err != nil {
   143  		return nil, err
   144  	}
   145  
   146  	ctx.SetIgnoreAutoCommit(false)
   147  	ctx.SetTransaction(nil)
   148  
   149  	return sql.RowsToRowIter(), nil
   150  }
   151  
   152  func (*Commit) String() string { return "COMMIT" }
   153  
   154  // WithChildren implements the Node interface.
   155  func (c *Commit) WithChildren(children ...sql.Node) (sql.Node, error) {
   156  	if len(children) != 0 {
   157  		return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0)
   158  	}
   159  
   160  	return c, nil
   161  }
   162  
   163  // Rollback undoes the changes performed in the current transaction. For compatibility, sessions that don't implement
   164  // sql.TransactionSession treat this as a no-op.
   165  type Rollback struct {
   166  	transactionNode
   167  }
   168  
   169  var _ sql.Node = (*Rollback)(nil)
   170  var _ sql.CollationCoercible = (*Rollback)(nil)
   171  
   172  // NewRollback creates a new Rollback node.
   173  func NewRollback() *Rollback {
   174  	return &Rollback{}
   175  }
   176  
   177  // RowIter implements the sql.Node interface.
   178  func (r *Rollback) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) {
   179  	ts, ok := ctx.Session.(sql.TransactionSession)
   180  	if !ok {
   181  		return sql.RowsToRowIter(), nil
   182  	}
   183  
   184  	transaction := ctx.GetTransaction()
   185  
   186  	if transaction == nil {
   187  		return sql.RowsToRowIter(), nil
   188  	}
   189  
   190  	err := ts.Rollback(ctx, transaction)
   191  	if err != nil {
   192  		return nil, err
   193  	}
   194  
   195  	// Like Commit, Rollback ends the current transaction and a new one begins with the next statement
   196  	ctx.SetIgnoreAutoCommit(false)
   197  	ctx.SetTransaction(nil)
   198  
   199  	return sql.RowsToRowIter(), nil
   200  }
   201  
   202  func (*Rollback) String() string { return "ROLLBACK" }
   203  
   204  // WithChildren implements the Node interface.
   205  func (r *Rollback) WithChildren(children ...sql.Node) (sql.Node, error) {
   206  	if len(children) != 0 {
   207  		return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0)
   208  	}
   209  
   210  	return r, nil
   211  }
   212  
   213  // CreateSavepoint creates a savepoint with the given name. For sessions that don't implement sql.TransactionSession,
   214  // this is a no-op.
   215  type CreateSavepoint struct {
   216  	transactionNode
   217  	Name string
   218  }
   219  
   220  var _ sql.Node = (*CreateSavepoint)(nil)
   221  var _ sql.CollationCoercible = (*CreateSavepoint)(nil)
   222  
   223  // NewCreateSavepoint creates a new CreateSavepoint node.
   224  func NewCreateSavepoint(name string) *CreateSavepoint {
   225  	return &CreateSavepoint{Name: name}
   226  }
   227  
   228  // RowIter implements the sql.Node interface.
   229  func (c *CreateSavepoint) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) {
   230  	ts, ok := ctx.Session.(sql.TransactionSession)
   231  	if !ok {
   232  		return sql.RowsToRowIter(), nil
   233  	}
   234  
   235  	transaction := ctx.GetTransaction()
   236  
   237  	if transaction == nil {
   238  		return sql.RowsToRowIter(), nil
   239  	}
   240  
   241  	err := ts.CreateSavepoint(ctx, transaction, c.Name)
   242  	if err != nil {
   243  		return nil, err
   244  	}
   245  
   246  	return sql.RowsToRowIter(), nil
   247  }
   248  
   249  func (c *CreateSavepoint) String() string { return fmt.Sprintf("SAVEPOINT %s", c.Name) }
   250  
   251  // WithChildren implements the Node interface.
   252  func (c *CreateSavepoint) WithChildren(children ...sql.Node) (sql.Node, error) {
   253  	if len(children) != 0 {
   254  		return nil, sql.ErrInvalidChildrenNumber.New(c, len(children), 0)
   255  	}
   256  
   257  	return c, nil
   258  }
   259  
   260  // RollbackSavepoint rolls back the current transaction to the given savepoint. For sessions that don't implement
   261  // sql.TransactionSession, this is a no-op.
   262  type RollbackSavepoint struct {
   263  	transactionNode
   264  	Name string
   265  }
   266  
   267  var _ sql.Node = (*RollbackSavepoint)(nil)
   268  var _ sql.CollationCoercible = (*RollbackSavepoint)(nil)
   269  
   270  // NewRollbackSavepoint creates a new RollbackSavepoint node.
   271  func NewRollbackSavepoint(name string) *RollbackSavepoint {
   272  	return &RollbackSavepoint{
   273  		Name: name,
   274  	}
   275  }
   276  
   277  // RowIter implements the sql.Node interface.
   278  func (r *RollbackSavepoint) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) {
   279  	ts, ok := ctx.Session.(sql.TransactionSession)
   280  	if !ok {
   281  		return sql.RowsToRowIter(), nil
   282  	}
   283  
   284  	transaction := ctx.GetTransaction()
   285  
   286  	if transaction == nil {
   287  		return sql.RowsToRowIter(), nil
   288  	}
   289  
   290  	err := ts.RollbackToSavepoint(ctx, transaction, r.Name)
   291  	if err != nil {
   292  		return nil, err
   293  	}
   294  
   295  	return sql.RowsToRowIter(), nil
   296  }
   297  
   298  func (r *RollbackSavepoint) String() string { return fmt.Sprintf("ROLLBACK TO SAVEPOINT %s", r.Name) }
   299  
   300  // WithChildren implements the Node interface.
   301  func (r *RollbackSavepoint) WithChildren(children ...sql.Node) (sql.Node, error) {
   302  	if len(children) != 0 {
   303  		return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0)
   304  	}
   305  
   306  	return r, nil
   307  }
   308  
   309  // ReleaseSavepoint releases the given savepoint. For sessions that don't implement sql.TransactionSession, this is a
   310  // no-op.
   311  type ReleaseSavepoint struct {
   312  	transactionNode
   313  	Name string
   314  }
   315  
   316  var _ sql.Node = (*ReleaseSavepoint)(nil)
   317  var _ sql.CollationCoercible = (*ReleaseSavepoint)(nil)
   318  
   319  // NewReleaseSavepoint creates a new ReleaseSavepoint node.
   320  func NewReleaseSavepoint(name string) *ReleaseSavepoint {
   321  	return &ReleaseSavepoint{
   322  		Name: name,
   323  	}
   324  }
   325  
   326  // RowIter implements the sql.Node interface.
   327  func (r *ReleaseSavepoint) RowIter(ctx *sql.Context, _ sql.Row) (sql.RowIter, error) {
   328  	ts, ok := ctx.Session.(sql.TransactionSession)
   329  	if !ok {
   330  		return sql.RowsToRowIter(), nil
   331  	}
   332  
   333  	transaction := ctx.GetTransaction()
   334  
   335  	if transaction == nil {
   336  		return sql.RowsToRowIter(), nil
   337  	}
   338  
   339  	err := ts.ReleaseSavepoint(ctx, transaction, r.Name)
   340  	if err != nil {
   341  		return nil, err
   342  	}
   343  
   344  	return sql.RowsToRowIter(), nil
   345  }
   346  
   347  func (r *ReleaseSavepoint) String() string { return fmt.Sprintf("RELEASE SAVEPOINT %s", r.Name) }
   348  
   349  // WithChildren implements the Node interface.
   350  func (r *ReleaseSavepoint) WithChildren(children ...sql.Node) (sql.Node, error) {
   351  	if len(children) != 0 {
   352  		return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), 0)
   353  	}
   354  
   355  	return r, nil
   356  }