vitess.io/vitess@v0.16.2/go/vt/vttablet/tabletserver/messager/message_manager_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package messager
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"reflect"
    25  	"runtime"
    26  	"sync"
    27  	"testing"
    28  	"time"
    29  
    30  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    31  
    32  	"vitess.io/vitess/go/test/utils"
    33  
    34  	"github.com/stretchr/testify/assert"
    35  
    36  	"vitess.io/vitess/go/sqltypes"
    37  	"vitess.io/vitess/go/sync2"
    38  	"vitess.io/vitess/go/vt/sqlparser"
    39  	"vitess.io/vitess/go/vt/vttablet/tabletserver/schema"
    40  	"vitess.io/vitess/go/vt/vttablet/tabletserver/tabletenv"
    41  
    42  	binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata"
    43  	querypb "vitess.io/vitess/go/vt/proto/query"
    44  )
    45  
    46  var (
    47  	testFields = []*querypb.Field{{
    48  		Name: "id",
    49  		Type: sqltypes.VarBinary,
    50  	}, {
    51  		Name: "message",
    52  		Type: sqltypes.VarBinary,
    53  	}}
    54  
    55  	testDBFields = []*querypb.Field{
    56  		{Type: sqltypes.Int64},
    57  		{Type: sqltypes.Int64},
    58  		{Type: sqltypes.Int64},
    59  		{Type: sqltypes.Int64},
    60  		{Type: sqltypes.Int64},
    61  		{Type: sqltypes.VarBinary},
    62  	}
    63  )
    64  
    65  func newMMTable() *schema.Table {
    66  	return &schema.Table{
    67  		Name: sqlparser.NewIdentifierCS("foo"),
    68  		Type: schema.Message,
    69  		MessageInfo: &schema.MessageInfo{
    70  			Fields:             testFields,
    71  			AckWaitDuration:    1 * time.Second,
    72  			PurgeAfterDuration: 3 * time.Second,
    73  			MinBackoff:         1 * time.Second,
    74  			BatchSize:          1,
    75  			CacheSize:          10,
    76  			PollInterval:       1 * time.Second,
    77  		},
    78  	}
    79  }
    80  
    81  func newMMTableWithBackoff() *schema.Table {
    82  	return &schema.Table{
    83  		Name: sqlparser.NewIdentifierCS("foo"),
    84  		Type: schema.Message,
    85  		MessageInfo: &schema.MessageInfo{
    86  			Fields:             testFields,
    87  			AckWaitDuration:    10 * time.Second,
    88  			PurgeAfterDuration: 3 * time.Second,
    89  			MinBackoff:         1 * time.Second,
    90  			MaxBackoff:         4 * time.Second,
    91  			BatchSize:          1,
    92  			CacheSize:          10,
    93  			PollInterval:       1 * time.Second,
    94  		},
    95  	}
    96  }
    97  
    98  func newMMRow(id int64) *querypb.Row {
    99  	return sqltypes.RowToProto3([]sqltypes.Value{
   100  		sqltypes.NewInt64(1),
   101  		sqltypes.NewInt64(1),
   102  		sqltypes.NewInt64(0),
   103  		sqltypes.NULL,
   104  		sqltypes.NewInt64(id),
   105  		sqltypes.NewVarBinary(fmt.Sprintf("%v", id)),
   106  	})
   107  }
   108  
