vitess.io/vitess@v0.16.2/go/vt/wrangler/traffic_switcher_env_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package wrangler
    18  
    19  import (
    20  	"fmt"
    21  	"sync"
    22  	"testing"
    23  	"time"
    24  
    25  	"vitess.io/vitess/go/sync2"
    26  	"vitess.io/vitess/go/vt/log"
    27  
    28  	"vitess.io/vitess/go/mysql/fakesqldb"
    29  
    30  	"context"
    31  
    32  	"vitess.io/vitess/go/mysql"
    33  	"vitess.io/vitess/go/sqltypes"
    34  	"vitess.io/vitess/go/vt/binlog/binlogplayer"
    35  	"vitess.io/vitess/go/vt/key"
    36  	"vitess.io/vitess/go/vt/logutil"
    37  	binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
    38  	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
    39  	"vitess.io/vitess/go/vt/proto/vschema"
    40  	vschemapb "vitess.io/vitess/go/vt/proto/vschema"
    41  	"vitess.io/vitess/go/vt/topo"
    42  	"vitess.io/vitess/go/vt/topo/memorytopo"
    43  	"vitess.io/vitess/go/vt/topotools"
    44  	"vitess.io/vitess/go/vt/vttablet/tabletmanager/vreplication"
    45  	"vitess.io/vitess/go/vt/vttablet/tmclient"
    46  )
    47  
    48  const (
    49  	streamInfoQuery    = "select id, source, message, cell, tablet_types, workflow_type, workflow_sub_type, defer_secondary_keys from _vt.vreplication where workflow='%s' and db_name='vt_%s'"
    50  	streamExtInfoQuery = "select id, source, pos, stop_pos, max_replication_lag, state, db_name, time_updated, transaction_timestamp, time_heartbeat, time_throttled, component_throttled, message, tags, workflow_type, workflow_sub_type, defer_secondary_keys from _vt.vreplication where db_name = 'vt_%s' and workflow = '%s'"
    51  	copyStateQuery     = "select table_name, lastpk from _vt.copy_state where vrepl_id = %d and id in (select max(id) from _vt.copy_state where vrepl_id = %d group by vrepl_id, table_name)"
    52  )
    53  
    54  var (
    55  	streamInfoKs         = fmt.Sprintf(streamInfoQuery, "test", "ks")
    56  	reverseStreamInfoKs1 = fmt.Sprintf(streamInfoQuery, "test_reverse", "ks1")
    57  	streamInfoKs2        = fmt.Sprintf(streamInfoQuery, "test", "ks2")
    58  
    59  	streamExtInfoKs2        = fmt.Sprintf(streamExtInfoQuery, "ks2", "test")
    60  	reverseStreamExtInfoKs2 = fmt.Sprintf(streamExtInfoQuery, "ks2", "test_reverse")
    61  	reverseStreamExtInfoKs1 = fmt.Sprintf(streamExtInfoQuery, "ks1", "test_reverse")
    62  	streamExtInfoKs         = fmt.Sprintf(streamExtInfoQuery, "ks", "test")
    63  )
    64  
// testMigraterEnv is the shared fixture for table and shard migration tests:
// a topo server, a Wrangler, one fake primary tablet per source/target shard,
// and one fake DB client per primary that backs that primary's VReplication
// test engine (see createDBClients).
type testMigraterEnv struct {
	// ts is the topo server; the constructors create a memorytopo with cells cell1 and cell2.
	ts              *topo.Server
	wr              *Wrangler
	// One fake primary per shard, in the same order as sourceShards / targetShards.
	sourcePrimaries []*fakeTablet
	targetPrimaries []*fakeTablet
	// Fake DB clients, index-aligned with sourcePrimaries / targetPrimaries.
	dbSourceClients []*fakeDBClient
	dbTargetClients []*fakeDBClient
	// allDBClients holds the source clients followed by the target clients.
	allDBClients    []*fakeDBClient
	targetKeyspace  string
	sourceShards    []string
	targetShards    []string
	// Key ranges parsed from sourceShards / targetShards, index-aligned with them.
	sourceKeyRanges []*topodatapb.KeyRange
	targetKeyRanges []*topodatapb.KeyRange
	// tmeDB is the fakesqldb passed to every fake tablet.
	tmeDB           *fakesqldb.DB
	// mu is held by startTablets, createDBClients and close.
	mu              sync.Mutex
}
    81  
