github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/transaction_iters.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"io"
    19  
    20  	"gopkg.in/src-d/go-errors.v1"
    21  
    22  	"github.com/dolthub/go-mysql-server/sql"
    23  	"github.com/dolthub/go-mysql-server/sql/plan"
    24  )
    25  
    26  type rowFunc func(ctx *sql.Context) (sql.Row, error)
    27  
    28  type lazyRowIter struct {
    29  	next rowFunc
    30  }
    31  
    32  func (i *lazyRowIter) Next(ctx *sql.Context) (sql.Row, error) {
    33  	if i.next != nil {
    34  		res, err := i.next(ctx)
    35  		i.next = nil
    36  		return res, err
    37  	}
    38  	return nil, io.EOF
    39  }
    40  
    41  func (i *lazyRowIter) Close(ctx *sql.Context) error {
    42  	return nil
    43  }
    44  
    45  // ErrTableNotLockable is returned whenever a lockable table can't be found.
    46  var ErrTableNotLockable = errors.NewKind("table %s is not lockable")
    47  
    48  func getLockable(node sql.Node) (sql.Lockable, error) {
    49  	switch node := node.(type) {
    50  	case *plan.ResolvedTable:
    51  		return getLockableTable(node.Table)
    52  	case sql.TableWrapper:
    53  		return getLockableTable(node.Underlying())
    54  	default:
    55  		return nil, ErrTableNotLockable.New("unknown")
    56  	}
    57  }
    58  
    59  func getLockableTable(table sql.Table) (sql.Lockable, error) {
    60  	switch t := table.(type) {
    61  	case sql.Lockable:
    62  		return t, nil
    63  	case sql.TableWrapper:
    64  		return getLockableTable(t.Underlying())
    65  	default:
    66  		return nil, ErrTableNotLockable.New(t.Name())
    67  	}
    68  }
    69  
    70  // transactionCommittingIter is a simple RowIter wrapper to allow the engine to conditionally commit a transaction
    71  // during the Close() operation
    72  type transactionCommittingIter struct {
    73  	childIter           sql.RowIter
    74  	transactionDatabase string
    75  }
    76  
    77  func (t transactionCommittingIter) Next(ctx *sql.Context) (sql.Row, error) {
    78  	return t.childIter.Next(ctx)
    79  }
    80  
    81  func (t transactionCommittingIter) Close(ctx *sql.Context) error {
    82  	var err error
    83  	if t.childIter != nil {
    84  		err = t.childIter.Close(ctx)
    85  	}
    86  	if err != nil {
    87  		return err
    88  	}
    89  
    90  	tx := ctx.GetTransaction()
    91  	// TODO: In the future we should ensure that analyzer supports implicit commits instead of directly
    92  	// accessing autocommit here.
    93  	// cc. https://dev.mysql.com/doc/refman/8.0/en/implicit-commit.html
    94  	autocommit, err := plan.IsSessionAutocommit(ctx)
    95  	if err != nil {
    96  		return err
    97  	}
    98  
    99  	commitTransaction := ((tx != nil) && !ctx.GetIgnoreAutoCommit()) && autocommit
   100  	if commitTransaction {
   101  		ts, ok := ctx.Session.(sql.TransactionSession)
   102  		if !ok {
   103  			return nil
   104  		}
   105  
   106  		ctx.GetLogger().Tracef("committing transaction %s", tx)
   107  		if err := ts.CommitTransaction(ctx, tx); err != nil {
   108  			return err
   109  		}
   110  
   111  		// Clearing out the current transaction will tell us to start a new one the next time this session queries
   112  		ctx.SetTransaction(nil)
   113  	}
   114  
   115  	return nil
   116  }