// testReceiver is a subscriber stub used with Subscribe.
type testReceiver struct {
	// rcv is the callback handed to Subscribe.
	rcv func(*sqltypes.Result) error
	// count is the number of times rcv has been invoked (updated atomically).
	count sync2.AtomicInt64
	// ch receives every result passed to rcv.
	ch chan *sqltypes.Result
}
   114  
   115  func newTestReceiver(size int) *testReceiver {
   116  	tr := &testReceiver{
   117  		ch: make(chan *sqltypes.Result, size),
   118  	}
   119  	tr.rcv = func(qr *sqltypes.Result) error {
   120  		tr.count.Add(1)
   121  		tr.ch <- qr
   122  		return nil
   123  	}
   124  	return tr
   125  }
   126  
   127  func (tr *testReceiver) WaitForCount(n int) {
   128  	for {
   129  		runtime.Gosched()
   130  		time.Sleep(10 * time.Millisecond)
   131  		if tr.count.Get() == int64(n) {
   132  			return
   133  		}
   134  	}
   135  }
   136  
   137  func TestReceiverCancel(t *testing.T) {
   138  	mm := newMessageManager(newFakeTabletServer(), newFakeVStreamer(), newMMTable(), sync2.NewSemaphore(1, 0))
   139  	mm.Open()
   140  	defer mm.Close()
   141  
   142  	r1 := newTestReceiver(0)
   143  	ctx, cancel := context.WithCancel(context.Background())
   144  	go cancel()
   145  	_ = mm.Subscribe(ctx, r1.rcv)
   146  
   147  	// r1 should eventually be unsubscribed.
   148  	for i := 0; i < 10; i++ {
   149  		runtime.Gosched()
   150  		time.Sleep(10 * time.Millisecond)
   151  		if len(mm.receivers) != 0 {
   152  			continue
   153  		}
   154  		return
   155  	}
   156  	t.Errorf("receivers were not cleared: %d", len(mm.receivers))
   157  }
   158  
   159  func TestMessageManagerState(t *testing.T) {
   160  	mm := newMessageManager(newFakeTabletServer(), newFakeVStreamer(), newMMTable(), sync2.NewSemaphore(1, 0))
   161  	// Do it twice
   162  	for i := 0; i < 2; i++ {
   163  		mm.Open()
   164  		// Idempotence.
   165  		mm.Open()
   166  		// The yield is for making sure runSend starts up
   167  		// and waits on cond.
   168  		runtime.Gosched()
   169  		mm.Close()
   170  		// Idempotence.
   171  		mm.Close()
   172  	}
   173  }
   174  
   175  func TestMessageManagerAdd(t *testing.T) {
   176  	ti := newMMTable()
   177  	ti.MessageInfo.CacheSize = 1
   178  	mm := newMessageManager(newFakeTabletServer(), newFakeVStreamer(), ti, sync2.NewSemaphore(1, 0))
   179  	mm.Open()
   180  	defer mm.Close()
   181  
   182  	row1 := &MessageRow{
   183  		Row: []sqltypes.Value{sqltypes.NewVarBinary("1")},
   184  	}
   185  	if mm.Add(row1) {
   186  		t.Error("Add(no receivers): true, want false")
   187  	}
   188  
   189  	r1 := newTestReceiver(0)
   190  	go func() { <-r1.ch }()
   191  	mm.Subscribe(context.Background(), r1.rcv)
   192  
   193  	if !mm.Add(row1) {
   194  		t.Error("Add(1 receiver): false, want true")
   195  	}
   196  	// Make sure message is enqueued.
   197  	r1.WaitForCount(2)
   198  	// This will fill up the cache.
   199  	mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("2")}})
   200  
   201  	// The third add has to fail.
   202  	if mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("3")}}) {
   203  		t.Error("Add(cache full): true, want false")
   204  	}
   205  }
   206  
// TestMessageManagerSend covers several delivery behaviors:
//   - a new subscriber first receives the field list;
//   - an added message is delivered and Postpone is then called;
//   - the sent message leaves the cache;
//   - messages are distributed round-robin across two receivers;
//   - a canceled receiver stops receiving.
func TestMessageManagerSend(t *testing.T) {
	tsv := newFakeTabletServer()
	mm := newMessageManager(tsv, newFakeVStreamer(), newMMTable(), sync2.NewSemaphore(1, 0))
	mm.Open()
	defer mm.Close()

	r1 := newTestReceiver(1)
	mm.Subscribe(context.Background(), r1.rcv)

	// The first result a subscriber sees carries only the field list.
	want := &sqltypes.Result{
		Fields: testFields,
	}
	if got := <-r1.ch; !got.Equal(want) {
		t.Errorf("Received: %v, want %v", got, want)
	}
	// Set the channel to verify call to Postpone.
	// Make it buffered so the thread doesn't block on repeated calls.
	ch := make(chan string, 20)
	tsv.SetChannel(ch)
	mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("1"), sqltypes.NULL}})
	want = &sqltypes.Result{
		Rows: [][]sqltypes.Value{{
			sqltypes.NewVarBinary("1"),
			sqltypes.NULL,
		}},
	}
	if got := <-r1.ch; !got.Equal(want) {
		t.Errorf("Received: %v, want %v", got, want)
	}

	// Ensure Postpone got called.
	// NOTE(review): the fake also sends "purge" on ch, so a purge racing
	// ahead of the postpone would fail this check. Confirm that ordering is
	// guaranteed.
	if got, want := <-ch, "postpone"; got != want {
		t.Errorf("Postpone: %s, want %v", got, want)
	}

	// Verify item has been removed from cache.
	// Need to obtain lock to prevent data race.
	// It may take some time for this to happen.
	// inQueue/inFlight are only ever cleared, never set back to true. Each one
	// records whether key "1" has been observed absent at least once.
	inQueue := true
	inFlight := true
	for i := 0; i < 10; i++ {
		mm.cache.mu.Lock()
		if _, ok := mm.cache.inQueue["1"]; !ok {
			inQueue = false
		}
		if _, ok := mm.cache.inFlight["1"]; !ok {
			inFlight = false
		}
		mm.cache.mu.Unlock()
		if inQueue || inFlight {
			runtime.Gosched()
			time.Sleep(10 * time.Millisecond)
			continue
		}
		break
	}
	assert.False(t, inQueue)
	assert.False(t, inFlight)

	// Test that mm stops sending to a canceled receiver.
	r2 := newTestReceiver(1)
	ctx, cancel := context.WithCancel(context.Background())
	mm.Subscribe(ctx, r2.rcv)
	// Drain r2's field-list result.
	<-r2.ch

	mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("2")}})
	mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("3")}})
	// Send should be round-robin.
	<-r1.ch
	<-r2.ch

	// Cancel and wait for it to take effect.
	// NOTE(review): if r2 is not removed within ~100ms, the loop falls through
	// without failing. The reads below may then block instead of reporting an
	// error.
	cancel()
	for i := 0; i < 10; i++ {
		runtime.Gosched()
		time.Sleep(10 * time.Millisecond)
		mm.mu.Lock()
		if len(mm.receivers) != 1 {
			mm.mu.Unlock()
			continue
		}
		mm.mu.Unlock()
		break
	}

	mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("4")}})
	mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("5")}})
	mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("6")}})
	// Only r1 should be receiving.
	<-r1.ch
	<-r1.ch
	<-r1.ch
}
   300  
