github.com/m3db/m3@v1.5.0/src/dbnode/client/session_test.go

// Copyright (c) 2016 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package client

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"github.com/m3db/m3/src/cluster/shard"
	"github.com/m3db/m3/src/dbnode/encoding"
	"github.com/m3db/m3/src/dbnode/generated/thrift/rpc"
	"github.com/m3db/m3/src/dbnode/sharding"
	"github.com/m3db/m3/src/dbnode/storage/index"
	"github.com/m3db/m3/src/dbnode/topology"
	"github.com/m3db/m3/src/dbnode/x/xpool"
	"github.com/m3db/m3/src/m3ninx/idx"
	xerror "github.com/m3db/m3/src/x/errors"
	"github.com/m3db/m3/src/x/ident"
	xretry "github.com/m3db/m3/src/x/retry"
	"github.com/m3db/m3/src/x/serialize"
	xtest "github.com/m3db/m3/src/x/test"

	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

const (
	sessionTestReplicas = 3
	sessionTestShards   = 3
)

type outcome int

const (
	outcomeSuccess outcome = iota
	outcomeFail
)

type testEnqueueFn func(idx int, op op)

var (
	// NB: allocating once to speed up tests.
	_testSessionOpts = NewOptions().
		SetCheckedBytesWrapperPoolSize(1).
		SetFetchBatchOpPoolSize(1).
		SetHostQueueOpsArrayPoolSize(1).
		SetTagEncoderPoolSize(1).
		SetWriteOpPoolSize(1).
		SetWriteTaggedOpPoolSize(1).
		SetSeriesIteratorPoolSize(1)
)

func testContext() context.Context {
	// nolint: govet
	ctx, _ := context.WithTimeout(context.Background(), time.Minute) //nolint
	return ctx
}

func newSessionTestOptions() Options {
	return applySessionTestOptions(_testSessionOpts)
}

func sessionTestShardSet() sharding.ShardSet {
	var ids []uint32
	for i := uint32(0); i < uint32(sessionTestShards); i++ {
		ids = append(ids, i)
	}

	shards := sharding.NewShards(ids, shard.Available)
	hashFn := func(id ident.ID) uint32 { return 0 }
	shardSet, _ := sharding.NewShardSet(shards, hashFn)
	return shardSet
}

func testHostName(i int) string { return fmt.Sprintf("testhost%d", i) }

func sessionTestHostAndShards(
	shardSet sharding.ShardSet,
) []topology.HostShardSet {
	var hosts []topology.Host
	for i := 0; i < sessionTestReplicas; i++ {
		id := testHostName(i)
		host := topology.NewHost(id, fmt.Sprintf("%s:9000", id))
		hosts = append(hosts, host)
	}

	var hostShardSets []topology.HostShardSet
	for _, host := range hosts {
		hostShardSet := topology.NewHostShardSet(host, shardSet)
		hostShardSets = append(hostShardSets, hostShardSet)
	}
	return hostShardSets
}

func applySessionTestOptions(opts Options) Options {
	shardSet := sessionTestShardSet()
	return opts.
		// Some of the test mocks expect things to only happen once, so disable retries
		// for the unit tests.
		SetWriteRetrier(xretry.NewRetrier(xretry.NewOptions().SetMaxRetries(0))).
		SetFetchRetrier(xretry.NewRetrier(xretry.NewOptions().SetMaxRetries(0))).
		SetSeriesIteratorPoolSize(0).
		SetWriteOpPoolSize(0).
		SetWriteTaggedOpPoolSize(0).
		SetFetchBatchOpPoolSize(0).
		SetTopologyInitializer(topology.NewStaticInitializer(
			topology.NewStaticOptions().
				SetReplicas(sessionTestReplicas).
				SetShardSet(shardSet).
				SetHostShardSets(sessionTestHostAndShards(shardSet))))
}

