vitess.io/vitess@v0.16.2/go/vt/vtgate/tx_conn_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vtgate
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"testing"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  
    26  	"vitess.io/vitess/go/test/utils"
    27  
    28  	"github.com/stretchr/testify/require"
    29  
    30  	"vitess.io/vitess/go/vt/discovery"
    31  	"vitess.io/vitess/go/vt/key"
    32  	"vitess.io/vitess/go/vt/srvtopo"
    33  	"vitess.io/vitess/go/vt/vterrors"
    34  	"vitess.io/vitess/go/vt/vttablet/sandboxconn"
    35  
    36  	querypb "vitess.io/vitess/go/vt/proto/query"
    37  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    38  	vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
    39  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    40  )
    41  
    42  var queries = []*querypb.BoundQuery{{Sql: "query1"}}
    43  var twoQueries = []*querypb.BoundQuery{{Sql: "query1"}, {Sql: "query1"}}
    44  
    45  func TestTxConnBegin(t *testing.T) {
    46  	sc, sbc0, _, rss0, _, _ := newTestTxConnEnv(t, "TestTxConn")
    47  	session := &vtgatepb.Session{}
    48  
    49  	// begin
    50  	safeSession := NewSafeSession(session)
    51  	err := sc.txConn.Begin(ctx, safeSession, nil)
    52  	require.NoError(t, err)
    53  	wantSession := vtgatepb.Session{InTransaction: true}
    54  	utils.MustMatch(t, &wantSession, session, "Session")
    55  	_, errors := sc.ExecuteMultiShard(ctx, nil, rss0, queries, safeSession, false, false)
    56  	require.Empty(t, errors)
    57  
    58  	// Begin again should cause a commit and a new begin.
    59  	require.NoError(t,
    60  		sc.txConn.Begin(ctx, safeSession, nil))
    61  	utils.MustMatch(t, &wantSession, session, "Session")
    62  	assert.EqualValues(t, 1, sbc0.CommitCount.Get(), "sbc0.CommitCount")
    63  }
    64  
    65  func TestTxConnCommitFailure(t *testing.T) {
    66  	sc, sbc0, sbc1, rss0, rss1, rss01 := newTestTxConnEnv(t, "TestTxConn")
    67  	sc.txConn.mode = vtgatepb.TransactionMode_MULTI
    68  
    69  	// Sequence the executes to ensure commit order
    70  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
    71  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
    72  	wantSession := vtgatepb.Session{
    73  		InTransaction: true,
    74  		ShardSessions: []*vtgatepb.Session_ShardSession{{
    75  			Target: &querypb.Target{
    76  				Keyspace:   "TestTxConn",
    77  				Shard:      "0",
    78  				TabletType: topodatapb.TabletType_PRIMARY,
    79  			},
    80  			TransactionId: 1,
    81  			TabletAlias:   sbc0.Tablet().Alias,
    82  		}},
    83  	}
    84  	utils.MustMatch(t, &wantSession, session.Session, "Session")
    85  	sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false)
    86  	wantSession = vtgatepb.Session{
    87  		InTransaction: true,
    88  		ShardSessions: []*vtgatepb.Session_ShardSession{{
    89  			Target: &querypb.Target{
    90  				Keyspace:   "TestTxConn",
    91  				Shard:      "0",
    92  				TabletType: topodatapb.TabletType_PRIMARY,
    93  			},
    94  			TransactionId: 1,
    95  			TabletAlias:   sbc0.Tablet().Alias,
    96  		}, {
    97  			Target: &querypb.Target{
    98  				Keyspace:   "TestTxConn",
    99  				Shard:      "1",
   100  				TabletType: topodatapb.TabletType_PRIMARY,
   101  			},
   102  			TransactionId: 1,
   103  			TabletAlias:   sbc1.Tablet().Alias,
   104  		}},
   105  	}
   106  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   107  
   108  	sbc1.MustFailCodes[vtrpcpb.Code_DEADLINE_EXCEEDED] = 1
   109  
   110  	expectErr := NewShardError(vterrors.New(
   111  		vtrpcpb.Code_DEADLINE_EXCEEDED,
   112  		fmt.Sprintf("%v error", vtrpcpb.Code_DEADLINE_EXCEEDED)),
   113  		rss1[0].Target)
   114  
   115  	require.ErrorContains(t, sc.txConn.Commit(ctx, session), expectErr.Error())
   116  	wantSession = vtgatepb.Session{}
   117  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   118  	assert.EqualValues(t, 1, sbc0.CommitCount.Get(), "sbc0.CommitCount")
   119  	assert.EqualValues(t, 1, sbc1.CommitCount.Get(), "sbc1.CommitCount")
   120  }
   121  
   122  func TestTxConnCommitSuccess(t *testing.T) {
   123  	sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, "TestTxConn")
   124  	sc.txConn.mode = vtgatepb.TransactionMode_MULTI
   125  
   126  	// Sequence the executes to ensure commit order
   127  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
   128  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   129  	wantSession := vtgatepb.Session{
   130  		InTransaction: true,
   131  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   132  			Target: &querypb.Target{
   133  				Keyspace:   "TestTxConn",
   134  				Shard:      "0",
   135  				TabletType: topodatapb.TabletType_PRIMARY,
   136  			},
   137  			TransactionId: 1,
   138  			TabletAlias:   sbc0.Tablet().Alias,
   139  		}},
   140  	}
   141  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   142  	sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false)
   143  	wantSession = vtgatepb.Session{
   144  		InTransaction: true,
   145  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   146  			Target: &querypb.Target{
   147  				Keyspace:   "TestTxConn",
   148  				Shard:      "0",
   149  				TabletType: topodatapb.TabletType_PRIMARY,
   150  			},
   151  			TransactionId: 1,
   152  			TabletAlias:   sbc0.Tablet().Alias,
   153  		}, {
   154  			Target: &querypb.Target{
   155  				Keyspace:   "TestTxConn",
   156  				Shard:      "1",
   157  				TabletType: topodatapb.TabletType_PRIMARY,
   158  			},
   159  			TransactionId: 1,
   160  			TabletAlias:   sbc1.Tablet().Alias,
   161  		}},
   162  	}
   163  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   164  
   165  	require.NoError(t,
   166  		sc.txConn.Commit(ctx, session))
   167  	wantSession = vtgatepb.Session{}
   168  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   169  	assert.EqualValues(t, 1, sbc0.CommitCount.Get(), "sbc0.CommitCount")
   170  	assert.EqualValues(t, 1, sbc1.CommitCount.Get(), "sbc1.CommitCount")
   171  }
   172  
   173  func TestTxConnReservedCommitSuccess(t *testing.T) {
   174  	sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, "TestTxConn")
   175  	sc.txConn.mode = vtgatepb.TransactionMode_MULTI
   176  
   177  	// Sequence the executes to ensure commit order
   178  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true})
   179  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   180  	wantSession := vtgatepb.Session{
   181  		InTransaction:  true,
   182  		InReservedConn: true,
   183  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   184  			Target: &querypb.Target{
   185  				Keyspace:   "TestTxConn",
   186  				Shard:      "0",
   187  				TabletType: topodatapb.TabletType_PRIMARY,
   188  			},
   189  			TransactionId: 1,
   190  			ReservedId:    1,
   191  			TabletAlias:   sbc0.Tablet().Alias,
   192  		}},
   193  	}
   194  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   195  	sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false)
   196  	wantSession = vtgatepb.Session{
   197  		InTransaction:  true,
   198  		InReservedConn: true,
   199  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   200  			Target: &querypb.Target{
   201  				Keyspace:   "TestTxConn",
   202  				Shard:      "0",
   203  				TabletType: topodatapb.TabletType_PRIMARY,
   204  			},
   205  			TransactionId: 1,
   206  			ReservedId:    1,
   207  			TabletAlias:   sbc0.Tablet().Alias,
   208  		}, {
   209  			Target: &querypb.Target{
   210  				Keyspace:   "TestTxConn",
   211  				Shard:      "1",
   212  				TabletType: topodatapb.TabletType_PRIMARY,
   213  			},
   214  			TransactionId: 1,
   215  			ReservedId:    1,
   216  			TabletAlias:   sbc1.Tablet().Alias,
   217  		}},
   218  	}
   219  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   220  
   221  	require.NoError(t,
   222  		sc.txConn.Commit(ctx, session))
   223  	wantSession = vtgatepb.Session{
   224  		InReservedConn: true,
   225  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   226  			Target: &querypb.Target{
   227  				Keyspace:   "TestTxConn",
   228  				Shard:      "0",
   229  				TabletType: topodatapb.TabletType_PRIMARY,
   230  			},
   231  			ReservedId:  2,
   232  			TabletAlias: sbc0.Tablet().Alias,
   233  		}, {
   234  			Target: &querypb.Target{
   235  				Keyspace:   "TestTxConn",
   236  				Shard:      "1",
   237  				TabletType: topodatapb.TabletType_PRIMARY,
   238  			},
   239  			ReservedId:  2,
   240  			TabletAlias: sbc1.Tablet().Alias,
   241  		}},
   242  	}
   243  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   244  	assert.EqualValues(t, 1, sbc0.CommitCount.Get(), "sbc0.CommitCount")
   245  	assert.EqualValues(t, 1, sbc1.CommitCount.Get(), "sbc1.CommitCount")
   246  
   247  	require.NoError(t,
   248  		sc.txConn.Release(ctx, session))
   249  	wantSession = vtgatepb.Session{InReservedConn: true}
   250  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   251  	assert.EqualValues(t, 1, sbc0.ReleaseCount.Get(), "sbc0.ReleaseCount")
   252  	assert.EqualValues(t, 1, sbc1.ReleaseCount.Get(), "sbc1.ReleaseCount")
   253  }
   254  
   255  func TestTxConnReservedOn2ShardTxOn1ShardAndCommit(t *testing.T) {
   256  	keyspace := "TestTxConn"
   257  	sc, sbc0, sbc1, rss0, rss1, _ := newTestTxConnEnv(t, keyspace)
   258  	sc.txConn.mode = vtgatepb.TransactionMode_MULTI
   259  
   260  	// Sequence the executes to ensure shard session order
   261  	session := NewSafeSession(&vtgatepb.Session{InReservedConn: true})
   262  
   263  	// this will create reserved connections against all tablets
   264  	_, errs := sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
   265  	require.Empty(t, errs)
   266  	_, errs = sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   267  	require.Empty(t, errs)
   268  
   269  	wantSession := vtgatepb.Session{
   270  		InReservedConn: true,
   271  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   272  			Target: &querypb.Target{
   273  				Keyspace:   keyspace,
   274  				Shard:      "1",
   275  				TabletType: topodatapb.TabletType_PRIMARY,
   276  			},
   277  			ReservedId:  1,
   278  			TabletAlias: sbc1.Tablet().Alias,
   279  		}, {
   280  			Target: &querypb.Target{
   281  				Keyspace:   keyspace,
   282  				Shard:      "0",
   283  				TabletType: topodatapb.TabletType_PRIMARY,
   284  			},
   285  			ReservedId:  1,
   286  			TabletAlias: sbc0.Tablet().Alias,
   287  		}},
   288  	}
   289  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   290  
   291  	session.Session.InTransaction = true
   292  
   293  	// start a transaction against rss0
   294  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   295  	wantSession = vtgatepb.Session{
   296  		InTransaction:  true,
   297  		InReservedConn: true,
   298  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   299  			Target: &querypb.Target{
   300  				Keyspace:   keyspace,
   301  				Shard:      "1",
   302  				TabletType: topodatapb.TabletType_PRIMARY,
   303  			},
   304  			ReservedId:  1,
   305  			TabletAlias: sbc1.Tablet().Alias,
   306  		}, {
   307  			Target: &querypb.Target{
   308  				Keyspace:   keyspace,
   309  				Shard:      "0",
   310  				TabletType: topodatapb.TabletType_PRIMARY,
   311  			},
   312  			TransactionId: 1,
   313  			ReservedId:    1,
   314  			TabletAlias:   sbc0.Tablet().Alias,
   315  		}},
   316  	}
   317  
   318  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   319  
   320  	require.NoError(t,
   321  		sc.txConn.Commit(ctx, session))
   322  
   323  	wantSession = vtgatepb.Session{
   324  		InReservedConn: true,
   325  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   326  			Target: &querypb.Target{
   327  				Keyspace:   keyspace,
   328  				Shard:      "1",
   329  				TabletType: topodatapb.TabletType_PRIMARY,
   330  			},
   331  			ReservedId:  1,
   332  			TabletAlias: sbc1.Tablet().Alias,
   333  		}, {
   334  			Target: &querypb.Target{
   335  				Keyspace:   keyspace,
   336  				Shard:      "0",
   337  				TabletType: topodatapb.TabletType_PRIMARY,
   338  			},
   339  			ReservedId:  2,
   340  			TabletAlias: sbc0.Tablet().Alias,
   341  		}},
   342  	}
   343  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   344  	assert.EqualValues(t, 1, sbc0.CommitCount.Get(), "sbc0.CommitCount")
   345  	assert.EqualValues(t, 0, sbc1.CommitCount.Get(), "sbc1.CommitCount")
   346  }
   347  
   348  func TestTxConnReservedOn2ShardTxOn1ShardAndRollback(t *testing.T) {
   349  	keyspace := "TestTxConn"
   350  	sc, sbc0, sbc1, rss0, rss1, _ := newTestTxConnEnv(t, keyspace)
   351  	sc.txConn.mode = vtgatepb.TransactionMode_MULTI
   352  
   353  	// Sequence the executes to ensure shard session order
   354  	session := NewSafeSession(&vtgatepb.Session{InReservedConn: true})
   355  
   356  	// this will create reserved connections against all tablets
   357  	_, errs := sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
   358  	require.Empty(t, errs)
   359  	_, errs = sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   360  	require.Empty(t, errs)
   361  
   362  	wantSession := vtgatepb.Session{
   363  		InReservedConn: true,
   364  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   365  			Target: &querypb.Target{
   366  				Keyspace:   keyspace,
   367  				Shard:      "1",
   368  				TabletType: topodatapb.TabletType_PRIMARY,
   369  			},
   370  			ReservedId:  1,
   371  			TabletAlias: sbc1.Tablet().Alias,
   372  		}, {
   373  			Target: &querypb.Target{
   374  				Keyspace:   keyspace,
   375  				Shard:      "0",
   376  				TabletType: topodatapb.TabletType_PRIMARY,
   377  			},
   378  			ReservedId:  1,
   379  			TabletAlias: sbc0.Tablet().Alias,
   380  		}},
   381  	}
   382  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   383  
   384  	session.Session.InTransaction = true
   385  
   386  	// start a transaction against rss0
   387  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   388  	wantSession = vtgatepb.Session{
   389  		InTransaction:  true,
   390  		InReservedConn: true,
   391  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   392  			Target: &querypb.Target{
   393  				Keyspace:   keyspace,
   394  				Shard:      "1",
   395  				TabletType: topodatapb.TabletType_PRIMARY,
   396  			},
   397  			ReservedId:  1,
   398  			TabletAlias: sbc1.Tablet().Alias,
   399  		}, {
   400  			Target: &querypb.Target{
   401  				Keyspace:   keyspace,
   402  				Shard:      "0",
   403  				TabletType: topodatapb.TabletType_PRIMARY,
   404  			},
   405  			TransactionId: 1,
   406  			ReservedId:    1,
   407  			TabletAlias:   sbc0.Tablet().Alias,
   408  		}},
   409  	}
   410  
   411  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   412  
   413  	require.NoError(t,
   414  		sc.txConn.Rollback(ctx, session))
   415  
   416  	wantSession = vtgatepb.Session{
   417  		InReservedConn: true,
   418  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   419  			Target: &querypb.Target{
   420  				Keyspace:   keyspace,
   421  				Shard:      "1",
   422  				TabletType: topodatapb.TabletType_PRIMARY,
   423  			},
   424  			ReservedId:  1,
   425  			TabletAlias: sbc1.Tablet().Alias,
   426  		}, {
   427  			Target: &querypb.Target{
   428  				Keyspace:   keyspace,
   429  				Shard:      "0",
   430  				TabletType: topodatapb.TabletType_PRIMARY,
   431  			},
   432  			ReservedId:  2,
   433  			TabletAlias: sbc0.Tablet().Alias,
   434  		}},
   435  	}
   436  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   437  	assert.EqualValues(t, 1, sbc0.RollbackCount.Get(), "sbc0.RollbackCount")
   438  	assert.EqualValues(t, 0, sbc1.RollbackCount.Get(), "sbc1.RollbackCount")
   439  }
   440  
   441  func TestTxConnCommitOrderFailure1(t *testing.T) {
   442  	sc, sbc0, sbc1, rss0, rss1, _ := newTestTxConnEnv(t, "TestTxConn")
   443  	sc.txConn.mode = vtgatepb.TransactionMode_MULTI
   444  
   445  	queries := []*querypb.BoundQuery{{Sql: "query1"}}
   446  
   447  	// Sequence the executes to ensure commit order
   448  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
   449  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   450  
   451  	session.SetCommitOrder(vtgatepb.CommitOrder_PRE)
   452  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   453  
   454  	session.SetCommitOrder(vtgatepb.CommitOrder_POST)
   455  	sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
   456  
   457  	sbc0.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
   458  	err := sc.txConn.Commit(ctx, session)
   459  	require.Error(t, err)
   460  	assert.Contains(t, err.Error(), "INVALID_ARGUMENT error", "commit error")
   461  
   462  	wantSession := vtgatepb.Session{}
   463  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   464  	assert.EqualValues(t, 1, sbc0.CommitCount.Get(), "sbc0.CommitCount")
   465  	// first commit failed so we don't try to commit the second shard
   466  	assert.EqualValues(t, 0, sbc1.CommitCount.Get(), "sbc1.CommitCount")
   467  	// When the commit fails, we try to clean up by issuing a rollback
   468  	assert.EqualValues(t, 2, sbc0.ReleaseCount.Get(), "sbc0.ReleaseCount")
   469  	assert.EqualValues(t, 1, sbc1.ReleaseCount.Get(), "sbc1.ReleaseCount")
   470  }
   471  
   472  func TestTxConnCommitOrderFailure2(t *testing.T) {
   473  	sc, sbc0, sbc1, rss0, rss1, _ := newTestTxConnEnv(t, "TestTxConn")
   474  	sc.txConn.mode = vtgatepb.TransactionMode_MULTI
   475  
   476  	queries := []*querypb.BoundQuery{{
   477  		Sql: "query1",
   478  	}}
   479  
   480  	// Sequence the executes to ensure commit order
   481  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
   482  	sc.ExecuteMultiShard(context.Background(), nil, rss1, queries, session, false, false)
   483  
   484  	session.SetCommitOrder(vtgatepb.CommitOrder_PRE)
   485  	sc.ExecuteMultiShard(context.Background(), nil, rss0, queries, session, false, false)
   486  
   487  	session.SetCommitOrder(vtgatepb.CommitOrder_POST)
   488  	sc.ExecuteMultiShard(context.Background(), nil, rss1, queries, session, false, false)
   489  
   490  	sbc1.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
   491  	err := sc.txConn.Commit(ctx, session)
   492  	require.Error(t, err)
   493  	assert.Contains(t, err.Error(), "INVALID_ARGUMENT error", "Commit")
   494  
   495  	wantSession := vtgatepb.Session{}
   496  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   497  	assert.EqualValues(t, 1, sbc0.CommitCount.Get(), "sbc0.CommitCount")
   498  	assert.EqualValues(t, 1, sbc1.CommitCount.Get(), "sbc1.CommitCount")
   499  	// When the commit fails, we try to clean up by issuing a rollback
   500  	assert.EqualValues(t, 0, sbc0.ReleaseCount.Get(), "sbc0.ReleaseCount")
   501  	assert.EqualValues(t, 2, sbc1.ReleaseCount.Get(), "sbc1.ReleaseCount")
   502  }
   503  
   504  func TestTxConnCommitOrderFailure3(t *testing.T) {
   505  	sc, sbc0, sbc1, rss0, rss1, _ := newTestTxConnEnv(t, "TestTxConn")
   506  	sc.txConn.mode = vtgatepb.TransactionMode_MULTI
   507  
   508  	queries := []*querypb.BoundQuery{{
   509  		Sql: "query1",
   510  	}}
   511  
   512  	// Sequence the executes to ensure commit order
   513  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
   514  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   515  
   516  	session.SetCommitOrder(vtgatepb.CommitOrder_PRE)
   517  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   518  
   519  	session.SetCommitOrder(vtgatepb.CommitOrder_POST)
   520  	sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
   521  
   522  	sbc1.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
   523  	require.NoError(t,
   524  		sc.txConn.Commit(ctx, session))
   525  
   526  	// The last failed commit must generate a warning.
   527  	expectErr := NewShardError(vterrors.New(
   528  		vtrpcpb.Code_INVALID_ARGUMENT,
   529  		fmt.Sprintf("%v error", vtrpcpb.Code_INVALID_ARGUMENT)),
   530  		rss1[0].Target)
   531  
   532  	wantSession := vtgatepb.Session{
   533  		Warnings: []*querypb.QueryWarning{{
   534  			Message: fmt.Sprintf("post-operation transaction had an error: %v", expectErr),
   535  		}},
   536  	}
   537  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   538  	assert.EqualValues(t, 2, sbc0.CommitCount.Get(), "sbc0.CommitCount")
   539  	assert.EqualValues(t, 1, sbc1.CommitCount.Get(), "sbc1.CommitCount")
   540  	assert.EqualValues(t, 0, sbc0.RollbackCount.Get(), "sbc0.RollbackCount")
   541  	assert.EqualValues(t, 0, sbc1.RollbackCount.Get(), "sbc1.RollbackCount")
   542  }
   543  
   544  func TestTxConnCommitOrderSuccess(t *testing.T) {
   545  	sc, sbc0, sbc1, rss0, rss1, _ := newTestTxConnEnv(t, "TestTxConn")
   546  	sc.txConn.mode = vtgatepb.TransactionMode_MULTI
   547  
   548  	queries := []*querypb.BoundQuery{{
   549  		Sql: "query1",
   550  	}}
   551  
   552  	// Sequence the executes to ensure commit order
   553  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
   554  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   555  	wantSession := vtgatepb.Session{
   556  		InTransaction: true,
   557  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   558  			Target: &querypb.Target{
   559  				Keyspace:   "TestTxConn",
   560  				Shard:      "0",
   561  				TabletType: topodatapb.TabletType_PRIMARY,
   562  			},
   563  			TransactionId: 1,
   564  			TabletAlias:   sbc0.Tablet().Alias,
   565  		}},
   566  	}
   567  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   568  
   569  	session.SetCommitOrder(vtgatepb.CommitOrder_PRE)
   570  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   571  	wantSession = vtgatepb.Session{
   572  		InTransaction: true,
   573  		PreSessions: []*vtgatepb.Session_ShardSession{{
   574  			Target: &querypb.Target{
   575  				Keyspace:   "TestTxConn",
   576  				Shard:      "0",
   577  				TabletType: topodatapb.TabletType_PRIMARY,
   578  			},
   579  			TransactionId: 2,
   580  			TabletAlias:   sbc0.Tablet().Alias,
   581  		}},
   582  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   583  			Target: &querypb.Target{
   584  				Keyspace:   "TestTxConn",
   585  				Shard:      "0",
   586  				TabletType: topodatapb.TabletType_PRIMARY,
   587  			},
   588  			TransactionId: 1,
   589  			TabletAlias:   sbc0.Tablet().Alias,
   590  		}},
   591  	}
   592  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   593  
   594  	session.SetCommitOrder(vtgatepb.CommitOrder_POST)
   595  	sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
   596  	wantSession = vtgatepb.Session{
   597  		InTransaction: true,
   598  		PreSessions: []*vtgatepb.Session_ShardSession{{
   599  			Target: &querypb.Target{
   600  				Keyspace:   "TestTxConn",
   601  				Shard:      "0",
   602  				TabletType: topodatapb.TabletType_PRIMARY,
   603  			},
   604  			TransactionId: 2,
   605  			TabletAlias:   sbc0.Tablet().Alias,
   606  		}},
   607  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   608  			Target: &querypb.Target{
   609  				Keyspace:   "TestTxConn",
   610  				Shard:      "0",
   611  				TabletType: topodatapb.TabletType_PRIMARY,
   612  			},
   613  			TransactionId: 1,
   614  			TabletAlias:   sbc0.Tablet().Alias,
   615  		}},
   616  		PostSessions: []*vtgatepb.Session_ShardSession{{
   617  			Target: &querypb.Target{
   618  				Keyspace:   "TestTxConn",
   619  				Shard:      "1",
   620  				TabletType: topodatapb.TabletType_PRIMARY,
   621  			},
   622  			TransactionId: 1,
   623  			TabletAlias:   sbc1.Tablet().Alias,
   624  		}},
   625  	}
   626  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   627  
   628  	// Ensure nothing changes if we reuse a transaction.
   629  	sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
   630  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   631  
   632  	require.NoError(t,
   633  		sc.txConn.Commit(ctx, session))
   634  	wantSession = vtgatepb.Session{}
   635  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   636  	assert.EqualValues(t, 2, sbc0.CommitCount.Get(), "sbc0.CommitCount")
   637  	assert.EqualValues(t, 1, sbc1.CommitCount.Get(), "sbc1.CommitCount")
   638  }
   639  
   640  func TestTxConnReservedCommitOrderSuccess(t *testing.T) {
   641  	sc, sbc0, sbc1, rss0, rss1, _ := newTestTxConnEnv(t, "TestTxConn")
   642  	sc.txConn.mode = vtgatepb.TransactionMode_MULTI
   643  
   644  	queries := []*querypb.BoundQuery{{
   645  		Sql: "query1",
   646  	}}
   647  
   648  	// Sequence the executes to ensure commit order
   649  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true})
   650  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   651  	wantSession := vtgatepb.Session{
   652  		InTransaction:  true,
   653  		InReservedConn: true,
   654  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   655  			Target: &querypb.Target{
   656  				Keyspace:   "TestTxConn",
   657  				Shard:      "0",
   658  				TabletType: topodatapb.TabletType_PRIMARY,
   659  			},
   660  			TransactionId: 1,
   661  			ReservedId:    1,
   662  			TabletAlias:   sbc0.Tablet().Alias,
   663  		}},
   664  	}
   665  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   666  
   667  	session.SetCommitOrder(vtgatepb.CommitOrder_PRE)
   668  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   669  	wantSession = vtgatepb.Session{
   670  		InTransaction:  true,
   671  		InReservedConn: true,
   672  		PreSessions: []*vtgatepb.Session_ShardSession{{
   673  			Target: &querypb.Target{
   674  				Keyspace:   "TestTxConn",
   675  				Shard:      "0",
   676  				TabletType: topodatapb.TabletType_PRIMARY,
   677  			},
   678  			TransactionId: 2,
   679  			ReservedId:    2,
   680  			TabletAlias:   sbc0.Tablet().Alias,
   681  		}},
   682  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   683  			Target: &querypb.Target{
   684  				Keyspace:   "TestTxConn",
   685  				Shard:      "0",
   686  				TabletType: topodatapb.TabletType_PRIMARY,
   687  			},
   688  			TransactionId: 1,
   689  			ReservedId:    1,
   690  			TabletAlias:   sbc0.Tablet().Alias,
   691  		}},
   692  	}
   693  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   694  
   695  	session.SetCommitOrder(vtgatepb.CommitOrder_POST)
   696  	sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
   697  	wantSession = vtgatepb.Session{
   698  		InTransaction:  true,
   699  		InReservedConn: true,
   700  		PreSessions: []*vtgatepb.Session_ShardSession{{
   701  			Target: &querypb.Target{
   702  				Keyspace:   "TestTxConn",
   703  				Shard:      "0",
   704  				TabletType: topodatapb.TabletType_PRIMARY,
   705  			},
   706  			TransactionId: 2,
   707  			ReservedId:    2,
   708  			TabletAlias:   sbc0.Tablet().Alias,
   709  		}},
   710  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   711  			Target: &querypb.Target{
   712  				Keyspace:   "TestTxConn",
   713  				Shard:      "0",
   714  				TabletType: topodatapb.TabletType_PRIMARY,
   715  			},
   716  			TransactionId: 1,
   717  			ReservedId:    1,
   718  			TabletAlias:   sbc0.Tablet().Alias,
   719  		}},
   720  		PostSessions: []*vtgatepb.Session_ShardSession{{
   721  			Target: &querypb.Target{
   722  				Keyspace:   "TestTxConn",
   723  				Shard:      "1",
   724  				TabletType: topodatapb.TabletType_PRIMARY,
   725  			},
   726  			TransactionId: 1,
   727  			ReservedId:    1,
   728  			TabletAlias:   sbc1.Tablet().Alias,
   729  		}},
   730  	}
   731  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   732  
   733  	// Ensure nothing changes if we reuse a transaction.
   734  	sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
   735  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   736  
   737  	require.NoError(t,
   738  		sc.txConn.Commit(ctx, session))
   739  	wantSession = vtgatepb.Session{
   740  		InReservedConn: true,
   741  		PreSessions: []*vtgatepb.Session_ShardSession{{
   742  			Target: &querypb.Target{
   743  				Keyspace:   "TestTxConn",
   744  				Shard:      "0",
   745  				TabletType: topodatapb.TabletType_PRIMARY,
   746  			},
   747  			ReservedId:  3,
   748  			TabletAlias: sbc0.Tablet().Alias,
   749  		}},
   750  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   751  			Target: &querypb.Target{
   752  				Keyspace:   "TestTxConn",
   753  				Shard:      "0",
   754  				TabletType: topodatapb.TabletType_PRIMARY,
   755  			},
   756  			ReservedId:  4,
   757  			TabletAlias: sbc0.Tablet().Alias,
   758  		}},
   759  		PostSessions: []*vtgatepb.Session_ShardSession{{
   760  			Target: &querypb.Target{
   761  				Keyspace:   "TestTxConn",
   762  				Shard:      "1",
   763  				TabletType: topodatapb.TabletType_PRIMARY,
   764  			},
   765  			ReservedId:  2,
   766  			TabletAlias: sbc1.Tablet().Alias,
   767  		}},
   768  	}
   769  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   770  	assert.EqualValues(t, 2, sbc0.CommitCount.Get(), "sbc0.CommitCount")
   771  	assert.EqualValues(t, 1, sbc1.CommitCount.Get(), "sbc1.CommitCount")
   772  
   773  	require.NoError(t,
   774  		sc.txConn.Release(ctx, session))
   775  	wantSession = vtgatepb.Session{InReservedConn: true}
   776  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   777  	assert.EqualValues(t, 2, sbc0.ReleaseCount.Get(), "sbc0.ReleaseCount")
   778  	assert.EqualValues(t, 1, sbc1.ReleaseCount.Get(), "sbc1.ReleaseCount")
   779  }
   780  
   781  func TestTxConnCommit2PC(t *testing.T) {
   782  	sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, "TestTxConnCommit2PC")
   783  
   784  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
   785  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   786  	sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false)
   787  	session.TransactionMode = vtgatepb.TransactionMode_TWOPC
   788  	require.NoError(t,
   789  		sc.txConn.Commit(ctx, session))
   790  	assert.EqualValues(t, 1, sbc0.CreateTransactionCount.Get(), "sbc0.CreateTransactionCount")
   791  	assert.EqualValues(t, 1, sbc1.PrepareCount.Get(), "sbc1.PrepareCount")
   792  	assert.EqualValues(t, 1, sbc0.StartCommitCount.Get(), "sbc0.StartCommitCount")
   793  	assert.EqualValues(t, 1, sbc1.CommitPreparedCount.Get(), "sbc1.CommitPreparedCount")
   794  	assert.EqualValues(t, 1, sbc0.ConcludeTransactionCount.Get(), "sbc0.ConcludeTransactionCount")
   795  }
   796  
   797  func TestTxConnCommit2PCOneParticipant(t *testing.T) {
   798  	sc, sbc0, _, rss0, _, _ := newTestTxConnEnv(t, "TestTxConnCommit2PCOneParticipant")
   799  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
   800  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   801  	session.TransactionMode = vtgatepb.TransactionMode_TWOPC
   802  	require.NoError(t,
   803  		sc.txConn.Commit(ctx, session))
   804  	assert.EqualValues(t, 1, sbc0.CommitCount.Get(), "sbc0.CommitCount")
   805  }
   806  
   807  func TestTxConnCommit2PCCreateTransactionFail(t *testing.T) {
   808  	sc, sbc0, sbc1, rss0, rss1, _ := newTestTxConnEnv(t, "TestTxConnCommit2PCCreateTransactionFail")
   809  
   810  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
   811  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   812  	sc.ExecuteMultiShard(ctx, nil, rss1, queries, session, false, false)
   813  
   814  	sbc0.MustFailCreateTransaction = 1
   815  	session.TransactionMode = vtgatepb.TransactionMode_TWOPC
   816  	err := sc.txConn.Commit(ctx, session)
   817  	want := "error: err"
   818  	require.Error(t, err)
   819  	assert.Contains(t, err.Error(), want, "Commit")
   820  	assert.EqualValues(t, 1, sbc0.CreateTransactionCount.Get(), "sbc0.CreateTransactionCount")
   821  	assert.EqualValues(t, 1, sbc0.RollbackCount.Get(), "sbc0.RollbackCount")
   822  	assert.EqualValues(t, 1, sbc1.RollbackCount.Get(), "sbc1.RollbackCount")
   823  	assert.EqualValues(t, 0, sbc1.PrepareCount.Get(), "sbc1.PrepareCount")
   824  	assert.EqualValues(t, 0, sbc0.StartCommitCount.Get(), "sbc0.StartCommitCount")
   825  	assert.EqualValues(t, 0, sbc1.CommitPreparedCount.Get(), "sbc1.CommitPreparedCount")
   826  	assert.EqualValues(t, 0, sbc0.ConcludeTransactionCount.Get(), "sbc0.ConcludeTransactionCount")
   827  }
   828  
   829  func TestTxConnCommit2PCPrepareFail(t *testing.T) {
   830  	sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, "TestTxConnCommit2PCPrepareFail")
   831  
   832  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
   833  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   834  	sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false)
   835  
   836  	sbc1.MustFailPrepare = 1
   837  	session.TransactionMode = vtgatepb.TransactionMode_TWOPC
   838  	err := sc.txConn.Commit(ctx, session)
   839  	want := "error: err"
   840  	require.Error(t, err)
   841  	assert.Contains(t, err.Error(), want, "Commit")
   842  	assert.EqualValues(t, 1, sbc0.CreateTransactionCount.Get(), "sbc0.CreateTransactionCount")
   843  	assert.EqualValues(t, 1, sbc1.PrepareCount.Get(), "sbc1.PrepareCount")
   844  	assert.EqualValues(t, 0, sbc0.StartCommitCount.Get(), "sbc0.StartCommitCount")
   845  	assert.EqualValues(t, 0, sbc1.CommitPreparedCount.Get(), "sbc1.CommitPreparedCount")
   846  	assert.EqualValues(t, 0, sbc0.ConcludeTransactionCount.Get(), "sbc0.ConcludeTransactionCount")
   847  }
   848  
   849  func TestTxConnCommit2PCStartCommitFail(t *testing.T) {
   850  	sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, "TestTxConnCommit2PCStartCommitFail")
   851  
   852  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
   853  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   854  	sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false)
   855  
   856  	sbc0.MustFailStartCommit = 1
   857  	session.TransactionMode = vtgatepb.TransactionMode_TWOPC
   858  	err := sc.txConn.Commit(ctx, session)
   859  	want := "error: err"
   860  	require.Error(t, err)
   861  	assert.Contains(t, err.Error(), want, "Commit")
   862  	assert.EqualValues(t, 1, sbc0.CreateTransactionCount.Get(), "sbc0.CreateTransactionCount")
   863  	assert.EqualValues(t, 1, sbc1.PrepareCount.Get(), "sbc1.PrepareCount")
   864  	assert.EqualValues(t, 1, sbc0.StartCommitCount.Get(), "sbc0.StartCommitCount")
   865  	assert.EqualValues(t, 0, sbc1.CommitPreparedCount.Get(), "sbc1.CommitPreparedCount")
   866  	assert.EqualValues(t, 0, sbc0.ConcludeTransactionCount.Get(), "sbc0.ConcludeTransactionCount")
   867  }
   868  
   869  func TestTxConnCommit2PCCommitPreparedFail(t *testing.T) {
   870  	sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, "TestTxConnCommit2PCCommitPreparedFail")
   871  
   872  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
   873  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   874  	sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false)
   875  
   876  	sbc1.MustFailCommitPrepared = 1
   877  	session.TransactionMode = vtgatepb.TransactionMode_TWOPC
   878  	err := sc.txConn.Commit(ctx, session)
   879  	want := "error: err"
   880  	require.Error(t, err)
   881  	assert.Contains(t, err.Error(), want, "Commit")
   882  	assert.EqualValues(t, 1, sbc0.CreateTransactionCount.Get(), "sbc0.CreateTransactionCount")
   883  	assert.EqualValues(t, 1, sbc1.PrepareCount.Get(), "sbc1.PrepareCount")
   884  	assert.EqualValues(t, 1, sbc0.StartCommitCount.Get(), "sbc0.StartCommitCount")
   885  	assert.EqualValues(t, 1, sbc1.CommitPreparedCount.Get(), "sbc1.CommitPreparedCount")
   886  	assert.EqualValues(t, 0, sbc0.ConcludeTransactionCount.Get(), "sbc0.ConcludeTransactionCount")
   887  }
   888  
   889  func TestTxConnCommit2PCConcludeTransactionFail(t *testing.T) {
   890  	sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, "TestTxConnCommit2PCConcludeTransactionFail")
   891  
   892  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
   893  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   894  	sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false)
   895  
   896  	sbc0.MustFailConcludeTransaction = 1
   897  	session.TransactionMode = vtgatepb.TransactionMode_TWOPC
   898  	err := sc.txConn.Commit(ctx, session)
   899  	want := "error: err"
   900  	require.Error(t, err)
   901  	assert.Contains(t, err.Error(), want, "Commit")
   902  	assert.EqualValues(t, 1, sbc0.CreateTransactionCount.Get(), "sbc0.CreateTransactionCount")
   903  	assert.EqualValues(t, 1, sbc1.PrepareCount.Get(), "sbc1.PrepareCount")
   904  	assert.EqualValues(t, 1, sbc0.StartCommitCount.Get(), "sbc0.StartCommitCount")
   905  	assert.EqualValues(t, 1, sbc1.CommitPreparedCount.Get(), "sbc1.CommitPreparedCount")
   906  	assert.EqualValues(t, 1, sbc0.ConcludeTransactionCount.Get(), "sbc0.ConcludeTransactionCount")
   907  }
   908  
   909  func TestTxConnRollback(t *testing.T) {
   910  	sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, "TxConnRollback")
   911  
   912  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true})
   913  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   914  	sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false)
   915  	require.NoError(t,
   916  		sc.txConn.Rollback(ctx, session))
   917  	wantSession := vtgatepb.Session{}
   918  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   919  	assert.EqualValues(t, 1, sbc0.RollbackCount.Get(), "sbc0.RollbackCount")
   920  	assert.EqualValues(t, 1, sbc1.RollbackCount.Get(), "sbc1.RollbackCount")
   921  }
   922  
   923  func TestTxConnReservedRollback(t *testing.T) {
   924  	sc, sbc0, sbc1, rss0, _, rss01 := newTestTxConnEnv(t, "TxConnReservedRollback")
   925  
   926  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true})
   927  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   928  	sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false)
   929  	require.NoError(t,
   930  		sc.txConn.Rollback(ctx, session))
   931  	wantSession := vtgatepb.Session{
   932  		InReservedConn: true,
   933  		ShardSessions: []*vtgatepb.Session_ShardSession{{
   934  			Target: &querypb.Target{
   935  				Keyspace:   "TxConnReservedRollback",
   936  				Shard:      "0",
   937  				TabletType: topodatapb.TabletType_PRIMARY,
   938  			},
   939  			ReservedId:  2,
   940  			TabletAlias: sbc0.Tablet().Alias,
   941  		}, {
   942  			Target: &querypb.Target{
   943  				Keyspace:   "TxConnReservedRollback",
   944  				Shard:      "1",
   945  				TabletType: topodatapb.TabletType_PRIMARY,
   946  			},
   947  			ReservedId:  2,
   948  			TabletAlias: sbc1.Tablet().Alias,
   949  		}},
   950  	}
   951  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   952  	assert.EqualValues(t, 1, sbc0.RollbackCount.Get(), "sbc0.RollbackCount")
   953  	assert.EqualValues(t, 1, sbc1.RollbackCount.Get(), "sbc1.RollbackCount")
   954  	assert.EqualValues(t, 0, sbc0.ReleaseCount.Get(), "sbc0.ReleaseCount")
   955  	assert.EqualValues(t, 0, sbc1.ReleaseCount.Get(), "sbc1.ReleaseCount")
   956  }
   957  
   958  func TestTxConnReservedRollbackFailure(t *testing.T) {
   959  	sc, sbc0, sbc1, rss0, rss1, rss01 := newTestTxConnEnv(t, "TxConnReservedRollback")
   960  
   961  	session := NewSafeSession(&vtgatepb.Session{InTransaction: true, InReservedConn: true})
   962  	sc.ExecuteMultiShard(ctx, nil, rss0, queries, session, false, false)
   963  	sc.ExecuteMultiShard(ctx, nil, rss01, twoQueries, session, false, false)
   964  
   965  	sbc1.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
   966  	assert.Error(t,
   967  		sc.txConn.Rollback(ctx, session))
   968  
   969  	expectErr := NewShardError(vterrors.New(
   970  		vtrpcpb.Code_INVALID_ARGUMENT,
   971  		fmt.Sprintf("%v error", vtrpcpb.Code_INVALID_ARGUMENT)),
   972  		rss1[0].Target)
   973  
   974  	wantSession := vtgatepb.Session{
   975  		InReservedConn: true,
   976  		Warnings: []*querypb.QueryWarning{{
   977  			Message: fmt.Sprintf("rollback encountered an error and connection to all shard for this session is released: %v", expectErr),
   978  		}},
   979  	}
   980  	utils.MustMatch(t, &wantSession, session.Session, "Session")
   981  	assert.EqualValues(t, 1, sbc0.RollbackCount.Get(), "sbc0.RollbackCount")
   982  	assert.EqualValues(t, 1, sbc1.RollbackCount.Get(), "sbc1.RollbackCount")
   983  	assert.EqualValues(t, 1, sbc0.ReleaseCount.Get(), "sbc0.ReleaseCount")
   984  	assert.EqualValues(t, 1, sbc1.ReleaseCount.Get(), "sbc1.ReleaseCount")
   985  }
   986  
   987  func TestTxConnResolveOnPrepare(t *testing.T) {
   988  	sc, sbc0, sbc1, _, _, _ := newTestTxConnEnv(t, "TestTxConn")
   989  
   990  	dtid := "TestTxConn:0:1234"
   991  	sbc0.ReadTransactionResults = []*querypb.TransactionMetadata{{
   992  		Dtid:  dtid,
   993  		State: querypb.TransactionState_PREPARE,
   994  		Participants: []*querypb.Target{{
   995  			Keyspace:   "TestTxConn",
   996  			Shard:      "1",
   997  			TabletType: topodatapb.TabletType_PRIMARY,
   998  		}},
   999  	}}
  1000  	err := sc.txConn.Resolve(ctx, dtid)
  1001  	require.NoError(t, err)
  1002  	assert.EqualValues(t, 1, sbc0.SetRollbackCount.Get(), "sbc0.SetRollbackCount")
  1003  	assert.EqualValues(t, 1, sbc1.RollbackPreparedCount.Get(), "sbc1.RollbackPreparedCount")
  1004  	assert.EqualValues(t, 0, sbc1.CommitPreparedCount.Get(), "sbc1.CommitPreparedCount")
  1005  	assert.EqualValues(t, 1, sbc0.ConcludeTransactionCount.Get(), "sbc0.ConcludeTransactionCount")
  1006  }
  1007  
  1008  func TestTxConnResolveOnRollback(t *testing.T) {
  1009  	sc, sbc0, sbc1, _, _, _ := newTestTxConnEnv(t, "TestTxConn")
  1010  
  1011  	dtid := "TestTxConn:0:1234"
  1012  	sbc0.ReadTransactionResults = []*querypb.TransactionMetadata{{
  1013  		Dtid:  dtid,
  1014  		State: querypb.TransactionState_ROLLBACK,
  1015  		Participants: []*querypb.Target{{
  1016  			Keyspace:   "TestTxConn",
  1017  			Shard:      "1",
  1018  			TabletType: topodatapb.TabletType_PRIMARY,
  1019  		}},
  1020  	}}
  1021  	require.NoError(t,
  1022  		sc.txConn.Resolve(ctx, dtid))
  1023  	assert.EqualValues(t, 0, sbc0.SetRollbackCount.Get(), "sbc0.SetRollbackCount")
  1024  	assert.EqualValues(t, 1, sbc1.RollbackPreparedCount.Get(), "sbc1.RollbackPreparedCount")
  1025  	assert.EqualValues(t, 0, sbc1.CommitPreparedCount.Get(), "sbc1.CommitPreparedCount")
  1026  	assert.EqualValues(t, 1, sbc0.ConcludeTransactionCount.Get(), "sbc0.ConcludeTransactionCount")
  1027  }
  1028  
  1029  func TestTxConnResolveOnCommit(t *testing.T) {
  1030  	sc, sbc0, sbc1, _, _, _ := newTestTxConnEnv(t, "TestTxConn")
  1031  
  1032  	dtid := "TestTxConn:0:1234"
  1033  	sbc0.ReadTransactionResults = []*querypb.TransactionMetadata{{
  1034  		Dtid:  dtid,
  1035  		State: querypb.TransactionState_COMMIT,
  1036  		Participants: []*querypb.Target{{
  1037  			Keyspace:   "TestTxConn",
  1038  			Shard:      "1",
  1039  			TabletType: topodatapb.TabletType_PRIMARY,
  1040  		}},
  1041  	}}
  1042  	require.NoError(t,
  1043  		sc.txConn.Resolve(ctx, dtid))
  1044  	assert.EqualValues(t, 0, sbc0.SetRollbackCount.Get(), "sbc0.SetRollbackCount")
  1045  	assert.EqualValues(t, 0, sbc1.RollbackPreparedCount.Get(), "sbc1.RollbackPreparedCount")
  1046  	assert.EqualValues(t, 1, sbc1.CommitPreparedCount.Get(), "sbc1.CommitPreparedCount")
  1047  	assert.EqualValues(t, 1, sbc0.ConcludeTransactionCount.Get(), "sbc0.ConcludeTransactionCount")
  1048  }
  1049  
  1050  func TestTxConnResolveInvalidDTID(t *testing.T) {
  1051  	sc, _, _, _, _, _ := newTestTxConnEnv(t, "TestTxConn")
  1052  
  1053  	err := sc.txConn.Resolve(ctx, "abcd")
  1054  	want := "invalid parts in dtid: abcd"
  1055  	require.EqualError(t, err, want, "Resolve")
  1056  }
  1057  
  1058  func TestTxConnResolveReadTransactionFail(t *testing.T) {
  1059  	sc, sbc0, _, _, _, _ := newTestTxConnEnv(t, "TestTxConn")
  1060  
  1061  	dtid := "TestTxConn:0:1234"
  1062  	sbc0.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 1
  1063  	err := sc.txConn.Resolve(ctx, dtid)
  1064  	want := "INVALID_ARGUMENT error"
  1065  	require.Error(t, err)
  1066  	assert.Contains(t, err.Error(), want, "Resolve")
  1067  }
  1068  
  1069  func TestTxConnResolveInternalError(t *testing.T) {
  1070  	sc, sbc0, _, _, _, _ := newTestTxConnEnv(t, "TestTxConn")
  1071  
  1072  	dtid := "TestTxConn:0:1234"
  1073  	sbc0.ReadTransactionResults = []*querypb.TransactionMetadata{{
  1074  		Dtid:  dtid,
  1075  		State: querypb.TransactionState_UNKNOWN,
  1076  		Participants: []*querypb.Target{{
  1077  			Keyspace:   "TestTxConn",
  1078  			Shard:      "1",
  1079  			TabletType: topodatapb.TabletType_PRIMARY,
  1080  		}},
  1081  	}}
  1082  	err := sc.txConn.Resolve(ctx, dtid)
  1083  	want := "invalid state: UNKNOWN"
  1084  	require.Error(t, err)
  1085  	assert.Contains(t, err.Error(), want, "Resolve")
  1086  }
  1087  
  1088  func TestTxConnResolveSetRollbackFail(t *testing.T) {
  1089  	sc, sbc0, sbc1, _, _, _ := newTestTxConnEnv(t, "TestTxConn")
  1090  
  1091  	dtid := "TestTxConn:0:1234"
  1092  	sbc0.ReadTransactionResults = []*querypb.TransactionMetadata{{
  1093  		Dtid:  dtid,
  1094  		State: querypb.TransactionState_PREPARE,
  1095  		Participants: []*querypb.Target{{
  1096  			Keyspace:   "TestTxConn",
  1097  			Shard:      "1",
  1098  			TabletType: topodatapb.TabletType_PRIMARY,
  1099  		}},
  1100  	}}
  1101  	sbc0.MustFailSetRollback = 1
  1102  	err := sc.txConn.Resolve(ctx, dtid)
  1103  	want := "error: err"
  1104  	require.Error(t, err)
  1105  	assert.Contains(t, err.Error(), want, "Resolve")
  1106  	assert.EqualValues(t, 1, sbc0.SetRollbackCount.Get(), "sbc0.SetRollbackCount")
  1107  	assert.EqualValues(t, 0, sbc1.RollbackPreparedCount.Get(), "sbc1.RollbackPreparedCount")
  1108  	assert.EqualValues(t, 0, sbc1.CommitPreparedCount.Get(), "sbc1.CommitPreparedCount")
  1109  	assert.EqualValues(t, 0, sbc0.ConcludeTransactionCount.Get(), "sbc0.ConcludeTransactionCount")
  1110  }
  1111  
  1112  func TestTxConnResolveRollbackPreparedFail(t *testing.T) {
  1113  	sc, sbc0, sbc1, _, _, _ := newTestTxConnEnv(t, "TestTxConn")
  1114  
  1115  	dtid := "TestTxConn:0:1234"
  1116  	sbc0.ReadTransactionResults = []*querypb.TransactionMetadata{{
  1117  		Dtid:  dtid,
  1118  		State: querypb.TransactionState_ROLLBACK,
  1119  		Participants: []*querypb.Target{{
  1120  			Keyspace:   "TestTxConn",
  1121  			Shard:      "1",
  1122  			TabletType: topodatapb.TabletType_PRIMARY,
  1123  		}},
  1124  	}}
  1125  	sbc1.MustFailRollbackPrepared = 1
  1126  	err := sc.txConn.Resolve(ctx, dtid)
  1127  	want := "error: err"
  1128  	require.Error(t, err)
  1129  	assert.Contains(t, err.Error(), want, "Resolve")
  1130  	assert.EqualValues(t, 0, sbc0.SetRollbackCount.Get(), "sbc0.SetRollbackCount")
  1131  	assert.EqualValues(t, 1, sbc1.RollbackPreparedCount.Get(), "sbc1.RollbackPreparedCount")
  1132  	assert.EqualValues(t, 0, sbc1.CommitPreparedCount.Get(), "sbc1.CommitPreparedCount")
  1133  	assert.EqualValues(t, 0, sbc0.ConcludeTransactionCount.Get(), "sbc0.ConcludeTransactionCount")
  1134  }
  1135  
  1136  func TestTxConnResolveCommitPreparedFail(t *testing.T) {
  1137  	sc, sbc0, sbc1, _, _, _ := newTestTxConnEnv(t, "TestTxConn")
  1138  
  1139  	dtid := "TestTxConn:0:1234"
  1140  	sbc0.ReadTransactionResults = []*querypb.TransactionMetadata{{
  1141  		Dtid:  dtid,
  1142  		State: querypb.TransactionState_COMMIT,
  1143  		Participants: []*querypb.Target{{
  1144  			Keyspace:   "TestTxConn",
  1145  			Shard:      "1",
  1146  			TabletType: topodatapb.TabletType_PRIMARY,
  1147  		}},
  1148  	}}
  1149  	sbc1.MustFailCommitPrepared = 1
  1150  	err := sc.txConn.Resolve(ctx, dtid)
  1151  	want := "error: err"
  1152  	require.Error(t, err)
  1153  	assert.Contains(t, err.Error(), want, "Resolve")
  1154  	assert.EqualValues(t, 0, sbc0.SetRollbackCount.Get(), "sbc0.SetRollbackCount")
  1155  	assert.EqualValues(t, 0, sbc1.RollbackPreparedCount.Get(), "sbc1.RollbackPreparedCount")
  1156  	assert.EqualValues(t, 1, sbc1.CommitPreparedCount.Get(), "sbc1.CommitPreparedCount")
  1157  	assert.EqualValues(t, 0, sbc0.ConcludeTransactionCount.Get(), "sbc0.ConcludeTransactionCount")
  1158  }
  1159  
  1160  func TestTxConnResolveConcludeTransactionFail(t *testing.T) {
  1161  	sc, sbc0, sbc1, _, _, _ := newTestTxConnEnv(t, "TestTxConn")
  1162  
  1163  	dtid := "TestTxConn:0:1234"
  1164  	sbc0.ReadTransactionResults = []*querypb.TransactionMetadata{{
  1165  		Dtid:  dtid,
  1166  		State: querypb.TransactionState_COMMIT,
  1167  		Participants: []*querypb.Target{{
  1168  			Keyspace:   "TestTxConn",
  1169  			Shard:      "1",
  1170  			TabletType: topodatapb.TabletType_PRIMARY,
  1171  		}},
  1172  	}}
  1173  	sbc0.MustFailConcludeTransaction = 1
  1174  	err := sc.txConn.Resolve(ctx, dtid)
  1175  	want := "error: err"
  1176  	require.Error(t, err)
  1177  	assert.Contains(t, err.Error(), want, "Resolve")
  1178  	assert.EqualValues(t, 0, sbc0.SetRollbackCount.Get(), "sbc0.SetRollbackCount")
  1179  	assert.EqualValues(t, 0, sbc1.RollbackPreparedCount.Get(), "sbc1.RollbackPreparedCount")
  1180  	assert.EqualValues(t, 1, sbc1.CommitPreparedCount.Get(), "sbc1.CommitPreparedCount")
  1181  	assert.EqualValues(t, 1, sbc0.ConcludeTransactionCount.Get(), "sbc0.ConcludeTransactionCount")
  1182  }
  1183  
  1184  func TestTxConnMultiGoSessions(t *testing.T) {
  1185  	txc := &TxConn{}
  1186  
  1187  	input := []*vtgatepb.Session_ShardSession{{
  1188  		Target: &querypb.Target{
  1189  			Keyspace: "0",
  1190  		},
  1191  	}}
  1192  	err := txc.runSessions(ctx, input, nil, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logger *executeLogger) error {
  1193  		return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "err %s", s.Target.Keyspace)
  1194  	})
  1195  	want := "err 0"
  1196  	require.EqualError(t, err, want, "runSessions(1)")
  1197  
  1198  	input = []*vtgatepb.Session_ShardSession{{
  1199  		Target: &querypb.Target{
  1200  			Keyspace: "0",
  1201  		},
  1202  	}, {
  1203  		Target: &querypb.Target{
  1204  			Keyspace: "1",
  1205  		},
  1206  	}}
  1207  	err = txc.runSessions(ctx, input, nil, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logger *executeLogger) error {
  1208  		return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "err %s", s.Target.Keyspace)
  1209  	})
  1210  	want = "err 0\nerr 1"
  1211  	require.EqualError(t, err, want, "runSessions(2)")
  1212  	wantCode := vtrpcpb.Code_INTERNAL
  1213  	assert.Equal(t, wantCode, vterrors.Code(err), "error code")
  1214  
  1215  	err = txc.runSessions(ctx, input, nil, func(ctx context.Context, s *vtgatepb.Session_ShardSession, logger *executeLogger) error {
  1216  		return nil
  1217  	})
  1218  	require.NoError(t, err)
  1219  }
  1220  
  1221  func TestTxConnMultiGoTargets(t *testing.T) {
  1222  	txc := &TxConn{}
  1223  	input := []*querypb.Target{{
  1224  		Keyspace: "0",
  1225  	}}
  1226  	err := txc.runTargets(input, func(t *querypb.Target) error {
  1227  		return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "err %s", t.Keyspace)
  1228  	})
  1229  	want := "err 0"
  1230  	require.EqualError(t, err, want, "runTargets(1)")
  1231  
  1232  	input = []*querypb.Target{{
  1233  		Keyspace: "0",
  1234  	}, {
  1235  		Keyspace: "1",
  1236  	}}
  1237  	err = txc.runTargets(input, func(t *querypb.Target) error {
  1238  		return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "err %s", t.Keyspace)
  1239  	})
  1240  	want = "err 0\nerr 1"
  1241  	require.EqualError(t, err, want, "runTargets(2)")
  1242  	wantCode := vtrpcpb.Code_INTERNAL
  1243  	assert.Equal(t, wantCode, vterrors.Code(err), "error code")
  1244  
  1245  	err = txc.runTargets(input, func(t *querypb.Target) error {
  1246  		return nil
  1247  	})
  1248  	require.NoError(t, err)
  1249  }
  1250  
  1251  func TestTxConnAccessModeReset(t *testing.T) {
  1252  	sc, _, _, _, _, _ := newTestTxConnEnv(t, "TestTxConn")
  1253  
  1254  	tcases := []struct {
  1255  		name string
  1256  		f    func(ctx context.Context, session *SafeSession) error
  1257  	}{{
  1258  		name: "begin-commit",
  1259  		f:    sc.txConn.Commit,
  1260  	}, {
  1261  		name: "begin-rollback",
  1262  		f:    sc.txConn.Rollback,
  1263  	}, {
  1264  		name: "begin-release",
  1265  		f:    sc.txConn.Release,
  1266  	}, {
  1267  		name: "begin-releaseAll",
  1268  		f:    sc.txConn.ReleaseAll,
  1269  	}}
  1270  
  1271  	for _, tcase := range tcases {
  1272  		t.Run(tcase.name, func(t *testing.T) {
  1273  			safeSession := NewSafeSession(&vtgatepb.Session{
  1274  				Options: &querypb.ExecuteOptions{
  1275  					TransactionAccessMode: []querypb.ExecuteOptions_TransactionAccessMode{querypb.ExecuteOptions_READ_ONLY},
  1276  				},
  1277  			})
  1278  
  1279  			// begin transaction
  1280  			require.NoError(t,
  1281  				sc.txConn.Begin(ctx, safeSession, nil))
  1282  
  1283  			// resolve transaction
  1284  			require.NoError(t,
  1285  				tcase.f(ctx, safeSession))
  1286  
  1287  			// check that the access mode is reset
  1288  			require.Nil(t, safeSession.Session.Options.TransactionAccessMode)
  1289  		})
  1290  	}
  1291  }
  1292  
  1293  func newTestTxConnEnv(t *testing.T, name string) (sc *ScatterConn, sbc0, sbc1 *sandboxconn.SandboxConn, rss0, rss1, rss01 []*srvtopo.ResolvedShard) {
  1294  	t.Helper()
  1295  	createSandbox(name)
  1296  	hc := discovery.NewFakeHealthCheck(nil)
  1297  	sc = newTestScatterConn(hc, newSandboxForCells([]string{"aa"}), "aa")
  1298  	sbc0 = hc.AddTestTablet("aa", "0", 1, name, "0", topodatapb.TabletType_PRIMARY, true, 1, nil)
  1299  	sbc1 = hc.AddTestTablet("aa", "1", 1, name, "1", topodatapb.TabletType_PRIMARY, true, 1, nil)
  1300  	res := srvtopo.NewResolver(newSandboxForCells([]string{"aa"}), sc.gateway, "aa")
  1301  	var err error
  1302  	rss0, err = res.ResolveDestination(ctx, name, topodatapb.TabletType_PRIMARY, key.DestinationShard("0"))
  1303  	require.NoError(t, err)
  1304  	rss1, err = res.ResolveDestination(ctx, name, topodatapb.TabletType_PRIMARY, key.DestinationShard("1"))
  1305  	require.NoError(t, err)
  1306  	rss01, err = res.ResolveDestination(ctx, name, topodatapb.TabletType_PRIMARY, key.DestinationShards([]string{"0", "1"}))
  1307  	require.NoError(t, err)
  1308  	return sc, sbc0, sbc1, rss0, rss1, rss01
  1309  }