// TestMessageManagerPostponeThrottle checks that postpone calls are throttled
// by the manager's semaphore. The semaphore has one slot, so while one
// PostponeMessages call is blocked, postponeCount must not grow past 1.
func TestMessageManagerPostponeThrottle(t *testing.T) {
	tsv := newFakeTabletServer()
	mm := newMessageManager(tsv, newFakeVStreamer(), newMMTable(), sync2.NewSemaphore(1, 0))
	mm.Open()
	defer mm.Close()

	r1 := newTestReceiver(1)
	mm.Subscribe(context.Background(), r1.rcv)
	// Drain r1's field-list result.
	<-r1.ch

	// Set the channel to verify call to Postpone.
	// It is unbuffered, so each fake postpone/purge call blocks until the
	// test receives from it.
	ch := make(chan string)
	tsv.SetChannel(ch)
	tsv.postponeCount.Set(0)

	mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("1"), sqltypes.NULL}})
	// Once we receive, mm will obtain the single semaphore and call postpone.
	// Postpone will wait on the unbuffered ch.
	<-r1.ch

	// Set up a second subscriber.
	r2 := newTestReceiver(1)
	mm.Subscribe(context.Background(), r2.rcv)
	<-r2.ch

	// Wait.
	for i := 0; i < 2; i++ {
		runtime.Gosched()
		time.Sleep(10 * time.Millisecond)
	}
	// postponeCount should still be 1: the first postpone call is blocked on
	// ch while holding the only semaphore slot.
	if got, want := tsv.postponeCount.Get(), int64(1); got != want {
		t.Errorf("tsv.postponeCount: %d, want %d", got, want)
	}

	// Receive on this channel will allow the next postpone to go through.
	<-ch
	// Wait.
	for i := 0; i < 2; i++ {
		runtime.Gosched()
		time.Sleep(10 * time.Millisecond)
	}
	// Verify again that no further postpone has happened yet.
	if got, want := tsv.postponeCount.Get(), int64(1); got != want {
		t.Errorf("tsv.postponeCount: %d, want %d", got, want)
	}
	// NOTE(review): the fake sends both "postpone" and "purge" on ch. This
	// final receive takes whichever arrives next, presumably so that no fake
	// call is left blocked on ch. Confirm intent.
	<-ch
}
   348  
