vitess.io/vitess@v0.16.2/go/vt/vtgate/tx_conn.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vtgate
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"sync"
    23  
    24  	"vitess.io/vitess/go/vt/concurrency"
    25  	"vitess.io/vitess/go/vt/dtids"
    26  	"vitess.io/vitess/go/vt/log"
    27  	"vitess.io/vitess/go/vt/sqlparser"
    28  	"vitess.io/vitess/go/vt/vterrors"
    29  	"vitess.io/vitess/go/vt/vttablet/queryservice"
    30  
    31  	querypb "vitess.io/vitess/go/vt/proto/query"
    32  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    33  	vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
    34  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    35  )
    36  
// TxConn is used for executing transactional requests.
// It bundles the gateway through which tablets are reached with the
// transaction mode that applies when a session does not choose one itself.
type TxConn struct {
	// tabletGateway is used directly for gateway-level calls and to look up
	// a tablet-specific query service when a tablet alias is known.
	tabletGateway *TabletGateway
	// mode is consulted by Commit when the session's own TransactionMode is
	// unspecified.
	mode          vtgatepb.TransactionMode
}
    42  
    43  // NewTxConn builds a new TxConn.
    44  func NewTxConn(gw *TabletGateway, txMode vtgatepb.TransactionMode) *TxConn {
    45  	return &TxConn{
    46  		tabletGateway: gw,
    47  		mode:          txMode,
    48  	}
    49  }
    50  
    51  var txAccessModeToEOTxAccessMode = map[sqlparser.TxAccessMode]querypb.ExecuteOptions_TransactionAccessMode{
    52  	sqlparser.WithConsistentSnapshot: querypb.ExecuteOptions_CONSISTENT_SNAPSHOT,
    53  	sqlparser.ReadWrite:              querypb.ExecuteOptions_READ_WRITE,
    54  	sqlparser.ReadOnly:               querypb.ExecuteOptions_READ_ONLY,
    55  }
    56  
    57  // Begin begins a new transaction. If one is already in progress, it commits it
    58  // and starts a new one.
    59  func (txc *TxConn) Begin(ctx context.Context, session *SafeSession, txAccessModes []sqlparser.TxAccessMode) error {
    60  	if session.InTransaction() {
    61  		if err := txc.Commit(ctx, session); err != nil {
    62  			return err
    63  		}
    64  	}
    65  	if len(txAccessModes) > 0 {
    66  		options := session.GetOrCreateOptions()
    67  		for _, txAccessMode := range txAccessModes {
    68  			accessMode, ok := txAccessModeToEOTxAccessMode[txAccessMode]
    69  			if !ok {
    70  				return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "[BUG] invalid transaction characteristic: %s", txAccessMode.ToString())
    71  			}
    72  			options.TransactionAccessMode = append(options.TransactionAccessMode, accessMode)
    73  		}
    74  	}
    75  	session.Session.InTransaction = true
    76  	return nil
    77  }
    78  
    79  // Commit commits the current transaction. The type of commit can be
    80  // best effort or 2pc depending on the session setting.
    81  func (txc *TxConn) Commit(ctx context.Context, session *SafeSession) error {
    82  	defer session.ResetTx()
    83  	if !session.InTransaction() {
    84  		return nil
    85  	}
    86  
    87  	twopc := false
    88  	switch session.TransactionMode {
    89  	case vtgatepb.TransactionMode_TWOPC:
    90  		twopc = true
    91  	case vtgatepb.TransactionMode_UNSPECIFIED:
    92  		twopc = txc.mode == vtgatepb.TransactionMode_TWOPC
    93  	}
    94  
    95  	if twopc {
    96  		return txc.commit2PC(ctx, session)
    97  	}
    98  	return txc.commitNormal(ctx, session)
    99  }
   100  
   101  func (txc *TxConn) queryService(alias *topodatapb.TabletAlias) (queryservice.QueryService, error) {
   102  	if alias == nil {
   103  		return txc.tabletGateway, nil
   104  	}
   105  	return txc.tabletGateway.QueryServiceByAlias(alias, nil)
   106  }
   107  
   108  func (txc *TxConn) commitShard(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *executeLogger) error {
   109  	if s.TransactionId == 0 {
   110  		return nil
   111  	}
   112  	var qs queryservice.QueryService
   113  	var err error
   114  	qs, err = txc.queryService(s.TabletAlias)
   115  	if err != nil {
   116  		return err
   117  	}
   118  	reservedID, err := qs.Commit(ctx, s.Target, s.TransactionId)
   119  	if err != nil {
   120  		return err
   121  	}
   122  	s.TransactionId = 0
   123  	s.ReservedId = reservedID
   124  	logging.log(nil, s.Target, nil, "commit", false, nil)
   125  	return nil
   126  }
   127  
   128  func (txc *TxConn) commitNormal(ctx context.Context, session *SafeSession) error {
   129  	if err := txc.runSessions(ctx, session.PreSessions, session.logging, txc.commitShard); err != nil {
   130  		_ = txc.Release(ctx, session)
   131  		return err
   132  	}
   133  
   134  	// Retain backward compatibility on commit order for the normal session.
   135  	for _, shardSession := range session.ShardSessions {
   136  		if err := txc.commitShard(ctx, shardSession, session.logging); err != nil {
   137  			_ = txc.Release(ctx, session)
   138  			return err
   139  		}
   140  	}
   141  
   142  	if err := txc.runSessions(ctx, session.PostSessions, session.logging, txc.commitShard); err != nil {
   143  		// If last commit fails, there will be nothing to rollback.
   144  		session.RecordWarning(&querypb.QueryWarning{Message: fmt.Sprintf("post-operation transaction had an error: %v", err)})
   145  		// With reserved connection we should release them.
   146  		if session.InReservedConn() {
   147  			_ = txc.Release(ctx, session)
   148  		}
   149  	}
   150  	return nil
   151  }
   152  
   153  // commit2PC will not used the pinned tablets - to make sure we use the current source, we need to use the gateway's queryservice
   154  func (txc *TxConn) commit2PC(ctx context.Context, session *SafeSession) error {
   155  	if len(session.PreSessions) != 0 || len(session.PostSessions) != 0 {
   156  		_ = txc.Rollback(ctx, session)
   157  		return vterrors.New(vtrpcpb.Code_FAILED_PRECONDITION, "pre or post actions not allowed for 2PC commits")
   158  	}
   159  
   160  	// If the number of participants is one or less, then it's a normal commit.
   161  	if len(session.ShardSessions) <= 1 {
   162  		return txc.commitNormal(ctx, session)
   163  	}
   164  
   165  	participants := make([]*querypb.Target, 0, len(session.ShardSessions)-1)
   166  	for _, s := range session.ShardSessions[1:] {
   167  		participants = append(participants, s.Target)
   168  	}
   169  	mmShard := session.ShardSessions[0]
   170  	dtid := dtids.New(mmShard)
   171  	err := txc.tabletGateway.CreateTransaction(ctx, mmShard.Target, dtid, participants)
   172  	if err != nil {
   173  		// Normal rollback is safe because nothing was prepared yet.
   174  		_ = txc.Rollback(ctx, session)
   175  		return err
   176  	}
   177  
   178  	err = txc.runSessions(ctx, session.ShardSessions[1:], session.logging, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *executeLogger) error {
   179  		return txc.tabletGateway.Prepare(ctx, s.Target, s.TransactionId, dtid)
   180  	})
   181  	if err != nil {
   182  		// TODO(sougou): Perform a more fine-grained cleanup
   183  		// including unprepared transactions.
   184  		if resumeErr := txc.Resolve(ctx, dtid); resumeErr != nil {
   185  			log.Warningf("Rollback failed after Prepare failure: %v", resumeErr)
   186  		}
   187  		// Return the original error even if the previous operation fails.
   188  		return err
   189  	}
   190  
   191  	err = txc.tabletGateway.StartCommit(ctx, mmShard.Target, mmShard.TransactionId, dtid)
   192  	if err != nil {
   193  		return err
   194  	}
   195  
   196  	err = txc.runSessions(ctx, session.ShardSessions[1:], session.logging, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *executeLogger) error {
   197  		return txc.tabletGateway.CommitPrepared(ctx, s.Target, dtid)
   198  	})
   199  	if err != nil {
   200  		return err
   201  	}
   202  
   203  	return txc.tabletGateway.ConcludeTransaction(ctx, mmShard.Target, dtid)
   204  }
   205  
   206  // Rollback rolls back the current transaction. There are no retries on this operation.
   207  func (txc *TxConn) Rollback(ctx context.Context, session *SafeSession) error {
   208  	if !session.InTransaction() {
   209  		return nil
   210  	}
   211  	defer session.ResetTx()
   212  
   213  	allsessions := append(session.PreSessions, session.ShardSessions...)
   214  	allsessions = append(allsessions, session.PostSessions...)
   215  
   216  	err := txc.runSessions(ctx, allsessions, session.logging, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *executeLogger) error {
   217  		if s.TransactionId == 0 {
   218  			return nil
   219  		}
   220  		qs, err := txc.queryService(s.TabletAlias)
   221  		if err != nil {
   222  			return err
   223  		}
   224  		reservedID, err := qs.Rollback(ctx, s.Target, s.TransactionId)
   225  		if err != nil {
   226  			return err
   227  		}
   228  		s.TransactionId = 0
   229  		s.ReservedId = reservedID
   230  		logging.log(nil, s.Target, nil, "rollback", false, nil)
   231  		return nil
   232  	})
   233  	if err != nil {
   234  		session.RecordWarning(&querypb.QueryWarning{Message: fmt.Sprintf("rollback encountered an error and connection to all shard for this session is released: %v", err)})
   235  		if session.InReservedConn() {
   236  			_ = txc.Release(ctx, session)
   237  		}
   238  	}
   239  	return err
   240  }
   241  
   242  // Release releases the reserved connection and/or rollbacks the transaction
   243  func (txc *TxConn) Release(ctx context.Context, session *SafeSession) error {
   244  	if !session.InTransaction() && !session.InReservedConn() {
   245  		return nil
   246  	}
   247  	defer session.Reset()
   248  
   249  	allsessions := append(session.PreSessions, session.ShardSessions...)
   250  	allsessions = append(allsessions, session.PostSessions...)
   251  
   252  	return txc.runSessions(ctx, allsessions, session.logging, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logging *executeLogger) error {
   253  		if s.ReservedId == 0 && s.TransactionId == 0 {
   254  			return nil
   255  		}
   256  		qs, err := txc.queryService(s.TabletAlias)
   257  		if err != nil {
   258  			return err
   259  		}
   260  		err = qs.Release(ctx, s.Target, s.TransactionId, s.ReservedId)
   261  		if err != nil {
   262  			return err
   263  		}
   264  		s.TransactionId = 0
   265  		s.ReservedId = 0
   266  		return nil
   267  	})
   268  }
   269  
   270  // ReleaseLock releases the reserved connection used for locking.
   271  func (txc *TxConn) ReleaseLock(ctx context.Context, session *SafeSession) error {
   272  	if !session.InLockSession() {
   273  		return nil
   274  	}
   275  	defer session.ResetLock()
   276  
   277  	session.ClearAdvisoryLock()
   278  	ls := session.LockSession
   279  	if ls.ReservedId == 0 {
   280  		return nil
   281  	}
   282  	qs, err := txc.queryService(ls.TabletAlias)
   283  	if err != nil {
   284  		return err
   285  	}
   286  	return qs.Release(ctx, ls.Target, 0, ls.ReservedId)
   287  }
   288  
   289  // ReleaseAll releases all the shard sessions and lock session.
   290  func (txc *TxConn) ReleaseAll(ctx context.Context, session *SafeSession) error {
   291  	if !session.InTransaction() && !session.InReservedConn() && !session.InLockSession() {
   292  		return nil
   293  	}
   294  	defer session.ResetAll()
   295  
   296  	allsessions := append(session.PreSessions, session.ShardSessions...)
   297  	allsessions = append(allsessions, session.PostSessions...)
   298  	if session.LockSession != nil {
   299  		allsessions = append(allsessions, session.LockSession)
   300  	}
   301  
   302  	return txc.runSessions(ctx, allsessions, session.logging, func(ctx context.Context, s *vtgatepb.Session_ShardSession, loggging *executeLogger) error {
   303  		if s.ReservedId == 0 && s.TransactionId == 0 {
   304  			return nil
   305  		}
   306  		qs, err := txc.queryService(s.TabletAlias)
   307  		if err != nil {
   308  			return err
   309  		}
   310  		err = qs.Release(ctx, s.Target, s.TransactionId, s.ReservedId)
   311  		if err != nil {
   312  			return err
   313  		}
   314  		s.TransactionId = 0
   315  		s.ReservedId = 0
   316  		return nil
   317  	})
   318  }
   319  
   320  // Resolve resolves the specified 2PC transaction.
   321  func (txc *TxConn) Resolve(ctx context.Context, dtid string) error {
   322  	mmShard, err := dtids.ShardSession(dtid)
   323  	if err != nil {
   324  		return err
   325  	}
   326  
   327  	transaction, err := txc.tabletGateway.ReadTransaction(ctx, mmShard.Target, dtid)
   328  	if err != nil {
   329  		return err
   330  	}
   331  	if transaction == nil || transaction.Dtid == "" {
   332  		// It was already resolved.
   333  		return nil
   334  	}
   335  	switch transaction.State {
   336  	case querypb.TransactionState_PREPARE:
   337  		// If state is PREPARE, make a decision to rollback and
   338  		// fallthrough to the rollback workflow.
   339  		qs, err := txc.queryService(mmShard.TabletAlias)
   340  		if err != nil {
   341  			return err
   342  		}
   343  		if err := qs.SetRollback(ctx, mmShard.Target, transaction.Dtid, mmShard.TransactionId); err != nil {
   344  			return err
   345  		}
   346  		fallthrough
   347  	case querypb.TransactionState_ROLLBACK:
   348  		if err := txc.resumeRollback(ctx, mmShard.Target, transaction); err != nil {
   349  			return err
   350  		}
   351  	case querypb.TransactionState_COMMIT:
   352  		if err := txc.resumeCommit(ctx, mmShard.Target, transaction); err != nil {
   353  			return err
   354  		}
   355  	default:
   356  		// Should never happen.
   357  		return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "invalid state: %v", transaction.State)
   358  	}
   359  	return nil
   360  }
   361  
   362  func (txc *TxConn) resumeRollback(ctx context.Context, target *querypb.Target, transaction *querypb.TransactionMetadata) error {
   363  	err := txc.runTargets(transaction.Participants, func(t *querypb.Target) error {
   364  		return txc.tabletGateway.RollbackPrepared(ctx, t, transaction.Dtid, 0)
   365  	})
   366  	if err != nil {
   367  		return err
   368  	}
   369  	return txc.tabletGateway.ConcludeTransaction(ctx, target, transaction.Dtid)
   370  }
   371  
   372  func (txc *TxConn) resumeCommit(ctx context.Context, target *querypb.Target, transaction *querypb.TransactionMetadata) error {
   373  	err := txc.runTargets(transaction.Participants, func(t *querypb.Target) error {
   374  		return txc.tabletGateway.CommitPrepared(ctx, t, transaction.Dtid)
   375  	})
   376  	if err != nil {
   377  		return err
   378  	}
   379  	return txc.tabletGateway.ConcludeTransaction(ctx, target, transaction.Dtid)
   380  }
   381  
   382  // runSessions executes the action for all shardSessions in parallel and returns a consolidated error.
   383  func (txc *TxConn) runSessions(ctx context.Context, shardSessions []*vtgatepb.Session_ShardSession, logging *executeLogger, action func(context.Context, *vtgatepb.Session_ShardSession, *executeLogger) error) error {
   384  	// Fastpath.
   385  	if len(shardSessions) == 1 {
   386  		return action(ctx, shardSessions[0], logging)
   387  	}
   388  
   389  	allErrors := new(concurrency.AllErrorRecorder)
   390  	var wg sync.WaitGroup
   391  	for _, s := range shardSessions {
   392  		wg.Add(1)
   393  		go func(s *vtgatepb.Session_ShardSession) {
   394  			defer wg.Done()
   395  			if err := action(ctx, s, logging); err != nil {
   396  				allErrors.RecordError(err)
   397  			}
   398  		}(s)
   399  	}
   400  	wg.Wait()
   401  	return allErrors.AggrError(vterrors.Aggregate)
   402  }
   403  
   404  // runTargets executes the action for all targets in parallel and returns a consolildated error.
   405  // Flow is identical to runSessions.
   406  func (txc *TxConn) runTargets(targets []*querypb.Target, action func(*querypb.Target) error) error {
   407  	if len(targets) == 1 {
   408  		return action(targets[0])
   409  	}
   410  	allErrors := new(concurrency.AllErrorRecorder)
   411  	var wg sync.WaitGroup
   412  	for _, t := range targets {
   413  		wg.Add(1)
   414  		go func(t *querypb.Target) {
   415  			defer wg.Done()
   416  			if err := action(t); err != nil {
   417  				allErrors.RecordError(err)
   418  			}
   419  		}(t)
   420  	}
   421  	wg.Wait()
   422  	return allErrors.AggrError(vterrors.Aggregate)
   423  }