vitess.io/vitess@v0.16.2/go/vt/vtgate/vtgate_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vtgate
    18  
    19  import (
    20  	"context"
    21  	"strings"
    22  	"testing"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  	"google.golang.org/protobuf/proto"
    26  
    27  	"vitess.io/vitess/go/test/utils"
    28  
    29  	"github.com/stretchr/testify/require"
    30  
    31  	"vitess.io/vitess/go/sqltypes"
    32  	"vitess.io/vitess/go/vt/discovery"
    33  	"vitess.io/vitess/go/vt/vterrors"
    34  	"vitess.io/vitess/go/vt/vttablet/sandboxconn"
    35  
    36  	querypb "vitess.io/vitess/go/vt/proto/query"
    37  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    38  	vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
    39  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    40  )
    41  
    42  // This file uses the sandbox_test framework.
    43  
    44  var hcVTGateTest *discovery.FakeHealthCheck
    45  
    46  var executeOptions = &querypb.ExecuteOptions{
    47  	IncludedFields: querypb.ExecuteOptions_TYPE_ONLY,
    48  }
    49  
    50  var primarySession *vtgatepb.Session
    51  
    52  func init() {
    53  	createSandbox(KsTestUnsharded).VSchema = `
    54  {
    55  	"sharded": false,
    56  	"tables": {
    57  		"t1": {}
    58  	}
    59  }
    60  `
    61  	createSandbox(KsTestBadVSchema).VSchema = `
    62  {
    63  	"sharded": true,
    64  	"tables": {
    65  		"t2": {
    66  			"auto_increment": {
    67  				"column": "id",
    68  				"sequence": "id_seq"
    69  			}
    70  		}
    71  	}
    72  }
    73  `
    74  	hcVTGateTest = discovery.NewFakeHealthCheck(nil)
    75  	transactionMode = "MULTI"
    76  	Init(context.Background(), hcVTGateTest, newSandboxForCells([]string{"aa"}), "aa", nil, querypb.ExecuteOptions_Gen4)
    77  
    78  	mysqlServerPort = 0
    79  	mysqlAuthServerImpl = "none"
    80  	initMySQLProtocol()
    81  }
    82  
    83  func TestVTGateExecute(t *testing.T) {
    84  	createSandbox(KsTestUnsharded)
    85  	hcVTGateTest.Reset()
    86  	sbc := hcVTGateTest.AddTestTablet("aa", "1.1.1.1", 1001, KsTestUnsharded, "0", topodatapb.TabletType_PRIMARY, true, 1, nil)
    87  	_, qr, err := rpcVTGate.Execute(
    88  		context.Background(),
    89  		&vtgatepb.Session{
    90  			Autocommit:   true,
    91  			TargetString: "@primary",
    92  			Options:      executeOptions,
    93  		},
    94  		"select id from t1",
    95  		nil,
    96  	)
    97  	if err != nil {
    98  		t.Errorf("want nil, got %v", err)
    99  	}
   100  	want := *sandboxconn.SingleRowResult
   101  	want.StatusFlags = 0 // VTGate result set does not contain status flags in sqltypes.Result
   102  	utils.MustMatch(t, &want, qr)
   103  	if !proto.Equal(sbc.Options[0], executeOptions) {
   104  		t.Errorf("got ExecuteOptions \n%+v, want \n%+v", sbc.Options[0], executeOptions)
   105  	}
   106  }
   107  
   108  func TestVTGateExecuteWithKeyspaceShard(t *testing.T) {
   109  	createSandbox(KsTestUnsharded)
   110  	hcVTGateTest.Reset()
   111  	hcVTGateTest.AddTestTablet("aa", "1.1.1.1", 1001, KsTestUnsharded, "0", topodatapb.TabletType_PRIMARY, true, 1, nil)
   112  
   113  	// Valid keyspace.
   114  	_, qr, err := rpcVTGate.Execute(
   115  		context.Background(),
   116  		&vtgatepb.Session{
   117  			TargetString: KsTestUnsharded,
   118  		},
   119  		"select id from none",
   120  		nil,
   121  	)
   122  	if err != nil {
   123  		t.Errorf("want nil, got %v", err)
   124  	}
   125  	wantQr := *sandboxconn.SingleRowResult
   126  	wantQr.StatusFlags = 0 // VTGate result set does not contain status flags in sqltypes.Result
   127  	utils.MustMatch(t, &wantQr, qr)
   128  
   129  	// Invalid keyspace.
   130  	_, _, err = rpcVTGate.Execute(
   131  		context.Background(),
   132  		&vtgatepb.Session{
   133  			TargetString: "invalid_keyspace",
   134  		},
   135  		"select id from none",
   136  		nil,
   137  	)
   138  	want := "VT05003: unknown database 'invalid_keyspace' in vschema"
   139  	assert.EqualError(t, err, want)
   140  
   141  	// Valid keyspace/shard.
   142  	_, qr, err = rpcVTGate.Execute(
   143  		context.Background(),
   144  		&vtgatepb.Session{
   145  			TargetString: KsTestUnsharded + ":0@primary",
   146  		},
   147  		"select id from none",
   148  		nil,
   149  	)
   150  	if err != nil {
   151  		t.Errorf("want nil, got %v", err)
   152  	}
   153  	utils.MustMatch(t, &wantQr, qr)
   154  
   155  	// Invalid keyspace/shard.
   156  	_, _, err = rpcVTGate.Execute(
   157  		context.Background(),
   158  		&vtgatepb.Session{
   159  			TargetString: KsTestUnsharded + ":noshard@primary",
   160  		},
   161  		"select id from none",
   162  		nil,
   163  	)
   164  	require.Error(t, err)
   165  	require.Contains(t, err.Error(), `no healthy tablet available for 'keyspace:"TestUnsharded" shard:"noshard" tablet_type:PRIMARY`)
   166  }
   167  
   168  func TestVTGateStreamExecute(t *testing.T) {
   169  	ks := KsTestUnsharded
   170  	shard := "0"
   171  	createSandbox(ks)
   172  	hcVTGateTest.Reset()
   173  	sbc := hcVTGateTest.AddTestTablet("aa", "1.1.1.1", 1001, ks, shard, topodatapb.TabletType_PRIMARY, true, 1, nil)
   174  	var qrs []*sqltypes.Result
   175  	err := rpcVTGate.StreamExecute(
   176  		context.Background(),
   177  		&vtgatepb.Session{
   178  			TargetString: "@primary",
   179  			Options:      executeOptions,
   180  		},
   181  		"select id from t1",
   182  		nil,
   183  		func(r *sqltypes.Result) error {
   184  			qrs = append(qrs, r)
   185  			return nil
   186  		},
   187  	)
   188  	require.NoError(t, err)
   189  	want := []*sqltypes.Result{{
   190  		Fields: sandboxconn.StreamRowResult.Fields,
   191  	}, {
   192  		Rows: sandboxconn.StreamRowResult.Rows,
   193  	}}
   194  	utils.MustMatch(t, want, qrs)
   195  	if !proto.Equal(sbc.Options[0], executeOptions) {
   196  		t.Errorf("got ExecuteOptions \n%+v, want \n%+v", sbc.Options[0], executeOptions)
   197  	}
   198  }
   199  
   200  func TestVTGateBindVarError(t *testing.T) {
   201  	ks := KsTestUnsharded
   202  	createSandbox(ks)
   203  	hcVTGateTest.Reset()
   204  	ctx := context.Background()
   205  	session := &vtgatepb.Session{}
   206  	bindVars := map[string]*querypb.BindVariable{
   207  		"v": {
   208  			Type:  querypb.Type_EXPRESSION,
   209  			Value: []byte("1"),
   210  		},
   211  	}
   212  	want := "v: invalid type specified for MakeValue: EXPRESSION"
   213  
   214  	tcases := []struct {
   215  		name string
   216  		f    func() error
   217  	}{{
   218  		name: "Execute",
   219  		f: func() error {
   220  			_, _, err := rpcVTGate.Execute(ctx, session, "", bindVars)
   221  			return err
   222  		},
   223  	}, {
   224  		name: "ExecuteBatch",
   225  		f: func() error {
   226  			_, _, err := rpcVTGate.ExecuteBatch(ctx, session, []string{""}, []map[string]*querypb.BindVariable{bindVars})
   227  			return err
   228  		},
   229  	}, {
   230  		name: "StreamExecute",
   231  		f: func() error {
   232  			return rpcVTGate.StreamExecute(ctx, session, "", bindVars, func(_ *sqltypes.Result) error { return nil })
   233  		},
   234  	}}
   235  	for _, tcase := range tcases {
   236  		if err := tcase.f(); err == nil || !strings.Contains(err.Error(), want) {
   237  			t.Errorf("%v error: %v, must contain %s", tcase.name, err, want)
   238  		}
   239  	}
   240  }
   241  
   242  func testErrorPropagation(t *testing.T, sbcs []*sandboxconn.SandboxConn, before func(sbc *sandboxconn.SandboxConn), after func(sbc *sandboxconn.SandboxConn), expected vtrpcpb.Code) {
   243  
   244  	// Execute
   245  	for _, sbc := range sbcs {
   246  		before(sbc)
   247  	}
   248  	_, _, err := rpcVTGate.Execute(
   249  		context.Background(),
   250  		primarySession,
   251  		"select id from t1",
   252  		nil,
   253  	)
   254  	if err == nil {
   255  		t.Errorf("error %v not propagated for Execute", expected)
   256  	} else {
   257  		ec := vterrors.Code(err)
   258  		if ec != expected {
   259  			t.Errorf("unexpected error, got code %v err %v, want %v", ec, err, expected)
   260  		}
   261  	}
   262  	for _, sbc := range sbcs {
   263  		after(sbc)
   264  	}
   265  
   266  	// StreamExecute
   267  	for _, sbc := range sbcs {
   268  		before(sbc)
   269  	}
   270  	err = rpcVTGate.StreamExecute(
   271  		context.Background(),
   272  		primarySession,
   273  		"select id from t1",
   274  		nil,
   275  		func(r *sqltypes.Result) error {
   276  			return nil
   277  		},
   278  	)
   279  	if err == nil {
   280  		t.Errorf("error %v not propagated for StreamExecute", expected)
   281  	} else {
   282  		ec := vterrors.Code(err)
   283  		if ec != expected {
   284  			t.Errorf("unexpected error, got %v want %v: %v", ec, expected, err)
   285  		}
   286  	}
   287  	for _, sbc := range sbcs {
   288  		after(sbc)
   289  	}
   290  }
   291  
   292  // TestErrorPropagation tests an error returned by sandboxconn is
   293  // properly propagated through vtgate layers.  We need both a primary
   294  // tablet and a rdonly tablet because we don't control the routing of
   295  // Commit.
   296  func TestErrorPropagation(t *testing.T) {
   297  	createSandbox(KsTestUnsharded)
   298  	hcVTGateTest.Reset()
   299  	// create a new session each time so that ShardSessions don't get re-used across tests
   300  	primarySession = &vtgatepb.Session{
   301  		TargetString: "@primary",
   302  	}
   303  
   304  	sbcm := hcVTGateTest.AddTestTablet("aa", "1.1.1.1", 1001, KsTestUnsharded, "0", topodatapb.TabletType_PRIMARY, true, 1, nil)
   305  	sbcrdonly := hcVTGateTest.AddTestTablet("aa", "1.1.1.2", 1001, KsTestUnsharded, "0", topodatapb.TabletType_RDONLY, true, 1, nil)
   306  	sbcs := []*sandboxconn.SandboxConn{
   307  		sbcm,
   308  		sbcrdonly,
   309  	}
   310  
   311  	testErrorPropagation(t, sbcs, func(sbc *sandboxconn.SandboxConn) {
   312  		sbc.MustFailCodes[vtrpcpb.Code_CANCELED] = 20
   313  	}, func(sbc *sandboxconn.SandboxConn) {
   314  		sbc.MustFailCodes[vtrpcpb.Code_CANCELED] = 0
   315  	}, vtrpcpb.Code_CANCELED)
   316  
   317  	testErrorPropagation(t, sbcs, func(sbc *sandboxconn.SandboxConn) {
   318  		sbc.MustFailCodes[vtrpcpb.Code_UNKNOWN] = 20
   319  	}, func(sbc *sandboxconn.SandboxConn) {
   320  		sbc.MustFailCodes[vtrpcpb.Code_UNKNOWN] = 0
   321  	}, vtrpcpb.Code_UNKNOWN)
   322  
   323  	testErrorPropagation(t, sbcs, func(sbc *sandboxconn.SandboxConn) {
   324  		sbc.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 20
   325  	}, func(sbc *sandboxconn.SandboxConn) {
   326  		sbc.MustFailCodes[vtrpcpb.Code_INVALID_ARGUMENT] = 0
   327  	}, vtrpcpb.Code_INVALID_ARGUMENT)
   328  
   329  	testErrorPropagation(t, sbcs, func(sbc *sandboxconn.SandboxConn) {
   330  		sbc.MustFailCodes[vtrpcpb.Code_DEADLINE_EXCEEDED] = 20
   331  	}, func(sbc *sandboxconn.SandboxConn) {
   332  		sbc.MustFailCodes[vtrpcpb.Code_DEADLINE_EXCEEDED] = 0
   333  	}, vtrpcpb.Code_DEADLINE_EXCEEDED)
   334  
   335  	testErrorPropagation(t, sbcs, func(sbc *sandboxconn.SandboxConn) {
   336  		sbc.MustFailCodes[vtrpcpb.Code_ALREADY_EXISTS] = 20
   337  	}, func(sbc *sandboxconn.SandboxConn) {
   338  		sbc.MustFailCodes[vtrpcpb.Code_ALREADY_EXISTS] = 0
   339  	}, vtrpcpb.Code_ALREADY_EXISTS)
   340  
   341  	testErrorPropagation(t, sbcs, func(sbc *sandboxconn.SandboxConn) {
   342  		sbc.MustFailCodes[vtrpcpb.Code_PERMISSION_DENIED] = 20
   343  	}, func(sbc *sandboxconn.SandboxConn) {
   344  		sbc.MustFailCodes[vtrpcpb.Code_PERMISSION_DENIED] = 0
   345  	}, vtrpcpb.Code_PERMISSION_DENIED)
   346  
   347  	testErrorPropagation(t, sbcs, func(sbc *sandboxconn.SandboxConn) {
   348  		sbc.MustFailCodes[vtrpcpb.Code_RESOURCE_EXHAUSTED] = 20
   349  	}, func(sbc *sandboxconn.SandboxConn) {
   350  		sbc.MustFailCodes[vtrpcpb.Code_RESOURCE_EXHAUSTED] = 0
   351  	}, vtrpcpb.Code_RESOURCE_EXHAUSTED)
   352  
   353  	testErrorPropagation(t, sbcs, func(sbc *sandboxconn.SandboxConn) {
   354  		sbc.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 20
   355  	}, func(sbc *sandboxconn.SandboxConn) {
   356  		sbc.MustFailCodes[vtrpcpb.Code_FAILED_PRECONDITION] = 0
   357  	}, vtrpcpb.Code_FAILED_PRECONDITION)
   358  
   359  	testErrorPropagation(t, sbcs, func(sbc *sandboxconn.SandboxConn) {
   360  		sbc.MustFailCodes[vtrpcpb.Code_ABORTED] = 20
   361  	}, func(sbc *sandboxconn.SandboxConn) {
   362  		sbc.MustFailCodes[vtrpcpb.Code_ABORTED] = 0
   363  	}, vtrpcpb.Code_ABORTED)
   364  
   365  	testErrorPropagation(t, sbcs, func(sbc *sandboxconn.SandboxConn) {
   366  		sbc.MustFailCodes[vtrpcpb.Code_INTERNAL] = 20
   367  	}, func(sbc *sandboxconn.SandboxConn) {
   368  		sbc.MustFailCodes[vtrpcpb.Code_INTERNAL] = 0
   369  	}, vtrpcpb.Code_INTERNAL)
   370  
   371  	testErrorPropagation(t, sbcs, func(sbc *sandboxconn.SandboxConn) {
   372  		sbc.MustFailCodes[vtrpcpb.Code_UNAVAILABLE] = 20
   373  	}, func(sbc *sandboxconn.SandboxConn) {
   374  		sbc.MustFailCodes[vtrpcpb.Code_UNAVAILABLE] = 0
   375  	}, vtrpcpb.Code_UNAVAILABLE)
   376  
   377  	testErrorPropagation(t, sbcs, func(sbc *sandboxconn.SandboxConn) {
   378  		sbc.MustFailCodes[vtrpcpb.Code_UNAUTHENTICATED] = 20
   379  	}, func(sbc *sandboxconn.SandboxConn) {
   380  		sbc.MustFailCodes[vtrpcpb.Code_UNAUTHENTICATED] = 0
   381  	}, vtrpcpb.Code_UNAUTHENTICATED)
   382  }
   383  
   384  // This test makes sure that if we start a transaction and hit a critical
   385  // error, a rollback is issued.
   386  func TestErrorIssuesRollback(t *testing.T) {
   387  	createSandbox(KsTestUnsharded)
   388  	hcVTGateTest.Reset()
   389  	sbc := hcVTGateTest.AddTestTablet("aa", "1.1.1.1", 1001, KsTestUnsharded, "0", topodatapb.TabletType_PRIMARY, true, 1, nil)
   390  
   391  	// Start a transaction, send one statement.
   392  	// Simulate an error that should trigger a rollback:
   393  	// vtrpcpb.Code_ABORTED case.
   394  	session, _, err := rpcVTGate.Execute(
   395  		context.Background(),
   396  		&vtgatepb.Session{},
   397  		"begin",
   398  		nil,
   399  	)
   400  	if err != nil {
   401  		t.Fatalf("cannot start a transaction: %v", err)
   402  	}
   403  	session, _, err = rpcVTGate.Execute(
   404  		context.Background(),
   405  		session,
   406  		"select id from t1",
   407  		nil,
   408  	)
   409  	if err != nil {
   410  		t.Fatalf("want nil, got %v", err)
   411  	}
   412  	if sbc.RollbackCount.Get() != 0 {
   413  		t.Errorf("want 0, got %d", sbc.RollbackCount.Get())
   414  	}
   415  	sbc.MustFailCodes[vtrpcpb.Code_ABORTED] = 20
   416  	_, _, err = rpcVTGate.Execute(
   417  		context.Background(),
   418  		session,
   419  		"select id from t1",
   420  		nil,
   421  	)
   422  	if err == nil {
   423  		t.Fatalf("want error but got nil")
   424  	}
   425  	if sbc.RollbackCount.Get() != 1 {
   426  		t.Errorf("want 1, got %d", sbc.RollbackCount.Get())
   427  	}
   428  	sbc.RollbackCount.Set(0)
   429  	sbc.MustFailCodes[vtrpcpb.Code_ABORTED] = 0
   430  
   431  	// Start a transaction, send one statement.
   432  	// Simulate an error that should trigger a rollback:
   433  	// vtrpcpb.ErrorCode_RESOURCE_EXHAUSTED case.
   434  	session, _, err = rpcVTGate.Execute(
   435  		context.Background(),
   436  		&vtgatepb.Session{},
   437  		"begin",
   438  		nil,
   439  	)
   440  	if err != nil {
   441  		t.Fatalf("cannot start a transaction: %v", err)
   442  	}
   443  	session, _, err = rpcVTGate.Execute(
   444  		context.Background(),
   445  		session,
   446  		"select id from t1",
   447  		nil,
   448  	)
   449  	if err != nil {
   450  		t.Fatalf("want nil, got %v", err)
   451  	}
   452  	if sbc.RollbackCount.Get() != 0 {
   453  		t.Errorf("want 0, got %d", sbc.RollbackCount.Get())
   454  	}
   455  	sbc.MustFailCodes[vtrpcpb.Code_RESOURCE_EXHAUSTED] = 20
   456  	_, _, err = rpcVTGate.Execute(
   457  		context.Background(),
   458  		session,
   459  		"select id from t1",
   460  		nil,
   461  	)
   462  	if err == nil {
   463  		t.Fatalf("want error but got nil")
   464  	}
   465  	if sbc.RollbackCount.Get() != 1 {
   466  		t.Errorf("want 1, got %d", sbc.RollbackCount.Get())
   467  	}
   468  	sbc.RollbackCount.Set(0)
   469  	sbc.MustFailCodes[vtrpcpb.Code_RESOURCE_EXHAUSTED] = 0
   470  
   471  	// Start a transaction, send one statement.
   472  	// Simulate an error that should *not* trigger a rollback:
   473  	// vtrpcpb.Code_ALREADY_EXISTS case.
   474  	session, _, err = rpcVTGate.Execute(
   475  		context.Background(),
   476  		&vtgatepb.Session{},
   477  		"begin",
   478  		nil,
   479  	)
   480  	if err != nil {
   481  		t.Fatalf("cannot start a transaction: %v", err)
   482  	}
   483  	session, _, err = rpcVTGate.Execute(
   484  		context.Background(),
   485  		session,
   486  		"select id from t1",
   487  		nil,
   488  	)
   489  	if err != nil {
   490  		t.Fatalf("want nil, got %v", err)
   491  	}
   492  	if sbc.RollbackCount.Get() != 0 {
   493  		t.Errorf("want 0, got %d", sbc.RollbackCount.Get())
   494  	}
   495  	sbc.MustFailCodes[vtrpcpb.Code_ALREADY_EXISTS] = 20
   496  	_, _, err = rpcVTGate.Execute(
   497  		context.Background(),
   498  		session,
   499  		"select id from t1",
   500  		nil,
   501  	)
   502  	if err == nil {
   503  		t.Fatalf("want error but got nil")
   504  	}
   505  	if sbc.RollbackCount.Get() != 0 {
   506  		t.Errorf("want 0, got %d", sbc.RollbackCount.Get())
   507  	}
   508  	sbc.MustFailCodes[vtrpcpb.Code_ALREADY_EXISTS] = 0
   509  }
   510  
   511  var shardedVSchema = `
   512  {
   513  	"sharded": true,
   514  	"vindexes": {
   515  		"hash_index": {
   516  			"type": "hash"
   517  		}
   518  	},
   519  	"tables": {
   520  		"sp_tbl": {
   521  			"column_vindexes": [
   522  				{
   523  					"column": "user_id",
   524  					"name": "hash_index"
   525  				}
   526  			]
   527  		}
   528  	}
   529  }
   530  `
   531  
   532  func TestMultiInternalSavepointVtGate(t *testing.T) {
   533  	s := createSandbox(KsTestSharded)
   534  	s.ShardSpec = "-40-80-"
   535  	s.VSchema = shardedVSchema
   536  	srvSchema := getSandboxSrvVSchema()
   537  	rpcVTGate.executor.vm.VSchemaUpdate(srvSchema, nil)
   538  	hcVTGateTest.Reset()
   539  
   540  	sbc1 := hcVTGateTest.AddTestTablet("aa", "1.1.1.1", 1, KsTestSharded, "-40", topodatapb.TabletType_PRIMARY, true, 1, nil)
   541  	sbc2 := hcVTGateTest.AddTestTablet("aa", "1.1.1.1", 2, KsTestSharded, "40-80", topodatapb.TabletType_PRIMARY, true, 1, nil)
   542  	sbc3 := hcVTGateTest.AddTestTablet("aa", "1.1.1.1", 3, KsTestSharded, "80-", topodatapb.TabletType_PRIMARY, true, 1, nil)
   543  
   544  	logChan := QueryLogger.Subscribe("Test")
   545  	defer QueryLogger.Unsubscribe(logChan)
   546  
   547  	session := &vtgatepb.Session{Autocommit: true}
   548  	require.True(t, session.GetAutocommit())
   549  	require.False(t, session.InTransaction)
   550  
   551  	var err error
   552  	session, _, err = rpcVTGate.Execute(context.Background(), session, "begin", nil)
   553  	require.NoError(t, err)
   554  	require.True(t, session.GetAutocommit())
   555  	require.True(t, session.InTransaction)
   556  
   557  	// this query goes to multiple shards so internal savepoint will be created.
   558  	session, _, err = rpcVTGate.Execute(context.Background(), session, "insert into sp_tbl(user_id) values (1), (3)", nil)
   559  	require.NoError(t, err)
   560  	require.True(t, session.GetAutocommit())
   561  	require.True(t, session.InTransaction)
   562  
   563  	wantQ := []*querypb.BoundQuery{{
   564  		Sql:           "savepoint x",
   565  		BindVariables: map[string]*querypb.BindVariable{},
   566  	}, {
   567  		Sql: "insert into sp_tbl(user_id) values (:_user_id_0)",
   568  		BindVariables: map[string]*querypb.BindVariable{
   569  			"_user_id_0": sqltypes.Int64BindVariable(1),
   570  			"_user_id_1": sqltypes.Int64BindVariable(3),
   571  			"vtg1":       sqltypes.Int64BindVariable(1),
   572  			"vtg2":       sqltypes.Int64BindVariable(3),
   573  		},
   574  	}}
   575  	assertQueriesWithSavepoint(t, sbc1, wantQ)
   576  	wantQ[1].Sql = "insert into sp_tbl(user_id) values (:_user_id_1)"
   577  	assertQueriesWithSavepoint(t, sbc2, wantQ)
   578  	assert.Len(t, sbc3.Queries, 0)
   579  	// internal savepoint should be removed.
   580  	assert.Len(t, session.Savepoints, 0)
   581  	sbc1.Queries = nil
   582  	sbc2.Queries = nil
   583  
   584  	// multi shard so new savepoint will be created.
   585  	session, _, err = rpcVTGate.Execute(context.Background(), session, "insert into sp_tbl(user_id) values (2), (4)", nil)
   586  	require.NoError(t, err)
   587  	wantQ = []*querypb.BoundQuery{{
   588  		Sql:           "savepoint x",
   589  		BindVariables: map[string]*querypb.BindVariable{},
   590  	}, {
   591  		Sql: "insert into sp_tbl(user_id) values (:_user_id_1)",
   592  		BindVariables: map[string]*querypb.BindVariable{
   593  			"_user_id_0": sqltypes.Int64BindVariable(2),
   594  			"_user_id_1": sqltypes.Int64BindVariable(4),
   595  			"vtg1":       sqltypes.Int64BindVariable(2),
   596  			"vtg2":       sqltypes.Int64BindVariable(4),
   597  		},
   598  	}}
   599  	assertQueriesWithSavepoint(t, sbc3, wantQ)
   600  	// internal savepoint should be removed.
   601  	assert.Len(t, session.Savepoints, 0)
   602  	sbc2.Queries = nil
   603  	sbc3.Queries = nil
   604  
   605  	// single shard so no savepoint will be created and neither any old savepoint will be executed
   606  	_, _, err = rpcVTGate.Execute(context.Background(), session, "insert into sp_tbl(user_id) values (5)", nil)
   607  	require.NoError(t, err)
   608  	wantQ = []*querypb.BoundQuery{{
   609  		Sql: "insert into sp_tbl(user_id) values (:_user_id_0)",
   610  		BindVariables: map[string]*querypb.BindVariable{
   611  			"_user_id_0": sqltypes.Int64BindVariable(5),
   612  			"vtg1":       sqltypes.Int64BindVariable(5),
   613  		},
   614  	}}
   615  	assertQueriesWithSavepoint(t, sbc2, wantQ)
   616  
   617  	testQueryLog(t, logChan, "Execute", "BEGIN", "begin", 0)
   618  	testQueryLog(t, logChan, "MarkSavepoint", "SAVEPOINT", "savepoint x", 0)
   619  	testQueryLog(t, logChan, "Execute", "INSERT", "insert into sp_tbl(user_id) values (:vtg1), (:vtg2)", 2)
   620  	testQueryLog(t, logChan, "MarkSavepoint", "SAVEPOINT", "savepoint y", 2)
   621  	testQueryLog(t, logChan, "Execute", "INSERT", "insert into sp_tbl(user_id) values (:vtg1), (:vtg2)", 2)
   622  	testQueryLog(t, logChan, "Execute", "INSERT", "insert into sp_tbl(user_id) values (:vtg1)", 1)
   623  }