// TestMessageManagerSendError checks that when the subscriber callback returns
// an error for a message (after accepting the field list), the manager still
// calls Postpone for that message.
func TestMessageManagerSendError(t *testing.T) {
	tsv := newFakeTabletServer()
	mm := newMessageManager(tsv, newFakeVStreamer(), newMMTable(), sync2.NewSemaphore(1, 0))
	mm.Open()
	defer mm.Close()
	ctx := context.Background()

	// Unbuffered. The helper goroutine consumes only the first (field-list)
	// result; the test body consumes the message below.
	ch := make(chan *sqltypes.Result)
	go func() { <-ch }()
	fieldSent := false
	// The first invocation (the field list) succeeds; every later one fails.
	mm.Subscribe(ctx, func(qr *sqltypes.Result) error {
		ch <- qr
		if !fieldSent {
			fieldSent = true
			return nil
		}
		return errors.New("intentional error")
	})

	postponech := make(chan string, 20)
	tsv.SetChannel(postponech)
	mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("1"), sqltypes.NULL}})
	<-ch

	// Ensure Postpone got called.
	if got, want := <-postponech, "postpone"; got != want {
		t.Errorf("Postpone: %s, want %v", got, want)
	}
}
   378  
   379  func TestMessageManagerFieldSendError(t *testing.T) {
   380  	mm := newMessageManager(newFakeTabletServer(), newFakeVStreamer(), newMMTable(), sync2.NewSemaphore(1, 0))
   381  	mm.Open()
   382  	defer mm.Close()
   383  	ctx := context.Background()
   384  
   385  	ch := make(chan *sqltypes.Result)
   386  	go func() { <-ch }()
   387  	done := mm.Subscribe(ctx, func(qr *sqltypes.Result) error {
   388  		ch <- qr
   389  		return errors.New("non-eof")
   390  	})
   391  
   392  	// This should not hang because a field send error must terminate
   393  	// subscription.
   394  	<-done
   395  }
   396  
   397  func TestMessageManagerBatchSend(t *testing.T) {
   398  	ti := newMMTable()
   399  	ti.MessageInfo.BatchSize = 2
   400  	mm := newMessageManager(newFakeTabletServer(), newFakeVStreamer(), ti, sync2.NewSemaphore(1, 0))
   401  	mm.Open()
   402  	defer mm.Close()
   403  
   404  	r1 := newTestReceiver(1)
   405  	mm.Subscribe(context.Background(), r1.rcv)
   406  	<-r1.ch
   407  
   408  	row1 := &MessageRow{
   409  		Row: []sqltypes.Value{sqltypes.NewVarBinary("1"), sqltypes.NULL},
   410  	}
   411  	mm.Add(row1)
   412  	want := &sqltypes.Result{
   413  		Rows: [][]sqltypes.Value{{
   414  			sqltypes.NewVarBinary("1"),
   415  			sqltypes.NULL,
   416  		}},
   417  	}
   418  	if got := <-r1.ch; !got.Equal(want) {
   419  		t.Errorf("Received: %v, want %v", got, row1)
   420  	}
   421  	mm.mu.Lock()
   422  	mm.cache.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("2"), sqltypes.NULL}})
   423  	mm.cache.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("3"), sqltypes.NULL}})
   424  	mm.cond.Broadcast()
   425  	mm.mu.Unlock()
   426  	want = &sqltypes.Result{
   427  		Rows: [][]sqltypes.Value{{
   428  			sqltypes.NewVarBinary("2"),
   429  			sqltypes.NULL,
   430  		}, {
   431  			sqltypes.NewVarBinary("3"),
   432  			sqltypes.NULL,
   433  		}},
   434  	}
   435  	if got := <-r1.ch; !got.Equal(want) {
   436  		t.Errorf("Received: %+v, want %+v", got, row1)
   437  	}
   438  }
   439  
   440  func TestMessageManagerStreamerSimple(t *testing.T) {
   441  	fvs := newFakeVStreamer()
   442  	fvs.setStreamerResponse([][]*binlogdatapb.VEvent{{{
   443  		// Event set 1.
   444  		Type: binlogdatapb.VEventType_GTID,
   445  		Gtid: "MySQL56/33333333-3333-3333-3333-333333333333:1-100",
   446  	}, {
   447  		Type: binlogdatapb.VEventType_OTHER,
   448  	}}, {{
   449  		// Event set 2.
   450  		Type: binlogdatapb.VEventType_FIELD,
   451  		FieldEvent: &binlogdatapb.FieldEvent{
   452  			TableName: "foo",
   453  			Fields:    testDBFields,
   454  		},
   455  	}}, {{
   456  		// Event set 3.
   457  		Type: binlogdatapb.VEventType_ROW,
   458  		RowEvent: &binlogdatapb.RowEvent{
   459  			TableName: "foo",
   460  			RowChanges: []*binlogdatapb.RowChange{{
   461  				After: newMMRow(1),
   462  			}},
   463  		},
   464  	}, {
   465  		Type: binlogdatapb.VEventType_GTID,
   466  		Gtid: "MySQL56/33333333-3333-3333-3333-333333333333:1-101",
   467  	}, {
   468  		Type: binlogdatapb.VEventType_COMMIT,
   469  	}}})
   470  	mm := newMessageManager(newFakeTabletServer(), fvs, newMMTable(), sync2.NewSemaphore(1, 0))
   471  	mm.Open()
   472  	defer mm.Close()
   473  
   474  	r1 := newTestReceiver(1)
   475  	mm.Subscribe(context.Background(), r1.rcv)
   476  	<-r1.ch
   477  
   478  	want := &sqltypes.Result{
   479  		Rows: [][]sqltypes.Value{{
   480  			sqltypes.NewInt64(1),
   481  			sqltypes.NewVarBinary("1"),
   482  		}},
   483  	}
   484  	if got := <-r1.ch; !got.Equal(want) {
   485  		t.Errorf("Received: %v, want %v", got, want)
   486  	}
   487  }
   488  
