github.com/m3db/m3@v1.5.1-0.20231129193456-75a402aa583b/src/dbnode/client/session_test.go

// Copyright (c) 2016 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package client

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/m3db/m3/src/cluster/shard"
	"github.com/m3db/m3/src/dbnode/encoding"
	"github.com/m3db/m3/src/dbnode/generated/thrift/rpc"
	"github.com/m3db/m3/src/dbnode/sharding"
	"github.com/m3db/m3/src/dbnode/storage/index"
	"github.com/m3db/m3/src/dbnode/topology"
	"github.com/m3db/m3/src/dbnode/x/xpool"
	"github.com/m3db/m3/src/m3ninx/idx"
	xerror "github.com/m3db/m3/src/x/errors"
	"github.com/m3db/m3/src/x/ident"
	xretry "github.com/m3db/m3/src/x/retry"
	"github.com/m3db/m3/src/x/sampler"
	"github.com/m3db/m3/src/x/serialize"
	xtest "github.com/m3db/m3/src/x/test"

	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

const (
	sessionTestReplicas = 3
	sessionTestShards   = 3
)

type outcome int

const (
	outcomeSuccess outcome = iota
	outcomeFail
)

type testEnqueueFn func(idx int, op op)

// NB: allocating once to speed up tests.
var _testSessionOpts = NewOptions().
	SetCheckedBytesWrapperPoolSize(1).
	SetFetchBatchOpPoolSize(1).
	SetHostQueueOpsArrayPoolSize(1).
	SetTagEncoderPoolSize(1).
	SetWriteOpPoolSize(1).
	SetWriteTaggedOpPoolSize(1).
	SetSeriesIteratorPoolSize(1).
	// Set 100% sample rate to test the code path that logs errors.
	SetLogErrorSampleRate(sampler.Rate(1)).
	SetLogHostFetchErrorSampleRate(sampler.Rate(1)).
	SetLogHostWriteErrorSampleRate(sampler.Rate(1))

func testContext() context.Context {
	// The cancel func is deliberately dropped; the one-minute timeout just
	// bounds how long a test can run.
	// nolint: govet
	ctx, _ := context.WithTimeout(context.Background(), time.Minute) //nolint
	return ctx
}

func newSessionTestOptions() Options {
	return applySessionTestOptions(_testSessionOpts)
}

func sessionTestShardSet() sharding.ShardSet {
	var ids []uint32
	for i := uint32(0); i < uint32(sessionTestShards); i++ {
		ids = append(ids, i)
	}

	shards := sharding.NewShards(ids, shard.Available)
	// Hash every ID to shard 0 so routing is deterministic in tests.
	hashFn := func(id ident.ID) uint32 { return 0 }
	shardSet, _ := sharding.NewShardSet(shards, hashFn)
	return shardSet
}

func testHostName(i int) string { return fmt.Sprintf("testhost%d", i) }

func sessionTestHostAndShards(
	shardSet sharding.ShardSet,
) []topology.HostShardSet {
	var hosts []topology.Host
	for i := 0; i < sessionTestReplicas; i++ {
		id := testHostName(i)
		host := topology.NewHost(id, fmt.Sprintf("%s:9000", id))
		hosts = append(hosts, host)
	}

	var hostShardSets []topology.HostShardSet
	for _, host := range hosts {
		hostShardSet := topology.NewHostShardSet(host, shardSet)
		hostShardSets = append(hostShardSets, hostShardSet)
	}
	return hostShardSets
}

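// applySessionTestOptions tunes the given options for unit tests: retries are
// disabled so mocks that expect a single call are not re-invoked, op pooling is
// disabled, and a static topology of sessionTestReplicas replicas over the test
// shard set is installed.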
func applySessionTestOptions(opts Options) Options {
	shardSet := sessionTestShardSet()
	return opts.
		// Some of the test mocks expect things to only happen once, so disable retries
		// for the unit tests.
		SetWriteRetrier(xretry.NewRetrier(xretry.NewOptions().SetMaxRetries(0))).
		SetFetchRetrier(xretry.NewRetrier(xretry.NewOptions().SetMaxRetries(0))).
		SetSeriesIteratorPoolSize(0).
		SetWriteOpPoolSize(0).
		SetWriteTaggedOpPoolSize(0).
		SetFetchBatchOpPoolSize(0).
		SetTopologyInitializer(topology.NewStaticInitializer(
			topology.NewStaticOptions().
				SetReplicas(sessionTestReplicas).
				SetShardSet(shardSet).
				SetHostShardSets(sessionTestHostAndShards(shardSet))))
}

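// newTestHostQueue builds a real host queue around the package-level test
// fixtures (h and the test*Pool variables), which are presumably defined in the
// host queue test files of this package rather than in this file.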
func newTestHostQueue(opts Options) *queue {
	hq, err := newHostQueue(h, hostQueueOpts{
		writeBatchRawRequestPool:                     testWriteBatchRawPool,
		writeBatchRawV2RequestPool:                   testWriteBatchRawV2Pool,
		writeBatchRawRequestElementArrayPool:         testWriteArrayPool,
		writeBatchRawV2RequestElementArrayPool:       testWriteV2ArrayPool,
		writeTaggedBatchRawRequestPool:               testWriteTaggedBatchRawPool,
		writeTaggedBatchRawV2RequestPool:             testWriteTaggedBatchRawV2Pool,
		writeTaggedBatchRawRequestElementArrayPool:   testWriteTaggedArrayPool,
		writeTaggedBatchRawV2RequestElementArrayPool: testWriteTaggedV2ArrayPool,
		fetchBatchRawV2RequestPool:                   testFetchBatchRawV2Pool,
		fetchBatchRawV2RequestElementArrayPool:       testFetchBatchRawV2ArrayPool,
		opts:                                         opts,
	})
	if err != nil {
		panic(err)
	}
	return hq.(*queue)
}

func TestSessionCreationFailure(t *testing.T) {
	topoOpts := topology.NewDynamicOptions()
	topoInit := topology.NewDynamicInitializer(topoOpts)
	opt := newSessionTestOptions().SetTopologyInitializer(topoInit)
	_, err := newSession(opt)
	assert.Error(t, err)
}

func TestSessionShardID(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	opts := newSessionTestOptions()
	s, err := newSession(opts)
	assert.NoError(t, err)

	_, err = s.ShardID(ident.StringID("foo"))
	assert.Error(t, err)
	assert.Equal(t, ErrSessionStatusNotOpen, err)

	mockHostQueues(ctrl, s.(*session), sessionTestReplicas, nil)

	require.NoError(t, s.Open())

	// The shard set we create in newSessionTestOptions always hashes IDs to shard 0.
	shard, err := s.ShardID(ident.StringID("foo"))
	require.NoError(t, err)
	assert.Equal(t, uint32(0), shard)

	assert.NoError(t, s.Close())
}

func TestSessionClusterConnectConsistencyLevelAll(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	level := topology.ConnectConsistencyLevelAll
	testSessionClusterConnectConsistencyLevel(t, ctrl, level, 0, outcomeSuccess)
	for i := 1; i <= 3; i++ {
		testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeFail)
	}
}

func TestSessionClusterConnectConsistencyLevelMajority(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	level := topology.ConnectConsistencyLevelMajority
	for i := 0; i <= 1; i++ {
		testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess)
	}
	for i := 2; i <= 3; i++ {
		testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeFail)
	}
}

func TestSessionClusterConnectConsistencyLevelOne(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	level := topology.ConnectConsistencyLevelOne
	for i := 0; i <= 2; i++ {
		testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess)
	}
	testSessionClusterConnectConsistencyLevel(t, ctrl, level, 3, outcomeFail)
}

func TestSessionClusterConnectConsistencyLevelNone(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	level := topology.ConnectConsistencyLevelNone
	for i := 0; i <= 3; i++ {
		testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess)
	}
}

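// TestIteratorPools verifies that IteratorPools returns ErrSessionStatusNotOpen
// until the session is open and, once open, exposes exactly the pools that were
// configured on the session.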
func TestIteratorPools(t *testing.T) {
	s := session{}
	itPool, err := s.IteratorPools()

	assert.EqualError(t, err, ErrSessionStatusNotOpen.Error())
	assert.Nil(t, itPool)

	multiReaderIteratorArray := encoding.NewMultiReaderIteratorArrayPool(nil)
	multiReaderIteratorPool := encoding.NewMultiReaderIteratorPool(nil)
	seriesIteratorPool := encoding.NewSeriesIteratorPool(nil)
	checkedBytesWrapperPool := xpool.NewCheckedBytesWrapperPool(nil)
	idPool := ident.NewPool(nil, ident.PoolOptions{})
	encoderPool := serialize.NewTagEncoderPool(nil, nil)
	decoderPool := serialize.NewTagDecoderPool(nil, nil)

	s.pools = sessionPools{
		multiReaderIteratorArray: multiReaderIteratorArray,
		multiReaderIterator:      multiReaderIteratorPool,
		seriesIterator:           seriesIteratorPool,
		checkedBytesWrapper:      checkedBytesWrapperPool,
		id:                       idPool,
		tagEncoder:               encoderPool,
		tagDecoder:               decoderPool,
	}

	// Error expected if state is not open
	itPool, err = s.IteratorPools()
	assert.EqualError(t, err, ErrSessionStatusNotOpen.Error())
	assert.Nil(t, itPool)

	s.state.status = statusOpen

	itPool, err = s.IteratorPools()
	require.NoError(t, err)
	assert.Equal(t, multiReaderIteratorArray, itPool.MultiReaderIteratorArray())
	assert.Equal(t, multiReaderIteratorPool, itPool.MultiReaderIterator())
	assert.Equal(t, seriesIteratorPool, itPool.SeriesIterator())
	assert.Equal(t, checkedBytesWrapperPool, itPool.CheckedBytesWrapper())
	assert.Equal(t, encoderPool, itPool.TagEncoder())
	assert.Equal(t, decoderPool, itPool.TagDecoder())
	assert.Equal(t, idPool, itPool.ID())
}

//nolint:dupl
func TestSeriesLimit_FetchTagged(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Mock the host queue to return a result with a single series; this yields 3 series total, one per shard.
	sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) {
		fOp := op.(*fetchTaggedOp)
		assert.Equal(t, int64(2), *fOp.request.SeriesLimit)
		shardID := strings.Split(host.ID(), "-")[2]
		op.CompletionFn()(fetchTaggedResultAccumulatorOpts{
			host: host,
			response: &rpc.FetchTaggedResult_{
				Exhaustive: true,
				Elements: []*rpc.FetchTaggedIDResult_{
					{
						// Use the shard ID for the metric ID so it's stable across replicas.
						ID: []byte(shardID),
					},
				},
			},
		}, nil)
	})

	iters, meta, err := sess.fetchTaggedAttempt(context.TODO(), ident.StringID("ns"),
		index.Query{Query: idx.NewAllQuery()},
		index.QueryOptions{
			// Set to 6 so the per-instance series limit is 2 (6 / 3 instances per replica * InstanceMultiple of 1).
			SeriesLimit:      6,
			InstanceMultiple: 1,
		})
	require.NoError(t, err)
	require.NotNil(t, iters)
	// Expect one series per shard.
	require.Equal(t, 3, iters.Len())
	require.True(t, meta.Exhaustive)
	require.NoError(t, sess.Close())
}

//nolint:dupl
func TestSeriesLimit_FetchTaggedIDs(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Mock the host queue to return a result with a single series; this yields 3 series total, one per shard.
	sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) {
		fOp := op.(*fetchTaggedOp)
		assert.Equal(t, int64(2), *fOp.request.SeriesLimit)
		shardID := strings.Split(host.ID(), "-")[2]
		op.CompletionFn()(fetchTaggedResultAccumulatorOpts{
			host: host,
			response: &rpc.FetchTaggedResult_{
				Exhaustive: true,
				Elements: []*rpc.FetchTaggedIDResult_{
					{
						// Use the shard ID for the metric ID so it's stable across replicas.
						ID: []byte(shardID),
					},
				},
			},
		}, nil)
	})

	iter, meta, err := sess.fetchTaggedIDsAttempt(context.TODO(), ident.StringID("ns"),
		index.Query{Query: idx.NewAllQuery()},
		index.QueryOptions{
			// Set to 6 so the per-instance series limit is 2 (6 / 3 instances per replica * InstanceMultiple of 1).
			SeriesLimit:      6,
			InstanceMultiple: 1,
		})
	require.NoError(t, err)
	require.NotNil(t, iter)
	// Expect one series per shard.
	require.Equal(t, 3, iter.Remaining())
	require.True(t, meta.Exhaustive)
	require.NoError(t, sess.Close())
}

//nolint:dupl
func TestSeriesLimit_Aggregate(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Mock the host queue to return a result with a single series; this yields 3 series total, one per shard.
	sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) {
		aOp := op.(*aggregateOp)
		assert.Equal(t, int64(2), *aOp.request.SeriesLimit)
		shardID := strings.Split(host.ID(), "-")[2]
		op.CompletionFn()(aggregateResultAccumulatorOpts{
			host: host,
			response: &rpc.AggregateQueryRawResult_{
				Exhaustive: true,
				Results: []*rpc.AggregateQueryRawResultTagNameElement{
					{
						// Use the shard ID for the tag name so it's stable across replicas.
						TagName: []byte(shardID),
						TagValues: []*rpc.AggregateQueryRawResultTagValueElement{
							{
								TagValue: []byte("value"),
							},
						},
					},
				},
			},
		}, nil)
	})
	iter, meta, err := sess.aggregateAttempt(context.TODO(), ident.StringID("ns"),
		index.Query{Query: idx.NewAllQuery()},
		index.AggregationOptions{
			QueryOptions: index.QueryOptions{
				// Set to 6 so the per-instance series limit is 2 (6 / 3 instances per replica * InstanceMultiple of 1).
				SeriesLimit:      6,
				InstanceMultiple: 1,
			},
		})
	require.NoError(t, err)
	require.NotNil(t, iter)
	require.Equal(t, 3, iter.Remaining())
	require.True(t, meta.Exhaustive)
	require.NoError(t, sess.Close())
}

func TestIterationStrategy_FetchTagged(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Mock the host queue to return a result with a single series; this yields 3 series total, one per shard.
	sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) {
		fOp := op.(*fetchTaggedOp)
		assert.Equal(t, int64(2), *fOp.request.SeriesLimit)
		shardID := strings.Split(host.ID(), "-")[2]
		op.CompletionFn()(fetchTaggedResultAccumulatorOpts{
			host: host,
			response: &rpc.FetchTaggedResult_{
				Exhaustive: true,
				Elements: []*rpc.FetchTaggedIDResult_{
					{
						// Use the shard ID for the metric ID so it's stable across replicas.
						ID: []byte(shardID),
					},
				},
			},
		}, nil)
	})

	strategy := encoding.IterateHighestFrequencyValue
	iters, meta, err := sess.fetchTaggedAttempt(context.TODO(), ident.StringID("ns"),
		index.Query{Query: idx.NewAllQuery()},
		index.QueryOptions{
			// Set to 6 so the per-instance series limit is 2 (6 / 3 instances per replica * InstanceMultiple of 1).
			SeriesLimit:                   6,
			InstanceMultiple:              1,
			IterateEqualTimestampStrategy: &strategy,
		})
	require.NoError(t, err)
	require.NotNil(t, iters)

	// Expect one series per shard.
	require.Equal(t, 3, iters.Len())

	// Confirm the strategy propagated to every series iterator.
	for _, i := range iters.Iters() {
		require.Equal(t, strategy, i.IterateEqualTimestampStrategy())
	}

	require.True(t, meta.Exhaustive)
	require.NoError(t, sess.Close())
}

func TestSessionClusterConnectConsistencyLevelAny(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	level := topology.ConnectConsistencyLevelAny
	for i := 0; i <= 3; i++ {
		testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess)
	}
}

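// TestDedicatedConnection exercises host selection for dedicated connections:
// the origin host is skipped, ShardStateFilter restricts candidates by shard
// state, BootstrappedNodesOnly is forwarded to the health check, and when every
// candidate fails the health check the errors surface as a MultiError and any
// channels that were opened are closed.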
func TestDedicatedConnection(t *testing.T) {
	ctrl := xtest.NewController(t)
	defer ctrl.Finish()

	var (
		shardID = uint32(32)

		topoMap = topology.NewMockMap(ctrl)

		local   = mockHost(ctrl, "h0", "local")
		remote1 = mockHost(ctrl, "h1", "remote1")
		remote2 = mockHost(ctrl, "h2", "remote2")

		availableShard    = shard.NewShard(shardID).SetState(shard.Available)
		initializingShard = shard.NewShard(shardID).SetState(shard.Initializing)
	)

	topoMap.EXPECT().RouteShardForEach(shardID, gomock.Any()).DoAndReturn(
		func(shardID uint32, callback func(int, shard.Shard, topology.Host)) error {
			callback(0, availableShard, local)
			callback(1, initializingShard, remote1)
			callback(2, availableShard, remote2)
			return nil
		}).Times(4)

	s := session{origin: local}
	s.opts = NewOptions().SetNewConnectionFn(noopNewConnection)
	s.healthCheckNewConnFn = testHealthCheck(nil, false)
	s.state.status = statusOpen
	s.state.topoMap = topoMap

	_, ch, err := s.DedicatedConnection(shardID, DedicatedConnectionOptions{})
	require.NoError(t, err)
	assert.Equal(t, "remote1", asNoopPooledChannel(ch).address)

	_, ch2, err := s.DedicatedConnection(shardID, DedicatedConnectionOptions{ShardStateFilter: shard.Available})
	require.NoError(t, err)
	assert.Equal(t, "remote2", asNoopPooledChannel(ch2).address)

	s.healthCheckNewConnFn = testHealthCheck(nil, true)
	_, ch3, err := s.DedicatedConnection(shardID, DedicatedConnectionOptions{BootstrappedNodesOnly: true})
	require.NoError(t, err)
	assert.Equal(t, "remote1", asNoopPooledChannel(ch3).address)

	healthErr := errors.New("unhealthy")
	s.healthCheckNewConnFn = testHealthCheck(healthErr, false)

	var channels []*noopPooledChannel
	s.opts = NewOptions().SetNewConnectionFn(func(_ string, _ string, _ Options) (Channel, rpc.TChanNode, error) {
		c := &noopPooledChannel{"test", 0}
		channels = append(channels, c)
		return c, nil, nil
	})
	_, _, err = s.DedicatedConnection(shardID, DedicatedConnectionOptions{})
	require.NotNil(t, err)
	multiErr, ok := err.(xerror.MultiError) // nolint: errorlint
	assert.True(t, ok, "expecting MultiError")
	assert.True(t, multiErr.Contains(healthErr))
	// Expect 2 channels because both remote hosts fail the health check, and each channel is closed.
	assert.Len(t, channels, 2)
	assert.Equal(t, 1, channels[0].CloseCount())
	assert.Equal(t, 1, channels[1].CloseCount())
}

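// testSessionClusterConnectConsistencyLevel opens a session against mocked host
// queues in which the first `failures` hosts never report any connections and
// the rest immediately report the minimum connection count, then asserts
// whether Open succeeds or times out for the given connect consistency level.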
func testSessionClusterConnectConsistencyLevel(
	t *testing.T,
	ctrl *gomock.Controller,
	level topology.ConnectConsistencyLevel,
	failures int,
	expected outcome,
) {
	opts := newSessionTestOptions()
	opts = opts.SetClusterConnectTimeout(10 * clusterConnectWaitInterval)
	opts = opts.SetClusterConnectConsistencyLevel(level)
	s, err := newSession(opts)
	assert.NoError(t, err)
	session := s.(*session)

	var failingConns int32
	session.newHostQueueFn = func(
		host topology.Host,
		opts hostQueueOpts,
	) (hostQueue, error) {
		hostQueue := NewMockhostQueue(ctrl)
		hostQueue.EXPECT().Open().Times(1)
		hostQueue.EXPECT().Host().Return(host).AnyTimes()
		if atomic.AddInt32(&failingConns, 1) <= int32(failures) {
			hostQueue.EXPECT().ConnectionCount().Return(0).AnyTimes()
		} else {
			min := opts.opts.MinConnectionCount()
			hostQueue.EXPECT().ConnectionCount().Return(min).AnyTimes()
		}
		hostQueue.EXPECT().Close().AnyTimes()
		return hostQueue, nil
	}

	err = session.Open()
	switch expected {
	case outcomeSuccess:
		assert.NoError(t, err)
	case outcomeFail:
		assert.Error(t, err)
		assert.Equal(t, ErrClusterConnectTimeout, err)
	}
}

// setupMultipleInstanceCluster sets up a db cluster with 3 shards and 3 replicas. The 3 shards are distributed
// across 9 hosts named "testhost-<replica>-<shard>", so each host owns 1 replica of 1 shard.
// The provided fn is executed whenever an operation is enqueued; it is dispatched in a separate goroutine to
// simulate queue processing, which also allows the function to acquire the session's state locks.
func setupMultipleInstanceCluster(t *testing.T, ctrl *gomock.Controller, fn func(op op, host topology.Host)) *session {
	opts := newSessionTestOptions()
	shardSet := sessionTestShardSet()
	var hostShardSets []topology.HostShardSet
	// Set up 9 hosts so there are 3 instances per replica; each instance owns a single shard.
	for i := 0; i < sessionTestReplicas; i++ {
		for j := 0; j < sessionTestShards; j++ {
			id := fmt.Sprintf("testhost-%d-%d", i, j)
			host := topology.NewHost(id, fmt.Sprintf("%s:9000", id))
			hostShard, _ := sharding.NewShardSet([]shard.Shard{shardSet.All()[j]}, shardSet.HashFn())
			hostShardSet := topology.NewHostShardSet(host, hostShard)
			hostShardSets = append(hostShardSets, hostShardSet)
		}
	}

	opts = opts.SetTopologyInitializer(topology.NewStaticInitializer(
		topology.NewStaticOptions().
			SetReplicas(sessionTestReplicas).
			SetShardSet(shardSet).
			SetHostShardSets(hostShardSets)))
	s, err := newSession(opts)
	assert.NoError(t, err)
	sess := s.(*session)

	sess.newHostQueueFn = func(host topology.Host, hostQueueOpts hostQueueOpts) (hostQueue, error) {
		q := NewMockhostQueue(ctrl)
		q.EXPECT().Open()
		q.EXPECT().ConnectionCount().Return(hostQueueOpts.opts.MinConnectionCount()).AnyTimes()
		q.EXPECT().Host().Return(host).AnyTimes()
		q.EXPECT().Enqueue(gomock.Any()).Do(func(op op) error {
			go func() {
				fn(op, host)
			}()
			return nil
		}).Return(nil)
		q.EXPECT().Close()
		return q, nil
	}

	require.NoError(t, sess.Open())
	return sess
}

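// mockHostQueues replaces the session's host queue constructor with mocks that
// report the minimum connection count on the second round of checks and invoke
// the provided enqueue fns, one per Enqueue call. The returned WaitGroup is
// done once every host queue has run its final enqueue fn.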
func mockHostQueues(
	ctrl *gomock.Controller,
	s *session,
	replicas int,
	enqueueFns []testEnqueueFn,
) *sync.WaitGroup {
	var enqueueWg sync.WaitGroup
	enqueueWg.Add(replicas)
	idx := 0
	s.newHostQueueFn = func(
		host topology.Host,
		opts hostQueueOpts,
	) (hostQueue, error) {
		// Make a copy of the enqueue fns for each host.
		hostEnqueueFns := make([]testEnqueueFn, len(enqueueFns))
		copy(hostEnqueueFns, enqueueFns)

		enqueuedIdx := idx
		hostQueue := NewMockhostQueue(ctrl)
		hostQueue.EXPECT().Open()
		hostQueue.EXPECT().Host().Return(host).AnyTimes()
		// It takes two attempts to establish the min connection count.
		hostQueue.EXPECT().ConnectionCount().Return(0).Times(sessionTestShards)
		hostQueue.EXPECT().ConnectionCount().Return(opts.opts.MinConnectionCount()).Times(sessionTestShards)
		var expectNextEnqueueFn func(fns []testEnqueueFn)
		expectNextEnqueueFn = func(fns []testEnqueueFn) {
			fn := fns[0]
			fns = fns[1:]
			hostQueue.EXPECT().Enqueue(gomock.Any()).Do(func(op op) error {
				fn(enqueuedIdx, op)
				if len(fns) > 0 {
					expectNextEnqueueFn(fns)
				} else {
					enqueueWg.Done()
				}
				return nil
			}).Return(nil)
		}
		if len(hostEnqueueFns) > 0 {
			expectNextEnqueueFn(hostEnqueueFns)
		}
		hostQueue.EXPECT().Close()
		idx++
		return hostQueue, nil
	}
	return &enqueueWg
}

func mockHost(ctrl *gomock.Controller, id, address string) topology.Host {
	host := topology.NewMockHost(ctrl)
	host.EXPECT().ID().Return(id).AnyTimes()
	host.EXPECT().Address().Return(address).AnyTimes()
	return host
}

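// testHealthCheck returns a health check stub that fails when the caller's
// checkBootstrapped flag differs from the expected bootstrappedNodesOnly value
// and otherwise returns the injected error (nil for a healthy node).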
func testHealthCheck(err error, bootstrappedNodesOnly bool) func(rpc.TChanNode, Options, bool) error {
	return func(client rpc.TChanNode, opts Options, checkBootstrapped bool) error {
		if checkBootstrapped != bootstrappedNodesOnly {
			return fmt.Errorf("checkBootstrapped value (%t) != expected (%t)",
				checkBootstrapped, bootstrappedNodesOnly)
		}
		return err
	}
}

func noopNewConnection(
	_ string,
	addr string,
	_ Options,
) (Channel, rpc.TChanNode, error) {
	return &noopPooledChannel{addr, 0}, nil, nil
}