// newTestHostQueue builds a host queue for tests. The host h and the test
// request/array pools referenced below are package-level fixtures declared in
// this package's other test files.
func newTestHostQueue(opts Options) *queue {
	hq, err := newHostQueue(h, hostQueueOpts{
		writeBatchRawRequestPool:                     testWriteBatchRawPool,
		writeBatchRawV2RequestPool:                   testWriteBatchRawV2Pool,
		writeBatchRawRequestElementArrayPool:         testWriteArrayPool,
		writeBatchRawV2RequestElementArrayPool:       testWriteV2ArrayPool,
		writeTaggedBatchRawRequestPool:               testWriteTaggedBatchRawPool,
		writeTaggedBatchRawV2RequestPool:             testWriteTaggedBatchRawV2Pool,
		writeTaggedBatchRawRequestElementArrayPool:   testWriteTaggedArrayPool,
		writeTaggedBatchRawV2RequestElementArrayPool: testWriteTaggedV2ArrayPool,
		fetchBatchRawV2RequestPool:                   testFetchBatchRawV2Pool,
		fetchBatchRawV2RequestElementArrayPool:       testFetchBatchRawV2ArrayPool,
		opts:                                         opts,
	})
	if err != nil {
		panic(err)
	}
	return hq.(*queue)
}

func TestSessionCreationFailure(t *testing.T) {
	topoOpts := topology.NewDynamicOptions()
	topoInit := topology.NewDynamicInitializer(topoOpts)
	opt := newSessionTestOptions().SetTopologyInitializer(topoInit)
	_, err := newSession(opt)
	assert.Error(t, err)
}

func TestSessionShardID(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	opts := newSessionTestOptions()
	s, err := newSession(opts)
	assert.NoError(t, err)

	_, err = s.ShardID(ident.StringID("foo"))
	assert.Error(t, err)
	assert.Equal(t, ErrSessionStatusNotOpen, err)

	mockHostQueues(ctrl, s.(*session), sessionTestReplicas, nil)

	require.NoError(t, s.Open())

	// The shard set we create in newSessionTestOptions always hashes IDs to shard 0.
	shard, err := s.ShardID(ident.StringID("foo"))
	require.NoError(t, err)
	assert.Equal(t, uint32(0), shard)

	assert.NoError(t, s.Close())
}

func TestSessionClusterConnectConsistencyLevelAll(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	level := topology.ConnectConsistencyLevelAll
	testSessionClusterConnectConsistencyLevel(t, ctrl, level, 0, outcomeSuccess)
	for i := 1; i <= 3; i++ {
		testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeFail)
	}
}

func TestSessionClusterConnectConsistencyLevelMajority(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	level := topology.ConnectConsistencyLevelMajority
	for i := 0; i <= 1; i++ {
		testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess)
	}
	for i := 2; i <= 3; i++ {
		testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeFail)
	}
}

func TestSessionClusterConnectConsistencyLevelOne(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	level := topology.ConnectConsistencyLevelOne
	for i := 0; i <= 2; i++ {
		testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess)
	}
	testSessionClusterConnectConsistencyLevel(t, ctrl, level, 3, outcomeFail)
}

func TestSessionClusterConnectConsistencyLevelNone(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	level := topology.ConnectConsistencyLevelNone
	for i := 0; i <= 3; i++ {
		testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess)
	}
}

func TestIteratorPools(t *testing.T) {
	s := session{}
	itPool, err := s.IteratorPools()

	assert.EqualError(t, err, ErrSessionStatusNotOpen.Error())
	assert.Nil(t, itPool)

	multiReaderIteratorArray := encoding.NewMultiReaderIteratorArrayPool(nil)
	multiReaderIteratorPool := encoding.NewMultiReaderIteratorPool(nil)
	seriesIteratorPool := encoding.NewSeriesIteratorPool(nil)
	checkedBytesWrapperPool := xpool.NewCheckedBytesWrapperPool(nil)
	idPool := ident.NewPool(nil, ident.PoolOptions{})
	encoderPool := serialize.NewTagEncoderPool(nil, nil)
	decoderPool := serialize.NewTagDecoderPool(nil, nil)

	s.pools = sessionPools{
		multiReaderIteratorArray: multiReaderIteratorArray,
		multiReaderIterator:      multiReaderIteratorPool,
		seriesIterator:           seriesIteratorPool,
		checkedBytesWrapper:      checkedBytesWrapperPool,
		id:                       idPool,
		tagEncoder:               encoderPool,
		tagDecoder:               decoderPool,
	}

	// Error expected if state is not open
	itPool, err = s.IteratorPools()
	assert.EqualError(t, err, ErrSessionStatusNotOpen.Error())
	assert.Nil(t, itPool)

	s.state.status = statusOpen

	itPool, err = s.IteratorPools()
	require.NoError(t, err)
	assert.Equal(t, multiReaderIteratorArray, itPool.MultiReaderIteratorArray())
	assert.Equal(t, multiReaderIteratorPool, itPool.MultiReaderIterator())
	assert.Equal(t, seriesIteratorPool, itPool.SeriesIterator())
	assert.Equal(t, checkedBytesWrapperPool, itPool.CheckedBytesWrapper())
	assert.Equal(t, encoderPool, itPool.TagEncoder())
	assert.Equal(t, decoderPool, itPool.TagDecoder())
	assert.Equal(t, idPool, itPool.ID())
}