// TestMessageManagerStreamerAndPoller checks that streamed rows are only added
// when their GTID is known and strictly newer than the last poll position.
// Rows 1 and 2 are dropped; only row 3 is expected.
func TestMessageManagerStreamerAndPoller(t *testing.T) {
	fvs := newFakeVStreamer()
	fvs.setPollerResponse([]*binlogdatapb.VStreamResultsResponse{{
		Fields: testDBFields,
		Gtid:   "MySQL56/33333333-3333-3333-3333-333333333333:1-100",
	}})
	mm := newMessageManager(newFakeTabletServer(), fvs, newMMTable(), sync2.NewSemaphore(1, 0))
	mm.Open()
	defer mm.Close()

	r1 := newTestReceiver(1)
	mm.Subscribe(context.Background(), r1.rcv)
	// Drain the field-list result.
	<-r1.ch

	// Wait (with no timeout) until the poller has recorded a position.
	for {
		runtime.Gosched()
		time.Sleep(10 * time.Millisecond)
		pos := mm.getLastPollPosition()
		if pos != nil {
			break
		}
	}

	fvs.setStreamerResponse([][]*binlogdatapb.VEvent{{{
		// Event set 1: field info.
		Type: binlogdatapb.VEventType_FIELD,
		FieldEvent: &binlogdatapb.FieldEvent{
			TableName: "foo",
			Fields:    testDBFields,
		},
	}}, {{
		// Event set 2: GTID won't be known till the first GTID event.
		// Row will not be added.
		Type: binlogdatapb.VEventType_ROW,
		RowEvent: &binlogdatapb.RowEvent{
			TableName: "foo",
			RowChanges: []*binlogdatapb.RowChange{{
				After: newMMRow(1),
			}},
		},
	}, {
		Type: binlogdatapb.VEventType_GTID,
		Gtid: "MySQL56/33333333-3333-3333-3333-333333333333:1-99",
	}, {
		Type: binlogdatapb.VEventType_COMMIT,
	}}, {{
		// Event set 3: GTID will be known, but <= last poll.
		// Row will not be added.
		Type: binlogdatapb.VEventType_ROW,
		RowEvent: &binlogdatapb.RowEvent{
			TableName: "foo",
			RowChanges: []*binlogdatapb.RowChange{{
				After: newMMRow(2),
			}},
		},
	}, {
		Type: binlogdatapb.VEventType_GTID,
		Gtid: "MySQL56/33333333-3333-3333-3333-333333333333:1-100",
	}, {
		Type: binlogdatapb.VEventType_COMMIT,
	}}, {{
		// Event set 4: GTID will be > last poll.
		// Row will be added.
		Type: binlogdatapb.VEventType_ROW,
		RowEvent: &binlogdatapb.RowEvent{
			TableName: "foo",
			RowChanges: []*binlogdatapb.RowChange{{
				After: newMMRow(3),
			}},
		},
	}, {
		Type: binlogdatapb.VEventType_GTID,
		Gtid: "MySQL56/33333333-3333-3333-3333-333333333333:1-101",
	}, {
		Type: binlogdatapb.VEventType_COMMIT,
	}}})

	// Only row 3 should reach the subscriber.
	want := &sqltypes.Result{
		Rows: [][]sqltypes.Value{{
			sqltypes.NewInt64(3),
			sqltypes.NewVarBinary("3"),
		}},
	}
	if got := <-r1.ch; !got.Equal(want) {
		t.Errorf("Received: %v, want %v", got, want)
	}
}
   576  
   577  func TestMessageManagerPoller(t *testing.T) {
   578  	ti := newMMTable()
   579  	ti.MessageInfo.BatchSize = 2
   580  	ti.MessageInfo.PollInterval = 20 * time.Second
   581  	fvs := newFakeVStreamer()
   582  	fvs.setPollerResponse([]*binlogdatapb.VStreamResultsResponse{{
   583  		Fields: testDBFields,
   584  		Gtid:   "MySQL56/33333333-3333-3333-3333-333333333333:1-100",
   585  	}, {
   586  		Rows: []*querypb.Row{
   587  			newMMRow(1),
   588  			newMMRow(2),
   589  			newMMRow(3),
   590  		},
   591  	}})
   592  	mm := newMessageManager(newFakeTabletServer(), fvs, ti, sync2.NewSemaphore(1, 0))
   593  	mm.Open()
   594  	defer mm.Close()
   595  
   596  	ctx, cancel := context.WithCancel(context.Background())
   597  	r1 := newTestReceiver(1)
   598  	mm.Subscribe(ctx, r1.rcv)
   599  	<-r1.ch
   600  
   601  	want := [][]sqltypes.Value{{
   602  		sqltypes.NewInt64(1),
   603  		sqltypes.NewVarBinary("1"),
   604  	}, {
   605  		sqltypes.NewInt64(2),
   606  		sqltypes.NewVarBinary("2"),
   607  	}, {
   608  		sqltypes.NewInt64(3),
   609  		sqltypes.NewVarBinary("3"),
   610  	}}
   611  	var got [][]sqltypes.Value
   612  	// We should get it in 2 iterations.
   613  	for i := 0; i < 2; i++ {
   614  		qr := <-r1.ch
   615  		got = append(got, qr.Rows...)
   616  	}
   617  	for _, gotrow := range got {
   618  		found := false
   619  		for _, wantrow := range want {
   620  			if reflect.DeepEqual(gotrow, wantrow) {
   621  				found = true
   622  				break
   623  			}
   624  		}
   625  		if !found {
   626  			t.Errorf("row: %v not found in %v", gotrow, want)
   627  		}
   628  	}
   629  
   630  	// If there are no receivers, nothing should fire.
   631  	cancel()
   632  	runtime.Gosched()
   633  	select {
   634  	case row := <-r1.ch:
   635  		t.Errorf("Expecting no value, got: %v", row)
   636  	default:
   637  	}
   638  }
   639  
