github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/sql/txn_state_test.go (about)

     1  // Copyright 2017 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package sql
    12  
    13  import (
    14  	"context"
    15  	"fmt"
    16  	"strings"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/cockroachdb/cockroach/pkg/kv"
    21  	"github.com/cockroachdb/cockroach/pkg/roachpb"
    22  	"github.com/cockroachdb/cockroach/pkg/settings/cluster"
    23  	"github.com/cockroachdb/cockroach/pkg/sql/sem/tree"
    24  	"github.com/cockroachdb/cockroach/pkg/testutils"
    25  	"github.com/cockroachdb/cockroach/pkg/util/fsm"
    26  	"github.com/cockroachdb/cockroach/pkg/util/hlc"
    27  	"github.com/cockroachdb/cockroach/pkg/util/leaktest"
    28  	"github.com/cockroachdb/cockroach/pkg/util/metric"
    29  	"github.com/cockroachdb/cockroach/pkg/util/mon"
    30  	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
    31  	"github.com/cockroachdb/cockroach/pkg/util/tracing"
    32  	"github.com/cockroachdb/errors"
    33  	"github.com/gogo/protobuf/proto"
    34  	opentracing "github.com/opentracing/opentracing-go"
    35  )
    36  
    37  var noRewindExpected = CmdPos(-1)
    38  
// testContext bundles the dependencies shared by the txnState tests in this
// file. It is built by makeTestContext:
//
//   - manualClock is the manually controlled time source behind clock.
//   - clock is the HLC clock handed to the mock DB and to transitionCtx.
//   - mockDB is a kv.DB backed by a mock txn sender factory.
//   - mon is the root memory monitor; per-txnState monitors created by
//     createOpenState are started with it as their parent.
//   - tracer is used to start the span attached to an open txnState.
//   - settings are the testing cluster settings the root monitor was built
//     with.
type testContext struct {
	manualClock *hlc.ManualClock
	clock       *hlc.Clock
	mockDB      *kv.DB
	mon         mon.BytesMonitor
	tracer      opentracing.Tracer
	// ctx is mimicking the spirit of a client connection's context. It is used
	// as the connCtx of the txnStates created by the helpers below.
	ctx      context.Context
	settings *cluster.Settings
}
    49  
    50  func makeTestContext() testContext {
    51  	manual := hlc.NewManualClock(123)
    52  	clock := hlc.NewClock(manual.UnixNano, time.Nanosecond)
    53  	factory := kv.MakeMockTxnSenderFactory(
    54  		func(context.Context, *roachpb.Transaction, roachpb.BatchRequest,
    55  		) (*roachpb.BatchResponse, *roachpb.Error) {
    56  			return nil, nil
    57  		})
    58  
    59  	settings := cluster.MakeTestingClusterSettings()
    60  	ambient := testutils.MakeAmbientCtx()
    61  	return testContext{
    62  		manualClock: manual,
    63  		clock:       clock,
    64  		mockDB:      kv.NewDB(ambient, factory, clock),
    65  		mon: mon.MakeMonitor(
    66  			"test root mon",
    67  			mon.MemoryResource,
    68  			nil,  /* curCount */
    69  			nil,  /* maxHist */
    70  			-1,   /* increment */
    71  			1000, /* noteworthy */
    72  			settings,
    73  		),
    74  		tracer:   ambient.Tracer,
    75  		ctx:      context.Background(),
    76  		settings: settings,
    77  	}
    78  }
    79  
    80  // createOpenState returns a txnState initialized with an open txn.
    81  func (tc *testContext) createOpenState(typ txnType) (fsm.State, *txnState) {
    82  	sp := tc.tracer.StartSpan("createOpenState")
    83  	ctx := opentracing.ContextWithSpan(tc.ctx, sp)
    84  	ctx, cancel := context.WithCancel(ctx)
    85  
    86  	txnStateMon := mon.MakeMonitor("test mon",
    87  		mon.MemoryResource,
    88  		nil,  /* curCount */
    89  		nil,  /* maxHist */
    90  		-1,   /* increment */
    91  		1000, /* noteworthy */
    92  		cluster.MakeTestingClusterSettings(),
    93  	)
    94  	txnStateMon.Start(tc.ctx, &tc.mon, mon.BoundAccount{})
    95  
    96  	ts := txnState{
    97  		Ctx:           ctx,
    98  		connCtx:       tc.ctx,
    99  		sp:            sp,
   100  		cancel:        cancel,
   101  		sqlTimestamp:  timeutil.Now(),
   102  		priority:      roachpb.NormalUserPriority,
   103  		mon:           &txnStateMon,
   104  		txnAbortCount: metric.NewCounter(MetaTxnAbort),
   105  	}
   106  	ts.mu.txn = kv.NewTxn(ctx, tc.mockDB, roachpb.NodeID(1) /* gatewayNodeID */)
   107  
   108  	state := stateOpen{
   109  		ImplicitTxn: fsm.FromBool(typ == implicitTxn),
   110  	}
   111  	return state, &ts
   112  }
   113  
   114  // createAbortedState returns a txnState initialized with an aborted txn.
   115  func (tc *testContext) createAbortedState() (fsm.State, *txnState) {
   116  	_, ts := tc.createOpenState(explicitTxn)
   117  	return stateAborted{}, ts
   118  }
   119  
   120  func (tc *testContext) createCommitWaitState() (fsm.State, *txnState, error) {
   121  	_, ts := tc.createOpenState(explicitTxn)
   122  	// Commit the KV txn, simulating what the execution layer is doing.
   123  	if err := ts.mu.txn.Commit(ts.Ctx); err != nil {
   124  		return nil, nil, err
   125  	}
   126  	s := stateCommitWait{}
   127  	return s, ts, nil
   128  }
   129  
   130  func (tc *testContext) createNoTxnState() (fsm.State, *txnState) {
   131  	txnStateMon := mon.MakeMonitor("test mon",
   132  		mon.MemoryResource,
   133  		nil,  /* curCount */
   134  		nil,  /* maxHist */
   135  		-1,   /* increment */
   136  		1000, /* noteworthy */
   137  		cluster.MakeTestingClusterSettings(),
   138  	)
   139  	ts := txnState{mon: &txnStateMon, connCtx: tc.ctx}
   140  	return stateNoTxn{}, &ts
   141  }
   142  
   143  // checkAdv returns an error if adv does not match all the expected fields.
   144  //
   145  // Pass noRewindExpected for expRewPos if a rewind is not expected.
   146  func checkAdv(adv advanceInfo, expCode advanceCode, expRewPos CmdPos, expEv txnEvent) error {
   147  	if adv.code != expCode {
   148  		return errors.Errorf("expected code: %s, but got: %s (%+v)", expCode, adv.code, adv)
   149  	}
   150  	if expRewPos == noRewindExpected {
   151  		if adv.rewCap != (rewindCapability{}) {
   152  			return errors.Errorf("expected not rewind, but got: %+v", adv)
   153  		}
   154  	} else {
   155  		if adv.rewCap.rewindPos != expRewPos {
   156  			return errors.Errorf("expected rewind to %d, but got: %+v", expRewPos, adv)
   157  		}
   158  	}
   159  	if expEv != adv.txnEvent {
   160  		return errors.Errorf("expected txnEvent: %s, got: %s", expEv, adv.txnEvent)
   161  	}
   162  	return nil
   163  }
   164  