//nolint:dupl
func TestSeriesLimit_FetchTagged(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Mock the host queue to return a result with a single series; this results in 3 series total, one per shard.
	sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) {
		fOp := op.(*fetchTaggedOp)
		assert.Equal(t, int64(2), *fOp.request.SeriesLimit)
		shardID := strings.Split(host.ID(), "-")[2]
		op.CompletionFn()(fetchTaggedResultAccumulatorOpts{
			host: host,
			response: &rpc.FetchTaggedResult_{
				Exhaustive: true,
				Elements: []*rpc.FetchTaggedIDResult_{
					{
						// Use the shard ID for the metric ID so it's stable across replicas.
						ID: []byte(shardID),
					},
				},
			},
		}, nil)
	})

	iters, meta, err := sess.fetchTaggedAttempt(context.TODO(), ident.StringID("ns"),
		index.Query{Query: idx.NewAllQuery()},
		index.QueryOptions{
			// Set to 6 so we can test that the per-instance series limit is 2
			// (6 / 3 instances per replica * InstanceMultiple).
			SeriesLimit:      6,
			InstanceMultiple: 1,
		})
	require.NoError(t, err)
	require.NotNil(t, iters)
	// Expect a series per shard.
	require.Equal(t, 3, iters.Len())
	require.True(t, meta.Exhaustive)
	require.NoError(t, sess.Close())
}

//nolint:dupl
func TestSeriesLimit_FetchTaggedIDs(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Mock the host queue to return a result with a single series; this results in 3 series total, one per shard.
	sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) {
		fOp := op.(*fetchTaggedOp)
		assert.Equal(t, int64(2), *fOp.request.SeriesLimit)
		shardID := strings.Split(host.ID(), "-")[2]
		op.CompletionFn()(fetchTaggedResultAccumulatorOpts{
			host: host,
			response: &rpc.FetchTaggedResult_{
				Exhaustive: true,
				Elements: []*rpc.FetchTaggedIDResult_{
					{
						// Use the shard ID for the metric ID so it's stable across replicas.
						ID: []byte(shardID),
					},
				},
			},
		}, nil)
	})

	iter, meta, err := sess.fetchTaggedIDsAttempt(context.TODO(), ident.StringID("ns"),
		index.Query{Query: idx.NewAllQuery()},
		index.QueryOptions{
			// Set to 6 so we can test that the per-instance series limit is 2
			// (6 / 3 instances per replica * InstanceMultiple).
			SeriesLimit:      6,
			InstanceMultiple: 1,
		})
	require.NoError(t, err)
	require.NotNil(t, iter)
	// Expect a series per shard.
	require.Equal(t, 3, iter.Remaining())
	require.True(t, meta.Exhaustive)
	require.NoError(t, sess.Close())
}