// TestMessagesPending1 tests for the case where you can't
// add items because the cache is full. A rejected Add should mark work as
// pending, so the poller runs well before its 30s interval.
func TestMessagesPending1(t *testing.T) {
	// Set a large polling interval.
	ti := newMMTable()
	ti.MessageInfo.CacheSize = 2
	ti.MessageInfo.PollInterval = 30 * time.Second
	fvs := newFakeVStreamer()
	mm := newMessageManager(newFakeTabletServer(), fvs, ti, sync2.NewSemaphore(1, 0))
	mm.Open()
	defer mm.Close()

	// Unbuffered receiver. The helper goroutine consumes only the field-list
	// result, so the next delivery blocks in the callback until the loop below
	// reads it.
	r1 := newTestReceiver(0)
	go func() { <-r1.ch }()
	mm.Subscribe(context.Background(), r1.rcv)

	mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("1")}})
	// Make sure the first message is enqueued.
	r1.WaitForCount(2)
	// This will fill up the cache.
	assert.True(t, mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("2")}}))
	assert.True(t, mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("3")}}))
	// This will fail and messagesPending will be set to true.
	assert.False(t, mm.Add(&MessageRow{Row: []sqltypes.Value{sqltypes.NewVarBinary("4")}}))

	fvs.setPollerResponse([]*binlogdatapb.VStreamResultsResponse{{
		Fields: testDBFields,
		Gtid:   "MySQL56/33333333-3333-3333-3333-333333333333:1-100",
	}, {
		Rows: []*querypb.Row{newMMRow(1)},
	}})

	// Now, let's pull more than 3 items. It should
	// trigger the poller, and there should be no wait.
	// The 15s bound is half the poll interval, so passing implies the poll
	// was triggered by pending work rather than by the timer.
	start := time.Now()
	for i := 0; i < 4; i++ {
		<-r1.ch
	}
	if d := time.Since(start); d > 15*time.Second {
		t.Errorf("pending work trigger did not happen. Duration: %v", d)
	}
}
   682  
   683  // TestMessagesPending2 tests for the case where
   684  // there are more pending items than the cache size.
   685  func TestMessagesPending2(t *testing.T) {
   686  	// Set a large polling interval.
   687  	ti := newMMTable()
   688  	ti.MessageInfo.CacheSize = 1
   689  	ti.MessageInfo.PollInterval = 30 * time.Second
   690  	fvs := newFakeVStreamer()
   691  	fvs.setPollerResponse([]*binlogdatapb.VStreamResultsResponse{{
   692  		Fields: testDBFields,
   693  		Gtid:   "MySQL56/33333333-3333-3333-3333-333333333333:1-100",
   694  	}, {
   695  		Rows: []*querypb.Row{newMMRow(1)},
   696  	}})
   697  	mm := newMessageManager(newFakeTabletServer(), fvs, ti, sync2.NewSemaphore(1, 0))
   698  	mm.Open()
   699  	defer mm.Close()
   700  
   701  	r1 := newTestReceiver(0)
   702  	go func() { <-r1.ch }()
   703  	mm.Subscribe(context.Background(), r1.rcv)
   704  
   705  	// Now, let's pull more than 1 item. It should
   706  	// trigger the poller every time cache gets empty.
   707  	start := time.Now()
   708  	for i := 0; i < 3; i++ {
   709  		<-r1.ch
   710  	}
   711  	if d := time.Since(start); d > 15*time.Second {
   712  		t.Errorf("pending work trigger did not happen. Duration: %v", d)
   713  	}
   714  }
   715  
   716  func TestMessageManagerPurge(t *testing.T) {
   717  	tsv := newFakeTabletServer()
   718  
   719  	// Make a buffered channel so the thread doesn't block on repeated calls.
   720  	ch := make(chan string, 20)
   721  	tsv.SetChannel(ch)
   722  
   723  	ti := newMMTable()
   724  	ti.MessageInfo.PollInterval = 1 * time.Millisecond
   725  	mm := newMessageManager(tsv, newFakeVStreamer(), ti, sync2.NewSemaphore(1, 0))
   726  	mm.Open()
   727  	defer mm.Close()
   728  	// Ensure Purge got called.
   729  	if got, want := <-ch, "purge"; got != want {
   730  		t.Errorf("Purge: %s, want %v", got, want)
   731  	}
   732  }
   733  
   734  func TestMMGenerate(t *testing.T) {
   735  	mm := newMessageManager(newFakeTabletServer(), newFakeVStreamer(), newMMTable(), sync2.NewSemaphore(1, 0))
   736  	mm.Open()
   737  	defer mm.Close()
   738  	query, bv := mm.GenerateAckQuery([]string{"1", "2"})
   739  	wantQuery := "update foo set time_acked = :time_acked, time_next = null where id in ::ids and time_acked is null"
   740  	if query != wantQuery {
   741  		t.Errorf("GenerateAckQuery query: %s, want %s", query, wantQuery)
   742  	}
   743  	bvv, _ := sqltypes.BindVariableToValue(bv["time_acked"])
   744  	gotAcked, _ := evalengine.ToInt64(bvv)
   745  	wantAcked := time.Now().UnixNano()
   746  	if wantAcked-gotAcked > 10e9 {
   747  		t.Errorf("gotAcked: %d, should be with 10s of %d", gotAcked, wantAcked)
   748  	}
   749  	gotids := bv["ids"]
   750  	wantids := sqltypes.TestBindVariable([]any{[]byte{'1'}, []byte{'2'}})
   751  	utils.MustMatch(t, wantids, gotids, "did not match")
   752  
   753  	query, bv = mm.GeneratePostponeQuery([]string{"1", "2"})
   754  	wantQuery = "update foo set time_next = :time_now + :wait_time + IF(FLOOR((:min_backoff<<ifnull(epoch, 0)) * :jitter) < :min_backoff, :min_backoff, FLOOR((:min_backoff<<ifnull(epoch, 0)) * :jitter)), epoch = ifnull(epoch, 0)+1 where id in ::ids and time_acked is null"
   755  	if query != wantQuery {
   756  		t.Errorf("GeneratePostponeQuery query: %s, want %s", query, wantQuery)
   757  	}
   758  	if _, ok := bv["time_now"]; !ok {
   759  		t.Errorf("time_now is absent in %v", bv)
   760  	} else {
   761  		// time_now cannot be compared.
   762  		delete(bv, "time_now")
   763  	}
   764  	if _, ok := bv["jitter"]; !ok {
   765  		t.Errorf("jitter is absent in %v", bv)
   766  	} else {
   767  		// jitter cannot be compared.
   768  		delete(bv, "jitter")
   769  	}
   770  	wantbv := map[string]*querypb.BindVariable{
   771  		"wait_time":   sqltypes.Int64BindVariable(1e9),
   772  		"min_backoff": sqltypes.Int64BindVariable(1e9),
   773  		"ids":         wantids,
   774  	}
   775  	utils.MustMatch(t, wantbv, bv, "did not match")
   776  
   777  	query, bv = mm.GeneratePurgeQuery(3)
   778  	wantQuery = "delete from foo where time_acked < :time_acked limit 500"
   779  	if query != wantQuery {
   780  		t.Errorf("GeneratePurgeQuery query: %s, want %s", query, wantQuery)
   781  	}
   782  	wantbv = map[string]*querypb.BindVariable{
   783  		"time_acked": sqltypes.Int64BindVariable(3),
   784  	}
   785  	if !reflect.DeepEqual(bv, wantbv) {
   786  		t.Errorf("gotid: %v, want %v", bv, wantbv)
   787  	}
   788  }
   789  
   790  func TestMMGenerateWithBackoff(t *testing.T) {
   791  	mm := newMessageManager(newFakeTabletServer(), newFakeVStreamer(), newMMTableWithBackoff(), sync2.NewSemaphore(1, 0))
   792  	mm.Open()
   793  	defer mm.Close()
   794  
   795  	wantids := sqltypes.TestBindVariable([]any{[]byte{'1'}, []byte{'2'}})
   796  
   797  	query, bv := mm.GeneratePostponeQuery([]string{"1", "2"})
   798  	wantQuery := "update foo set time_next = :time_now + :wait_time + IF(FLOOR((:min_backoff<<ifnull(epoch, 0)) * :jitter) < :min_backoff, :min_backoff, IF(FLOOR((:min_backoff<<ifnull(epoch, 0)) * :jitter) > :max_backoff, :max_backoff, FLOOR((:min_backoff<<ifnull(epoch, 0)) * :jitter))), epoch = ifnull(epoch, 0)+1 where id in ::ids and time_acked is null"
   799  	if query != wantQuery {
   800  		t.Errorf("GeneratePostponeQuery query: %s, want %s", query, wantQuery)
   801  	}
   802  	if _, ok := bv["time_now"]; !ok {
   803  		t.Errorf("time_now is absent in %v", bv)
   804  	} else {
   805  		// time_now cannot be compared.
   806  		delete(bv, "time_now")
   807  	}
   808  	if _, ok := bv["jitter"]; !ok {
   809  		t.Errorf("jitter is absent in %v", bv)
   810  	} else {
   811  		// jitter cannot be compared.
   812  		delete(bv, "jitter")
   813  	}
   814  	wantbv := map[string]*querypb.BindVariable{
   815  		"wait_time":   sqltypes.Int64BindVariable(1e10),
   816  		"min_backoff": sqltypes.Int64BindVariable(1e9),
   817  		"max_backoff": sqltypes.Int64BindVariable(4e9),
   818  		"ids":         wantids,
   819  	}
   820  	if !reflect.DeepEqual(bv, wantbv) {
   821  		t.Errorf("gotid: %v, want %v", bv, wantbv)
   822  	}
   823  }
   824  