// expKVTxn is used with checkTxn to check that fields on a kv.Txn
// correspond to expectations. Any field left nil will not be checked.
//
// debugName is matched as a prefix of the txn's debug name (followed by
// " ("), userPriority is compared for equality, and the three *Nanos fields
// are compared against the wall time of the txn's write, read and max
// timestamps respectively.
type expKVTxn struct {
	debugName    *string
	userPriority *roachpb.UserPriority
	// For the timestamps we just check the physical part. The logical part is
	// incremented every time the clock is read and so it's unpredictable.
	tsNanos     *int64
	origTSNanos *int64
	maxTSNanos  *int64
}
   176  
   177  func checkTxn(txn *kv.Txn, exp expKVTxn) error {
   178  	if txn == nil {
   179  		return errors.Errorf("expected a KV txn but found an uninitialized txn")
   180  	}
   181  	if exp.debugName != nil && !strings.HasPrefix(txn.DebugName(), *exp.debugName+" (") {
   182  		return errors.Errorf("expected DebugName: %s, but got: %s",
   183  			*exp.debugName, txn.DebugName())
   184  	}
   185  	if exp.userPriority != nil && *exp.userPriority != txn.UserPriority() {
   186  		return errors.Errorf("expected UserPriority: %s, but got: %s",
   187  			*exp.userPriority, txn.UserPriority())
   188  	}
   189  	proto := txn.TestingCloneTxn()
   190  	if exp.tsNanos != nil && *exp.tsNanos != proto.WriteTimestamp.WallTime {
   191  		return errors.Errorf("expected Timestamp: %d, but got: %s",
   192  			*exp.tsNanos, proto.WriteTimestamp)
   193  	}
   194  	if origTimestamp := txn.ReadTimestamp(); exp.origTSNanos != nil &&
   195  		*exp.origTSNanos != origTimestamp.WallTime {
   196  		return errors.Errorf("expected DeprecatedOrigTimestamp: %d, but got: %s",
   197  			*exp.origTSNanos, origTimestamp)
   198  	}
   199  	if exp.maxTSNanos != nil && *exp.maxTSNanos != proto.MaxTimestamp.WallTime {
   200  		return errors.Errorf("expected MaxTimestamp: %d, but got: %s",
   201  			*exp.maxTSNanos, proto.MaxTimestamp)
   202  	}
   203  	return nil
   204  }
   205  
