// Source: github.com/m3db/m3@v1.5.1-0.20231129193456-75a402aa583b/src/dbnode/client/session_test.go

// Copyright (c) 2016 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
20 21 package client 22 23 import ( 24 "context" 25 "errors" 26 "fmt" 27 "strings" 28 "sync" 29 "sync/atomic" 30 "testing" 31 "time" 32 33 "github.com/m3db/m3/src/cluster/shard" 34 "github.com/m3db/m3/src/dbnode/encoding" 35 "github.com/m3db/m3/src/dbnode/generated/thrift/rpc" 36 "github.com/m3db/m3/src/dbnode/sharding" 37 "github.com/m3db/m3/src/dbnode/storage/index" 38 "github.com/m3db/m3/src/dbnode/topology" 39 "github.com/m3db/m3/src/dbnode/x/xpool" 40 "github.com/m3db/m3/src/m3ninx/idx" 41 xerror "github.com/m3db/m3/src/x/errors" 42 "github.com/m3db/m3/src/x/ident" 43 xretry "github.com/m3db/m3/src/x/retry" 44 "github.com/m3db/m3/src/x/sampler" 45 "github.com/m3db/m3/src/x/serialize" 46 xtest "github.com/m3db/m3/src/x/test" 47 48 "github.com/golang/mock/gomock" 49 "github.com/stretchr/testify/assert" 50 "github.com/stretchr/testify/require" 51 ) 52 53 const ( 54 sessionTestReplicas = 3 55 sessionTestShards = 3 56 ) 57 58 type outcome int 59 60 const ( 61 outcomeSuccess outcome = iota 62 outcomeFail 63 ) 64 65 type testEnqueueFn func(idx int, op op) 66 67 // NB: allocating once to speedup tests. 68 var _testSessionOpts = NewOptions(). 69 SetCheckedBytesWrapperPoolSize(1). 70 SetFetchBatchOpPoolSize(1). 71 SetHostQueueOpsArrayPoolSize(1). 72 SetTagEncoderPoolSize(1). 73 SetWriteOpPoolSize(1). 74 SetWriteTaggedOpPoolSize(1). 75 SetSeriesIteratorPoolSize(1). 76 // Set 100% sample rate to test the code path that logs errors. 77 SetLogErrorSampleRate(sampler.Rate(1)). 78 SetLogHostFetchErrorSampleRate(sampler.Rate(1)). 
79 SetLogHostWriteErrorSampleRate(sampler.Rate(1)) 80 81 func testContext() context.Context { 82 // nolint: govet 83 ctx, _ := context.WithTimeout(context.Background(), time.Minute) //nolint 84 return ctx 85 } 86 87 func newSessionTestOptions() Options { 88 return applySessionTestOptions(_testSessionOpts) 89 } 90 91 func sessionTestShardSet() sharding.ShardSet { 92 var ids []uint32 93 for i := uint32(0); i < uint32(sessionTestShards); i++ { 94 ids = append(ids, i) 95 } 96 97 shards := sharding.NewShards(ids, shard.Available) 98 hashFn := func(id ident.ID) uint32 { return 0 } 99 shardSet, _ := sharding.NewShardSet(shards, hashFn) 100 return shardSet 101 } 102 103 func testHostName(i int) string { return fmt.Sprintf("testhost%d", i) } 104 105 func sessionTestHostAndShards( 106 shardSet sharding.ShardSet, 107 ) []topology.HostShardSet { 108 var hosts []topology.Host 109 for i := 0; i < sessionTestReplicas; i++ { 110 id := testHostName(i) 111 host := topology.NewHost(id, fmt.Sprintf("%s:9000", id)) 112 hosts = append(hosts, host) 113 } 114 115 var hostShardSets []topology.HostShardSet 116 for _, host := range hosts { 117 hostShardSet := topology.NewHostShardSet(host, shardSet) 118 hostShardSets = append(hostShardSets, hostShardSet) 119 } 120 return hostShardSets 121 } 122 123 func applySessionTestOptions(opts Options) Options { 124 shardSet := sessionTestShardSet() 125 return opts. 126 // Some of the test mocks expect things to only happen once, so disable retries 127 // for the unit tests. 128 SetWriteRetrier(xretry.NewRetrier(xretry.NewOptions().SetMaxRetries(0))). 129 SetFetchRetrier(xretry.NewRetrier(xretry.NewOptions().SetMaxRetries(0))). 130 SetSeriesIteratorPoolSize(0). 131 SetWriteOpPoolSize(0). 132 SetWriteTaggedOpPoolSize(0). 133 SetFetchBatchOpPoolSize(0). 134 SetTopologyInitializer(topology.NewStaticInitializer( 135 topology.NewStaticOptions(). 136 SetReplicas(sessionTestReplicas). 137 SetShardSet(shardSet). 
138 SetHostShardSets(sessionTestHostAndShards(shardSet)))) 139 } 140 141 func newTestHostQueue(opts Options) *queue { 142 hq, err := newHostQueue(h, hostQueueOpts{ 143 writeBatchRawRequestPool: testWriteBatchRawPool, 144 writeBatchRawV2RequestPool: testWriteBatchRawV2Pool, 145 writeBatchRawRequestElementArrayPool: testWriteArrayPool, 146 writeBatchRawV2RequestElementArrayPool: testWriteV2ArrayPool, 147 writeTaggedBatchRawRequestPool: testWriteTaggedBatchRawPool, 148 writeTaggedBatchRawV2RequestPool: testWriteTaggedBatchRawV2Pool, 149 writeTaggedBatchRawRequestElementArrayPool: testWriteTaggedArrayPool, 150 writeTaggedBatchRawV2RequestElementArrayPool: testWriteTaggedV2ArrayPool, 151 fetchBatchRawV2RequestPool: testFetchBatchRawV2Pool, 152 fetchBatchRawV2RequestElementArrayPool: testFetchBatchRawV2ArrayPool, 153 opts: opts, 154 }) 155 if err != nil { 156 panic(err) 157 } 158 return hq.(*queue) 159 } 160 161 func TestSessionCreationFailure(t *testing.T) { 162 topoOpts := topology.NewDynamicOptions() 163 topoInit := topology.NewDynamicInitializer(topoOpts) 164 opt := newSessionTestOptions().SetTopologyInitializer(topoInit) 165 _, err := newSession(opt) 166 assert.Error(t, err) 167 } 168 169 func TestSessionShardID(t *testing.T) { 170 ctrl := gomock.NewController(t) 171 defer ctrl.Finish() 172 173 opts := newSessionTestOptions() 174 s, err := newSession(opts) 175 assert.NoError(t, err) 176 177 _, err = s.ShardID(ident.StringID("foo")) 178 assert.Error(t, err) 179 assert.Equal(t, ErrSessionStatusNotOpen, err) 180 181 mockHostQueues(ctrl, s.(*session), sessionTestReplicas, nil) 182 183 require.NoError(t, s.Open()) 184 185 // The shard set we create in newSessionTestOptions always hashes to uint32 186 shard, err := s.ShardID(ident.StringID("foo")) 187 require.NoError(t, err) 188 assert.Equal(t, uint32(0), shard) 189 190 assert.NoError(t, s.Close()) 191 } 192 193 func TestSessionClusterConnectConsistencyLevelAll(t *testing.T) { 194 ctrl := gomock.NewController(t) 195 defer 
ctrl.Finish() 196 197 level := topology.ConnectConsistencyLevelAll 198 testSessionClusterConnectConsistencyLevel(t, ctrl, level, 0, outcomeSuccess) 199 for i := 1; i <= 3; i++ { 200 testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeFail) 201 } 202 } 203 204 func TestSessionClusterConnectConsistencyLevelMajority(t *testing.T) { 205 ctrl := gomock.NewController(t) 206 defer ctrl.Finish() 207 208 level := topology.ConnectConsistencyLevelMajority 209 for i := 0; i <= 1; i++ { 210 testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess) 211 } 212 for i := 2; i <= 3; i++ { 213 testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeFail) 214 } 215 } 216 217 func TestSessionClusterConnectConsistencyLevelOne(t *testing.T) { 218 ctrl := gomock.NewController(t) 219 defer ctrl.Finish() 220 221 level := topology.ConnectConsistencyLevelOne 222 for i := 0; i <= 2; i++ { 223 testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess) 224 } 225 testSessionClusterConnectConsistencyLevel(t, ctrl, level, 3, outcomeFail) 226 } 227 228 func TestSessionClusterConnectConsistencyLevelNone(t *testing.T) { 229 ctrl := gomock.NewController(t) 230 defer ctrl.Finish() 231 232 level := topology.ConnectConsistencyLevelNone 233 for i := 0; i <= 3; i++ { 234 testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess) 235 } 236 } 237 238 func TestIteratorPools(t *testing.T) { 239 s := session{} 240 itPool, err := s.IteratorPools() 241 242 assert.EqualError(t, err, ErrSessionStatusNotOpen.Error()) 243 assert.Nil(t, itPool) 244 245 multiReaderIteratorArray := encoding.NewMultiReaderIteratorArrayPool(nil) 246 multiReaderIteratorPool := encoding.NewMultiReaderIteratorPool(nil) 247 seriesIteratorPool := encoding.NewSeriesIteratorPool(nil) 248 checkedBytesWrapperPool := xpool.NewCheckedBytesWrapperPool(nil) 249 idPool := ident.NewPool(nil, ident.PoolOptions{}) 250 encoderPool := serialize.NewTagEncoderPool(nil, nil) 251 
decoderPool := serialize.NewTagDecoderPool(nil, nil) 252 253 s.pools = sessionPools{ 254 multiReaderIteratorArray: multiReaderIteratorArray, 255 multiReaderIterator: multiReaderIteratorPool, 256 seriesIterator: seriesIteratorPool, 257 checkedBytesWrapper: checkedBytesWrapperPool, 258 id: idPool, 259 tagEncoder: encoderPool, 260 tagDecoder: decoderPool, 261 } 262 263 // Error expected if state is not open 264 itPool, err = s.IteratorPools() 265 assert.EqualError(t, err, ErrSessionStatusNotOpen.Error()) 266 assert.Nil(t, itPool) 267 268 s.state.status = statusOpen 269 270 itPool, err = s.IteratorPools() 271 require.NoError(t, err) 272 assert.Equal(t, multiReaderIteratorArray, itPool.MultiReaderIteratorArray()) 273 assert.Equal(t, multiReaderIteratorPool, itPool.MultiReaderIterator()) 274 assert.Equal(t, seriesIteratorPool, itPool.SeriesIterator()) 275 assert.Equal(t, checkedBytesWrapperPool, itPool.CheckedBytesWrapper()) 276 assert.Equal(t, encoderPool, itPool.TagEncoder()) 277 assert.Equal(t, decoderPool, itPool.TagDecoder()) 278 assert.Equal(t, idPool, itPool.ID()) 279 } 280 281 //nolint:dupl 282 func TestSeriesLimit_FetchTagged(t *testing.T) { 283 ctrl := gomock.NewController(t) 284 defer ctrl.Finish() 285 286 // mock the host queue to return a result with a single series, this results in 3 series total, one per shard. 287 sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) { 288 fOp := op.(*fetchTaggedOp) 289 assert.Equal(t, int64(2), *fOp.request.SeriesLimit) 290 shardID := strings.Split(host.ID(), "-")[2] 291 op.CompletionFn()(fetchTaggedResultAccumulatorOpts{ 292 host: host, 293 response: &rpc.FetchTaggedResult_{ 294 Exhaustive: true, 295 Elements: []*rpc.FetchTaggedIDResult_{ 296 { 297 // use shard id for the metric id so it's stable across replicas. 
298 ID: []byte(shardID), 299 }, 300 }, 301 }, 302 }, nil) 303 }) 304 305 iters, meta, err := sess.fetchTaggedAttempt(context.TODO(), ident.StringID("ns"), 306 index.Query{Query: idx.NewAllQuery()}, 307 index.QueryOptions{ 308 // set to 6 so we can test the instance series limit is 2 (6 /3 instances per replica * InstanceMultiple) 309 SeriesLimit: 6, 310 InstanceMultiple: 1, 311 }) 312 require.NoError(t, err) 313 require.NotNil(t, iters) 314 // expect a series per shard. 315 require.Equal(t, 3, iters.Len()) 316 require.True(t, meta.Exhaustive) 317 require.NoError(t, sess.Close()) 318 } 319 320 //nolint:dupl 321 func TestSeriesLimit_FetchTaggedIDs(t *testing.T) { 322 ctrl := gomock.NewController(t) 323 defer ctrl.Finish() 324 325 // mock the host queue to return a result with a single series, this results in 3 series total, one per shard. 326 sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) { 327 fOp := op.(*fetchTaggedOp) 328 assert.Equal(t, int64(2), *fOp.request.SeriesLimit) 329 shardID := strings.Split(host.ID(), "-")[2] 330 op.CompletionFn()(fetchTaggedResultAccumulatorOpts{ 331 host: host, 332 response: &rpc.FetchTaggedResult_{ 333 Exhaustive: true, 334 Elements: []*rpc.FetchTaggedIDResult_{ 335 { 336 // use shard id for the metric id so it's stable across replicas. 337 ID: []byte(shardID), 338 }, 339 }, 340 }, 341 }, nil) 342 }) 343 344 iter, meta, err := sess.fetchTaggedIDsAttempt(context.TODO(), ident.StringID("ns"), 345 index.Query{Query: idx.NewAllQuery()}, 346 index.QueryOptions{ 347 // set to 6 so we can test the instance series limit is 2 (6 /3 instances per replica * InstanceMultiple) 348 SeriesLimit: 6, 349 InstanceMultiple: 1, 350 }) 351 require.NoError(t, err) 352 require.NotNil(t, iter) 353 // expect a series per shard. 
354 require.Equal(t, 3, iter.Remaining()) 355 require.True(t, meta.Exhaustive) 356 require.NoError(t, sess.Close()) 357 } 358 359 //nolint:dupl 360 func TestSeriesLimit_Aggregate(t *testing.T) { 361 ctrl := gomock.NewController(t) 362 defer ctrl.Finish() 363 364 // mock the host queue to return a result with a single series, this results in 3 series total, one per shard. 365 sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) { 366 aOp := op.(*aggregateOp) 367 assert.Equal(t, int64(2), *aOp.request.SeriesLimit) 368 shardID := strings.Split(host.ID(), "-")[2] 369 op.CompletionFn()(aggregateResultAccumulatorOpts{ 370 host: host, 371 response: &rpc.AggregateQueryRawResult_{ 372 Exhaustive: true, 373 Results: []*rpc.AggregateQueryRawResultTagNameElement{ 374 { 375 // use shard id for the tag value so it's stable across replicas. 376 TagName: []byte(shardID), 377 TagValues: []*rpc.AggregateQueryRawResultTagValueElement{ 378 { 379 TagValue: []byte("value"), 380 }, 381 }, 382 }, 383 }, 384 }, 385 }, nil) 386 }) 387 iter, meta, err := sess.aggregateAttempt(context.TODO(), ident.StringID("ns"), 388 index.Query{Query: idx.NewAllQuery()}, 389 index.AggregationOptions{ 390 QueryOptions: index.QueryOptions{ 391 // set to 6 so we can test the instance series limit is 2 (6 /3 instances per replica * InstanceMultiple) 392 SeriesLimit: 6, 393 InstanceMultiple: 1, 394 }, 395 }) 396 require.NoError(t, err) 397 require.NotNil(t, iter) 398 require.Equal(t, 3, iter.Remaining()) 399 require.True(t, meta.Exhaustive) 400 require.NoError(t, sess.Close()) 401 } 402 403 func TestIterationStrategy_FetchTagged(t *testing.T) { 404 ctrl := gomock.NewController(t) 405 defer ctrl.Finish() 406 407 // mock the host queue to return a result with a single series, this results in 3 series total, one per shard. 
408 sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) { 409 fOp := op.(*fetchTaggedOp) 410 assert.Equal(t, int64(2), *fOp.request.SeriesLimit) 411 shardID := strings.Split(host.ID(), "-")[2] 412 op.CompletionFn()(fetchTaggedResultAccumulatorOpts{ 413 host: host, 414 response: &rpc.FetchTaggedResult_{ 415 Exhaustive: true, 416 Elements: []*rpc.FetchTaggedIDResult_{ 417 { 418 // use shard id for the metric id so it's stable across replicas. 419 ID: []byte(shardID), 420 }, 421 }, 422 }, 423 }, nil) 424 }) 425 426 stategy := encoding.IterateHighestFrequencyValue 427 iters, meta, err := sess.fetchTaggedAttempt(context.TODO(), ident.StringID("ns"), 428 index.Query{Query: idx.NewAllQuery()}, 429 index.QueryOptions{ 430 // set to 6 so we can test the instance series limit is 2 (6 /3 instances per replica * InstanceMultiple) 431 SeriesLimit: 6, 432 InstanceMultiple: 1, 433 IterateEqualTimestampStrategy: &stategy, 434 }) 435 require.NoError(t, err) 436 require.NotNil(t, iters) 437 438 // expect a series per shard. 439 require.Equal(t, 3, iters.Len()) 440 441 // Confirm propagated strategy. 
442 for _, i := range iters.Iters() { 443 require.Equal(t, stategy, i.IterateEqualTimestampStrategy()) 444 } 445 446 require.True(t, meta.Exhaustive) 447 require.NoError(t, sess.Close()) 448 } 449 450 func TestSessionClusterConnectConsistencyLevelAny(t *testing.T) { 451 ctrl := gomock.NewController(t) 452 defer ctrl.Finish() 453 454 level := topology.ConnectConsistencyLevelAny 455 for i := 0; i <= 3; i++ { 456 testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess) 457 } 458 } 459 460 func TestDedicatedConnection(t *testing.T) { 461 ctrl := xtest.NewController(t) 462 defer ctrl.Finish() 463 464 var ( 465 shardID = uint32(32) 466 467 topoMap = topology.NewMockMap(ctrl) 468 469 local = mockHost(ctrl, "h0", "local") 470 remote1 = mockHost(ctrl, "h1", "remote1") 471 remote2 = mockHost(ctrl, "h2", "remote2") 472 473 availableShard = shard.NewShard(shardID).SetState(shard.Available) 474 initializingShard = shard.NewShard(shardID).SetState(shard.Initializing) 475 ) 476 477 topoMap.EXPECT().RouteShardForEach(shardID, gomock.Any()).DoAndReturn( 478 func(shardID uint32, callback func(int, shard.Shard, topology.Host)) error { 479 callback(0, availableShard, local) 480 callback(1, initializingShard, remote1) 481 callback(2, availableShard, remote2) 482 return nil 483 }).Times(4) 484 485 s := session{origin: local} 486 s.opts = NewOptions().SetNewConnectionFn(noopNewConnection) 487 s.healthCheckNewConnFn = testHealthCheck(nil, false) 488 s.state.status = statusOpen 489 s.state.topoMap = topoMap 490 491 _, ch, err := s.DedicatedConnection(shardID, DedicatedConnectionOptions{}) 492 require.NoError(t, err) 493 assert.Equal(t, "remote1", asNoopPooledChannel(ch).address) 494 495 _, ch2, err := s.DedicatedConnection(shardID, DedicatedConnectionOptions{ShardStateFilter: shard.Available}) 496 require.NoError(t, err) 497 assert.Equal(t, "remote2", asNoopPooledChannel(ch2).address) 498 499 s.healthCheckNewConnFn = testHealthCheck(nil, true) 500 _, ch3, err := 
s.DedicatedConnection(shardID, DedicatedConnectionOptions{BootstrappedNodesOnly: true}) 501 require.NoError(t, err) 502 assert.Equal(t, "remote1", asNoopPooledChannel(ch3).address) 503 504 healthErr := errors.New("unhealthy") 505 s.healthCheckNewConnFn = testHealthCheck(healthErr, false) 506 507 var channels []*noopPooledChannel 508 s.opts = NewOptions().SetNewConnectionFn(func(_ string, _ string, _ Options) (Channel, rpc.TChanNode, error) { 509 c := &noopPooledChannel{"test", 0} 510 channels = append(channels, c) 511 return c, nil, nil 512 }) 513 _, _, err = s.DedicatedConnection(shardID, DedicatedConnectionOptions{}) 514 require.NotNil(t, err) 515 multiErr, ok := err.(xerror.MultiError) // nolint: errorlint 516 assert.True(t, ok, "expecting MultiError") 517 assert.True(t, multiErr.Contains(healthErr)) 518 // 2 because of 2 remote hosts failing health check 519 assert.Len(t, channels, 2) 520 assert.Equal(t, 1, channels[0].CloseCount()) 521 assert.Equal(t, 1, channels[1].CloseCount()) 522 } 523 524 func testSessionClusterConnectConsistencyLevel( 525 t *testing.T, 526 ctrl *gomock.Controller, 527 level topology.ConnectConsistencyLevel, 528 failures int, 529 expected outcome, 530 ) { 531 opts := newSessionTestOptions() 532 opts = opts.SetClusterConnectTimeout(10 * clusterConnectWaitInterval) 533 opts = opts.SetClusterConnectConsistencyLevel(level) 534 s, err := newSession(opts) 535 assert.NoError(t, err) 536 session := s.(*session) 537 538 var failingConns int32 539 session.newHostQueueFn = func( 540 host topology.Host, 541 opts hostQueueOpts, 542 ) (hostQueue, error) { 543 hostQueue := NewMockhostQueue(ctrl) 544 hostQueue.EXPECT().Open().Times(1) 545 hostQueue.EXPECT().Host().Return(host).AnyTimes() 546 if atomic.AddInt32(&failingConns, 1) <= int32(failures) { 547 hostQueue.EXPECT().ConnectionCount().Return(0).AnyTimes() 548 } else { 549 min := opts.opts.MinConnectionCount() 550 hostQueue.EXPECT().ConnectionCount().Return(min).AnyTimes() 551 } 552 
hostQueue.EXPECT().Close().AnyTimes() 553 return hostQueue, nil 554 } 555 556 err = session.Open() 557 switch expected { 558 case outcomeSuccess: 559 assert.NoError(t, err) 560 case outcomeFail: 561 assert.Error(t, err) 562 assert.Equal(t, ErrClusterConnectTimeout, err) 563 } 564 } 565 566 // setupMultipleInstanceCluster sets up a db cluster with 3 shards and 3 replicas. The 3 shards are distributed across 567 // 9 hosts, so each host has 1 replica of 1 shard. 568 // the function passed is executed when an operation is enqueued. the provided fn is dispatched in a separate goroutine 569 // to simulate the queue processing. this also allows the function to access the state locks. 570 func setupMultipleInstanceCluster(t *testing.T, ctrl *gomock.Controller, fn func(op op, host topology.Host)) *session { 571 opts := newSessionTestOptions() 572 shardSet := sessionTestShardSet() 573 var hostShardSets []topology.HostShardSet 574 // setup 9 hosts so there are 3 instances per replica. Each instance has a single shard. 575 for i := 0; i < sessionTestReplicas; i++ { 576 for j := 0; j < sessionTestShards; j++ { 577 id := fmt.Sprintf("testhost-%d-%d", i, j) 578 host := topology.NewHost(id, fmt.Sprintf("%s:9000", id)) 579 hostShard, _ := sharding.NewShardSet([]shard.Shard{shardSet.All()[j]}, shardSet.HashFn()) 580 hostShardSet := topology.NewHostShardSet(host, hostShard) 581 hostShardSets = append(hostShardSets, hostShardSet) 582 } 583 } 584 585 opts = opts.SetTopologyInitializer(topology.NewStaticInitializer( 586 topology.NewStaticOptions(). 587 SetReplicas(sessionTestReplicas). 588 SetShardSet(shardSet). 
589 SetHostShardSets(hostShardSets))) 590 s, err := newSession(opts) 591 assert.NoError(t, err) 592 sess := s.(*session) 593 594 sess.newHostQueueFn = func(host topology.Host, hostQueueOpts hostQueueOpts) (hostQueue, error) { 595 q := NewMockhostQueue(ctrl) 596 q.EXPECT().Open() 597 q.EXPECT().ConnectionCount().Return(hostQueueOpts.opts.MinConnectionCount()).AnyTimes() 598 q.EXPECT().Host().Return(host).AnyTimes() 599 q.EXPECT().Enqueue(gomock.Any()).Do(func(op op) error { 600 go func() { 601 fn(op, host) 602 }() 603 return nil 604 }).Return(nil) 605 q.EXPECT().Close() 606 return q, nil 607 } 608 609 require.NoError(t, sess.Open()) 610 return sess 611 } 612 613 func mockHostQueues( 614 ctrl *gomock.Controller, 615 s *session, 616 replicas int, 617 enqueueFns []testEnqueueFn, 618 ) *sync.WaitGroup { 619 var enqueueWg sync.WaitGroup 620 enqueueWg.Add(replicas) 621 idx := 0 622 s.newHostQueueFn = func( 623 host topology.Host, 624 opts hostQueueOpts, 625 ) (hostQueue, error) { 626 // Make a copy of the enqueue fns for each host 627 hostEnqueueFns := make([]testEnqueueFn, len(enqueueFns)) 628 copy(hostEnqueueFns, enqueueFns) 629 630 enqueuedIdx := idx 631 hostQueue := NewMockhostQueue(ctrl) 632 hostQueue.EXPECT().Open() 633 hostQueue.EXPECT().Host().Return(host).AnyTimes() 634 // Take two attempts to establish min connection count 635 hostQueue.EXPECT().ConnectionCount().Return(0).Times(sessionTestShards) 636 hostQueue.EXPECT().ConnectionCount().Return(opts.opts.MinConnectionCount()).Times(sessionTestShards) 637 var expectNextEnqueueFn func(fns []testEnqueueFn) 638 expectNextEnqueueFn = func(fns []testEnqueueFn) { 639 fn := fns[0] 640 fns = fns[1:] 641 hostQueue.EXPECT().Enqueue(gomock.Any()).Do(func(op op) error { 642 fn(enqueuedIdx, op) 643 if len(fns) > 0 { 644 expectNextEnqueueFn(fns) 645 } else { 646 enqueueWg.Done() 647 } 648 return nil 649 }).Return(nil) 650 } 651 if len(hostEnqueueFns) > 0 { 652 expectNextEnqueueFn(hostEnqueueFns) 653 } 654 
hostQueue.EXPECT().Close() 655 idx++ 656 return hostQueue, nil 657 } 658 return &enqueueWg 659 } 660 661 func mockHost(ctrl *gomock.Controller, id, address string) topology.Host { 662 host := topology.NewMockHost(ctrl) 663 host.EXPECT().ID().Return(id).AnyTimes() 664 host.EXPECT().Address().Return(address).AnyTimes() 665 return host 666 } 667 668 func testHealthCheck(err error, bootstrappedNodesOnly bool) func(rpc.TChanNode, Options, bool) error { 669 return func(client rpc.TChanNode, opts Options, checkBootstrapped bool) error { 670 if checkBootstrapped != bootstrappedNodesOnly { 671 return fmt.Errorf("checkBootstrapped value (%t) != expected (%t)", 672 checkBootstrapped, bootstrappedNodesOnly) 673 } 674 return err 675 } 676 } 677 678 func noopNewConnection( 679 _ string, 680 addr string, 681 _ Options, 682 ) (Channel, rpc.TChanNode, error) { 683 return &noopPooledChannel{addr, 0}, nil, nil 684 }