vitess.io/vitess@v0.16.2/go/vt/vtgate/vstream_manager.go

/*
Copyright 2019 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package vtgate

import (
	"context"
	"fmt"
	"io"
	"strings"
	"sync"
	"time"

	"vitess.io/vitess/go/vt/discovery"
	querypb "vitess.io/vitess/go/vt/proto/query"
	"vitess.io/vitess/go/vt/topo"

	vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"

	"google.golang.org/protobuf/proto"

	"vitess.io/vitess/go/vt/log"
	binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
	topodatapb "vitess.io/vitess/go/vt/proto/topodata"
	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
	"vitess.io/vitess/go/vt/srvtopo"
	"vitess.io/vitess/go/vt/vterrors"
)

// vstreamManager manages vstream requests.
type vstreamManager struct {
	resolver *srvtopo.Resolver
	toposerv srvtopo.Server
	cell     string
}

// maxSkewTimeoutSeconds is the maximum allowed skew between two streams when the MinimizeSkew flag is set
const maxSkewTimeoutSeconds = 10 * 60

// vstream contains the metadata for one VStream request.
type vstream struct {
	// mu protects parts of vgtid, the semantics of a send, and journaler.
	// Once streaming begins, the Gtid within each ShardGtid will be updated on each event.
	// Also, the list of ShardGtids can change on a journaling event.
	// All other parts of vgtid can be read without a lock.
	// The lock is also held to ensure that all grouped events are sent together.
	// This can happen if vstreamer breaks up large transactions into smaller chunks.
	mu        sync.Mutex
	vgtid     *binlogdatapb.VGtid
	send      func(events []*binlogdatapb.VEvent) error
	journaler map[int64]*journalEvent

	// err can only be set once.
	// errMu protects err by ensuring its value is read or written by only one goroutine at a time.
	once  sync.Once
	err   error
	errMu sync.Mutex

	// Other input parameters
	tabletType topodatapb.TabletType
	filter     *binlogdatapb.Filter
	resolver   *srvtopo.Resolver
	optCells   string

	cancel context.CancelFunc
	wg     sync.WaitGroup

	// This flag is set by the client; default false.
	// If true, skew detection is enabled and we align the streams so that they receive events from
	// about the same time as each other. Note that there is no exact ordering of events across shards.
	minimizeSkew bool

	// This flag is set by the client; default false.
	// If true, the corresponding journal event is sent to the client when a reshard is detected.
	// The default behavior is to automatically migrate the resharded streams from the old to the new shards.
	stopOnReshard bool

	// mutex used to synchronize access to skew detection parameters
	skewMu sync.Mutex
	// channel is created whenever a skew is detected. Closing it implies the current skew has been fixed.
	skewCh chan bool
	// if a skew lasts for this long, we timeout the vstream call. currently hardcoded
	skewTimeoutSeconds int64
	// the slow streamId which is causing the skew. streamId is of the form <keyspace>/<shard>
	laggard string
	// transaction timestamp of the slowest stream
	lowestTS int64
	// the timestamp of the most recent event, keyed by streamId. streamId is of the form <keyspace>/<shard>
	timestamps map[string]int64

	// the shard map tracking the copy completion, keyed by streamId. streamId is of the form <keyspace>/<shard>
	copyCompletedShard map[string]struct{}

	vsm *vstreamManager

	eventCh           chan []*binlogdatapb.VEvent
	heartbeatInterval uint32
	ts                *topo.Server
}

type journalEvent struct {
	journal      *binlogdatapb.Journal
	participants map[*binlogdatapb.ShardGtid]bool
	done         chan struct{}
}

func newVStreamManager(resolver *srvtopo.Resolver, serv srvtopo.Server, cell string) *vstreamManager {
	return &vstreamManager{
		resolver: resolver,
		toposerv: serv,
		cell:     cell,
	}
}
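
// A minimal construction sketch (illustrative only; the resolver, serv, and
// cell values here are hypothetical stand-ins for what vtgate wires up at
// startup):
//
//	vsm := newVStreamManager(resolver, serv, "cell1")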

func (vsm *vstreamManager) VStream(ctx context.Context, tabletType topodatapb.TabletType, vgtid *binlogdatapb.VGtid,
	filter *binlogdatapb.Filter, flags *vtgatepb.VStreamFlags, send func(events []*binlogdatapb.VEvent) error) error {
	vgtid, filter, flags, err := vsm.resolveParams(ctx, tabletType, vgtid, filter, flags)
	if err != nil {
		return err
	}
	ts, err := vsm.toposerv.GetTopoServer()
	if err != nil {
		return err
	}
	if ts == nil {
		log.Errorf("unable to get topo server in VStream()")
		return fmt.Errorf("unable to get topo server")
	}
	vs := &vstream{
		vgtid:              vgtid,
		tabletType:         tabletType,
		optCells:           flags.Cells,
		filter:             filter,
		send:               send,
		resolver:           vsm.resolver,
		journaler:          make(map[int64]*journalEvent),
		minimizeSkew:       flags.GetMinimizeSkew(),
		stopOnReshard:      flags.GetStopOnReshard(),
		skewTimeoutSeconds: maxSkewTimeoutSeconds,
		timestamps:         make(map[string]int64),
		vsm:                vsm,
		eventCh:            make(chan []*binlogdatapb.VEvent),
		heartbeatInterval:  flags.GetHeartbeatInterval(),
		ts:                 ts,
		copyCompletedShard: make(map[string]struct{}),
	}
	return vs.stream(ctx)
}
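
// A minimal usage sketch (illustrative only; the keyspace "ks" and shard "0"
// are hypothetical). The send callback receives batches of events until the
// stream ends or errors; a nil filter selects all tables via resolveParams:
//
//	vgtid := &binlogdatapb.VGtid{ShardGtids: []*binlogdatapb.ShardGtid{{
//		Keyspace: "ks", Shard: "0", Gtid: "current",
//	}}}
//	err := vsm.VStream(ctx, topodatapb.TabletType_REPLICA, vgtid, nil,
//		&vtgatepb.VStreamFlags{}, func(events []*binlogdatapb.VEvent) error {
//			for _, ev := range events {
//				log.Infof("received event type %v", ev.Type)
//			}
//			return nil
//		})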

// resolveParams provides defaults for the inputs if they're not specified.
func (vsm *vstreamManager) resolveParams(ctx context.Context, tabletType topodatapb.TabletType, vgtid *binlogdatapb.VGtid,
	filter *binlogdatapb.Filter, flags *vtgatepb.VStreamFlags) (*binlogdatapb.VGtid, *binlogdatapb.Filter, *vtgatepb.VStreamFlags, error) {

	if filter == nil {
		filter = &binlogdatapb.Filter{
			Rules: []*binlogdatapb.Rule{{
				Match: "/.*",
			}},
		}
	}

	if flags == nil {
		flags = &vtgatepb.VStreamFlags{}
	}
	if vgtid == nil || len(vgtid.ShardGtids) == 0 {
		return nil, nil, nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "vgtid must have at least one value with a starting position")
	}
	// To fetch from all keyspaces, the input must contain a single ShardGtid
	// that has an empty keyspace, and the Gtid must be "current". In the
	// future, we'll allow the Gtid to be empty which will also support
	// copying of existing data.
	if len(vgtid.ShardGtids) == 1 && vgtid.ShardGtids[0].Keyspace == "" {
		if vgtid.ShardGtids[0].Gtid != "current" {
			return nil, nil, nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "for an empty keyspace, the Gtid value must be 'current': %v", vgtid)
		}
		keyspaces, err := vsm.toposerv.GetSrvKeyspaceNames(ctx, vsm.cell, false)
		if err != nil {
			return nil, nil, nil, err
		}
		newvgtid := &binlogdatapb.VGtid{}
		for _, keyspace := range keyspaces {
			newvgtid.ShardGtids = append(newvgtid.ShardGtids, &binlogdatapb.ShardGtid{
				Keyspace: keyspace,
				Gtid:     "current",
			})
		}
		vgtid = newvgtid
	}
	newvgtid := &binlogdatapb.VGtid{}
	for _, sgtid := range vgtid.ShardGtids {
		if sgtid.Shard == "" {
			if sgtid.Gtid != "current" {
				return nil, nil, nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "if shards are unspecified, the Gtid value must be 'current': %v", vgtid)
			}
			// TODO(sougou): this should work with the new Migrate workflow
			_, _, allShards, err := vsm.resolver.GetKeyspaceShards(ctx, sgtid.Keyspace, tabletType)
			if err != nil {
				return nil, nil, nil, err
			}
			for _, shard := range allShards {
				newvgtid.ShardGtids = append(newvgtid.ShardGtids, &binlogdatapb.ShardGtid{
					Keyspace: sgtid.Keyspace,
					Shard:    shard.Name,
					Gtid:     sgtid.Gtid,
				})
			}
		} else {
			newvgtid.ShardGtids = append(newvgtid.ShardGtids, sgtid)
		}
	}

	// TODO: add tablepk validations.

	return newvgtid, filter, flags, nil
}
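
// For illustration: given a hypothetical two-shard keyspace "ks", an input
// VGtid of {Keyspace: "ks", Gtid: "current"} with no shard set would be
// expanded by resolveParams into one ShardGtid per shard:
//
//	{ShardGtids: [
//		{Keyspace: "ks", Shard: "-80", Gtid: "current"},
//		{Keyspace: "ks", Shard: "80-", Gtid: "current"},
//	]}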

func (vsm *vstreamManager) RecordStreamDelay() {
	vstreamSkewDelayCount.Add(1)
}

func (vsm *vstreamManager) GetTotalStreamDelay() int64 {
	return vstreamSkewDelayCount.Get()
}

func (vs *vstream) stream(ctx context.Context) error {
	ctx, vs.cancel = context.WithCancel(ctx)
	defer vs.cancel()

	go vs.sendEvents(ctx)

	// Make a copy first, because the ShardGtids list can change once streaming starts.
	copylist := append(([]*binlogdatapb.ShardGtid)(nil), vs.vgtid.ShardGtids...)
	for _, sgtid := range copylist {
		vs.startOneStream(ctx, sgtid)
	}
	vs.wg.Wait()

	return vs.getError()
}

func (vs *vstream) sendEvents(ctx context.Context) {
	var heartbeat <-chan time.Time
	var resetHeartbeat func()

	if vs.heartbeatInterval == 0 {
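		// Heartbeats are disabled: use a channel that is never written to,
		// so the heartbeat case in the select below can never fire.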
		heartbeat = make(chan time.Time)
		resetHeartbeat = func() {}
	} else {
		d := time.Duration(vs.heartbeatInterval) * time.Second
		timer := time.NewTicker(d)
		defer timer.Stop()

		heartbeat = timer.C
		resetHeartbeat = func() { timer.Reset(d) }
	}

	send := func(evs []*binlogdatapb.VEvent) error {
		if err := vs.send(evs); err != nil {
			vs.once.Do(func() {
				vs.setError(err)
			})
			return err
		}
		return nil
	}
	for {
		select {
		case <-ctx.Done():
			vs.once.Do(func() {
				vs.setError(fmt.Errorf("context canceled"))
			})
			return
		case evs := <-vs.eventCh:
			if err := send(evs); err != nil {
				vs.once.Do(func() {
					vs.setError(err)
				})
				return
			}
			resetHeartbeat()
		case t := <-heartbeat:
			now := t.UnixNano()
			evs := []*binlogdatapb.VEvent{{
				Type:        binlogdatapb.VEventType_HEARTBEAT,
				Timestamp:   now / 1e9,
				CurrentTime: now,
			}}
			if err := send(evs); err != nil {
				vs.once.Do(func() {
					vs.setError(err)
				})
				return
			}
		}
	}
}

// startOneStream sets up one shard stream.
func (vs *vstream) startOneStream(ctx context.Context, sgtid *binlogdatapb.ShardGtid) {
	vs.wg.Add(1)
	go func() {
		defer vs.wg.Done()
		err := vs.streamFromTablet(ctx, sgtid)

		// Set the error on exit. First one wins.
		if err != nil {
			log.Errorf("Error in vstream for %+v: %s", sgtid, err)
			vs.once.Do(func() {
				vs.setError(err)
				vs.cancel()
			})
		}
	}()
}

// MaxSkew is the threshold for a skew to be detected. Since MySQL timestamps are in seconds we account for
// two round-offs: one for the actual event and another while accounting for the clock skew
const MaxSkew = int64(2)

// computeSkew sets the timestamp of the current event for the calling stream, accounts for a clock skew
// and declares that a skew has arisen if the streams are too far apart
func (vs *vstream) computeSkew(streamID string, event *binlogdatapb.VEvent) bool {
	vs.skewMu.Lock()
	defer vs.skewMu.Unlock()
	// account for skew between this vtgate and the source mysql server
	secondsInThePast := event.CurrentTime/1e9 - event.Timestamp
	vs.timestamps[streamID] = time.Now().Unix() - secondsInThePast

	var minTs, maxTs int64
	var laggardStream string

	if len(vs.timestamps) <= 1 {
		return false
	}
	for k, ts := range vs.timestamps {
		if ts < minTs || minTs == 0 {
			minTs = ts
			laggardStream = k
		}
		if ts > maxTs {
			maxTs = ts
		}
	}
	if vs.laggard != "" { // we are skewed, check if this event has fixed the skew
		if (maxTs - minTs) <= MaxSkew {
			vs.laggard = ""
			close(vs.skewCh)
		}
	} else {
		if (maxTs - minTs) > MaxSkew { // check if we are skewed due to this event
			log.Infof("Skew found, laggard is %s, %+v", laggardStream, vs.timestamps)
			vs.laggard = laggardStream
			vs.skewCh = make(chan bool)
		}
	}
	return vs.mustPause(streamID)
}
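
// A worked example of the skew arithmetic above (values are illustrative):
// if an event carries Timestamp=1000 (commit time, in seconds) and
// CurrentTime=1003e9 (read time, in nanoseconds), then secondsInThePast is
// 1003-1000 = 3, and the stream's adjusted timestamp is time.Now().Unix()-3.
// With MaxSkew=2, two streams whose adjusted timestamps differ by more than
// two seconds are declared skewed.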

// mustPause returns true if a skew exists and the stream calling this is not the slowest one
func (vs *vstream) mustPause(streamID string) bool {
	switch vs.laggard {
	case "":
		return false
	case streamID:
		// current stream is the laggard, not pausing
		return false
	}

	if (vs.timestamps[streamID] - vs.lowestTS) <= MaxSkew {
		// current stream is not the laggard, but the skew is still within the limit
		return false
	}
	vs.vsm.RecordStreamDelay()
	return true
}

// alignStreams is called by each individual shard's stream before an event is sent to the client or after each heartbeat.
// It checks for skew (if the minimizeSkew option is set). If skew is present, this stream is delayed until the skew is fixed.
// The faster stream detects the skew and waits. The slower stream resets the skew when it catches up.
func (vs *vstream) alignStreams(ctx context.Context, event *binlogdatapb.VEvent, keyspace, shard string) error {
	if !vs.minimizeSkew || event.Timestamp == 0 {
		return nil
	}
	streamID := fmt.Sprintf("%s/%s", keyspace, shard)
	for {
		mustPause := vs.computeSkew(streamID, event)
		if event.Type == binlogdatapb.VEventType_HEARTBEAT {
			return nil
		}
		if !mustPause {
			return nil
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(time.Duration(vs.skewTimeoutSeconds) * time.Second):
			log.Errorf("timed out while waiting for skew to reduce: %s", streamID)
			return fmt.Errorf("timed out while waiting for skew to reduce: %s", streamID)
		case <-vs.skewCh:
			// once skew is fixed the channel is closed and all waiting streams "wake up"
		}
	}
}
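
// For example (illustrative shard names): with streams on ks/-80 and ks/80-,
// if ks/-80 runs more than MaxSkew seconds ahead, its calls to alignStreams
// block in the select above until ks/80- catches up, at which point
// computeSkew closes skewCh and every waiting stream resumes.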

func (vs *vstream) getCells() []string {
	var cells []string
	if vs.optCells != "" {
		for _, cell := range strings.Split(strings.TrimSpace(vs.optCells), ",") {
			cells = append(cells, strings.TrimSpace(cell))
		}
	}

	if len(cells) == 0 {
		// use the vtgate's cell by default
		cells = append(cells, vs.vsm.cell)
	}
	return cells
}
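
// For example, an optCells of " cell1, cell2 " yields ["cell1", "cell2"],
// while an empty optCells falls back to this vtgate's own cell.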

// streamFromTablet streams from one shard. If transactions come in separate chunks, they are grouped and sent.
func (vs *vstream) streamFromTablet(ctx context.Context, sgtid *binlogdatapb.ShardGtid) error {
	// journalDone is assigned a channel when a journal event is encountered.
	// It will be closed when all journal events converge.
	var journalDone chan struct{}

	errCount := 0
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-journalDone:
			// Normally unreachable. This can only happen if a server misbehaves
			// and does not end the stream after we return an error.
			return nil
		default:
		}

		var eventss [][]*binlogdatapb.VEvent
		var err error
		cells := vs.getCells()
		tp, err := discovery.NewTabletPicker(vs.ts, cells, sgtid.Keyspace, sgtid.Shard, vs.tabletType.String())
		if err != nil {
			log.Errorf(err.Error())
			return err
		}
		tablet, err := tp.PickForStreaming(ctx)
		if err != nil {
			log.Errorf(err.Error())
			return err
		}
		log.Infof("Picked tablet %s for %s/%s/%s/%s", tablet.Alias.String(), strings.Join(cells, ","),
			sgtid.Keyspace, sgtid.Shard, vs.tabletType.String())
		target := &querypb.Target{
			Keyspace:   sgtid.Keyspace,
			Shard:      sgtid.Shard,
			TabletType: vs.tabletType,
			Cell:       vs.vsm.cell,
		}
		tabletConn, err := vs.vsm.resolver.GetGateway().QueryServiceByAlias(tablet.Alias, target)
		if err != nil {
			log.Errorf(err.Error())
			return err
		}

		errCh := make(chan error, 1)
		go func() {
			_ = tabletConn.StreamHealth(ctx, func(shr *querypb.StreamHealthResponse) error {
				var err error
				if ctx.Err() != nil {
					err = fmt.Errorf("context has ended")
				} else if shr == nil || shr.RealtimeStats == nil || shr.Target == nil {
					err = fmt.Errorf("health check failed")
				} else if vs.tabletType != shr.Target.TabletType {
					err = fmt.Errorf("tablet type has changed from %s to %s, restarting vstream",
						vs.tabletType, shr.Target.TabletType)
				} else if shr.RealtimeStats.HealthError != "" {
					err = fmt.Errorf("tablet %s is no longer healthy: %s, restarting vstream",
						tablet.Alias, shr.RealtimeStats.HealthError)
				}
				if err != nil {
					errCh <- err
					return err
				}
				return nil
			})
		}()

		log.Infof("Starting to vstream from %s", tablet.Alias.String())
		// Safe to access sgtid.Gtid here (because it can't change until streaming begins).
		req := &binlogdatapb.VStreamRequest{
			Target:       target,
			Position:     sgtid.Gtid,
			Filter:       vs.filter,
			TableLastPKs: sgtid.TablePKs,
		}
		err = tabletConn.VStream(ctx, req, func(events []*binlogdatapb.VEvent) error {
			// We received a valid event. Reset error count.
			errCount = 0

			select {
			case <-ctx.Done():
				return ctx.Err()
			case streamErr := <-errCh:
				log.Warningf("Tablet state changed: %s, attempting to restart", streamErr)
				return vterrors.New(vtrpcpb.Code_UNAVAILABLE, streamErr.Error())
			case <-journalDone:
				// Normally unreachable. This can only happen if a server misbehaves
				// and does not end the stream after we return an error.
				return io.EOF
			default:
			}

			sendevents := make([]*binlogdatapb.VEvent, 0, len(events))
			for _, event := range events {
				switch event.Type {
				case binlogdatapb.VEventType_FIELD:
					// Update table names and send.
					// If we're streaming from multiple keyspaces, this will disambiguate
					// duplicate table names.
					ev := proto.Clone(event).(*binlogdatapb.VEvent)
					ev.FieldEvent.TableName = sgtid.Keyspace + "." + ev.FieldEvent.TableName
					sendevents = append(sendevents, ev)
				case binlogdatapb.VEventType_ROW:
					// Update table names and send.
					ev := proto.Clone(event).(*binlogdatapb.VEvent)
					ev.RowEvent.TableName = sgtid.Keyspace + "." + ev.RowEvent.TableName
					sendevents = append(sendevents, ev)
				case binlogdatapb.VEventType_COMMIT, binlogdatapb.VEventType_DDL, binlogdatapb.VEventType_OTHER:
					sendevents = append(sendevents, event)
					eventss = append(eventss, sendevents)

					if err := vs.alignStreams(ctx, event, sgtid.Keyspace, sgtid.Shard); err != nil {
						return err
					}

					if err := vs.sendAll(ctx, sgtid, eventss); err != nil {
						return err
					}
					eventss = nil
					sendevents = nil
				case binlogdatapb.VEventType_COPY_COMPLETED:
					sendevents = append(sendevents, event)
					if fullyCopied, doneEvent := vs.isCopyFullyCompleted(ctx, sgtid, event); fullyCopied {
						sendevents = append(sendevents, doneEvent)
					}
					eventss = append(eventss, sendevents)

					if err := vs.alignStreams(ctx, event, sgtid.Keyspace, sgtid.Shard); err != nil {
						return err
					}

					if err := vs.sendAll(ctx, sgtid, eventss); err != nil {
						return err
					}
					eventss = nil
					sendevents = nil
				case binlogdatapb.VEventType_HEARTBEAT:
					// Remove all heartbeat events for now.
					// Otherwise they can accumulate indefinitely if there are no real events.
					// TODO(sougou): figure out a model for this.
					if err := vs.alignStreams(ctx, event, sgtid.Keyspace, sgtid.Shard); err != nil {
						return err
					}

				case binlogdatapb.VEventType_JOURNAL:
					journal := event.Journal
					// Journal events are not sent to clients by default, but only when StopOnReshard is set
					if vs.stopOnReshard && journal.MigrationType == binlogdatapb.MigrationType_SHARDS {
						sendevents = append(sendevents, event)
						eventss = append(eventss, sendevents)
						if err := vs.sendAll(ctx, sgtid, eventss); err != nil {
							return err
						}
						eventss = nil
						sendevents = nil
					}
					je, err := vs.getJournalEvent(ctx, sgtid, journal)
					if err != nil {
						return err
					}
					if je != nil {
						// Wait till all other participants converge and return EOF.
						journalDone = je.done
						select {
						case <-ctx.Done():
							return ctx.Err()
						case <-journalDone:
							return io.EOF
						}
					}
				default:
					sendevents = append(sendevents, event)
				}
			}
			if len(sendevents) != 0 {
				eventss = append(eventss, sendevents)
			}
			return nil
		})
		// If the stream was ended (by a journal event), return nil without checking for an error.
		select {
		case <-journalDone:
			return nil
		default:
		}
		if err == nil {
			// Unreachable.
			err = vterrors.Errorf(vtrpcpb.Code_UNKNOWN, "vstream ended unexpectedly")
		}
		if vterrors.Code(err) != vtrpcpb.Code_FAILED_PRECONDITION && vterrors.Code(err) != vtrpcpb.Code_UNAVAILABLE {
			log.Errorf("vstream for %s/%s error: %v", sgtid.Keyspace, sgtid.Shard, err)
			return err
		}
		errCount++
		if errCount >= 3 {
			log.Errorf("vstream for %s/%s had three consecutive failures: %v", sgtid.Keyspace, sgtid.Shard, err)
			return err
		}
		log.Infof("vstream for %s/%s error, retrying: %v", sgtid.Keyspace, sgtid.Shard, err)
	}
}

// sendAll sends a group of events together while holding the lock.
func (vs *vstream) sendAll(ctx context.Context, sgtid *binlogdatapb.ShardGtid, eventss [][]*binlogdatapb.VEvent) error {
	vs.mu.Lock()
	defer vs.mu.Unlock()

	// Send all chunks while holding the lock.
	for _, events := range eventss {
		if err := vs.getError(); err != nil {
			return err
		}
		// Convert all GTIDs to VGTIDs. This should be done here while holding the lock.
		for j, event := range events {
			if event.Type == binlogdatapb.VEventType_GTID {
				// Update the VGtid and send that instead.
				sgtid.Gtid = event.Gtid
				events[j] = &binlogdatapb.VEvent{
					Type:     binlogdatapb.VEventType_VGTID,
					Vgtid:    proto.Clone(vs.vgtid).(*binlogdatapb.VGtid),
					Keyspace: event.Keyspace,
					Shard:    event.Shard,
				}
			} else if event.Type == binlogdatapb.VEventType_LASTPK {
				var foundIndex = -1
				eventTablePK := event.LastPKEvent.TableLastPK
				for idx, pk := range sgtid.TablePKs {
					if pk.TableName == eventTablePK.TableName {
						foundIndex = idx
						break
					}
				}
				if foundIndex == -1 {
					if !event.LastPKEvent.Completed {
						sgtid.TablePKs = append(sgtid.TablePKs, eventTablePK)
					}
				} else {
					if event.LastPKEvent.Completed {
						// remove tablepk from sgtid
						sgtid.TablePKs[foundIndex] = sgtid.TablePKs[len(sgtid.TablePKs)-1]
						sgtid.TablePKs[len(sgtid.TablePKs)-1] = nil
						sgtid.TablePKs = sgtid.TablePKs[:len(sgtid.TablePKs)-1]
					} else {
						sgtid.TablePKs[foundIndex] = eventTablePK
					}
				}
				events[j] = &binlogdatapb.VEvent{
					Type:     binlogdatapb.VEventType_VGTID,
					Vgtid:    proto.Clone(vs.vgtid).(*binlogdatapb.VGtid),
					Keyspace: event.Keyspace,
					Shard:    event.Shard,
				}
			}
		}
		select {
		case <-ctx.Done():
			return nil
		case vs.eventCh <- events:
		}
	}
	return nil
}
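
// For example (illustrative): a per-shard GTID event from ks/-80 is
// rewritten here into a VGTID event that carries a snapshot of the entire
// vs.vgtid (the position of every shard in the stream), so a client can
// resume the whole VStream from the last VGTID it received.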

// isCopyFullyCompleted returns true if all streams have received a copy_completed event.
// If true, it will also return a new copy_completed event that needs to be sent.
// This new event represents the completion of all the copy operations.
func (vs *vstream) isCopyFullyCompleted(ctx context.Context, sgtid *binlogdatapb.ShardGtid, event *binlogdatapb.VEvent) (bool, *binlogdatapb.VEvent) {
	vs.mu.Lock()
	defer vs.mu.Unlock()

	vs.copyCompletedShard[fmt.Sprintf("%s/%s", event.Keyspace, event.Shard)] = struct{}{}

	for _, shard := range vs.vgtid.ShardGtids {
		if _, ok := vs.copyCompletedShard[fmt.Sprintf("%s/%s", shard.Keyspace, shard.Shard)]; !ok {
			return false, nil
		}
	}
	return true, &binlogdatapb.VEvent{
		Type: binlogdatapb.VEventType_COPY_COMPLETED,
	}
}
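
// For example (illustrative shards): with ShardGtids for ks/-80 and ks/80-,
// the aggregate COPY_COMPLETED event (with no keyspace/shard set) is only
// produced after both per-shard COPY_COMPLETED events have been recorded.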

func (vs *vstream) getError() error {
	vs.errMu.Lock()
	defer vs.errMu.Unlock()
	return vs.err
}

func (vs *vstream) setError(err error) {
	vs.errMu.Lock()
	defer vs.errMu.Unlock()
	vs.err = err
}

// getJournalEvent returns a journalEvent. The caller has to wait on its done channel.
// Once it closes, the caller has to return (end their stream).
// The function has three parts:
// Part 1: The first stream that encounters a journal event creates a journalEvent.
// Part 2: Every stream joins the journalEvent. If all have not joined, the journalEvent
// is returned to the caller.
// Part 3: If all streams have joined, then new streams are created to replace existing
// streams, the done channel is closed and returned. This section is executed exactly
// once after the last stream joins.
func (vs *vstream) getJournalEvent(ctx context.Context, sgtid *binlogdatapb.ShardGtid, journal *binlogdatapb.Journal) (*journalEvent, error) {
	if journal.MigrationType == binlogdatapb.MigrationType_TABLES {
		// We cannot support table migrations yet because there is no
		// good model for it yet. For example, what if a table is migrated
		// out of the current keyspace we're streaming from?
		return nil, nil
	}

	vs.mu.Lock()
	defer vs.mu.Unlock()

	je, ok := vs.journaler[journal.Id]
	if !ok {
		log.Infof("Journal event received: %v", journal)
		// Identify the list of ShardGtids that match the participants of the journal.
		je = &journalEvent{
			journal:      journal,
			participants: make(map[*binlogdatapb.ShardGtid]bool),
			done:         make(chan struct{}),
		}
		const (
			undecided = iota
			matchAll
			matchNone
		)
		// We start off as undecided. Once we transition to
		// matchAll or matchNone, we have to stay in that state.
		mode := undecided
	nextParticipant:
		for _, jks := range journal.Participants {
			for _, inner := range vs.vgtid.ShardGtids {
				if inner.Keyspace == jks.Keyspace && inner.Shard == jks.Shard {
					switch mode {
					case undecided, matchAll:
						mode = matchAll
						je.participants[inner] = false
					case matchNone:
						return nil, fmt.Errorf("not all journaling participants are in the stream: journal: %v, stream: %v", journal.Participants, vs.vgtid.ShardGtids)
					}
					continue nextParticipant
				}
			}
			switch mode {
			case undecided, matchNone:
				mode = matchNone
			case matchAll:
				return nil, fmt.Errorf("not all journaling participants are in the stream: journal: %v, stream: %v", journal.Participants, vs.vgtid.ShardGtids)
			}
		}
		if mode == matchNone {
			// Unreachable. Journal events are only added to participants.
			// But if we do receive such an event, the right action will be to ignore it.
			return nil, nil
		}
		vs.journaler[journal.Id] = je
	}

	if _, ok := je.participants[sgtid]; !ok {
		// Unreachable. See above.
		return nil, nil
	}
	je.participants[sgtid] = true

	for _, waiting := range je.participants {
		if !waiting {
			// Some participants are yet to join the wait.
			return je, nil
		}
	}

	if !vs.stopOnReshard { // stop streaming from the current shards and start streaming the new shards
		// All participants are waiting. Replace old shard gtids with new ones.
		newsgtids := make([]*binlogdatapb.ShardGtid, 0, len(vs.vgtid.ShardGtids)-len(je.participants)+len(je.journal.ShardGtids))
		log.Infof("Removing shard gtids: %v", je.participants)
		for _, cursgtid := range vs.vgtid.ShardGtids {
			if je.participants[cursgtid] {
				continue
			}
			newsgtids = append(newsgtids, cursgtid)
		}

		log.Infof("Adding shard gtids: %v", je.journal.ShardGtids)
		for _, sgtid := range je.journal.ShardGtids {
			newsgtids = append(newsgtids, sgtid)
			// It's ok to start the streams even though ShardGtids are not updated yet.
			// This is because we're still holding the lock.
			vs.startOneStream(ctx, sgtid)
		}
		vs.vgtid.ShardGtids = newsgtids
	}
	close(je.done)
	return je, nil
}
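
// A sketch of part 3 on a hypothetical reshard: the stream covers ks/-80 and
// ks/80-, and a SHARDS journal lists those two shards as participants with
// new ShardGtids for ks/-40, ks/40-c0, and ks/c0-. Once both participant
// streams have joined, the two old ShardGtids are removed from vs.vgtid,
// streams for the three new shards are started, and je.done is closed so the
// old streams exit.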