//nolint:dupl
func TestSeriesLimit_Aggregate(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Mock the host queue to return a result with a single series; this results in 3 series total, one per shard.
	sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) {
		aOp := op.(*aggregateOp)
		assert.Equal(t, int64(2), *aOp.request.SeriesLimit)
		shardID := strings.Split(host.ID(), "-")[2]
		op.CompletionFn()(aggregateResultAccumulatorOpts{
			host: host,
			response: &rpc.AggregateQueryRawResult_{
				Exhaustive: true,
				Results: []*rpc.AggregateQueryRawResultTagNameElement{
					{
						// Use the shard ID for the tag name so it's stable across replicas.
						TagName: []byte(shardID),
						TagValues: []*rpc.AggregateQueryRawResultTagValueElement{
							{
								TagValue: []byte("value"),
							},
						},
					},
				},
			},
		}, nil)
	})
	iter, meta, err := sess.aggregateAttempt(context.TODO(), ident.StringID("ns"),
		index.Query{Query: idx.NewAllQuery()},
		index.AggregationOptions{
			QueryOptions: index.QueryOptions{
				// Set to 6 so we can test that the per-instance series limit is 2
				// (6 / 3 instances per replica * InstanceMultiple).
				SeriesLimit:      6,
				InstanceMultiple: 1,
			},
		})
	require.NoError(t, err)
	require.NotNil(t, iter)
	require.Equal(t, 3, iter.Remaining())
	require.True(t, meta.Exhaustive)
	require.NoError(t, sess.Close())
}

func TestIterationStrategy_FetchTagged(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// Mock the host queue to return a result with a single series; this results in 3 series total, one per shard.
	sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) {
		fOp := op.(*fetchTaggedOp)
		assert.Equal(t, int64(2), *fOp.request.SeriesLimit)
		shardID := strings.Split(host.ID(), "-")[2]
		op.CompletionFn()(fetchTaggedResultAccumulatorOpts{
			host: host,
			response: &rpc.FetchTaggedResult_{
				Exhaustive: true,
				Elements: []*rpc.FetchTaggedIDResult_{
					{
						// Use the shard ID for the metric ID so it's stable across replicas.
						ID: []byte(shardID),
					},
				},
			},
		}, nil)
	})

	strategy := encoding.IterateHighestFrequencyValue
	iters, meta, err := sess.fetchTaggedAttempt(context.TODO(), ident.StringID("ns"),
		index.Query{Query: idx.NewAllQuery()},
		index.QueryOptions{
			// Set to 6 so we can test that the per-instance series limit is 2
			// (6 / 3 instances per replica * InstanceMultiple).
			SeriesLimit:                   6,
			InstanceMultiple:              1,
			IterateEqualTimestampStrategy: &strategy,
		})
	require.NoError(t, err)
	require.NotNil(t, iters)

	// Expect a series per shard.
	require.Equal(t, 3, iters.Len())

	// Confirm propagated strategy.
	for _, i := range iters.Iters() {
		require.Equal(t, strategy, i.IterateEqualTimestampStrategy())
	}

	require.True(t, meta.Exhaustive)
	require.NoError(t, sess.Close())
}

func TestSessionClusterConnectConsistencyLevelAny(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	level := topology.ConnectConsistencyLevelAny
	for i := 0; i <= 3; i++ {
		testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess)
	}
}