// TestTransitions drives the txnState state machine through a table of
// single-event transitions. Each case builds an initial (state, txnState)
// pair, applies one event (optionally with a payload), and then checks three
// things: the resulting fsm state, the advanceInfo left on the txnState, and
// the KV txn (or its absence).
func TestTransitions(t *testing.T) {
	defer leaktest.AfterTest(t)()

	ctx := context.Background()
	// dummyRewCap is the rewind capability handed to the machine by the
	// auto-retry cases; its rewindPos is what those cases expect to see echoed
	// back in the advanceInfo.
	dummyRewCap := rewindCapability{rewindPos: CmdPos(12)}
	testCon := makeTestContext()
	tranCtx := transitionCtx{
		db:             testCon.mockDB,
		nodeIDOrZero:   roachpb.NodeID(5),
		clock:          testCon.clock,
		tracer:         tracing.NewTracer(),
		connMon:        &testCon.mon,
		sessionTracing: &SessionTracing{},
		settings:       testCon.settings,
	}

	// expAdvance lists the parts of the advanceInfo that each case pins down.
	// The expected rewind position is derived from expCode in the run loop.
	type expAdvance struct {
		expCode advanceCode
		expEv   txnEvent
	}

	// Expected values shared by the cases below. The timestamps are read from
	// the test clock up front; only their wall time is compared by checkTxn.
	txnName := sqlTxnName
	now := testCon.clock.Now()
	pri := roachpb.NormalUserPriority
	maxTS := testCon.clock.Now().Add(testCon.clock.MaxOffset().Nanoseconds(), 0 /* logical */)
	type test struct {
		name string

		// A function used to init the txnState to the desired state before the
		// transition. The returned State and txnState are to be used to initialize
		// a Machine.
		init func() (fsm.State, *txnState, error)

		// The event to deliver to the state machine.
		ev fsm.Event
		// evPayload, if not nil, is the payload to be delivered with the event.
		evPayload fsm.EventPayload
		// evFun, if specified, replaces ev and allows a test to create an event
		// that depends on the transactionState.
		evFun func(ts *txnState) (fsm.Event, fsm.EventPayload)

		// The expected state of the fsm after the transition.
		expState fsm.State

		// The expected advance instructions resulting from the transition.
		expAdv expAdvance

		// If nil, the kv txn is expected to be nil. Otherwise, the txn's fields are
		// compared.
		expTxn *expKVTxn
	}
	tests := []test{
		//
		// Tests starting from the NoTxn state.
		//
		{
			// Start an implicit txn from NoTxn.
			name: "NoTxn->Starting (implicit txn)",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createNoTxnState()
				return s, ts, nil
			},
			ev: eventTxnStart{ImplicitTxn: fsm.True},
			evPayload: makeEventTxnStartPayload(pri, tree.ReadWrite, timeutil.Now(),
				nil /* historicalTimestamp */, tranCtx),
			expState: stateOpen{ImplicitTxn: fsm.True},
			expAdv: expAdvance{
				// We expect to stayInPlace; upon starting a txn the statement is
				// executed again, this time in state Open.
				expCode: stayInPlace,
				expEv:   txnStart,
			},
			expTxn: &expKVTxn{
				debugName:    &txnName,
				userPriority: &pri,
				tsNanos:      &now.WallTime,
				origTSNanos:  &now.WallTime,
				maxTSNanos:   &maxTS.WallTime,
			},
		},
		{
			// Start an explicit txn from NoTxn.
			name: "NoTxn->Starting (explicit txn)",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createNoTxnState()
				return s, ts, nil
			},
			ev: eventTxnStart{ImplicitTxn: fsm.False},
			evPayload: makeEventTxnStartPayload(pri, tree.ReadWrite, timeutil.Now(),
				nil /* historicalTimestamp */, tranCtx),
			expState: stateOpen{ImplicitTxn: fsm.False},
			expAdv: expAdvance{
				expCode: advanceOne,
				expEv:   txnStart,
			},
			expTxn: &expKVTxn{
				debugName:    &txnName,
				userPriority: &pri,
				tsNanos:      &now.WallTime,
				origTSNanos:  &now.WallTime,
				maxTSNanos:   &maxTS.WallTime,
			},
		},
		//
		// Tests starting from the Open state.
		//
		{
			// Finish an implicit txn.
			name: "Open (implicit) -> NoTxn",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createOpenState(implicitTxn)
				// We commit the KV transaction, as that's done by the layer below
				// txnState.
				if err := ts.mu.txn.Commit(ts.Ctx); err != nil {
					return nil, nil, err
				}
				return s, ts, nil
			},
			ev:        eventTxnFinish{},
			evPayload: eventTxnFinishPayload{commit: true},
			expState:  stateNoTxn{},
			expAdv: expAdvance{
				expCode: advanceOne,
				expEv:   txnCommit,
			},
			expTxn: nil,
		},
		{
			// Finish an explicit txn.
			name: "Open (explicit) -> NoTxn",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createOpenState(explicitTxn)
				// We commit the KV transaction, as that's done by the layer below
				// txnState.
				if err := ts.mu.txn.Commit(ts.Ctx); err != nil {
					return nil, nil, err
				}
				return s, ts, nil
			},
			ev:        eventTxnFinish{},
			evPayload: eventTxnFinishPayload{commit: true},
			expState:  stateNoTxn{},
			expAdv: expAdvance{
				expCode: advanceOne,
				expEv:   txnCommit,
			},
			expTxn: nil,
		},
		{
			// Get a retriable error while we can auto-retry.
			name: "Open + auto-retry",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createOpenState(explicitTxn)
				return s, ts, nil
			},
			evFun: func(ts *txnState) (fsm.Event, fsm.EventPayload) {
				b := eventRetriableErrPayload{
					err:    ts.mu.txn.PrepareRetryableError(ctx, "test retriable err"),
					rewCap: dummyRewCap,
				}
				return eventRetriableErr{CanAutoRetry: fsm.True, IsCommit: fsm.False}, b
			},
			expState: stateOpen{ImplicitTxn: fsm.False},
			expAdv: expAdvance{
				expCode: rewind,
				expEv:   txnRestart,
			},
			// Expect non-nil txn.
			expTxn: &expKVTxn{},
		},
		{
			// Like the above test - get a retriable error while we can auto-retry,
			// except this time the error is on a COMMIT. This shouldn't make any
			// difference; we should still auto-retry like the above.
			name: "Open + auto-retry (COMMIT)",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createOpenState(explicitTxn)
				return s, ts, nil
			},
			evFun: func(ts *txnState) (fsm.Event, fsm.EventPayload) {
				b := eventRetriableErrPayload{
					err:    ts.mu.txn.PrepareRetryableError(ctx, "test retriable err"),
					rewCap: dummyRewCap,
				}
				return eventRetriableErr{CanAutoRetry: fsm.True, IsCommit: fsm.True}, b
			},
			expState: stateOpen{ImplicitTxn: fsm.False},
			expAdv: expAdvance{
				expCode: rewind,
				expEv:   txnRestart,
			},
			// Expect non-nil txn.
			expTxn: &expKVTxn{},
		},
		{
			// Get a retriable error when we can no longer auto-retry, but the client
			// is doing client-side retries.
			name: "Open + client retry",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createOpenState(explicitTxn)
				return s, ts, nil
			},
			evFun: func(ts *txnState) (fsm.Event, fsm.EventPayload) {
				b := eventRetriableErrPayload{
					err:    ts.mu.txn.PrepareRetryableError(ctx, "test retriable err"),
					rewCap: rewindCapability{},
				}
				return eventRetriableErr{CanAutoRetry: fsm.False, IsCommit: fsm.False}, b
			},
			expState: stateAborted{},
			expAdv: expAdvance{
				expCode: skipBatch,
				expEv:   noEvent,
			},
			// Expect non-nil txn.
			expTxn: &expKVTxn{},
		},
		{
			// Like the above (a retriable error when we can no longer auto-retry, but
			// the client is doing client-side retries) except the retriable error
			// comes from a COMMIT statement. This means that the client didn't
			// properly respect the client-directed retries protocol (it should've
			// done a RELEASE such that COMMIT couldn't get retriable errors), and so
			// we can't go to RestartWait.
			name: "Open + client retry + error on COMMIT",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createOpenState(explicitTxn)
				return s, ts, nil
			},
			evFun: func(ts *txnState) (fsm.Event, fsm.EventPayload) {
				b := eventRetriableErrPayload{
					err:    ts.mu.txn.PrepareRetryableError(ctx, "test retriable err"),
					rewCap: rewindCapability{},
				}
				return eventRetriableErr{CanAutoRetry: fsm.False, IsCommit: fsm.True}, b
			},
			expState: stateNoTxn{},
			expAdv: expAdvance{
				expCode: skipBatch,
				expEv:   txnRollback,
			},
			// Expect nil txn.
			expTxn: nil,
		},
		{
			// An error on COMMIT leaves us in NoTxn, not in Aborted.
			name: "Open + non-retriable error on COMMIT",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createOpenState(explicitTxn)
				return s, ts, nil
			},
			ev:        eventNonRetriableErr{IsCommit: fsm.True},
			evPayload: eventNonRetriableErrPayload{err: fmt.Errorf("test non-retriable err")},
			expState:  stateNoTxn{},
			expAdv: expAdvance{
				expCode: skipBatch,
				expEv:   txnRollback,
			},
			// Expect nil txn.
			expTxn: nil,
		},
		{
			// Same event as in "Open + client retry" (a retriable error that can't
			// be auto-retried, not on a COMMIT), but this time with an implicit
			// txn. We expect to go to NoTxn rather than to Aborted.
			name: "Open + useless retriable error (implicit)",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createOpenState(implicitTxn)
				return s, ts, nil
			},
			evFun: func(ts *txnState) (fsm.Event, fsm.EventPayload) {
				b := eventRetriableErrPayload{
					err:    ts.mu.txn.PrepareRetryableError(ctx, "test retriable err"),
					rewCap: rewindCapability{},
				}
				return eventRetriableErr{CanAutoRetry: fsm.False, IsCommit: fsm.False}, b
			},
			expState: stateNoTxn{},
			expAdv: expAdvance{
				expCode: skipBatch,
				expEv:   txnRollback,
			},
			// Expect the txn to have been cleared.
			expTxn: nil,
		},
		{
			// We get a non-retriable error.
			name: "Open + non-retriable error",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createOpenState(explicitTxn)
				return s, ts, nil
			},
			ev:        eventNonRetriableErr{IsCommit: fsm.False},
			evPayload: eventNonRetriableErrPayload{err: fmt.Errorf("test non-retriable err")},
			expState:  stateAborted{},
			expAdv: expAdvance{
				expCode: skipBatch,
				expEv:   noEvent,
			},
			expTxn: &expKVTxn{},
		},
		{
			// We go to CommitWait (after a RELEASE SAVEPOINT).
			name: "Open->CommitWait",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createOpenState(explicitTxn)
				// Simulate what execution does before generating this event.
				err := ts.mu.txn.Commit(ts.Ctx)
				return s, ts, err
			},
			ev:       eventTxnReleased{},
			expState: stateCommitWait{},
			expAdv: expAdvance{
				expCode: advanceOne,
				expEv:   txnCommit,
			},
			expTxn: &expKVTxn{},
		},
		{
			// Restarting from Open via ROLLBACK TO SAVEPOINT.
			name: "Open + restart",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createOpenState(explicitTxn)
				return s, ts, nil
			},
			ev:       eventTxnRestart{},
			expState: stateOpen{ImplicitTxn: fsm.False},
			expAdv: expAdvance{
				expCode: advanceOne,
				expEv:   txnRestart,
			},
			// We would like to test that the transaction's epoch bumped if the txn
			// performed any operations, but it's not easy to do the test.
			expTxn: &expKVTxn{},
		},
		//
		// Tests starting from the Aborted state.
		//
		{
			// The txn finished, such as after a ROLLBACK.
			name: "Aborted->NoTxn",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createAbortedState()
				return s, ts, nil
			},
			ev:        eventTxnFinish{},
			evPayload: eventTxnFinishPayload{commit: false},
			expState:  stateNoTxn{},
			expAdv: expAdvance{
				expCode: advanceOne,
				expEv:   txnRollback,
			},
			expTxn: nil,
		},
		{
			// The txn is starting again (ROLLBACK TO SAVEPOINT <not cockroach_restart> while in Aborted).
			name: "Aborted->Open",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createAbortedState()
				return s, ts, nil
			},
			ev:       eventSavepointRollback{},
			expState: stateOpen{ImplicitTxn: fsm.False},
			expAdv: expAdvance{
				expCode: advanceOne,
				expEv:   noEvent,
			},
			expTxn: &expKVTxn{},
		},
		{
			// The txn is starting again (ROLLBACK TO SAVEPOINT cockroach_restart while in Aborted).
			name: "Aborted->Restart",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createAbortedState()
				return s, ts, nil
			},
			ev:       eventTxnRestart{},
			expState: stateOpen{ImplicitTxn: fsm.False},
			expAdv: expAdvance{
				expCode: advanceOne,
				expEv:   txnRestart,
			},
			expTxn: &expKVTxn{
				userPriority: &pri,
				tsNanos:      &now.WallTime,
				origTSNanos:  &now.WallTime,
				maxTSNanos:   &maxTS.WallTime,
			},
		},
		{
			// The txn is starting again (e.g. ROLLBACK TO SAVEPOINT while in Aborted).
			// Verify that the historical timestamp from the evPayload is propagated
			// to the expTxn.
			//
			// NOTE(review): this case sets no evPayload; as written it delivers the
			// same event as "Aborted->Restart" and only checks that the txn's write
			// timestamp wall time equals now. Confirm whether a historical
			// timestamp was meant to be exercised here.
			name: "Aborted->Starting (historical)",
			init: func() (fsm.State, *txnState, error) {
				s, ts := testCon.createAbortedState()
				return s, ts, nil
			},
			ev:       eventTxnRestart{},
			expState: stateOpen{ImplicitTxn: fsm.False},
			expAdv: expAdvance{
				expCode: advanceOne,
				expEv:   txnRestart,
			},
			expTxn: &expKVTxn{
				tsNanos: proto.Int64(now.WallTime),
			},
		},
		//
		// Tests starting from the CommitWait state.
		//
		{
			// The txn finishes with a commit while in CommitWait.
			name: "CommitWait->NoTxn",
			init: func() (fsm.State, *txnState, error) {
				return testCon.createCommitWaitState()
			},
			ev:        eventTxnFinish{},
			evPayload: eventTxnFinishPayload{commit: true},
			expState:  stateNoTxn{},
			expAdv: expAdvance{
				expCode: advanceOne,
				expEv:   txnCommit,
			},
			expTxn: nil,
		},
		{
			// A non-retriable error while in CommitWait leaves us in CommitWait.
			name: "CommitWait + err",
			init: func() (fsm.State, *txnState, error) {
				return testCon.createCommitWaitState()
			},
			ev:        eventNonRetriableErr{IsCommit: fsm.False},
			evPayload: eventNonRetriableErrPayload{err: fmt.Errorf("test non-retriable err")},
			expState:  stateCommitWait{},
			expAdv: expAdvance{
				// expEv is left unset, so the zero value of txnEvent is expected.
				expCode: skipBatch,
			},
			expTxn: &expKVTxn{},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Get the initial state.
			s, ts, err := tc.init()
			if err != nil {
				t.Fatal(err)
			}
			machine := fsm.MakeMachine(TxnStateTransitions, s, ts)

			// Perform the test's transition.
			ev := tc.ev
			payload := tc.evPayload
			if tc.evFun != nil {
				ev, payload = tc.evFun(ts)
			}
			if err := machine.ApplyWithPayload(ctx, ev, payload); err != nil {
				t.Fatal(err)
			}

			// Check that we moved to the right high-level state.
			if state := machine.CurState(); state != tc.expState {
				t.Fatalf("expected state %#v, got: %#v", tc.expState, state)
			}

			// Check the resulting advanceInfo. A rewind position is only expected
			// when the case expects the rewind code, in which case it must be the
			// position carried by dummyRewCap.
			adv := ts.consumeAdvanceInfo()
			expRewPos := noRewindExpected
			if tc.expAdv.expCode == rewind {
				expRewPos = dummyRewCap.rewindPos
			}
			if err := checkAdv(
				adv, tc.expAdv.expCode, expRewPos, tc.expAdv.expEv,
			); err != nil {
				t.Fatal(err)
			}

			// Check that the KV txn is in the expected state.
			if tc.expTxn == nil {
				if ts.mu.txn != nil {
					t.Fatalf("expected no txn, got: %+v", ts.mu.txn)
				}
			} else {
				if err := checkTxn(ts.mu.txn, *tc.expTxn); err != nil {
					t.Fatal(err)
				}
			}
		})
	}
}