// fakeTabletServer is a stub tablet server for the messager tests. It embeds
// a real tabletenv.Env and counts PostponeMessages and PurgeMessages calls.
// When ch is set, it also announces each such call on ch.
type fakeTabletServer struct {
	tabletenv.Env
	postponeCount sync2.AtomicInt64
	purgeCount    sync2.AtomicInt64

	// mu guards ch.
	mu sync.Mutex
	ch chan string
}
   833  
   834  func newFakeTabletServer() *fakeTabletServer {
   835  	config := tabletenv.NewDefaultConfig()
   836  	return &fakeTabletServer{
   837  		Env: tabletenv.NewEnv(config, "MessagerTest"),
   838  	}
   839  }
   840  
// CheckMySQL is a no-op stub.
func (fts *fakeTabletServer) CheckMySQL() {}
   842  
   843  func (fts *fakeTabletServer) SetChannel(ch chan string) {
   844  	fts.mu.Lock()
   845  	fts.ch = ch
   846  	fts.mu.Unlock()
   847  }
   848  
   849  func (fts *fakeTabletServer) PostponeMessages(ctx context.Context, target *querypb.Target, gen QueryGenerator, ids []string) (count int64, err error) {
   850  	fts.postponeCount.Add(1)
   851  	fts.mu.Lock()
   852  	ch := fts.ch
   853  	fts.mu.Unlock()
   854  	if ch != nil {
   855  		ch <- "postpone"
   856  	}
   857  	return 0, nil
   858  }
   859  
   860  func (fts *fakeTabletServer) PurgeMessages(ctx context.Context, target *querypb.Target, gen QueryGenerator, timeCutoff int64) (count int64, err error) {
   861  	fts.purgeCount.Add(1)
   862  	fts.mu.Lock()
   863  	ch := fts.ch
   864  	fts.mu.Unlock()
   865  	if ch != nil {
   866  		ch <- "purge"
   867  	}
   868  	return 0, nil
   869  }
   870  