func TestDedicatedConnection(t *testing.T) {
	ctrl := xtest.NewController(t)
	defer ctrl.Finish()

	var (
		shardID = uint32(32)

		topoMap = topology.NewMockMap(ctrl)

		local   = mockHost(ctrl, "h0", "local")
		remote1 = mockHost(ctrl, "h1", "remote1")
		remote2 = mockHost(ctrl, "h2", "remote2")

		availableShard    = shard.NewShard(shardID).SetState(shard.Available)
		initializingShard = shard.NewShard(shardID).SetState(shard.Initializing)
	)

	topoMap.EXPECT().RouteShardForEach(shardID, gomock.Any()).DoAndReturn(
		func(shardID uint32, callback func(int, shard.Shard, topology.Host)) error {
			callback(0, availableShard, local)
			callback(1, initializingShard, remote1)
			callback(2, availableShard, remote2)
			return nil
		}).Times(4)

	s := session{origin: local}
	s.opts = NewOptions().SetNewConnectionFn(noopNewConnection)
	s.healthCheckNewConnFn = testHealthCheck(nil, false)
	s.state.status = statusOpen
	s.state.topoMap = topoMap

	_, ch, err := s.DedicatedConnection(shardID, DedicatedConnectionOptions{})
	require.NoError(t, err)
	assert.Equal(t, "remote1", asNoopPooledChannel(ch).address)

	_, ch2, err := s.DedicatedConnection(shardID, DedicatedConnectionOptions{ShardStateFilter: shard.Available})
	require.NoError(t, err)
	assert.Equal(t, "remote2", asNoopPooledChannel(ch2).address)

	s.healthCheckNewConnFn = testHealthCheck(nil, true)
	_, ch3, err := s.DedicatedConnection(shardID, DedicatedConnectionOptions{BootstrappedNodesOnly: true})
	require.NoError(t, err)
	assert.Equal(t, "remote1", asNoopPooledChannel(ch3).address)

	healthErr := errors.New("unhealthy")
	s.healthCheckNewConnFn = testHealthCheck(healthErr, false)

	var channels []*noopPooledChannel
	s.opts = NewOptions().SetNewConnectionFn(func(_ string, _ string, _ Options) (Channel, rpc.TChanNode, error) {
		c := &noopPooledChannel{"test", 0}
		channels = append(channels, c)
		return c, nil, nil
	})
	_, _, err = s.DedicatedConnection(shardID, DedicatedConnectionOptions{})
	require.NotNil(t, err)
	multiErr, ok := err.(xerror.MultiError) // nolint: errorlint
	assert.True(t, ok, "expecting MultiError")
	assert.True(t, multiErr.Contains(healthErr))
	// Expect 2 channels because both remote hosts failed the health check.
	assert.Len(t, channels, 2)
	assert.Equal(t, 1, channels[0].CloseCount())
	assert.Equal(t, 1, channels[1].CloseCount())
}

func testSessionClusterConnectConsistencyLevel(
	t *testing.T,
	ctrl *gomock.Controller,
	level topology.ConnectConsistencyLevel,
	failures int,
	expected outcome,
) {
	opts := newSessionTestOptions()
	opts = opts.SetClusterConnectTimeout(10 * clusterConnectWaitInterval)
	opts = opts.SetClusterConnectConsistencyLevel(level)
	s, err := newSession(opts)
	assert.NoError(t, err)
	session := s.(*session)

	var failingConns int32
	session.newHostQueueFn = func(
		host topology.Host,
		opts hostQueueOpts,
	) (hostQueue, error) {
		hostQueue := NewMockhostQueue(ctrl)
		hostQueue.EXPECT().Open().Times(1)
		hostQueue.EXPECT().Host().Return(host).AnyTimes()
		if atomic.AddInt32(&failingConns, 1) <= int32(failures) {
			hostQueue.EXPECT().ConnectionCount().Return(0).AnyTimes()
		} else {
			min := opts.opts.MinConnectionCount()
			hostQueue.EXPECT().ConnectionCount().Return(min).AnyTimes()
		}
		hostQueue.EXPECT().Close().AnyTimes()
		return hostQueue, nil
	}

	err = session.Open()
	switch expected {
	case outcomeSuccess:
		assert.NoError(t, err)
	case outcomeFail:
		assert.Error(t, err)
		assert.Equal(t, ErrClusterConnectTimeout, err)
	}
}

