github.com/m3db/m3@v1.5.0/src/dbnode/client/session_test.go (about) 1 // Copyright (c) 2016 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 package client 22 23 import ( 24 "context" 25 "errors" 26 "fmt" 27 "strings" 28 "sync" 29 "sync/atomic" 30 "testing" 31 "time" 32 33 "github.com/m3db/m3/src/cluster/shard" 34 "github.com/m3db/m3/src/dbnode/encoding" 35 "github.com/m3db/m3/src/dbnode/generated/thrift/rpc" 36 "github.com/m3db/m3/src/dbnode/sharding" 37 "github.com/m3db/m3/src/dbnode/storage/index" 38 "github.com/m3db/m3/src/dbnode/topology" 39 "github.com/m3db/m3/src/dbnode/x/xpool" 40 "github.com/m3db/m3/src/m3ninx/idx" 41 xerror "github.com/m3db/m3/src/x/errors" 42 "github.com/m3db/m3/src/x/ident" 43 xretry "github.com/m3db/m3/src/x/retry" 44 "github.com/m3db/m3/src/x/serialize" 45 xtest "github.com/m3db/m3/src/x/test" 46 47 "github.com/golang/mock/gomock" 48 "github.com/stretchr/testify/assert" 49 "github.com/stretchr/testify/require" 50 ) 51 52 const ( 53 sessionTestReplicas = 3 54 sessionTestShards = 3 55 ) 56 57 type outcome int 58 59 const ( 60 outcomeSuccess outcome = iota 61 outcomeFail 62 ) 63 64 type testEnqueueFn func(idx int, op op) 65 66 var ( 67 // NB: allocating once to speedup tests. 68 _testSessionOpts = NewOptions(). 69 SetCheckedBytesWrapperPoolSize(1). 70 SetFetchBatchOpPoolSize(1). 71 SetHostQueueOpsArrayPoolSize(1). 72 SetTagEncoderPoolSize(1). 73 SetWriteOpPoolSize(1). 74 SetWriteTaggedOpPoolSize(1). 75 SetSeriesIteratorPoolSize(1) 76 ) 77 78 func testContext() context.Context { 79 // nolint: govet 80 ctx, _ := context.WithTimeout(context.Background(), time.Minute) //nolint 81 return ctx 82 } 83 84 func newSessionTestOptions() Options { 85 return applySessionTestOptions(_testSessionOpts) 86 } 87 88 func sessionTestShardSet() sharding.ShardSet { 89 var ids []uint32 90 for i := uint32(0); i < uint32(sessionTestShards); i++ { 91 ids = append(ids, i) 92 } 93 94 shards := sharding.NewShards(ids, shard.Available) 95 hashFn := func(id ident.ID) uint32 { return 0 } 96 shardSet, _ := sharding.NewShardSet(shards, hashFn) 97 return shardSet 98 } 99 100 func testHostName(i int) string { return fmt.Sprintf("testhost%d", i) } 101 102 func sessionTestHostAndShards( 103 shardSet sharding.ShardSet, 104 ) []topology.HostShardSet { 105 var hosts []topology.Host 106 for i := 0; i < sessionTestReplicas; i++ { 107 id := testHostName(i) 108 host := topology.NewHost(id, fmt.Sprintf("%s:9000", id)) 109 hosts = append(hosts, host) 110 } 111 112 var hostShardSets []topology.HostShardSet 113 for _, host := range hosts { 114 hostShardSet := topology.NewHostShardSet(host, shardSet) 115 hostShardSets = append(hostShardSets, hostShardSet) 116 } 117 return hostShardSets 118 } 119 120 func applySessionTestOptions(opts Options) Options { 121 shardSet := sessionTestShardSet() 122 return opts. 123 // Some of the test mocks expect things to only happen once, so disable retries 124 // for the unit tests. 125 SetWriteRetrier(xretry.NewRetrier(xretry.NewOptions().SetMaxRetries(0))). 126 SetFetchRetrier(xretry.NewRetrier(xretry.NewOptions().SetMaxRetries(0))). 127 SetSeriesIteratorPoolSize(0). 128 SetWriteOpPoolSize(0). 129 SetWriteTaggedOpPoolSize(0). 130 SetFetchBatchOpPoolSize(0). 131 SetTopologyInitializer(topology.NewStaticInitializer( 132 topology.NewStaticOptions(). 133 SetReplicas(sessionTestReplicas). 134 SetShardSet(shardSet). 135 SetHostShardSets(sessionTestHostAndShards(shardSet)))) 136 } 137 138 func newTestHostQueue(opts Options) *queue { 139 hq, err := newHostQueue(h, hostQueueOpts{ 140 writeBatchRawRequestPool: testWriteBatchRawPool, 141 writeBatchRawV2RequestPool: testWriteBatchRawV2Pool, 142 writeBatchRawRequestElementArrayPool: testWriteArrayPool, 143 writeBatchRawV2RequestElementArrayPool: testWriteV2ArrayPool, 144 writeTaggedBatchRawRequestPool: testWriteTaggedBatchRawPool, 145 writeTaggedBatchRawV2RequestPool: testWriteTaggedBatchRawV2Pool, 146 writeTaggedBatchRawRequestElementArrayPool: testWriteTaggedArrayPool, 147 writeTaggedBatchRawV2RequestElementArrayPool: testWriteTaggedV2ArrayPool, 148 fetchBatchRawV2RequestPool: testFetchBatchRawV2Pool, 149 fetchBatchRawV2RequestElementArrayPool: testFetchBatchRawV2ArrayPool, 150 opts: opts, 151 }) 152 if err != nil { 153 panic(err) 154 } 155 return hq.(*queue) 156 } 157 158 func TestSessionCreationFailure(t *testing.T) { 159 topoOpts := topology.NewDynamicOptions() 160 topoInit := topology.NewDynamicInitializer(topoOpts) 161 opt := newSessionTestOptions().SetTopologyInitializer(topoInit) 162 _, err := newSession(opt) 163 assert.Error(t, err) 164 } 165 166 func TestSessionShardID(t *testing.T) { 167 ctrl := gomock.NewController(t) 168 defer ctrl.Finish() 169 170 opts := newSessionTestOptions() 171 s, err := newSession(opts) 172 assert.NoError(t, err) 173 174 _, err = s.ShardID(ident.StringID("foo")) 175 assert.Error(t, err) 176 assert.Equal(t, ErrSessionStatusNotOpen, err) 177 178 mockHostQueues(ctrl, s.(*session), sessionTestReplicas, nil) 179 180 require.NoError(t, s.Open()) 181 182 // The shard set we create in newSessionTestOptions always hashes to uint32 183 shard, err := s.ShardID(ident.StringID("foo")) 184 require.NoError(t, err) 185 assert.Equal(t, uint32(0), shard) 186 187 assert.NoError(t, s.Close()) 188 } 189 190 func TestSessionClusterConnectConsistencyLevelAll(t *testing.T) { 191 ctrl := gomock.NewController(t) 192 defer ctrl.Finish() 193 194 level := topology.ConnectConsistencyLevelAll 195 testSessionClusterConnectConsistencyLevel(t, ctrl, level, 0, outcomeSuccess) 196 for i := 1; i <= 3; i++ { 197 testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeFail) 198 } 199 } 200 201 func TestSessionClusterConnectConsistencyLevelMajority(t *testing.T) { 202 ctrl := gomock.NewController(t) 203 defer ctrl.Finish() 204 205 level := topology.ConnectConsistencyLevelMajority 206 for i := 0; i <= 1; i++ { 207 testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess) 208 } 209 for i := 2; i <= 3; i++ { 210 testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeFail) 211 } 212 } 213 214 func TestSessionClusterConnectConsistencyLevelOne(t *testing.T) { 215 ctrl := gomock.NewController(t) 216 defer ctrl.Finish() 217 218 level := topology.ConnectConsistencyLevelOne 219 for i := 0; i <= 2; i++ { 220 testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess) 221 } 222 testSessionClusterConnectConsistencyLevel(t, ctrl, level, 3, outcomeFail) 223 } 224 225 func TestSessionClusterConnectConsistencyLevelNone(t *testing.T) { 226 ctrl := gomock.NewController(t) 227 defer ctrl.Finish() 228 229 level := topology.ConnectConsistencyLevelNone 230 for i := 0; i <= 3; i++ { 231 testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess) 232 } 233 } 234 235 func TestIteratorPools(t *testing.T) { 236 s := session{} 237 itPool, err := s.IteratorPools() 238 239 assert.EqualError(t, err, ErrSessionStatusNotOpen.Error()) 240 assert.Nil(t, itPool) 241 242 multiReaderIteratorArray := encoding.NewMultiReaderIteratorArrayPool(nil) 243 multiReaderIteratorPool := encoding.NewMultiReaderIteratorPool(nil) 244 seriesIteratorPool := encoding.NewSeriesIteratorPool(nil) 245 checkedBytesWrapperPool := xpool.NewCheckedBytesWrapperPool(nil) 246 idPool := ident.NewPool(nil, ident.PoolOptions{}) 247 encoderPool := serialize.NewTagEncoderPool(nil, nil) 248 decoderPool := serialize.NewTagDecoderPool(nil, nil) 249 250 s.pools = sessionPools{ 251 multiReaderIteratorArray: multiReaderIteratorArray, 252 multiReaderIterator: multiReaderIteratorPool, 253 seriesIterator: seriesIteratorPool, 254 checkedBytesWrapper: checkedBytesWrapperPool, 255 id: idPool, 256 tagEncoder: encoderPool, 257 tagDecoder: decoderPool, 258 } 259 260 // Error expected if state is not open 261 itPool, err = s.IteratorPools() 262 assert.EqualError(t, err, ErrSessionStatusNotOpen.Error()) 263 assert.Nil(t, itPool) 264 265 s.state.status = statusOpen 266 267 itPool, err = s.IteratorPools() 268 require.NoError(t, err) 269 assert.Equal(t, multiReaderIteratorArray, itPool.MultiReaderIteratorArray()) 270 assert.Equal(t, multiReaderIteratorPool, itPool.MultiReaderIterator()) 271 assert.Equal(t, seriesIteratorPool, itPool.SeriesIterator()) 272 assert.Equal(t, checkedBytesWrapperPool, itPool.CheckedBytesWrapper()) 273 assert.Equal(t, encoderPool, itPool.TagEncoder()) 274 assert.Equal(t, decoderPool, itPool.TagDecoder()) 275 assert.Equal(t, idPool, itPool.ID()) 276 } 277 278 //nolint:dupl 279 func TestSeriesLimit_FetchTagged(t *testing.T) { 280 ctrl := gomock.NewController(t) 281 defer ctrl.Finish() 282 283 // mock the host queue to return a result with a single series, this results in 3 series total, one per shard. 284 sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) { 285 fOp := op.(*fetchTaggedOp) 286 assert.Equal(t, int64(2), *fOp.request.SeriesLimit) 287 shardID := strings.Split(host.ID(), "-")[2] 288 op.CompletionFn()(fetchTaggedResultAccumulatorOpts{ 289 host: host, 290 response: &rpc.FetchTaggedResult_{ 291 Exhaustive: true, 292 Elements: []*rpc.FetchTaggedIDResult_{ 293 { 294 // use shard id for the metric id so it's stable across replicas. 295 ID: []byte(shardID), 296 }, 297 }, 298 }, 299 }, nil) 300 }) 301 302 iters, meta, err := sess.fetchTaggedAttempt(context.TODO(), ident.StringID("ns"), 303 index.Query{Query: idx.NewAllQuery()}, 304 index.QueryOptions{ 305 // set to 6 so we can test the instance series limit is 2 (6 /3 instances per replica * InstanceMultiple) 306 SeriesLimit: 6, 307 InstanceMultiple: 1, 308 }) 309 require.NoError(t, err) 310 require.NotNil(t, iters) 311 // expect a series per shard. 312 require.Equal(t, 3, iters.Len()) 313 require.True(t, meta.Exhaustive) 314 require.NoError(t, sess.Close()) 315 } 316 317 //nolint:dupl 318 func TestSeriesLimit_FetchTaggedIDs(t *testing.T) { 319 ctrl := gomock.NewController(t) 320 defer ctrl.Finish() 321 322 // mock the host queue to return a result with a single series, this results in 3 series total, one per shard. 323 sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) { 324 fOp := op.(*fetchTaggedOp) 325 assert.Equal(t, int64(2), *fOp.request.SeriesLimit) 326 shardID := strings.Split(host.ID(), "-")[2] 327 op.CompletionFn()(fetchTaggedResultAccumulatorOpts{ 328 host: host, 329 response: &rpc.FetchTaggedResult_{ 330 Exhaustive: true, 331 Elements: []*rpc.FetchTaggedIDResult_{ 332 { 333 // use shard id for the metric id so it's stable across replicas. 334 ID: []byte(shardID), 335 }, 336 }, 337 }, 338 }, nil) 339 }) 340 341 iter, meta, err := sess.fetchTaggedIDsAttempt(context.TODO(), ident.StringID("ns"), 342 index.Query{Query: idx.NewAllQuery()}, 343 index.QueryOptions{ 344 // set to 6 so we can test the instance series limit is 2 (6 /3 instances per replica * InstanceMultiple) 345 SeriesLimit: 6, 346 InstanceMultiple: 1, 347 }) 348 require.NoError(t, err) 349 require.NotNil(t, iter) 350 // expect a series per shard. 351 require.Equal(t, 3, iter.Remaining()) 352 require.True(t, meta.Exhaustive) 353 require.NoError(t, sess.Close()) 354 } 355 356 //nolint:dupl 357 func TestSeriesLimit_Aggregate(t *testing.T) { 358 ctrl := gomock.NewController(t) 359 defer ctrl.Finish() 360 361 // mock the host queue to return a result with a single series, this results in 3 series total, one per shard. 362 sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) { 363 aOp := op.(*aggregateOp) 364 assert.Equal(t, int64(2), *aOp.request.SeriesLimit) 365 shardID := strings.Split(host.ID(), "-")[2] 366 op.CompletionFn()(aggregateResultAccumulatorOpts{ 367 host: host, 368 response: &rpc.AggregateQueryRawResult_{ 369 Exhaustive: true, 370 Results: []*rpc.AggregateQueryRawResultTagNameElement{ 371 { 372 // use shard id for the tag value so it's stable across replicas. 373 TagName: []byte(shardID), 374 TagValues: []*rpc.AggregateQueryRawResultTagValueElement{ 375 { 376 TagValue: []byte("value"), 377 }, 378 }, 379 }, 380 }, 381 }, 382 }, nil) 383 }) 384 iter, meta, err := sess.aggregateAttempt(context.TODO(), ident.StringID("ns"), 385 index.Query{Query: idx.NewAllQuery()}, 386 index.AggregationOptions{ 387 QueryOptions: index.QueryOptions{ 388 // set to 6 so we can test the instance series limit is 2 (6 /3 instances per replica * InstanceMultiple) 389 SeriesLimit: 6, 390 InstanceMultiple: 1, 391 }, 392 }) 393 require.NoError(t, err) 394 require.NotNil(t, iter) 395 require.Equal(t, 3, iter.Remaining()) 396 require.True(t, meta.Exhaustive) 397 require.NoError(t, sess.Close()) 398 } 399 400 func TestIterationStrategy_FetchTagged(t *testing.T) { 401 ctrl := gomock.NewController(t) 402 defer ctrl.Finish() 403 404 // mock the host queue to return a result with a single series, this results in 3 series total, one per shard. 405 sess := setupMultipleInstanceCluster(t, ctrl, func(op op, host topology.Host) { 406 fOp := op.(*fetchTaggedOp) 407 assert.Equal(t, int64(2), *fOp.request.SeriesLimit) 408 shardID := strings.Split(host.ID(), "-")[2] 409 op.CompletionFn()(fetchTaggedResultAccumulatorOpts{ 410 host: host, 411 response: &rpc.FetchTaggedResult_{ 412 Exhaustive: true, 413 Elements: []*rpc.FetchTaggedIDResult_{ 414 { 415 // use shard id for the metric id so it's stable across replicas. 416 ID: []byte(shardID), 417 }, 418 }, 419 }, 420 }, nil) 421 }) 422 423 stategy := encoding.IterateHighestFrequencyValue 424 iters, meta, err := sess.fetchTaggedAttempt(context.TODO(), ident.StringID("ns"), 425 index.Query{Query: idx.NewAllQuery()}, 426 index.QueryOptions{ 427 // set to 6 so we can test the instance series limit is 2 (6 /3 instances per replica * InstanceMultiple) 428 SeriesLimit: 6, 429 InstanceMultiple: 1, 430 IterateEqualTimestampStrategy: &stategy, 431 }) 432 require.NoError(t, err) 433 require.NotNil(t, iters) 434 435 // expect a series per shard. 436 require.Equal(t, 3, iters.Len()) 437 438 // Confirm propagated strategy. 439 for _, i := range iters.Iters() { 440 require.Equal(t, stategy, i.IterateEqualTimestampStrategy()) 441 } 442 443 require.True(t, meta.Exhaustive) 444 require.NoError(t, sess.Close()) 445 } 446 447 func TestSessionClusterConnectConsistencyLevelAny(t *testing.T) { 448 ctrl := gomock.NewController(t) 449 defer ctrl.Finish() 450 451 level := topology.ConnectConsistencyLevelAny 452 for i := 0; i <= 3; i++ { 453 testSessionClusterConnectConsistencyLevel(t, ctrl, level, i, outcomeSuccess) 454 } 455 } 456 457 func TestDedicatedConnection(t *testing.T) { 458 ctrl := xtest.NewController(t) 459 defer ctrl.Finish() 460 461 var ( 462 shardID = uint32(32) 463 464 topoMap = topology.NewMockMap(ctrl) 465 466 local = mockHost(ctrl, "h0", "local") 467 remote1 = mockHost(ctrl, "h1", "remote1") 468 remote2 = mockHost(ctrl, "h2", "remote2") 469 470 availableShard = shard.NewShard(shardID).SetState(shard.Available) 471 initializingShard = shard.NewShard(shardID).SetState(shard.Initializing) 472 ) 473 474 topoMap.EXPECT().RouteShardForEach(shardID, gomock.Any()).DoAndReturn( 475 func(shardID uint32, callback func(int, shard.Shard, topology.Host)) error { 476 callback(0, availableShard, local) 477 callback(1, initializingShard, remote1) 478 callback(2, availableShard, remote2) 479 return nil 480 }).Times(4) 481 482 s := session{origin: local} 483 s.opts = NewOptions().SetNewConnectionFn(noopNewConnection) 484 s.healthCheckNewConnFn = testHealthCheck(nil, false) 485 s.state.status = statusOpen 486 s.state.topoMap = topoMap 487 488 _, ch, err := s.DedicatedConnection(shardID, DedicatedConnectionOptions{}) 489 require.NoError(t, err) 490 assert.Equal(t, "remote1", asNoopPooledChannel(ch).address) 491 492 _, ch2, err := s.DedicatedConnection(shardID, DedicatedConnectionOptions{ShardStateFilter: shard.Available}) 493 require.NoError(t, err) 494 assert.Equal(t, "remote2", asNoopPooledChannel(ch2).address) 495 496 s.healthCheckNewConnFn = testHealthCheck(nil, true) 497 _, ch3, err := s.DedicatedConnection(shardID, DedicatedConnectionOptions{BootstrappedNodesOnly: true}) 498 require.NoError(t, err) 499 assert.Equal(t, "remote1", asNoopPooledChannel(ch3).address) 500 501 healthErr := errors.New("unhealthy") 502 s.healthCheckNewConnFn = testHealthCheck(healthErr, false) 503 504 var channels []*noopPooledChannel 505 s.opts = NewOptions().SetNewConnectionFn(func(_ string, _ string, _ Options) (Channel, rpc.TChanNode, error) { 506 c := &noopPooledChannel{"test", 0} 507 channels = append(channels, c) 508 return c, nil, nil 509 }) 510 _, _, err = s.DedicatedConnection(shardID, DedicatedConnectionOptions{}) 511 require.NotNil(t, err) 512 multiErr, ok := err.(xerror.MultiError) // nolint: errorlint 513 assert.True(t, ok, "expecting MultiError") 514 assert.True(t, multiErr.Contains(healthErr)) 515 // 2 because of 2 remote hosts failing health check 516 assert.Len(t, channels, 2) 517 assert.Equal(t, 1, channels[0].CloseCount()) 518 assert.Equal(t, 1, channels[1].CloseCount()) 519 } 520 521 func testSessionClusterConnectConsistencyLevel( 522 t *testing.T, 523 ctrl *gomock.Controller, 524 level topology.ConnectConsistencyLevel, 525 failures int, 526 expected outcome, 527 ) { 528 opts := newSessionTestOptions() 529 opts = opts.SetClusterConnectTimeout(10 * clusterConnectWaitInterval) 530 opts = opts.SetClusterConnectConsistencyLevel(level) 531 s, err := newSession(opts) 532 assert.NoError(t, err) 533 session := s.(*session) 534 535 var failingConns int32 536 session.newHostQueueFn = func( 537 host topology.Host, 538 opts hostQueueOpts, 539 ) (hostQueue, error) { 540 hostQueue := NewMockhostQueue(ctrl) 541 hostQueue.EXPECT().Open().Times(1) 542 hostQueue.EXPECT().Host().Return(host).AnyTimes() 543 if atomic.AddInt32(&failingConns, 1) <= int32(failures) { 544 hostQueue.EXPECT().ConnectionCount().Return(0).AnyTimes() 545 } else { 546 min := opts.opts.MinConnectionCount() 547 hostQueue.EXPECT().ConnectionCount().Return(min).AnyTimes() 548 } 549 hostQueue.EXPECT().Close().AnyTimes() 550 return hostQueue, nil 551 } 552 553 err = session.Open() 554 switch expected { 555 case outcomeSuccess: 556 assert.NoError(t, err) 557 case outcomeFail: 558 assert.Error(t, err) 559 assert.Equal(t, ErrClusterConnectTimeout, err) 560 } 561 } 562 563 // setupMultipleInstanceCluster sets up a db cluster with 3 shards and 3 replicas. The 3 shards are distributed across 564 // 9 hosts, so each host has 1 replica of 1 shard. 565 // the function passed is executed when an operation is enqueued. the provided fn is dispatched in a separate goroutine 566 // to simulate the queue processing. this also allows the function to access the state locks. 567 func setupMultipleInstanceCluster(t *testing.T, ctrl *gomock.Controller, fn func(op op, host topology.Host)) *session { 568 opts := newSessionTestOptions() 569 shardSet := sessionTestShardSet() 570 var hostShardSets []topology.HostShardSet 571 // setup 9 hosts so there are 3 instances per replica. Each instance has a single shard. 572 for i := 0; i < sessionTestReplicas; i++ { 573 for j := 0; j < sessionTestShards; j++ { 574 id := fmt.Sprintf("testhost-%d-%d", i, j) 575 host := topology.NewHost(id, fmt.Sprintf("%s:9000", id)) 576 hostShard, _ := sharding.NewShardSet([]shard.Shard{shardSet.All()[j]}, shardSet.HashFn()) 577 hostShardSet := topology.NewHostShardSet(host, hostShard) 578 hostShardSets = append(hostShardSets, hostShardSet) 579 } 580 } 581 582 opts = opts.SetTopologyInitializer(topology.NewStaticInitializer( 583 topology.NewStaticOptions(). 584 SetReplicas(sessionTestReplicas). 585 SetShardSet(shardSet). 586 SetHostShardSets(hostShardSets))) 587 s, err := newSession(opts) 588 assert.NoError(t, err) 589 sess := s.(*session) 590 591 sess.newHostQueueFn = func(host topology.Host, hostQueueOpts hostQueueOpts) (hostQueue, error) { 592 q := NewMockhostQueue(ctrl) 593 q.EXPECT().Open() 594 q.EXPECT().ConnectionCount().Return(hostQueueOpts.opts.MinConnectionCount()).AnyTimes() 595 q.EXPECT().Host().Return(host).AnyTimes() 596 q.EXPECT().Enqueue(gomock.Any()).Do(func(op op) error { 597 go func() { 598 fn(op, host) 599 }() 600 return nil 601 }).Return(nil) 602 q.EXPECT().Close() 603 return q, nil 604 } 605 606 require.NoError(t, sess.Open()) 607 return sess 608 } 609 610 func mockHostQueues( 611 ctrl *gomock.Controller, 612 s *session, 613 replicas int, 614 enqueueFns []testEnqueueFn, 615 ) *sync.WaitGroup { 616 var enqueueWg sync.WaitGroup 617 enqueueWg.Add(replicas) 618 idx := 0 619 s.newHostQueueFn = func( 620 host topology.Host, 621 opts hostQueueOpts, 622 ) (hostQueue, error) { 623 // Make a copy of the enqueue fns for each host 624 hostEnqueueFns := make([]testEnqueueFn, len(enqueueFns)) 625 copy(hostEnqueueFns, enqueueFns) 626 627 enqueuedIdx := idx 628 hostQueue := NewMockhostQueue(ctrl) 629 hostQueue.EXPECT().Open() 630 hostQueue.EXPECT().Host().Return(host).AnyTimes() 631 // Take two attempts to establish min connection count 632 hostQueue.EXPECT().ConnectionCount().Return(0).Times(sessionTestShards) 633 hostQueue.EXPECT().ConnectionCount().Return(opts.opts.MinConnectionCount()).Times(sessionTestShards) 634 var expectNextEnqueueFn func(fns []testEnqueueFn) 635 expectNextEnqueueFn = func(fns []testEnqueueFn) { 636 fn := fns[0] 637 fns = fns[1:] 638 hostQueue.EXPECT().Enqueue(gomock.Any()).Do(func(op op) error { 639 fn(enqueuedIdx, op) 640 if len(fns) > 0 { 641 expectNextEnqueueFn(fns) 642 } else { 643 enqueueWg.Done() 644 } 645 return nil 646 }).Return(nil) 647 } 648 if len(hostEnqueueFns) > 0 { 649 expectNextEnqueueFn(hostEnqueueFns) 650 } 651 hostQueue.EXPECT().Close() 652 idx++ 653 return hostQueue, nil 654 } 655 return &enqueueWg 656 } 657 658 func mockHost(ctrl *gomock.Controller, id, address string) topology.Host { 659 host := topology.NewMockHost(ctrl) 660 host.EXPECT().ID().Return(id).AnyTimes() 661 host.EXPECT().Address().Return(address).AnyTimes() 662 return host 663 } 664 665 func testHealthCheck(err error, bootstrappedNodesOnly bool) func(rpc.TChanNode, Options, bool) error { 666 return func(client rpc.TChanNode, opts Options, checkBootstrapped bool) error { 667 if checkBootstrapped != bootstrappedNodesOnly { 668 return fmt.Errorf("checkBootstrapped value (%t) != expected (%t)", 669 checkBootstrapped, bootstrappedNodesOnly) 670 } 671 return err 672 } 673 } 674 675 func noopNewConnection( 676 _ string, 677 addr string, 678 _ Options, 679 ) (Channel, rpc.TChanNode, error) { 680 return &noopPooledChannel{addr, 0}, nil, nil 681 }