github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/transaction.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/dolthub/vitess/go/mysql"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/plan"
    24  	"github.com/dolthub/go-mysql-server/sql/types"
    25  )
    26  
    27  func (b *BaseBuilder) buildRollbackSavepoint(ctx *sql.Context, n *plan.RollbackSavepoint, row sql.Row) (sql.RowIter, error) {
    28  	ts, ok := ctx.Session.(sql.TransactionSession)
    29  	if !ok {
    30  		return sql.RowsToRowIter(), nil
    31  	}
    32  
    33  	transaction := ctx.GetTransaction()
    34  
    35  	if transaction == nil {
    36  		return sql.RowsToRowIter(), nil
    37  	}
    38  
    39  	err := ts.RollbackToSavepoint(ctx, transaction, n.Name)
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  
    44  	return sql.RowsToRowIter(), nil
    45  }
    46  
    47  func (b *BaseBuilder) buildReleaseSavepoint(ctx *sql.Context, n *plan.ReleaseSavepoint, row sql.Row) (sql.RowIter, error) {
    48  	ts, ok := ctx.Session.(sql.TransactionSession)
    49  	if !ok {
    50  		return sql.RowsToRowIter(), nil
    51  	}
    52  
    53  	transaction := ctx.GetTransaction()
    54  
    55  	if transaction == nil {
    56  		return sql.RowsToRowIter(), nil
    57  	}
    58  
    59  	err := ts.ReleaseSavepoint(ctx, transaction, n.Name)
    60  	if err != nil {
    61  		return nil, err
    62  	}
    63  
    64  	return sql.RowsToRowIter(), nil
    65  }
    66  
    67  func (b *BaseBuilder) buildCreateSavepoint(ctx *sql.Context, n *plan.CreateSavepoint, row sql.Row) (sql.RowIter, error) {
    68  	ts, ok := ctx.Session.(sql.TransactionSession)
    69  	if !ok {
    70  		return sql.RowsToRowIter(), nil
    71  	}
    72  
    73  	transaction := ctx.GetTransaction()
    74  
    75  	if transaction == nil {
    76  		return sql.RowsToRowIter(), nil
    77  	}
    78  
    79  	err := ts.CreateSavepoint(ctx, transaction, n.Name)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	return sql.RowsToRowIter(), nil
    85  }
    86  
    87  func (b *BaseBuilder) buildStartTransaction(ctx *sql.Context, n *plan.StartTransaction, row sql.Row) (sql.RowIter, error) {
    88  	ts, ok := ctx.Session.(sql.TransactionSession)
    89  	if !ok {
    90  		return sql.RowsToRowIter(), nil
    91  	}
    92  
    93  	currentTx := ctx.GetTransaction()
    94  	// A START TRANSACTION statement commits any pending work before beginning a new tx
    95  	// TODO: this work is wasted in the case that START TRANSACTION is the first statement after COMMIT
    96  	//  an isDirty method on the transaction would allow us to avoid this
    97  	if currentTx != nil {
    98  		err := ts.CommitTransaction(ctx, currentTx)
    99  		if err != nil {
   100  			return nil, err
   101  		}
   102  	}
   103  
   104  	transaction, err := ts.StartTransaction(ctx, n.TransChar)
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  
   109  	ctx.SetTransaction(transaction)
   110  	// until this transaction is committed or rolled back, don't begin or commit any transactions automatically
   111  	ctx.SetIgnoreAutoCommit(true)
   112  
   113  	return sql.RowsToRowIter(), nil
   114  }
   115  
   116  func (b *BaseBuilder) buildStartReplica(ctx *sql.Context, n *plan.StartReplica, row sql.Row) (sql.RowIter, error) {
   117  	if n.ReplicaController == nil {
   118  		return nil, plan.ErrNoReplicationController.New()
   119  	}
   120  
   121  	err := n.ReplicaController.StartReplica(ctx)
   122  	return sql.RowsToRowIter(), err
   123  }
   124  
   125  func (b *BaseBuilder) buildUnlockTables(ctx *sql.Context, n *plan.UnlockTables, row sql.Row) (sql.RowIter, error) {
   126  	span, ctx := ctx.Span("plan.UnlockTables")
   127  	defer span.End()
   128  
   129  	if err := n.Catalog.UnlockTables(ctx, ctx.ID()); err != nil {
   130  		return nil, err
   131  	}
   132  
   133  	return sql.RowsToRowIter(), nil
   134  }
   135  
   136  func (b *BaseBuilder) buildCommit(ctx *sql.Context, n *plan.Commit, row sql.Row) (sql.RowIter, error) {
   137  	ts, ok := ctx.Session.(sql.TransactionSession)
   138  	if !ok {
   139  		return sql.RowsToRowIter(), nil
   140  	}
   141  
   142  	transaction := ctx.GetTransaction()
   143  
   144  	if transaction == nil {
   145  		return sql.RowsToRowIter(), nil
   146  	}
   147  
   148  	err := ts.CommitTransaction(ctx, transaction)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	ctx.SetIgnoreAutoCommit(false)
   154  	ctx.SetTransaction(nil)
   155  
   156  	return sql.RowsToRowIter(), nil
   157  }
   158  
   159  func (b *BaseBuilder) buildNoopTriggerRollback(ctx *sql.Context, n *plan.NoopTriggerRollback, row sql.Row) (sql.RowIter, error) {
   160  	return b.buildNodeExec(ctx, n.Child, row)
   161  
   162  }
   163  
   164  func (b *BaseBuilder) buildKill(ctx *sql.Context, n *plan.Kill, row sql.Row) (sql.RowIter, error) {
   165  	return &lazyRowIter{
   166  		func(ctx *sql.Context) (sql.Row, error) {
   167  			ctx.ProcessList.Kill(n.ConnID)
   168  			if n.Kt == plan.KillType_Connection {
   169  				ctx.KillConnection(n.ConnID)
   170  			}
   171  			return sql.NewRow(types.NewOkResult(0)), nil
   172  		},
   173  	}, nil
   174  }
   175  
   176  func (b *BaseBuilder) buildResetReplica(ctx *sql.Context, n *plan.ResetReplica, row sql.Row) (sql.RowIter, error) {
   177  	if n.ReplicaController == nil {
   178  		return nil, plan.ErrNoReplicationController.New()
   179  	}
   180  
   181  	err := n.ReplicaController.ResetReplica(ctx, n.All)
   182  	return sql.RowsToRowIter(), err
   183  }
   184  
   185  func (b *BaseBuilder) buildRollback(ctx *sql.Context, n *plan.Rollback, row sql.Row) (sql.RowIter, error) {
   186  	ts, ok := ctx.Session.(sql.TransactionSession)
   187  	if !ok {
   188  		return sql.RowsToRowIter(), nil
   189  	}
   190  
   191  	transaction := ctx.GetTransaction()
   192  
   193  	if transaction == nil {
   194  		return sql.RowsToRowIter(), nil
   195  	}
   196  
   197  	err := ts.Rollback(ctx, transaction)
   198  	if err != nil {
   199  		return nil, err
   200  	}
   201  
   202  	// Like Commit, Rollback ends the current transaction and a new one begins with the next statement
   203  	ctx.SetIgnoreAutoCommit(false)
   204  	ctx.SetTransaction(nil)
   205  
   206  	return sql.RowsToRowIter(), nil
   207  }
   208  
   209  func (b *BaseBuilder) buildChangeReplicationSource(ctx *sql.Context, n *plan.ChangeReplicationSource, row sql.Row) (sql.RowIter, error) {
   210  	if n.ReplicaController == nil {
   211  		return nil, plan.ErrNoReplicationController.New()
   212  	}
   213  
   214  	err := n.ReplicaController.SetReplicationSourceOptions(ctx, n.Options)
   215  	return sql.RowsToRowIter(), err
   216  }
   217  
   218  func (b *BaseBuilder) buildLockTables(ctx *sql.Context, n *plan.LockTables, row sql.Row) (sql.RowIter, error) {
   219  	span, ctx := ctx.Span("plan.LockTables")
   220  	defer span.End()
   221  
   222  	for _, l := range n.Locks {
   223  		lockable, err := getLockable(l.Table)
   224  		if err != nil {
   225  			// If a table is not lockable, just skip it
   226  			ctx.Warn(0, err.Error())
   227  			continue
   228  		}
   229  
   230  		if err := lockable.Lock(ctx, l.Write); err != nil {
   231  			ctx.Error(0, "unable to lock table: %s", err)
   232  		} else {
   233  			n.Catalog.LockTable(ctx, lockable.Name())
   234  		}
   235  	}
   236  
   237  	return sql.RowsToRowIter(), nil
   238  }
   239  
   240  func (b *BaseBuilder) buildSignal(ctx *sql.Context, n *plan.Signal, row sql.Row) (sql.RowIter, error) {
   241  	//TODO: implement CLASS_ORIGIN
   242  	//TODO: implement SUBCLASS_ORIGIN
   243  	//TODO: implement CONSTRAINT_CATALOG
   244  	//TODO: implement CONSTRAINT_SCHEMA
   245  	//TODO: implement CONSTRAINT_NAME
   246  	//TODO: implement CATALOG_NAME
   247  	//TODO: implement SCHEMA_NAME
   248  	//TODO: implement TABLE_NAME
   249  	//TODO: implement COLUMN_NAME
   250  	//TODO: implement CURSOR_NAME
   251  	if n.SqlStateValue[0:2] == "01" {
   252  		//TODO: implement warnings
   253  		return nil, fmt.Errorf("warnings not yet implemented")
   254  	} else {
   255  
   256  		messageItem := n.Info[plan.SignalConditionItemName_MessageText]
   257  		strValue := messageItem.StrValue
   258  		if messageItem.ExprVal != nil {
   259  			exprResult, err := messageItem.ExprVal.Eval(ctx, nil)
   260  			if err != nil {
   261  				return nil, err
   262  			}
   263  			s, ok := exprResult.(string)
   264  			if !ok {
   265  				return nil, fmt.Errorf("message text expression did not evaluate to a string")
   266  			}
   267  			strValue = s
   268  		}
   269  
   270  		return nil, mysql.NewSQLError(
   271  			int(n.Info[plan.SignalConditionItemName_MysqlErrno].IntValue),
   272  			n.SqlStateValue,
   273  			strValue,
   274  		)
   275  	}
   276  }
   277  
   278  func (b *BaseBuilder) buildStopReplica(ctx *sql.Context, n *plan.StopReplica, row sql.Row) (sql.RowIter, error) {
   279  	if n.ReplicaController == nil {
   280  		return nil, plan.ErrNoReplicationController.New()
   281  	}
   282  
   283  	err := n.ReplicaController.StopReplica(ctx)
   284  	return sql.RowsToRowIter(), err
   285  }
   286  
   287  func (b *BaseBuilder) buildChangeReplicationFilter(ctx *sql.Context, n *plan.ChangeReplicationFilter, row sql.Row) (sql.RowIter, error) {
   288  	if n.ReplicaController == nil {
   289  		return nil, plan.ErrNoReplicationController.New()
   290  	}
   291  
   292  	err := n.ReplicaController.SetReplicationFilterOptions(ctx, n.Options)
   293  	return sql.RowsToRowIter(), err
   294  }
   295  
   296  func (b *BaseBuilder) buildExecuteQuery(ctx *sql.Context, n *plan.ExecuteQuery, row sql.Row) (sql.RowIter, error) {
   297  	return nil, fmt.Errorf("%T does not have an execution iterator", n)
   298  }
   299  
   300  func (b *BaseBuilder) buildUse(ctx *sql.Context, n *plan.Use, row sql.Row) (sql.RowIter, error) {
   301  	return n.RowIter(ctx, row)
   302  }
   303  
   304  func (b *BaseBuilder) buildTransactionCommittingNode(ctx *sql.Context, n *plan.TransactionCommittingNode, row sql.Row) (sql.RowIter, error) {
   305  	iter, err := b.Build(ctx, n.Child(), row)
   306  	if err != nil {
   307  		return nil, err
   308  	}
   309  	return transactionCommittingIter{childIter: iter}, nil
   310  }