// testShardMigraterEnv has some convenience functions for adding expected queries.
// They are approximate and should be only used to test other features like stream migration.
// Use explicit queries for testing the actual shard migration.
//
// It embeds testMigraterEnv, so all of that type's fields and methods are
// available; source and target shards both belong to keyspace "ks".
type testShardMigraterEnv struct {
	testMigraterEnv
}
    88  
// testTabletPickerChoice names the keyspace/shard the tablet picker should use.
// The tablet picker requires these to be set, otherwise it errors out. The
// values must also match an existing tablet; otherwise the picker sleeps and
// retries, which makes tests time out and fail.
// Each new migrater env sets them to the first source shard. The tests don't
// depend on which tablet is picked, so this works for now.
type testTabletPickerChoice struct {
	keyspace string
	shard    string
}

// tpChoice is overwritten by newTestTableMigraterCustom and newTestShardMigrater
// with the keyspace/shard of their first source primary.
// NOTE(review): this is package-level mutable state shared by every env.
var tpChoice *testTabletPickerChoice
    99  
   100  func newTestTableMigrater(ctx context.Context, t *testing.T) *testMigraterEnv {
   101  	return newTestTableMigraterCustom(ctx, t, []string{"-40", "40-"}, []string{"-80", "80-"}, "select * %s")
   102  }
   103  
   104  // newTestTableMigraterCustom creates a customized test tablet migrater.
   105  // fmtQuery should be of the form: 'select a, b %s group by a'.
   106  // The test will Sprintf a from clause and where clause as needed.
   107  func newTestTableMigraterCustom(ctx context.Context, t *testing.T, sourceShards, targetShards []string, fmtQuery string) *testMigraterEnv {
   108  	tme := &testMigraterEnv{}
   109  	tme.ts = memorytopo.NewServer("cell1", "cell2")
   110  	tme.wr = New(logutil.NewConsoleLogger(), tme.ts, tmclient.NewTabletManagerClient())
   111  	tme.wr.sem = sync2.NewSemaphore(1, 1)
   112  	tme.sourceShards = sourceShards
   113  	tme.targetShards = targetShards
   114  	tme.tmeDB = fakesqldb.New(t)
   115  	tabletID := 10
   116  	for _, shard := range sourceShards {
   117  		tme.sourcePrimaries = append(tme.sourcePrimaries, newFakeTablet(t, tme.wr, "cell1", uint32(tabletID), topodatapb.TabletType_PRIMARY, tme.tmeDB, TabletKeyspaceShard(t, "ks1", shard)))
   118  		tabletID += 10
   119  
   120  		_, sourceKeyRange, err := topo.ValidateShardName(shard)
   121  		if err != nil {
   122  			t.Fatal(err)
   123  		}
   124  		tme.sourceKeyRanges = append(tme.sourceKeyRanges, sourceKeyRange)
   125  	}
   126  	tpChoiceTablet := tme.sourcePrimaries[0].Tablet
   127  	tpChoice = &testTabletPickerChoice{
   128  		keyspace: tpChoiceTablet.Keyspace,
   129  		shard:    tpChoiceTablet.Shard,
   130  	}
   131  	for _, shard := range targetShards {
   132  		tme.targetPrimaries = append(tme.targetPrimaries, newFakeTablet(t, tme.wr, "cell1", uint32(tabletID), topodatapb.TabletType_PRIMARY, tme.tmeDB, TabletKeyspaceShard(t, "ks2", shard)))
   133  		tabletID += 10
   134  
   135  		_, targetKeyRange, err := topo.ValidateShardName(shard)
   136  		if err != nil {
   137  			t.Fatal(err)
   138  		}
   139  		tme.targetKeyRanges = append(tme.targetKeyRanges, targetKeyRange)
   140  	}
   141  
   142  	vs := &vschemapb.Keyspace{
   143  		Sharded: true,
   144  		Vindexes: map[string]*vschemapb.Vindex{
   145  			"hash": {
   146  				Type: "hash",
   147  			},
   148  		},
   149  		Tables: map[string]*vschemapb.Table{
   150  			"t1": {
   151  				ColumnVindexes: []*vschemapb.ColumnVindex{{
   152  					Column: "c1",
   153  					Name:   "hash",
   154  				}},
   155  			},
   156  			"t2": {
   157  				ColumnVindexes: []*vschemapb.ColumnVindex{{
   158  					Column: "c1",
   159  					Name:   "hash",
   160  				}},
   161  			},
   162  		},
   163  	}
   164  	if len(sourceShards) != 1 {
   165  		if err := tme.ts.SaveVSchema(ctx, "ks1", vs); err != nil {
   166  			t.Fatal(err)
   167  		}
   168  	}
   169  	if len(targetShards) != 1 {
   170  		if err := tme.ts.SaveVSchema(ctx, "ks2", vs); err != nil {
   171  			t.Fatal(err)
   172  		}
   173  	}
   174  	if err := tme.ts.RebuildSrvVSchema(ctx, nil); err != nil {
   175  		t.Fatal(err)
   176  	}
   177  	err := topotools.RebuildKeyspace(ctx, logutil.NewConsoleLogger(), tme.ts, "ks1", []string{"cell1"}, false)
   178  	if err != nil {
   179  		t.Fatal(err)
   180  	}
   181  	err = topotools.RebuildKeyspace(ctx, logutil.NewConsoleLogger(), tme.ts, "ks2", []string{"cell1"}, false)
   182  	if err != nil {
   183  		t.Fatal(err)
   184  	}
   185  
   186  	tme.startTablets(t)
   187  	tme.createDBClients(ctx, t)
   188  	tme.setPrimaryPositions()
   189  	now := time.Now().Unix()
   190  	for i, targetShard := range targetShards {
   191  		var streamInfoRows []string
   192  		var streamExtInfoRows []string
   193  		for j, sourceShard := range sourceShards {
   194  			bls := &binlogdatapb.BinlogSource{
   195  				Keyspace: "ks1",
   196  				Shard:    sourceShard,
   197  				Filter: &binlogdatapb.Filter{
   198  					Rules: []*binlogdatapb.Rule{{
   199  						Match:  "t1",
   200  						Filter: fmt.Sprintf(fmtQuery, fmt.Sprintf("from t1 where in_keyrange('%s')", targetShard)),
   201  					}, {
   202  						Match:  "t2",
   203  						Filter: fmt.Sprintf(fmtQuery, fmt.Sprintf("from t2 where in_keyrange('%s')", targetShard)),
   204  					}},
   205  				},
   206  			}
   207  			streamInfoRows = append(streamInfoRows, fmt.Sprintf("%d|%v|||", j+1, bls))
   208  			streamExtInfoRows = append(streamExtInfoRows, fmt.Sprintf("%d|||||Running|vt_ks1|%d|%d|0|0||||0", j+1, now, now))
   209  			tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, j+1, j+1), noResult)
   210  		}
   211  		tme.dbTargetClients[i].addInvariant(streamInfoKs2, sqltypes.MakeTestResult(sqltypes.MakeTestFields(
   212  			"id|source|message|cell|tablet_types",
   213  			"int64|varchar|varchar|varchar|varchar"),
   214  			streamInfoRows...))
   215  		tme.dbTargetClients[i].addInvariant(streamExtInfoKs2, sqltypes.MakeTestResult(sqltypes.MakeTestFields(
   216  			"id|source|pos|stop_pos|max_replication_lag|state|db_name|time_updated|transaction_timestamp|time_heartbeat|time_throttled|component_throttled|message|tags|workflow_type|workflow_sub_type|defer_secondary_keys",
   217  			"int64|varchar|int64|int64|int64|varchar|varchar|int64|int64|int64|int64|int64|varchar|varchar|int64|int64|int64"),
   218  			streamExtInfoRows...))
   219  		tme.dbTargetClients[i].addInvariant(reverseStreamExtInfoKs2, sqltypes.MakeTestResult(sqltypes.MakeTestFields(
   220  			"id|source|pos|stop_pos|max_replication_lag|state|db_name|time_updated|transaction_timestamp|time_heartbeat|time_throttled|component_throttled|message|tags|workflow_type|workflow_sub_type|defer_secondary_keys",
   221  			"int64|varchar|int64|int64|int64|varchar|varchar|int64|int64|int64|int64|int64|varchar|varchar|int64|int64|int64"),
   222  			streamExtInfoRows...))
   223  	}
   224  
   225  	for i, sourceShard := range sourceShards {
   226  		var streamInfoRows []string
   227  		for j, targetShard := range targetShards {
   228  			bls := &binlogdatapb.BinlogSource{
   229  				Keyspace: "ks2",
   230  				Shard:    targetShard,
   231  				Filter: &binlogdatapb.Filter{
   232  					Rules: []*binlogdatapb.Rule{{
   233  						Match:  "t1",
   234  						Filter: fmt.Sprintf(fmtQuery, fmt.Sprintf("from t1 where in_keyrange('%s')", sourceShard)),
   235  					}, {
   236  						Match:  "t2",
   237  						Filter: fmt.Sprintf(fmtQuery, fmt.Sprintf("from t2 where in_keyrange('%s')", sourceShard)),
   238  					}},
   239  				},
   240  			}
   241  			streamInfoRows = append(streamInfoRows, fmt.Sprintf("%d|%v|||", j+1, bls))
   242  			tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, j+1, j+1), noResult)
   243  		}
   244  		tme.dbSourceClients[i].addInvariant(reverseStreamInfoKs1, sqltypes.MakeTestResult(sqltypes.MakeTestFields(
   245  			"id|source|message|cell|tablet_types",
   246  			"int64|varchar|varchar|varchar|varchar"),
   247  			streamInfoRows...),
   248  		)
   249  	}
   250  
   251  	if err := topotools.SaveRoutingRules(ctx, tme.wr.ts, map[string][]string{
   252  		"t1":     {"ks1.t1"},
   253  		"ks2.t1": {"ks1.t1"},
   254  		"t2":     {"ks1.t2"},
   255  		"ks2.t2": {"ks1.t2"},
   256  	}); err != nil {
   257  		t.Fatal(err)
   258  	}
   259  	if err := tme.ts.RebuildSrvVSchema(ctx, nil); err != nil {
   260  		t.Fatal(err)
   261  	}
   262  
   263  	tme.targetKeyspace = "ks2"
   264  	return tme
   265  }
   266  
   267  func newTestShardMigrater(ctx context.Context, t *testing.T, sourceShards, targetShards []string) *testShardMigraterEnv {
   268  	tme := &testShardMigraterEnv{}
   269  	tme.ts = memorytopo.NewServer("cell1", "cell2")
   270  	tme.wr = New(logutil.NewConsoleLogger(), tme.ts, tmclient.NewTabletManagerClient())
   271  	tme.sourceShards = sourceShards
   272  	tme.targetShards = targetShards
   273  	tme.tmeDB = fakesqldb.New(t)
   274  	tme.wr.sem = sync2.NewSemaphore(1, 0)
   275  
   276  	tabletID := 10
   277  	for _, shard := range sourceShards {
   278  		tme.sourcePrimaries = append(tme.sourcePrimaries, newFakeTablet(t, tme.wr, "cell1", uint32(tabletID), topodatapb.TabletType_PRIMARY, tme.tmeDB, TabletKeyspaceShard(t, "ks", shard)))
   279  		tabletID += 10
   280  
   281  		_, sourceKeyRange, err := topo.ValidateShardName(shard)
   282  		if err != nil {
   283  			t.Fatal(err)
   284  		}
   285  		tme.sourceKeyRanges = append(tme.sourceKeyRanges, sourceKeyRange)
   286  	}
   287  	tpChoiceTablet := tme.sourcePrimaries[0].Tablet
   288  	tpChoice = &testTabletPickerChoice{
   289  		keyspace: tpChoiceTablet.Keyspace,
   290  		shard:    tpChoiceTablet.Shard,
   291  	}
   292  
   293  	for _, shard := range targetShards {
   294  		tme.targetPrimaries = append(tme.targetPrimaries, newFakeTablet(t, tme.wr, "cell1", uint32(tabletID), topodatapb.TabletType_PRIMARY, tme.tmeDB, TabletKeyspaceShard(t, "ks", shard)))
   295  		tabletID += 10
   296  
   297  		_, targetKeyRange, err := topo.ValidateShardName(shard)
   298  		if err != nil {
   299  			t.Fatal(err)
   300  		}
   301  		tme.targetKeyRanges = append(tme.targetKeyRanges, targetKeyRange)
   302  	}
   303  
   304  	vs := &vschemapb.Keyspace{
   305  		Sharded: true,
   306  		Vindexes: map[string]*vschema.Vindex{
   307  			"thash": {
   308  				Type: "hash",
   309  			},
   310  		},
   311  		Tables: map[string]*vschema.Table{
   312  			"t1": {
   313  				ColumnVindexes: []*vschema.ColumnVindex{{
   314  					Columns: []string{"c1"},
   315  					Name:    "thash",
   316  				}},
   317  			},
   318  			"t2": {
   319  				ColumnVindexes: []*vschema.ColumnVindex{{
   320  					Columns: []string{"c1"},
   321  					Name:    "thash",
   322  				}},
   323  			},
   324  			"t3": {
   325  				ColumnVindexes: []*vschema.ColumnVindex{{
   326  					Columns: []string{"c1"},
   327  					Name:    "thash",
   328  				}},
   329  			},
   330  		},
   331  	}
   332  	if err := tme.ts.SaveVSchema(ctx, "ks", vs); err != nil {
   333  		t.Fatal(err)
   334  	}
   335  	if err := tme.ts.RebuildSrvVSchema(ctx, nil); err != nil {
   336  		t.Fatal(err)
   337  	}
   338  	err := topotools.RebuildKeyspace(ctx, logutil.NewConsoleLogger(), tme.ts, "ks", nil, false)
   339  	if err != nil {
   340  		t.Fatal(err)
   341  	}
   342  
   343  	tme.startTablets(t)
   344  	tme.createDBClients(ctx, t)
   345  	tme.setPrimaryPositions()
   346  	now := time.Now().Unix()
   347  	for i, targetShard := range targetShards {
   348  		var rows, rowsRdOnly []string
   349  		var streamExtInfoRows []string
   350  		for j, sourceShard := range sourceShards {
   351  			if !key.KeyRangesIntersect(tme.targetKeyRanges[i], tme.sourceKeyRanges[j]) {
   352  				continue
   353  			}
   354  			bls := &binlogdatapb.BinlogSource{
   355  				Keyspace: "ks",
   356  				Shard:    sourceShard,
   357  				Filter: &binlogdatapb.Filter{
   358  					Rules: []*binlogdatapb.Rule{{
   359  						Match:  "/.*",
   360  						Filter: targetShard,
   361  					}},
   362  				},
   363  			}
   364  			rows = append(rows, fmt.Sprintf("%d|%v||||0|0|0", j+1, bls))
   365  			rowsRdOnly = append(rows, fmt.Sprintf("%d|%v|||RDONLY|0|0|0", j+1, bls))
   366  			streamExtInfoRows = append(streamExtInfoRows, fmt.Sprintf("%d|||||Running|vt_ks1|%d|%d|0|0|||", j+1, now, now))
   367  			tme.dbTargetClients[i].addInvariant(fmt.Sprintf(copyStateQuery, j+1, j+1), noResult)
   368  		}
   369  		tme.dbTargetClients[i].addInvariant(streamInfoKs, sqltypes.MakeTestResult(sqltypes.MakeTestFields(
   370  			"id|source|message|cell|tablet_types|workflow_type|workflow_sub_type|defer_secondary_keys",
   371  			"int64|varchar|varchar|varchar|varchar|int64|int64|int64"),
   372  			rows...),
   373  		)
   374  		tme.dbTargetClients[i].addInvariant(streamInfoKs+"-rdonly", sqltypes.MakeTestResult(sqltypes.MakeTestFields(
   375  			"id|source|message|cell|tablet_types|workflow_type|workflow_sub_type|defer_secondary_keys",
   376  			"int64|varchar|varchar|varchar|varchar|int64|int64|int64"),
   377  			rowsRdOnly...),
   378  		)
   379  		tme.dbTargetClients[i].addInvariant(streamExtInfoKs, sqltypes.MakeTestResult(sqltypes.MakeTestFields(
   380  			"id|source|pos|stop_pos|max_replication_lag|state|db_name|time_updated|transaction_timestamp|time_heartbeat|time_throttled|component_throttled|message|tags",
   381  			"int64|varchar|int64|int64|int64|varchar|varchar|int64|int64|int64|int64|varchar|varchar|varchar"),
   382  			streamExtInfoRows...))
   383  	}
   384  
   385  	tme.targetKeyspace = "ks"
   386  	for i, dbclient := range tme.dbSourceClients {
   387  		var streamExtInfoRows []string
   388  		dbclient.addInvariant(streamInfoKs, &sqltypes.Result{})
   389  		for j := range targetShards {
   390  			streamExtInfoRows = append(streamExtInfoRows, fmt.Sprintf("%d|||||Running|vt_ks|%d|%d|0|0|||", j+1, now, now))
   391  			tme.dbSourceClients[i].addInvariant(fmt.Sprintf(copyStateQuery, j+1, j+1), noResult)
   392  		}
   393  		tme.dbSourceClients[i].addInvariant(streamExtInfoKs, sqltypes.MakeTestResult(sqltypes.MakeTestFields(
   394  			"id|source|pos|stop_pos|max_replication_lag|state|db_name|time_updated|transaction_timestamp|time_heartbeat|time_throttled|component_throttled|message|tags",
   395  			"int64|varchar|int64|int64|int64|varchar|varchar|int64|int64|int64|int64|varchar|varchar|varchar"),
   396  			streamExtInfoRows...))
   397  	}
   398  	return tme
   399  }
   400  
   401  func (tme *testMigraterEnv) startTablets(t *testing.T) {
   402  	tme.mu.Lock()
   403  	defer tme.mu.Unlock()
   404  	allPrimarys := append(tme.sourcePrimaries, tme.targetPrimaries...)
   405  	for _, primary := range allPrimarys {
   406  		primary.StartActionLoop(t, tme.wr)
   407  	}
   408  	// Wait for the shard record primaries to be set.
   409  	for _, primary := range allPrimarys {
   410  		primaryFound := false
   411  		for i := 0; i < 10; i++ {
   412  			si, err := tme.wr.ts.GetShard(context.Background(), primary.Tablet.Keyspace, primary.Tablet.Shard)
   413  			if err != nil {
   414  				t.Fatal(err)
   415  			}
   416  			if si.PrimaryAlias != nil {
   417  				primaryFound = true
   418  				break
   419  			}
   420  			time.Sleep(10 * time.Millisecond)
   421  		}
   422  		if !primaryFound {
   423  			t.Fatalf("shard primary did not get updated for tablet: %v", primary)
   424  		}
   425  	}
   426  }
   427  
   428  func (tme *testMigraterEnv) stopTablets(t *testing.T) {
   429  	for _, primary := range tme.sourcePrimaries {
   430  		primary.StopActionLoop(t)
   431  	}
   432  	for _, primary := range tme.targetPrimaries {
   433  		primary.StopActionLoop(t)
   434  	}
   435  }
   436  
   437  func (tme *testMigraterEnv) createDBClients(ctx context.Context, t *testing.T) {
   438  	tme.mu.Lock()
   439  	defer tme.mu.Unlock()
   440  	for _, primary := range tme.sourcePrimaries {
   441  		dbclient := newFakeDBClient(primary.Tablet.Alias.String())
   442  		tme.dbSourceClients = append(tme.dbSourceClients, dbclient)
   443  		dbClientFactory := func() binlogplayer.DBClient { return dbclient }
   444  		// Replace existing engine with a new one
   445  		primary.TM.VREngine = vreplication.NewTestEngine(tme.ts, "", primary.FakeMysqlDaemon, dbClientFactory, dbClientFactory, dbclient.DBName(), nil)
   446  		primary.TM.VREngine.Open(ctx)
   447  	}
   448  	for _, primary := range tme.targetPrimaries {
   449  		log.Infof("Adding as targetPrimary %s", primary.Tablet.Alias)
   450  		dbclient := newFakeDBClient(primary.Tablet.Alias.String())
   451  		tme.dbTargetClients = append(tme.dbTargetClients, dbclient)
   452  		dbClientFactory := func() binlogplayer.DBClient { return dbclient }
   453  		// Replace existing engine with a new one
   454  		primary.TM.VREngine = vreplication.NewTestEngine(tme.ts, "", primary.FakeMysqlDaemon, dbClientFactory, dbClientFactory, dbclient.DBName(), nil)
   455  		primary.TM.VREngine.Open(ctx)
   456  	}
   457  	tme.allDBClients = append(tme.dbSourceClients, tme.dbTargetClients...)
   458  }
   459  
   460  func (tme *testMigraterEnv) setPrimaryPositions() {
   461  	for _, primary := range tme.sourcePrimaries {
   462  		primary.FakeMysqlDaemon.CurrentPrimaryPosition = mysql.Position{
   463  			GTIDSet: mysql.MariadbGTIDSet{
   464  				5: mysql.MariadbGTID{
   465  					Domain:   5,
   466  					Server:   456,
   467  					Sequence: 892,
   468  				},
   469  			},
   470  		}
   471  	}
   472  	for _, primary := range tme.targetPrimaries {
   473  		primary.FakeMysqlDaemon.CurrentPrimaryPosition = mysql.Position{
   474  			GTIDSet: mysql.MariadbGTIDSet{
   475  				5: mysql.MariadbGTID{
   476  					Domain:   5,
   477  					Server:   456,
   478  					Sequence: 893,
   479  				},
   480  			},
   481  		}
   482  	}
   483  }
   484  
   485  func (tme *testMigraterEnv) expectNoPreviousJournals() {
   486  	// validate that no previous journals exist
   487  	for _, dbclient := range tme.dbSourceClients {
   488  		dbclient.addQueryRE(tsCheckJournals, &sqltypes.Result{}, nil)
   489  	}
   490  }
   491  
   492  func (tme *testMigraterEnv) expectNoPreviousReverseJournals() {
   493  	// validate that no previous journals exist
   494  	for _, dbclient := range tme.dbTargetClients {
   495  		dbclient.addQueryRE(tsCheckJournals, &sqltypes.Result{}, nil)
   496  	}
   497  }
   498  
   499  func (tme *testShardMigraterEnv) forAllStreams(f func(i, j int)) {
   500  	for i := range tme.targetShards {
   501  		for j := range tme.sourceShards {
   502  			if !key.KeyRangesIntersect(tme.targetKeyRanges[i], tme.sourceKeyRanges[j]) {
   503  				continue
   504  			}
   505  			f(i, j)
   506  		}
   507  	}
   508  }
   509  
   510  func (tme *testShardMigraterEnv) expectCheckJournals() {
   511  	for _, dbclient := range tme.dbSourceClients {
   512  		dbclient.addQueryRE("select val from _vt.resharding_journal where id=.*", &sqltypes.Result{}, nil)
   513  	}
   514  }
   515  
   516  func (tme *testShardMigraterEnv) expectWaitForCatchup() {
   517  	state := sqltypes.MakeTestResult(sqltypes.MakeTestFields(
   518  		"pos|state|message",
   519  		"varchar|varchar|varchar"),
   520  		"MariaDB/5-456-892|Running",
   521  	)
   522  	tme.forAllStreams(func(i, j int) {
   523  		tme.dbTargetClients[i].addQuery(fmt.Sprintf("select pos, state, message from _vt.vreplication where id=%d", j+1), state, nil)
   524  
   525  		// mi.waitForCatchup-> mi.wr.tmc.VReplicationExec('stopped for cutover')
   526  		tme.dbTargetClients[i].addQuery(fmt.Sprintf("select id from _vt.vreplication where id = %d", j+1), &sqltypes.Result{Rows: [][]sqltypes.Value{{sqltypes.NewInt64(int64(j + 1))}}}, nil)
   527  		tme.dbTargetClients[i].addQuery(fmt.Sprintf("update _vt.vreplication set state = 'Stopped', message = 'stopped for cutover' where id in (%d)", j+1), &sqltypes.Result{}, nil)
   528  		tme.dbTargetClients[i].addQuery(fmt.Sprintf("select * from _vt.vreplication where id = %d", j+1), stoppedResult(j+1), nil)
   529  	})
   530  }
   531  
   532  func (tme *testShardMigraterEnv) expectDeleteReverseVReplication() {
   533  	// NOTE: this is not a faithful reproduction of what should happen.
   534  	// The ids returned are not accurate.
   535  	for _, dbclient := range tme.dbSourceClients {
   536  		dbclient.addQuery("select id from _vt.vreplication where db_name = 'vt_ks' and workflow = 'test_reverse'", resultid12, nil)
   537  		dbclient.addQuery("delete from _vt.vreplication where id in (1, 2)", &sqltypes.Result{}, nil)
   538  		dbclient.addQuery("delete from _vt.copy_state where vrepl_id in (1, 2)", &sqltypes.Result{}, nil)
   539  		dbclient.addQuery("delete from _vt.post_copy_action where vrepl_id in (1, 2)", &sqltypes.Result{}, nil)
   540  	}
   541  }
   542  
   543  func (tme *testShardMigraterEnv) expectCreateReverseVReplication() {
   544  	tme.expectDeleteReverseVReplication()
   545  	tme.forAllStreams(func(i, j int) {
   546  		tme.dbSourceClients[j].addQueryRE(fmt.Sprintf("insert into _vt.vreplication.*%s.*%s.*MariaDB/5-456-893.*Stopped", tme.targetShards[i], key.KeyRangeString(tme.sourceKeyRanges[j])), &sqltypes.Result{InsertID: uint64(j + 1)}, nil)
   547  		tme.dbSourceClients[j].addQuery(fmt.Sprintf("select * from _vt.vreplication where id = %d", j+1), stoppedResult(j+1), nil)
   548  	})
   549  }
   550  
   551  func (tme *testShardMigraterEnv) expectCreateJournals() {
   552  	for _, dbclient := range tme.dbSourceClients {
   553  		dbclient.addQueryRE("insert into _vt.resharding_journal.*", &sqltypes.Result{}, nil)
   554  	}
   555  }
   556  
   557  func (tme *testShardMigraterEnv) expectStartReverseVReplication() {
   558  	// NOTE: this is not a faithful reproduction of what should happen.
   559  	// The ids returned are not accurate.
   560  	for _, dbclient := range tme.dbSourceClients {
   561  		dbclient.addQuery("select id from _vt.vreplication where db_name = 'vt_ks'", resultid34, nil)
   562  		dbclient.addQuery("update _vt.vreplication set state = 'Running', message = '' where id in (3, 4)", &sqltypes.Result{}, nil)
   563  		dbclient.addQuery("select * from _vt.vreplication where id = 3", runningResult(3), nil)
   564  		dbclient.addQuery("select * from _vt.vreplication where id = 4", runningResult(4), nil)
   565  	}
   566  }
   567  
   568  func (tme *testShardMigraterEnv) expectFrozenTargetVReplication() {
   569  	// NOTE: this is not a faithful reproduction of what should happen.
   570  	// The ids returned are not accurate.
   571  	for _, dbclient := range tme.dbTargetClients {
   572  		dbclient.addQuery("select id from _vt.vreplication where db_name = 'vt_ks' and workflow = 'test'", resultid12, nil)
   573  		dbclient.addQuery("update _vt.vreplication set message = 'FROZEN' where id in (1, 2)", &sqltypes.Result{}, nil)
   574  		dbclient.addQuery("select * from _vt.vreplication where id = 1", stoppedResult(1), nil)
   575  		dbclient.addQuery("select * from _vt.vreplication where id = 2", stoppedResult(2), nil)
   576  	}
   577  }
   578  
   579  func (tme *testShardMigraterEnv) expectDeleteTargetVReplication() {
   580  	// NOTE: this is not a faithful reproduction of what should happen.
   581  	// The ids returned are not accurate.
   582  	for _, dbclient := range tme.dbTargetClients {
   583  		dbclient.addQuery("select id from _vt.vreplication where db_name = 'vt_ks' and workflow = 'test'", resultid12, nil)
   584  		dbclient.addQuery("delete from _vt.vreplication where id in (1, 2)", &sqltypes.Result{}, nil)
   585  		dbclient.addQuery("delete from _vt.copy_state where vrepl_id in (1, 2)", &sqltypes.Result{}, nil)
   586  		dbclient.addQuery("delete from _vt.post_copy_action where vrepl_id in (1, 2)", &sqltypes.Result{}, nil)
   587  	}
   588  }
   589  
   590  func (tme *testShardMigraterEnv) expectCancelMigration() {
   591  	for _, dbclient := range tme.dbTargetClients {
   592  		dbclient.addQuery("select id from _vt.vreplication where db_name = 'vt_ks' and workflow = 'test'", &sqltypes.Result{}, nil)
   593  	}
   594  	for _, dbclient := range tme.dbSourceClients {
   595  		dbclient.addQuery("select id from _vt.vreplication where db_name = 'vt_ks' and workflow != 'test_reverse'", &sqltypes.Result{}, nil)
   596  	}
   597  	tme.expectDeleteReverseVReplication()
   598  }
   599  
   600  func (tme *testShardMigraterEnv) expectNoPreviousJournals() {
   601  	// validate that no previous journals exist
   602  	for _, dbclient := range tme.dbSourceClients {
   603  		dbclient.addQueryRE(tsCheckJournals, &sqltypes.Result{}, nil)
   604  	}
   605  }
   606  
   607  func (tme *testMigraterEnv) close(t *testing.T) {
   608  	tme.mu.Lock()
   609  	defer tme.mu.Unlock()
   610  	tme.stopTablets(t)
   611  	for _, dbclient := range tme.dbSourceClients {
   612  		dbclient.Close()
   613  	}
   614  	for _, dbclient := range tme.dbTargetClients {
   615  		dbclient.Close()
   616  	}
   617  	tme.tmeDB.CloseAllConnections()
   618  	tme.ts.Close()
   619  	tme.wr.tmc.Close()
   620  	tme.wr = nil
   621  }