vitess.io/vitess@v0.16.2/go/test/endtoend/messaging/message_test.go (about)

     1  /*
     2  Copyright 2020 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package messaging
    18  
import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"vitess.io/vitess/go/mysql"
	"vitess.io/vitess/go/sqltypes"
	"vitess.io/vitess/go/test/endtoend/cluster"
	"vitess.io/vitess/go/test/endtoend/utils"
	cmp "vitess.io/vitess/go/test/utils"
	"vitess.io/vitess/go/vt/log"
	"vitess.io/vitess/go/vt/proto/query"
	querypb "vitess.io/vitess/go/vt/proto/query"
	"vitess.io/vitess/go/vt/proto/topodata"
	"vitess.io/vitess/go/vt/vtgate/evalengine"
	"vitess.io/vitess/go/vt/vtgate/vtgateconn"
)
    43  
    44  var testMessage = "{\"message\":\"hello world\"}"
    45  var testShardedMessagef = "{\"message\": \"hello world\", \"id\": %d}"
    46  
// createMessage is the DDL for the vitess_message table used by TestMessage.
// Its columns are id, priority, epoch, time_next and time_acked (marked
// "required" below), plus tenant_id and message. The table comment carries
// the Vitess messaging options read by vttablet
// (vt_ack_wait, vt_purge_after, vt_batch_size, vt_cache_size,
// vt_poller_interval).
var createMessage = `
create table vitess_message(
	# required columns
	id bigint NOT NULL COMMENT 'often an event id, can also be auto-increment or a sequence',
	priority tinyint NOT NULL DEFAULT '50' COMMENT 'lower number priorities process first',
	epoch bigint NOT NULL DEFAULT '0' COMMENT 'Vitess increments this each time it sends a message, and is used for incremental backoff doubling',
	time_next bigint DEFAULT 0 COMMENT 'the earliest time the message will be sent in epoch nanoseconds. Must be null if time_acked is set',
	time_acked bigint DEFAULT NULL COMMENT 'the time the message was acked in epoch nanoseconds. Must be null if time_next is set',

	# add as many custom fields here as required
	# optional - these are suggestions
	tenant_id bigint,
	message json,

	# required indexes
	primary key(id),
	index next_idx(time_next),
	index poller_idx(time_acked, priority, time_next desc)

	# add any secondary indexes or foreign keys - no restrictions
) comment 'vitess_message,vt_ack_wait=1,vt_purge_after=3,vt_batch_size=2,vt_cache_size=10,vt_poller_interval=1'
`
    69  
    70  func TestMessage(t *testing.T) {
    71  	ctx := context.Background()
    72  
    73  	vtParams := mysql.ConnParams{
    74  		Host: "localhost",
    75  		Port: clusterInstance.VtgateMySQLPort,
    76  	}
    77  	conn, err := mysql.Connect(ctx, &vtParams)
    78  	require.NoError(t, err)
    79  	defer conn.Close()
    80  
    81  	streamConn, err := mysql.Connect(ctx, &vtParams)
    82  	require.NoError(t, err)
    83  	defer streamConn.Close()
    84  
    85  	utils.Exec(t, conn, fmt.Sprintf("use %s", lookupKeyspace))
    86  	utils.Exec(t, conn, createMessage)
    87  	clusterInstance.VtctlProcess.ExecuteCommand(fmt.Sprintf("ReloadSchemaKeyspace %s", lookupKeyspace))
    88  
    89  	defer utils.Exec(t, conn, "drop table vitess_message")
    90  
    91  	utils.Exec(t, streamConn, "set workload = 'olap'")
    92  	err = streamConn.ExecuteStreamFetch("stream * from vitess_message")
    93  	require.NoError(t, err)
    94  
    95  	wantFields := []*querypb.Field{{
    96  		Name: "id",
    97  		Type: sqltypes.Int64,
    98  	}, {
    99  		Name: "tenant_id",
   100  		Type: sqltypes.Int64,
   101  	}, {
   102  		Name: "message",
   103  		Type: sqltypes.TypeJSON,
   104  	}}
   105  	gotFields, err := streamConn.Fields()
   106  	for i, field := range gotFields {
   107  		// Remove other artifacts.
   108  		gotFields[i] = &querypb.Field{
   109  			Name: field.Name,
   110  			Type: field.Type,
   111  		}
   112  	}
   113  	require.NoError(t, err)
   114  	cmp.MustMatch(t, wantFields, gotFields)
   115  
   116  	utils.Exec(t, conn, fmt.Sprintf("insert into vitess_message(id, tenant_id, message) values(1, 1, '%s')", testMessage))
   117  
   118  	// account for jitter in timings, maxJitter uses the current hardcoded value for jitter in message_manager.go
   119  	jitter := int64(0)
   120  	maxJitter := int64(1.4 * 1e9)
   121  
   122  	// Consume first message.
   123  	start := time.Now().UnixNano()
   124  	got, err := streamConn.FetchNext(nil)
   125  	require.NoError(t, err)
   126  
   127  	want := []sqltypes.Value{
   128  		sqltypes.NewInt64(1),
   129  		sqltypes.NewInt64(1),
   130  		sqltypes.TestValue(sqltypes.TypeJSON, testMessage),
   131  	}
   132  	cmp.MustMatch(t, want, got)
   133  
   134  	qr := utils.Exec(t, conn, "select time_next, epoch from vitess_message where id = 1")
   135  	next, epoch := getTimeEpoch(qr)
   136  	jitter += epoch * maxJitter
   137  	// epoch could be 0 or 1, depending on how fast the row is updated
   138  	switch epoch {
   139  	case 0:
   140  		if !(start-1e9 < next && next < (start+jitter)) {
   141  			t.Errorf("next: %d. must be within 1s of start: %d", next/1e9, (start+jitter)/1e9)
   142  		}
   143  	case 1:
   144  		if !(start < next && next < (start+jitter)+3e9) {
   145  			t.Errorf("next: %d. must be about 1s after start: %d", next/1e9, (start+jitter)/1e9)
   146  		}
   147  	default:
   148  		t.Errorf("epoch: %d, must be 0 or 1", epoch)
   149  	}
   150  
   151  	// Consume the resend.
   152  	_, err = streamConn.FetchNext(nil)
   153  	require.NoError(t, err)
   154  	qr = utils.Exec(t, conn, "select time_next, epoch from vitess_message where id = 1")
   155  	next, epoch = getTimeEpoch(qr)
   156  	jitter += epoch * maxJitter
   157  	// epoch could be 1 or 2, depending on how fast the row is updated
   158  	switch epoch {
   159  	case 1:
   160  		if !(start < next && next < (start+jitter)+3e9) {
   161  			t.Errorf("next: %d. must be about 1s after start: %d", next/1e9, (start+jitter)/1e9)
   162  		}
   163  	case 2:
   164  		if !(start+2e9 < next && next < (start+jitter)+6e9) {
   165  			t.Errorf("next: %d. must be about 3s after start: %d", next/1e9, (start+jitter)/1e9)
   166  		}
   167  	default:
   168  		t.Errorf("epoch: %d, must be 1 or 2", epoch)
   169  	}
   170  
   171  	// Ack the message.
   172  	qr = utils.Exec(t, conn, "update vitess_message set time_acked = 123, time_next = null where id = 1 and time_acked is null")
   173  	assert.Equal(t, uint64(1), qr.RowsAffected)
   174  
   175  	// Within 3+1 seconds, the row should be deleted.
   176  	time.Sleep(4 * time.Second)
   177  	qr = utils.Exec(t, conn, "select time_acked, epoch from vitess_message where id = 1")
   178  	assert.Equal(t, 0, len(qr.Rows))
   179  }
   180  
// createThreeColMessage is the DDL for vitess_message3, used by
// TestThreeColMessage. It has the same columns as the vitess_message table
// plus two test-specific columns, msg1 and msg2. It carries the same
// Vitess messaging options in the table comment.
var createThreeColMessage = `
create table vitess_message3(
	# required columns
	id bigint NOT NULL COMMENT 'often an event id, can also be auto-increment or a sequence',
	priority tinyint NOT NULL DEFAULT '50' COMMENT 'lower number priorities process first',
	epoch bigint NOT NULL DEFAULT '0' COMMENT 'Vitess increments this each time it sends a message, and is used for incremental backoff doubling',
	time_next bigint DEFAULT 0 COMMENT 'the earliest time the message will be sent in epoch nanoseconds. Must be null if time_acked is set',
	time_acked bigint DEFAULT NULL COMMENT 'the time the message was acked in epoch nanoseconds. Must be null if time_next is set',

	# add as many custom fields here as required
	# optional - these are suggestions
	tenant_id bigint,
	message json,

	# custom to this test
	msg1 varchar(128),
	msg2 bigint,

	# required indexes
	primary key(id),
	index next_idx(time_next),
	index poller_idx(time_acked, priority, time_next desc)

	# add any secondary indexes or foreign keys - no restrictions
) comment 'vitess_message,vt_ack_wait=1,vt_purge_after=3,vt_batch_size=2,vt_cache_size=10,vt_poller_interval=1'
`
   207  
   208  func TestThreeColMessage(t *testing.T) {
   209  	ctx := context.Background()
   210  
   211  	vtParams := mysql.ConnParams{
   212  		Host: "localhost",
   213  		Port: clusterInstance.VtgateMySQLPort,
   214  	}
   215  	conn, err := mysql.Connect(ctx, &vtParams)
   216  	require.NoError(t, err)
   217  	defer conn.Close()
   218  
   219  	streamConn, err := mysql.Connect(ctx, &vtParams)
   220  	require.NoError(t, err)
   221  	defer streamConn.Close()
   222  
   223  	utils.Exec(t, conn, fmt.Sprintf("use %s", lookupKeyspace))
   224  	utils.Exec(t, conn, createThreeColMessage)
   225  	defer utils.Exec(t, conn, "drop table vitess_message3")
   226  
   227  	utils.Exec(t, streamConn, "set workload = 'olap'")
   228  	err = streamConn.ExecuteStreamFetch("stream * from vitess_message3")
   229  	require.NoError(t, err)
   230  
   231  	wantFields := []*querypb.Field{{
   232  		Name: "id",
   233  		Type: sqltypes.Int64,
   234  	}, {
   235  		Name: "tenant_id",
   236  		Type: sqltypes.Int64,
   237  	}, {
   238  		Name: "message",
   239  		Type: sqltypes.TypeJSON,
   240  	}, {
   241  		Name: "msg1",
   242  		Type: sqltypes.VarChar,
   243  	}, {
   244  		Name: "msg2",
   245  		Type: sqltypes.Int64,
   246  	}}
   247  	gotFields, err := streamConn.Fields()
   248  	for i, field := range gotFields {
   249  		// Remove other artifacts.
   250  		gotFields[i] = &querypb.Field{
   251  			Name: field.Name,
   252  			Type: field.Type,
   253  		}
   254  	}
   255  	require.NoError(t, err)
   256  	cmp.MustMatch(t, wantFields, gotFields)
   257  
   258  	utils.Exec(t, conn, fmt.Sprintf("insert into vitess_message3(id, tenant_id, message, msg1, msg2) values(1, 3, '%s', 'hello world', 3)", testMessage))
   259  
   260  	got, err := streamConn.FetchNext(nil)
   261  	require.NoError(t, err)
   262  	want := []sqltypes.Value{
   263  		sqltypes.NewInt64(1),
   264  		sqltypes.NewInt64(3),
   265  		sqltypes.TestValue(sqltypes.TypeJSON, testMessage),
   266  		sqltypes.NewVarChar("hello world"),
   267  		sqltypes.NewInt64(3),
   268  	}
   269  	cmp.MustMatch(t, want, got)
   270  
   271  	// Verify Ack.
   272  	qr := utils.Exec(t, conn, "update vitess_message3 set time_acked = 123, time_next = null where id = 1 and time_acked is null")
   273  	assert.Equal(t, uint64(1), qr.RowsAffected)
   274  }
   275  
// createSpecificStreamingColsMessage is the DDL for vitess_message4, used by
// TestSpecificStreamingColsMessage. Its columns match vitess_message3, but
// the table comment adds vt_message_cols=id|msg1. That option restricts
// the streamed columns to id and msg1, as the test's expected fields show.
var createSpecificStreamingColsMessage = `create table vitess_message4(
	# required columns
	id bigint NOT NULL COMMENT 'often an event id, can also be auto-increment or a sequence',
	priority tinyint NOT NULL DEFAULT '50' COMMENT 'lower number priorities process first',
	epoch bigint NOT NULL DEFAULT '0' COMMENT 'Vitess increments this each time it sends a message, and is used for incremental backoff doubling',
	time_next bigint DEFAULT 0 COMMENT 'the earliest time the message will be sent in epoch nanoseconds. Must be null if time_acked is set',
	time_acked bigint DEFAULT NULL COMMENT 'the time the message was acked in epoch nanoseconds. Must be null if time_next is set',

	# add as many custom fields here as required
	# optional - these are suggestions
	tenant_id bigint,
	message json,

	# custom to this test
	msg1 varchar(128),
	msg2 bigint,

	# required indexes
	primary key(id),
	index next_idx(time_next),
	index poller_idx(time_acked, priority, time_next desc)

	# add any secondary indexes or foreign keys - no restrictions
) comment 'vitess_message,vt_message_cols=id|msg1,vt_ack_wait=1,vt_purge_after=3,vt_batch_size=2,vt_cache_size=10,vt_poller_interval=1'`
   300  
   301  func TestSpecificStreamingColsMessage(t *testing.T) {
   302  	ctx := context.Background()
   303  
   304  	vtParams := mysql.ConnParams{
   305  		Host: "localhost",
   306  		Port: clusterInstance.VtgateMySQLPort,
   307  	}
   308  	conn, err := mysql.Connect(ctx, &vtParams)
   309  	require.NoError(t, err)
   310  	defer conn.Close()
   311  
   312  	streamConn, err := mysql.Connect(ctx, &vtParams)
   313  	require.NoError(t, err)
   314  	defer streamConn.Close()
   315  
   316  	utils.Exec(t, conn, fmt.Sprintf("use %s", lookupKeyspace))
   317  	utils.Exec(t, conn, createSpecificStreamingColsMessage)
   318  	defer utils.Exec(t, conn, "drop table vitess_message4")
   319  
   320  	utils.Exec(t, streamConn, "set workload = 'olap'")
   321  	err = streamConn.ExecuteStreamFetch("stream * from vitess_message4")
   322  	require.NoError(t, err)
   323  
   324  	wantFields := []*querypb.Field{{
   325  		Name: "id",
   326  		Type: sqltypes.Int64,
   327  	}, {
   328  		Name: "msg1",
   329  		Type: sqltypes.VarChar,
   330  	}}
   331  	gotFields, err := streamConn.Fields()
   332  	for i, field := range gotFields {
   333  		// Remove other artifacts.
   334  		gotFields[i] = &querypb.Field{
   335  			Name: field.Name,
   336  			Type: field.Type,
   337  		}
   338  	}
   339  	require.NoError(t, err)
   340  	cmp.MustMatch(t, wantFields, gotFields)
   341  
   342  	utils.Exec(t, conn, "insert into vitess_message4(id, msg1, msg2) values(1, 'hello world', 3)")
   343  
   344  	got, err := streamConn.FetchNext(nil)
   345  	require.NoError(t, err)
   346  	want := []sqltypes.Value{
   347  		sqltypes.NewInt64(1),
   348  		sqltypes.NewVarChar("hello world"),
   349  	}
   350  	cmp.MustMatch(t, want, got)
   351  
   352  	// Verify Ack.
   353  	qr := utils.Exec(t, conn, "update vitess_message4 set time_acked = 123, time_next = null where id = 1 and time_acked is null")
   354  	assert.Equal(t, uint64(1), qr.RowsAffected)
   355  }
   356  
   357  func getTimeEpoch(qr *sqltypes.Result) (int64, int64) {
   358  	if len(qr.Rows) != 1 {
   359  		return 0, 0
   360  	}
   361  	t, _ := evalengine.ToInt64(qr.Rows[0][0])
   362  	e, _ := evalengine.ToInt64(qr.Rows[0][1])
   363  	return t, e
   364  }
   365  
   366  func TestSharded(t *testing.T) {
   367  	// validate the messaging for sharded keyspace(user)
   368  	testMessaging(t, "sharded_message", userKeyspace)
   369  }
   370  
   371  func TestUnsharded(t *testing.T) {
   372  	// validate messaging for unsharded keyspace(lookup)
   373  	testMessaging(t, "unsharded_message", lookupKeyspace)
   374  }
   375  
   376  // TestReparenting checks the client connection count after reparenting.
   377  func TestReparenting(t *testing.T) {
   378  	defer cluster.PanicHandler(t)
   379  	name := "sharded_message"
   380  
   381  	ctx := context.Background()
   382  	// start grpc connection with vtgate and validate client
   383  	// connection counts in tablets
   384  	stream, err := VtgateGrpcConn(ctx, clusterInstance)
   385  	require.Nil(t, err)
   386  	defer stream.Close()
   387  	_, err = stream.MessageStream(userKeyspace, "", nil, name)
   388  	require.Nil(t, err)
   389  
   390  	assertClientCount(t, 1, shard0Primary)
   391  	assertClientCount(t, 0, shard0Replica)
   392  	assertClientCount(t, 1, shard1Primary)
   393  
   394  	// do planned reparenting, make one replica as primary
   395  	// and validate client connection count in correspond tablets
   396  	clusterInstance.VtctlclientProcess.ExecuteCommandWithOutput(
   397  		"PlannedReparentShard", "--",
   398  		"--keyspace_shard", userKeyspace+"/-80",
   399  		"--new_primary", shard0Replica.Alias)
   400  	// validate topology
   401  	err = clusterInstance.VtctlclientProcess.ExecuteCommand("Validate")
   402  	require.Nil(t, err)
   403  
   404  	// Verify connection has migrated.
   405  	// The wait must be at least 6s which is how long vtgate will
   406  	// wait before retrying: that is 30s/5 where 30s is the default
   407  	// message_stream_grace_period.
   408  	time.Sleep(10 * time.Second)
   409  	assertClientCount(t, 0, shard0Primary)
   410  	assertClientCount(t, 1, shard0Replica)
   411  	assertClientCount(t, 1, shard1Primary)
   412  	session := stream.Session("@primary", nil)
   413  	msg3 := fmt.Sprintf(testShardedMessagef, 3)
   414  	cluster.ExecuteQueriesUsingVtgate(t, session, fmt.Sprintf("insert into sharded_message (id, tenant_id, message) values (3,3,'%s')", msg3))
   415  
   416  	// validate that we have received inserted message
   417  	stream.Next()
   418  
   419  	// make old primary again as new primary
   420  	clusterInstance.VtctlclientProcess.ExecuteCommandWithOutput(
   421  		"PlannedReparentShard", "--",
   422  		"--keyspace_shard", userKeyspace+"/-80",
   423  		"--new_primary", shard0Primary.Alias)
   424  	// validate topology
   425  	err = clusterInstance.VtctlclientProcess.ExecuteCommand("Validate")
   426  	require.Nil(t, err)
   427  	time.Sleep(10 * time.Second)
   428  	assertClientCount(t, 1, shard0Primary)
   429  	assertClientCount(t, 0, shard0Replica)
   430  	assertClientCount(t, 1, shard1Primary)
   431  
   432  	_, err = session.Execute(context.Background(), "update "+name+" set time_acked = 1, time_next = null where id in (3) and time_acked is null", nil)
   433  	require.Nil(t, err)
   434  }
   435  
   436  // TestConnection validate the connection count and message streaming.
   437  func TestConnection(t *testing.T) {
   438  	defer cluster.PanicHandler(t)
   439  
   440  	name := "sharded_message"
   441  
   442  	// 1 sec sleep added to avoid invalid connection count
   443  	time.Sleep(time.Second)
   444  
   445  	// create two grpc connection with vtgate and verify
   446  	// client connection count in vttablet of the primary
   447  	assertClientCount(t, 0, shard0Primary)
   448  	assertClientCount(t, 0, shard1Primary)
   449  
   450  	ctx := context.Background()
   451  	// first connection with vtgate
   452  	stream, err := VtgateGrpcConn(ctx, clusterInstance)
   453  	require.Nil(t, err)
   454  	_, err = stream.MessageStream(userKeyspace, "", nil, name)
   455  	require.Nil(t, err)
   456  	// validate client count of vttablet
   457  	time.Sleep(time.Second)
   458  	assertClientCount(t, 1, shard0Primary)
   459  	assertClientCount(t, 1, shard1Primary)
   460  	// second connection with vtgate, secont connection
   461  	// will only be used for client connection counts
   462  	stream1, err := VtgateGrpcConn(ctx, clusterInstance)
   463  	require.Nil(t, err)
   464  	_, err = stream1.MessageStream(userKeyspace, "", nil, name)
   465  	require.Nil(t, err)
   466  	// validate client count of vttablet
   467  	time.Sleep(time.Second)
   468  	assertClientCount(t, 2, shard0Primary)
   469  	assertClientCount(t, 2, shard1Primary)
   470  
   471  	// insert data in primary and validate that we receive this
   472  	// in message stream
   473  	session := stream.Session("@primary", nil)
   474  	// insert data in primary
   475  	msg2 := fmt.Sprintf(testShardedMessagef, 2)
   476  	msg5 := fmt.Sprintf(testShardedMessagef, 5)
   477  	cluster.ExecuteQueriesUsingVtgate(t, session, fmt.Sprintf("insert into sharded_message (id, tenant_id, message) values (2,2,'%s')", msg2))
   478  	cluster.ExecuteQueriesUsingVtgate(t, session, fmt.Sprintf("insert into sharded_message (id, tenant_id, message) values (5,5,'%s')", msg5))
   479  	// validate in msg stream
   480  	_, err = stream.Next()
   481  	require.Nil(t, err)
   482  	_, err = stream.Next()
   483  	require.Nil(t, err)
   484  
   485  	_, err = session.Execute(context.Background(), "update "+name+" set time_acked = 1, time_next = null where id in (2, 5) and time_acked is null", nil)
   486  	require.Nil(t, err)
   487  	// After closing one stream, ensure vttablets have dropped it.
   488  	stream.Close()
   489  	time.Sleep(time.Second)
   490  	assertClientCount(t, 1, shard0Primary)
   491  	assertClientCount(t, 1, shard1Primary)
   492  
   493  	stream1.Close()
   494  }
   495  
   496  func testMessaging(t *testing.T, name, ks string) {
   497  	defer cluster.PanicHandler(t)
   498  	ctx := context.Background()
   499  	stream, err := VtgateGrpcConn(ctx, clusterInstance)
   500  	require.Nil(t, err)
   501  	defer stream.Close()
   502  
   503  	session := stream.Session("@primary", nil)
   504  	msg4 := fmt.Sprintf(testShardedMessagef, 4)
   505  	msg1 := fmt.Sprintf(testShardedMessagef, 1)
   506  	cluster.ExecuteQueriesUsingVtgate(t, session, fmt.Sprintf("insert into "+name+" (id, tenant_id, message) values (4,4,'%s')", msg4))
   507  	cluster.ExecuteQueriesUsingVtgate(t, session, fmt.Sprintf("insert into "+name+" (id, tenant_id, message) values (1,1,'%s')", msg1))
   508  
   509  	// validate fields
   510  	res, err := stream.MessageStream(ks, "", nil, name)
   511  	require.Nil(t, err)
   512  	require.Equal(t, 3, len(res.Fields))
   513  	validateField(t, res.Fields[0], "id", query.Type_INT64)
   514  	validateField(t, res.Fields[1], "tenant_id", query.Type_INT64)
   515  	validateField(t, res.Fields[2], "message", query.Type_JSON)
   516  
   517  	// validate recieved msgs
   518  	resMap := make(map[string]string)
   519  	res, err = stream.Next()
   520  	require.Nil(t, err)
   521  	for _, row := range res.Rows {
   522  		resMap[row[0].ToString()] = row[1].ToString()
   523  	}
   524  
   525  	if name == "sharded_message" {
   526  		res, err = stream.Next()
   527  		require.Nil(t, err)
   528  		for _, row := range res.Rows {
   529  			resMap[row[0].ToString()] = row[1].ToString()
   530  		}
   531  	}
   532  
   533  	assert.Equal(t, "1", resMap["1"])
   534  	assert.Equal(t, "4", resMap["4"])
   535  
   536  	resMap = make(map[string]string)
   537  	stream.ClearMem()
   538  	// validate message ack with id 4
   539  	qr, err := session.Execute(context.Background(), "update "+name+" set time_acked = 1, time_next = null where id in (4) and time_acked is null", nil)
   540  	require.Nil(t, err)
   541  	assert.Equal(t, uint64(1), qr.RowsAffected)
   542  
   543  	for res, err = stream.Next(); err == nil; res, err = stream.Next() {
   544  		for _, row := range res.Rows {
   545  			resMap[row[0].ToString()] = row[1].ToString()
   546  		}
   547  	}
   548  
   549  	assert.Equal(t, "1", resMap["1"])
   550  
   551  	// validate message ack with 1 and 4, only 1 should be ack
   552  	qr, err = session.Execute(context.Background(), "update "+name+" set time_acked = 1, time_next = null where id in (1, 4) and time_acked is null", nil)
   553  	require.Nil(t, err)
   554  	assert.Equal(t, uint64(1), qr.RowsAffected)
   555  }
   556  
   557  func validateField(t *testing.T, field *query.Field, name string, _type query.Type) {
   558  	assert.Equal(t, name, field.Name)
   559  	assert.Equal(t, _type, field.Type)
   560  }
   561  
   562  // MsgStream handles all meta required for grpc connection with vtgate.
   563  type VTGateStream struct {
   564  	ctx      context.Context
   565  	host     string
   566  	respChan chan *sqltypes.Result
   567  	mem      *sqltypes.Result
   568  	*vtgateconn.VTGateConn
   569  }
   570  
   571  // VtgateGrpcConn create new msg stream for grpc connection with vtgate.
   572  func VtgateGrpcConn(ctx context.Context, cluster *cluster.LocalProcessCluster) (*VTGateStream, error) {
   573  	stream := new(VTGateStream)
   574  	stream.ctx = ctx
   575  	stream.host = fmt.Sprintf("%s:%d", cluster.Hostname, cluster.VtgateProcess.GrpcPort)
   576  	conn, err := vtgateconn.Dial(ctx, stream.host)
   577  	// init components
   578  	stream.respChan = make(chan *sqltypes.Result)
   579  	stream.VTGateConn = conn
   580  
   581  	return stream, err
   582  }
   583  
   584  // MessageStream strarts the stream for the corresponding connection.
   585  func (stream *VTGateStream) MessageStream(ks, shard string, keyRange *topodata.KeyRange, name string) (*sqltypes.Result, error) {
   586  	// start message stream which send received message to the respChan
   587  	session := stream.Session("@primary", nil)
   588  	resultStream, err := session.StreamExecute(stream.ctx, fmt.Sprintf("stream * from %s", name), nil)
   589  	if err != nil {
   590  		return nil, err
   591  	}
   592  	qr, err := resultStream.Recv()
   593  	if err != nil {
   594  		return nil, err
   595  	}
   596  	go func() {
   597  		for {
   598  			qr, err := resultStream.Recv()
   599  			if err != nil {
   600  				log.Infof("Message stream ended: %v", err)
   601  				return
   602  			}
   603  
   604  			if stream.mem != nil && stream.mem.Equal(qr) {
   605  				continue
   606  			}
   607  
   608  			stream.mem = qr
   609  			stream.respChan <- qr
   610  		}
   611  	}()
   612  	return qr, nil
   613  }
   614  
   615  // ClearMem cleares the last result stored.
   616  func (stream *VTGateStream) ClearMem() {
   617  	stream.mem = nil
   618  }
   619  
   620  // Next reads the new msg available in stream.
   621  func (stream *VTGateStream) Next() (*sqltypes.Result, error) {
   622  	timer := time.NewTimer(10 * time.Second)
   623  	select {
   624  	case s := <-stream.respChan:
   625  		return s, nil
   626  	case <-timer.C:
   627  		return nil, fmt.Errorf("time limit exceeded")
   628  	}
   629  }
   630  
   631  // assertClientCount read connected client count from the vttablet debug vars.
   632  func assertClientCount(t *testing.T, expected int, vttablet *cluster.Vttablet) {
   633  	var vars struct {
   634  		Messages map[string]int
   635  	}
   636  
   637  	parseDebugVars(t, &vars, vttablet)
   638  
   639  	got := vars.Messages["sharded_message.ClientCount"]
   640  	if got != expected {
   641  		t.Fatalf("wrong number of clients: got %d, expected %d. messages:\n%#v", got, expected, vars.Messages)
   642  	}
   643  }
   644  
   645  func parseDebugVars(t *testing.T, output interface{}, vttablet *cluster.Vttablet) {
   646  	debugVarURL := fmt.Sprintf("http://%s:%d/debug/vars", vttablet.VttabletProcess.TabletHostname, vttablet.HTTPPort)
   647  	resp, err := http.Get(debugVarURL)
   648  	if err != nil {
   649  		t.Fatalf("failed to fetch %q: %v", debugVarURL, err)
   650  	}
   651  	defer resp.Body.Close()
   652  
   653  	respByte, err := io.ReadAll(resp.Body)
   654  	if err != nil {
   655  		t.Fatalf("failed to read body %q: %v", debugVarURL, err)
   656  	}
   657  
   658  	if resp.StatusCode != 200 {
   659  		t.Fatalf("status code %d while fetching %q:\n%s", resp.StatusCode, debugVarURL, respByte)
   660  	}
   661  
   662  	if err := json.Unmarshal(respByte, output); err != nil {
   663  		t.Fatalf("failed to unmarshal JSON from %q: %v", debugVarURL, err)
   664  	}
   665  }