github.com/kaisenlinux/docker.io@v0.0.0-20230510090727-ea55db55fac7/swarmkit/manager/state/raft/testutils/testutils.go

package testutils

import (
	"context"
	"io/ioutil"
	"net"
	"os"
	"reflect"
	"sync"
	"testing"
	"time"

	"google.golang.org/grpc"

	"code.cloudfoundry.org/clock/fakeclock"
	etcdraft "github.com/coreos/etcd/raft"
	"github.com/coreos/etcd/raft/raftpb"
	"github.com/docker/swarmkit/api"
	"github.com/docker/swarmkit/ca"
	cautils "github.com/docker/swarmkit/ca/testutils"
	"github.com/docker/swarmkit/identity"
	"github.com/docker/swarmkit/manager/health"
	"github.com/docker/swarmkit/manager/state/raft"
	"github.com/docker/swarmkit/manager/state/store"
	"github.com/docker/swarmkit/testutils"
	"github.com/pkg/errors"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestNode represents a raft test node
type TestNode struct {
	*raft.Node
	Server         *grpc.Server
	Listener       *WrappedListener
	SecurityConfig *ca.SecurityConfig
	Address        string
	StateDir       string
	cancel         context.CancelFunc
	KeyRotator     *SimpleKeyRotator
}

// Leader is a wrapper around the real Leader method that suppresses the error.
// TODO: tests should use the Leader method directly.
func (n *TestNode) Leader() uint64 {
	id, _ := n.Node.Leader()
	return id
}

// AdvanceTicks advances the fake clock that drives the raft state machine.
func AdvanceTicks(clockSource *fakeclock.FakeClock, ticks int) {
	// A FakeClock timer won't fire multiple times if time is advanced
	// more than its interval, so advance one tick at a time.
	for i := 0; i != ticks; i++ {
		clockSource.Increment(time.Second)
	}
}

// WaitForCluster waits until the elected leader is one of the specified nodes
// and all nodes report the same raft state.
func WaitForCluster(t *testing.T, clockSource *fakeclock.FakeClock, nodes map[uint64]*TestNode) {
	err := testutils.PollFunc(clockSource, func() error {
		var prev *etcdraft.Status
		var leadNode *TestNode
	nodeLoop:
		for _, n := range nodes {
			if prev == nil {
				prev = new(etcdraft.Status)
				*prev = n.Status()
				for _, n2 := range nodes {
					if n2.Config.ID == prev.Lead {
						leadNode = n2
						continue nodeLoop
					}
				}
				return errors.New("did not find a ready leader in member list")
			}
			cur := n.Status()

			for _, n2 := range nodes {
				if n2.Config.ID == cur.Lead {
					if cur.Lead != prev.Lead || cur.Term != prev.Term || cur.Applied != prev.Applied {
						return errors.New("state does not match on all nodes")
					}
					continue nodeLoop
				}
			}
			return errors.New("did not find leader in member list")
		}
		// Don't raise an error just because the test machine is running slowly.
		for i := 0; i < 5; i++ {
			if leadNode.ReadyForProposals() {
				return nil
			}
			time.Sleep(2 * time.Second)
		}
		return errors.New("leader is not ready")
	})
	require.NoError(t, err)
}

// WaitForPeerNumber waits until the number of peers in the cluster converges
// to the specified count.
func WaitForPeerNumber(t *testing.T, clockSource *fakeclock.FakeClock, nodes map[uint64]*TestNode, count int) {
	assert.NoError(t, testutils.PollFunc(clockSource, func() error {
		for _, n := range nodes {
			if len(n.GetMemberlist()) != count {
				return errors.New("unexpected number of members")
			}
		}
		return nil
	}))
}
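// The helpers above are clock-driven: a test owns a fakeclock.FakeClock and
// advances it explicitly instead of waiting in real time. A minimal sketch of
// how a test might combine them (the enclosing test, the tc *cautils.TestCA
// value, and the number of ticks are assumptions for illustration, not taken
// from an existing test):
//
//	nodes, clockSource := NewRaftCluster(t, tc)
//	defer TeardownCluster(nodes)
//
//	// Drive a few election timeouts on the fake clock, then wait until the
//	// members converge on a single leader and identical applied state.
//	AdvanceTicks(clockSource, 5)
//	WaitForCluster(t, clockSource, nodes)
//	WaitForPeerNumber(t, clockSource, nodes, 3)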
// WrappedListener disables the Close method so the listening socket can be
// reused across server restarts. CloseListener must be called to actually
// release the socket.
type WrappedListener struct {
	net.Listener
	acceptConn chan net.Conn
	acceptErr  chan error
	closed     chan struct{}
}

// NewWrappedListener creates a new wrapped listener to register the raft server
func NewWrappedListener(l net.Listener) *WrappedListener {
	wrappedListener := WrappedListener{
		Listener:   l,
		acceptConn: make(chan net.Conn, 10),
		acceptErr:  make(chan error, 1),
		closed:     make(chan struct{}, 10), // grpc closes multiple times
	}
	// Accept connections
	go func() {
		for {
			conn, err := l.Accept()
			if err != nil {
				wrappedListener.acceptErr <- err
				return
			}
			wrappedListener.acceptConn <- conn
		}
	}()

	return &wrappedListener
}

// Accept accepts new connections on a wrapped listener
func (l *WrappedListener) Accept() (net.Conn, error) {
	// The closed signal must take precedence over taking a connection
	// from the channel.
	select {
	case <-l.closed:
		return nil, errors.New("listener closed")
	default:
	}

	select {
	case conn := <-l.acceptConn:
		return conn, nil
	case err := <-l.acceptErr:
		return nil, err
	case <-l.closed:
		return nil, errors.New("listener closed")
	}
}

// Close notifies that the listener can't accept any more connections; the
// underlying socket stays open.
func (l *WrappedListener) Close() error {
	l.closed <- struct{}{}
	return nil
}

// CloseListener closes the underlying listener
func (l *WrappedListener) CloseListener() error {
	return l.Listener.Close()
}

// RecycleWrappedListener creates a new wrappedListener that uses the same
// listening socket as the supplied wrappedListener.
func RecycleWrappedListener(old *WrappedListener) *WrappedListener {
	return &WrappedListener{
		Listener:   old.Listener,
		acceptConn: old.acceptConn,
		acceptErr:  old.acceptErr,
		closed:     make(chan struct{}, 10), // grpc closes multiple times
	}
}
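// The listener wrapper above exists so a restarted node can keep its TCP
// address. A rough sketch of the restart flow as CopyNode/RestartNode use it
// below (node is assumed to be an existing *TestNode; the new gRPC server
// setup is elided):
//
//	// Stopping the gRPC server calls Close, which only marks the wrapped
//	// listener as closed; the socket itself stays bound.
//	node.Server.Stop()
//
//	// A fresh wrapper over the same socket can then serve the new node.
//	wrapped := RecycleWrappedListener(node.Listener)
//	go newServer.Serve(wrapped)
//
//	// Only CloseListener finally releases the port, e.g. in ShutdownNode.
//	node.Listener.CloseListener()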
// SimpleKeyRotator is a basic DEK rotator for tests.
type SimpleKeyRotator struct {
	mu                 sync.Mutex
	rotateCh           chan struct{}
	updateFunc         func() error
	overrideNeedRotate *bool
	raft.EncryptionKeys
}

// GetKeys returns the current set of keys
func (s *SimpleKeyRotator) GetKeys() raft.EncryptionKeys {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.EncryptionKeys
}

// NeedsRotation returns whether the keys need to be rotated
func (s *SimpleKeyRotator) NeedsRotation() bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.overrideNeedRotate != nil {
		return *s.overrideNeedRotate
	}
	return s.EncryptionKeys.PendingDEK != nil
}

// UpdateKeys updates the current encryption keys
func (s *SimpleKeyRotator) UpdateKeys(newKeys raft.EncryptionKeys) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.updateFunc != nil {
		return s.updateFunc()
	}
	s.EncryptionKeys = newKeys
	return nil
}

// RotationNotify returns the rotation notification channel
func (s *SimpleKeyRotator) RotationNotify() chan struct{} {
	return s.rotateCh
}

// QueuePendingKey queues a pending DEK so the rotator reports that a rotation
// is needed.
func (s *SimpleKeyRotator) QueuePendingKey(key []byte) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.EncryptionKeys.PendingDEK = key
}

// SetUpdateFunc enables you to inject an error when updating keys
func (s *SimpleKeyRotator) SetUpdateFunc(updateFunc func() error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.updateFunc = updateFunc
}

// SetNeedsRotation enables you to inject a value for NeedsRotation
func (s *SimpleKeyRotator) SetNeedsRotation(override *bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.overrideNeedRotate = override
}

// NewSimpleKeyRotator returns a basic EncryptionKeyRotator
func NewSimpleKeyRotator(keys raft.EncryptionKeys) *SimpleKeyRotator {
	return &SimpleKeyRotator{
		rotateCh:       make(chan struct{}),
		EncryptionKeys: keys,
	}
}

var _ raft.EncryptionKeyRotator = NewSimpleKeyRotator(raft.EncryptionKeys{})
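// A minimal sketch of how a test might drive a DEK rotation through this
// rotator (the key bytes are illustrative, and a running raft node configured
// with this rotator is assumed to be listening on RotationNotify):
//
//	kr := NewSimpleKeyRotator(raft.EncryptionKeys{CurrentDEK: []byte("current")})
//
//	// Queue a pending DEK; NeedsRotation now reports true.
//	kr.QueuePendingKey([]byte("pending"))
//
//	// Wake up the raft node so it notices the pending key; once it finishes
//	// re-encrypting its state, it calls UpdateKeys to promote the pending DEK.
//	kr.RotationNotify() <- struct{}{}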
// NewNode creates a new raft node to use for tests
func NewNode(t *testing.T, clockSource *fakeclock.FakeClock, tc *cautils.TestCA, opts ...raft.NodeOptions) *TestNode {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err, "can't bind to raft service port")
	wrappedListener := NewWrappedListener(l)

	securityConfig, err := tc.NewNodeConfig(ca.ManagerRole)
	require.NoError(t, err)

	serverOpts := []grpc.ServerOption{grpc.Creds(securityConfig.ServerTLSCreds)}
	s := grpc.NewServer(serverOpts...)

	cfg := raft.DefaultNodeConfig()

	stateDir, err := ioutil.TempDir("", t.Name())
	require.NoError(t, err, "can't create temporary state directory")

	keyRotator := NewSimpleKeyRotator(raft.EncryptionKeys{CurrentDEK: []byte("current")})
	newNodeOpts := raft.NodeOptions{
		ID:             securityConfig.ClientTLSCreds.NodeID(),
		Addr:           l.Addr().String(),
		Config:         cfg,
		StateDir:       stateDir,
		ClockSource:    clockSource,
		TLSCredentials: securityConfig.ClientTLSCreds,
		KeyRotator:     keyRotator,
	}

	if len(opts) > 1 {
		panic("more than one optional argument provided")
	}
	if len(opts) == 1 {
		newNodeOpts.JoinAddr = opts[0].JoinAddr
		if opts[0].Addr != "" {
			newNodeOpts.Addr = opts[0].Addr
		}
		newNodeOpts.DisableStackDump = opts[0].DisableStackDump
	}

	n := raft.NewNode(newNodeOpts)

	healthServer := health.NewHealthServer()
	api.RegisterHealthServer(s, healthServer)
	raft.Register(s, n)

	go s.Serve(wrappedListener)

	healthServer.SetServingStatus("Raft", api.HealthCheckResponse_SERVING)

	return &TestNode{
		Node:           n,
		Listener:       wrappedListener,
		SecurityConfig: securityConfig,
		Address:        newNodeOpts.Addr,
		StateDir:       newNodeOpts.StateDir,
		Server:         s,
		KeyRotator:     keyRotator,
	}
}

// NewInitNode creates a new raft node initiating the cluster
// for other members to join
func NewInitNode(t *testing.T, tc *cautils.TestCA, raftConfig *api.RaftConfig, opts ...raft.NodeOptions) (*TestNode, *fakeclock.FakeClock) {
	clockSource := fakeclock.NewFakeClock(time.Now())
	n := NewNode(t, clockSource, tc, opts...)
	ctx, cancel := context.WithCancel(context.Background())
	n.cancel = cancel

	err := n.Node.JoinAndStart(ctx)
	require.NoError(t, err, "can't join cluster")

	leadershipCh, cancel := n.SubscribeLeadership()
	defer cancel()

	go n.Run(ctx)

	// Wait for the node to become the leader.
	<-leadershipCh

	if raftConfig != nil {
		assert.NoError(t, n.MemoryStore().Update(func(tx store.Tx) error {
			return store.CreateCluster(tx, &api.Cluster{
				ID: identity.NewID(),
				Spec: api.ClusterSpec{
					Annotations: api.Annotations{
						Name: store.DefaultClusterName,
					},
					Raft: *raftConfig,
				},
			})
		}))
	}

	return n, clockSource
}

// NewJoinNode creates a new raft node joining an existing cluster
func NewJoinNode(t *testing.T, clockSource *fakeclock.FakeClock, join string, tc *cautils.TestCA, opts ...raft.NodeOptions) *TestNode {
	var derivedOpts raft.NodeOptions
	if len(opts) == 1 {
		derivedOpts = opts[0]
	}
	derivedOpts.JoinAddr = join
	n := NewNode(t, clockSource, tc, derivedOpts)

	ctx, cancel := context.WithCancel(context.Background())
	n.cancel = cancel
	err := n.Node.JoinAndStart(ctx)
	require.NoError(t, err, "can't join cluster")

	go n.Run(ctx)

	return n
}

// CopyNode returns a copy of a node
func CopyNode(t *testing.T, clockSource *fakeclock.FakeClock, oldNode *TestNode, forceNewCluster bool, kr *SimpleKeyRotator) (*TestNode, context.Context) {
	wrappedListener := RecycleWrappedListener(oldNode.Listener)
	securityConfig := oldNode.SecurityConfig
	serverOpts := []grpc.ServerOption{grpc.Creds(securityConfig.ServerTLSCreds)}
	s := grpc.NewServer(serverOpts...)

	cfg := raft.DefaultNodeConfig()

	if kr == nil {
		kr = oldNode.KeyRotator
	}

	newNodeOpts := raft.NodeOptions{
		ID:              securityConfig.ClientTLSCreds.NodeID(),
		Addr:            oldNode.Address,
		Config:          cfg,
		StateDir:        oldNode.StateDir,
		ForceNewCluster: forceNewCluster,
		ClockSource:     clockSource,
		SendTimeout:     2 * time.Second,
		TLSCredentials:  securityConfig.ClientTLSCreds,
		KeyRotator:      kr,
	}

	ctx, cancel := context.WithCancel(context.Background())
	n := raft.NewNode(newNodeOpts)

	healthServer := health.NewHealthServer()
	api.RegisterHealthServer(s, healthServer)
	raft.Register(s, n)

	go s.Serve(wrappedListener)

	healthServer.SetServingStatus("Raft", api.HealthCheckResponse_SERVING)

	return &TestNode{
		Node:           n,
		Listener:       wrappedListener,
		SecurityConfig: securityConfig,
		Address:        newNodeOpts.Addr,
		StateDir:       newNodeOpts.StateDir,
		cancel:         cancel,
		Server:         s,
		KeyRotator:     kr,
	}, ctx
}

// RestartNode restarts a raft test node
func RestartNode(t *testing.T, clockSource *fakeclock.FakeClock, oldNode *TestNode, forceNewCluster bool) *TestNode {
	n, ctx := CopyNode(t, clockSource, oldNode, forceNewCluster, nil)

	err := n.Node.JoinAndStart(ctx)
	require.NoError(t, err, "can't join cluster")

	go n.Node.Run(ctx)

	return n
}
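// A rough sketch of how a test might exercise the restart path with these
// helpers (the member index and the forceNewCluster value are arbitrary
// choices for illustration):
//
//	nodes, clockSource := NewRaftCluster(t, tc)
//	defer TeardownCluster(nodes)
//
//	// Stop only the raft side of one member, keeping its state directory
//	// and listener around, then bring it back on the same address.
//	nodes[2].Server.Stop()
//	nodes[2].ShutdownRaft()
//	nodes[2] = RestartNode(t, clockSource, nodes[2], false)
//	WaitForCluster(t, clockSource, nodes)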
// NewRaftCluster creates a new raft cluster with 3 nodes for testing
func NewRaftCluster(t *testing.T, tc *cautils.TestCA, config ...*api.RaftConfig) (map[uint64]*TestNode, *fakeclock.FakeClock) {
	var (
		raftConfig  *api.RaftConfig
		clockSource *fakeclock.FakeClock
	)
	if len(config) > 1 {
		panic("more than one optional argument provided")
	}
	if len(config) == 1 {
		raftConfig = config[0]
	}
	nodes := make(map[uint64]*TestNode)
	nodes[1], clockSource = NewInitNode(t, tc, raftConfig)
	AddRaftNode(t, clockSource, nodes, tc)
	AddRaftNode(t, clockSource, nodes, tc)
	return nodes, clockSource
}

// AddRaftNode adds an additional raft test node to an existing cluster
func AddRaftNode(t *testing.T, clockSource *fakeclock.FakeClock, nodes map[uint64]*TestNode, tc *cautils.TestCA, opts ...raft.NodeOptions) {
	n := uint64(len(nodes) + 1)
	nodes[n] = NewJoinNode(t, clockSource, nodes[1].Address, tc, opts...)
	WaitForCluster(t, clockSource, nodes)
}

// TeardownCluster destroys a raft cluster used for tests
func TeardownCluster(nodes map[uint64]*TestNode) {
	for _, node := range nodes {
		ShutdownNode(node)
	}
}

// ShutdownNode shuts down a raft test node and deletes the content
// of the state directory
func ShutdownNode(node *TestNode) {
	node.Server.Stop()
	if node.cancel != nil {
		node.cancel()
		<-node.Done()
	}
	os.RemoveAll(node.StateDir)
	node.Listener.CloseListener()
}

// ShutdownRaft shuts down only the raft component of the node.
func (n *TestNode) ShutdownRaft() {
	if n.cancel != nil {
		n.cancel()
		<-n.Done()
	}
}

// CleanupNonRunningNode frees resources associated with a node which is not
// running.
func CleanupNonRunningNode(node *TestNode) {
	node.Server.Stop()
	os.RemoveAll(node.StateDir)
	node.Listener.CloseListener()
}

// Leader determines which node is the leader among a set of raft nodes
// belonging to the same cluster
func Leader(nodes map[uint64]*TestNode) *TestNode {
	for _, n := range nodes {
		if n.Config.ID == n.Leader() {
			return n
		}
	}
	panic("could not find a leader")
}

// ProposeValue proposes a value to a raft test cluster
func ProposeValue(t *testing.T, raftNode *TestNode, timeout time.Duration, nodeID ...string) (*api.Node, error) {
	nodeIDStr := "id1"
	if len(nodeID) != 0 {
		nodeIDStr = nodeID[0]
	}
	node := &api.Node{
		ID: nodeIDStr,
		Spec: api.NodeSpec{
			Annotations: api.Annotations{
				Name: nodeIDStr,
			},
		},
	}

	storeActions := []api.StoreAction{
		{
			Action: api.StoreActionKindCreate,
			Target: &api.StoreAction_Node{
				Node: node,
			},
		},
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout)

	err := raftNode.ProposeValue(ctx, storeActions, func() {
		err := raftNode.MemoryStore().ApplyStoreActions(storeActions)
		assert.NoError(t, err, "error applying actions")
	})
	cancel()
	if err != nil {
		return nil, err
	}

	return node, nil
}
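// A minimal sketch of the usual propose-then-verify pattern built on the
// helpers in this file (the 10-second timeout is an illustrative choice;
// omitting the nodeID argument falls back to the "id1" default, and
// CheckValue is defined below):
//
//	leader := Leader(nodes)
//	value, err := ProposeValue(t, leader, 10*time.Second)
//	require.NoError(t, err)
//
//	// Every member should eventually apply the same store action.
//	for _, node := range nodes {
//		CheckValue(t, clockSource, node, value)
//	}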
// CheckValue checks that the value has been propagated between raft members
func CheckValue(t *testing.T, clockSource *fakeclock.FakeClock, raftNode *TestNode, createdNode *api.Node) {
	assert.NoError(t, testutils.PollFunc(clockSource, func() error {
		var err error
		raftNode.MemoryStore().View(func(tx store.ReadTx) {
			var allNodes []*api.Node
			allNodes, err = store.FindNodes(tx, store.All)
			if err != nil {
				return
			}
			if len(allNodes) != 1 {
				err = errors.Errorf("expected 1 node, got %d nodes", len(allNodes))
				return
			}
			if !reflect.DeepEqual(allNodes[0], createdNode) {
				err = errors.New("node did not match expected value")
			}
		})
		return err
	}))
}

// CheckNoValue checks that there is no value replicated on nodes, generally
// used to test the absence of a leader
func CheckNoValue(t *testing.T, clockSource *fakeclock.FakeClock, raftNode *TestNode) {
	assert.NoError(t, testutils.PollFunc(clockSource, func() error {
		var err error
		raftNode.MemoryStore().View(func(tx store.ReadTx) {
			var allNodes []*api.Node
			allNodes, err = store.FindNodes(tx, store.All)
			if err != nil {
				return
			}
			if len(allNodes) != 0 {
				err = errors.Errorf("expected no nodes, got %d", len(allNodes))
			}
		})
		return err
	}))
}

// CheckValuesOnNodes checks that all the nodes in the cluster have the same
// replicated data, generally used to check if a node can catch up with the logs
// correctly
func CheckValuesOnNodes(t *testing.T, clockSource *fakeclock.FakeClock, checkNodes map[uint64]*TestNode, ids []string, values []*api.Node) {
	iteration := 0
	for checkNodeID, node := range checkNodes {
		assert.NoError(t, testutils.PollFunc(clockSource, func() error {
			var err error
			node.MemoryStore().View(func(tx store.ReadTx) {
				var allNodes []*api.Node
				allNodes, err = store.FindNodes(tx, store.All)
				if err != nil {
					return
				}
				for i, id := range ids {
					n := store.GetNode(tx, id)
					if n == nil {
						err = errors.Errorf("node %s not found on %d (iteration %d)", id, checkNodeID, iteration)
						return
					}
					if !reflect.DeepEqual(values[i], n) {
						err = errors.Errorf("node %s did not match expected value on %d (iteration %d)", id, checkNodeID, iteration)
						return
					}
				}
				if len(allNodes) != len(ids) {
					err = errors.Errorf("expected %d nodes, got %d (iteration %d)", len(ids), len(allNodes), iteration)
					return
				}
			})
			return err
		}))
		iteration++
	}
}

// GetAllValuesOnNode returns all values on this node
func GetAllValuesOnNode(t *testing.T, clockSource *fakeclock.FakeClock, raftNode *TestNode) ([]string, []*api.Node) {
	ids := []string{}
	values := []*api.Node{}
	assert.NoError(t, testutils.PollFunc(clockSource, func() error {
		var err error
		raftNode.MemoryStore().View(func(tx store.ReadTx) {
			var allNodes []*api.Node
			allNodes, err = store.FindNodes(tx, store.All)
			if err != nil {
				return
			}
			for _, node := range allNodes {
				ids = append(ids, node.ID)
				values = append(values, node)
			}
		})
		return err
	}))

	return ids, values
}

// NewSnapshotMessage creates and returns a raftpb.Message of type MsgSnap
// where the snapshot data is of the given size and the value of each byte
// is (index of the byte) % 256.
func NewSnapshotMessage(from, to uint64, size int) *raftpb.Message {
	data := make([]byte, size)
	for i := 0; i < size; i++ {
		data[i] = byte(i % (1 << 8))
	}

	return &raftpb.Message{
		Type: raftpb.MsgSnap,
		From: from,
		To:   to,
		Snapshot: raftpb.Snapshot{
			Data: data,
			// Include the snapshot size in the Index field for testing.
			Metadata: raftpb.SnapshotMetadata{
				Index: uint64(len(data)),
			},
		},
	}
}

// VerifySnapshot verifies that each byte of the snapshot data has the value
// (index % 256) and that the snapshot metadata Index matches the data length.
func VerifySnapshot(raftMsg *raftpb.Message) bool {
	for i, b := range raftMsg.Snapshot.Data {
		if int(b) != i%(1<<8) {
			return false
		}
	}

	return len(raftMsg.Snapshot.Data) == int(raftMsg.Snapshot.Metadata.Index)
}
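// A small sketch of how the snapshot helpers might be combined inside an
// enclosing test (the from/to IDs and the 1024-byte size are arbitrary):
//
//	msg := NewSnapshotMessage(1, 2, 1024)
//	if !VerifySnapshot(msg) {
//		t.Fatal("generated snapshot failed verification")
//	}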