// setupMultipleInstanceCluster sets up a db cluster with 3 shards and 3 replicas. The 3 shards are distributed across
// 9 hosts, so each host has 1 replica of 1 shard.
// The function passed in is executed when an operation is enqueued. The provided fn is dispatched in a separate
// goroutine to simulate the queue processing; this also allows the function to access the state locks.
func setupMultipleInstanceCluster(t *testing.T, ctrl *gomock.Controller, fn func(op op, host topology.Host)) *session {
	opts := newSessionTestOptions()
	shardSet := sessionTestShardSet()
	var hostShardSets []topology.HostShardSet
	// Set up 9 hosts so there are 3 instances per replica. Each instance has a single shard.
	for i := 0; i < sessionTestReplicas; i++ {
		for j := 0; j < sessionTestShards; j++ {
			id := fmt.Sprintf("testhost-%d-%d", i, j)
			host := topology.NewHost(id, fmt.Sprintf("%s:9000", id))
			hostShard, _ := sharding.NewShardSet([]shard.Shard{shardSet.All()[j]}, shardSet.HashFn())
			hostShardSet := topology.NewHostShardSet(host, hostShard)
			hostShardSets = append(hostShardSets, hostShardSet)
		}
	}

	opts = opts.SetTopologyInitializer(topology.NewStaticInitializer(
		topology.NewStaticOptions().
			SetReplicas(sessionTestReplicas).
			SetShardSet(shardSet).
			SetHostShardSets(hostShardSets)))
	s, err := newSession(opts)
	assert.NoError(t, err)
	sess := s.(*session)

	sess.newHostQueueFn = func(host topology.Host, hostQueueOpts hostQueueOpts) (hostQueue, error) {
		q := NewMockhostQueue(ctrl)
		q.EXPECT().Open()
		q.EXPECT().ConnectionCount().Return(hostQueueOpts.opts.MinConnectionCount()).AnyTimes()
		q.EXPECT().Host().Return(host).AnyTimes()
		q.EXPECT().Enqueue(gomock.Any()).Do(func(op op) error {
			go func() {
				fn(op, host)
			}()
			return nil
		}).Return(nil)
		q.EXPECT().Close()
		return q, nil
	}

	require.NoError(t, sess.Open())
	return sess
}

func mockHostQueues(
	ctrl *gomock.Controller,
	s *session,
	replicas int,
	enqueueFns []testEnqueueFn,
) *sync.WaitGroup {
	var enqueueWg sync.WaitGroup
	enqueueWg.Add(replicas)
	idx := 0
	s.newHostQueueFn = func(
		host topology.Host,
		opts hostQueueOpts,
	) (hostQueue, error) {
		// Make a copy of the enqueue fns for each host
		hostEnqueueFns := make([]testEnqueueFn, len(enqueueFns))
		copy(hostEnqueueFns, enqueueFns)

		enqueuedIdx := idx
		hostQueue := NewMockhostQueue(ctrl)
		hostQueue.EXPECT().Open()
		hostQueue.EXPECT().Host().Return(host).AnyTimes()
		// Take two attempts to establish min connection count
		hostQueue.EXPECT().ConnectionCount().Return(0).Times(sessionTestShards)
		hostQueue.EXPECT().ConnectionCount().Return(opts.opts.MinConnectionCount()).Times(sessionTestShards)
		var expectNextEnqueueFn func(fns []testEnqueueFn)
		expectNextEnqueueFn = func(fns []testEnqueueFn) {
			fn := fns[0]
			fns = fns[1:]
			hostQueue.EXPECT().Enqueue(gomock.Any()).Do(func(op op) error {
				fn(enqueuedIdx, op)
				if len(fns) > 0 {
					expectNextEnqueueFn(fns)
				} else {
					enqueueWg.Done()
				}
				return nil
			}).Return(nil)
		}
		if len(hostEnqueueFns) > 0 {
			expectNextEnqueueFn(hostEnqueueFns)
		}
		hostQueue.EXPECT().Close()
		idx++
		return hostQueue, nil
	}
	return &enqueueWg
}

func mockHost(ctrl *gomock.Controller, id, address string) topology.Host {
	host := topology.NewMockHost(ctrl)
	host.EXPECT().ID().Return(id).AnyTimes()
	host.EXPECT().Address().Return(address).AnyTimes()
	return host
}

func testHealthCheck(err error, bootstrappedNodesOnly bool) func(rpc.TChanNode, Options, bool) error {
	return func(client rpc.TChanNode, opts Options, checkBootstrapped bool) error {
		if checkBootstrapped != bootstrappedNodesOnly {
			return fmt.Errorf("checkBootstrapped value (%t) != expected (%t)",
				checkBootstrapped, bootstrappedNodesOnly)
		}
		return err
	}
}

func noopNewConnection(
	_ string,
	addr string,
	_ Options,
) (Channel, rpc.TChanNode, error) {
	return &noopPooledChannel{addr, 0}, nil, nil
}