// fakeVStreamer is a scripted vstreamer. Stream replays (and consumes)
// streamerResponse; StreamResults replays pollerResponse on every call.
type fakeVStreamer struct {
	// streamInvocations counts calls to Stream.
	streamInvocations sync2.AtomicInt64
	// mu guards streamerResponse and pollerResponse.
	mu                sync.Mutex
	streamerResponse  [][]*binlogdatapb.VEvent
	pollerResponse    []*binlogdatapb.VStreamResultsResponse
}
   877  
   878  func newFakeVStreamer() *fakeVStreamer { return &fakeVStreamer{} }
   879  
   880  func (fv *fakeVStreamer) setStreamerResponse(sr [][]*binlogdatapb.VEvent) {
   881  	fv.mu.Lock()
   882  	defer fv.mu.Unlock()
   883  	fv.streamerResponse = sr
   884  }
   885  
   886  func (fv *fakeVStreamer) setPollerResponse(pr []*binlogdatapb.VStreamResultsResponse) {
   887  	fv.mu.Lock()
   888  	defer fv.mu.Unlock()
   889  	fv.pollerResponse = pr
   890  }
   891  
   892  func (fv *fakeVStreamer) Stream(ctx context.Context, startPos string, tablePKs []*binlogdatapb.TableLastPK, filter *binlogdatapb.Filter, send func([]*binlogdatapb.VEvent) error) error {
   893  	fv.streamInvocations.Add(1)
   894  	for {
   895  		fv.mu.Lock()
   896  		sr := fv.streamerResponse
   897  		fv.streamerResponse = nil
   898  		fv.mu.Unlock()
   899  		for _, r := range sr {
   900  			if err := send(r); err != nil {
   901  				return err
   902  			}
   903  		}
   904  		select {
   905  		case <-ctx.Done():
   906  			return io.EOF
   907  		default:
   908  		}
   909  		runtime.Gosched()
   910  		time.Sleep(10 * time.Millisecond)
   911  	}
   912  }
   913  
   914  func (fv *fakeVStreamer) StreamResults(ctx context.Context, query string, send func(*binlogdatapb.VStreamResultsResponse) error) error {
   915  	fv.mu.Lock()
   916  	defer fv.mu.Unlock()
   917  	for _, r := range fv.pollerResponse {
   918  		if err := send(r); err != nil {
   919  			return err
   920  		}
   921  	}
   922  	return nil
   923  }