github.com/kaisenlinux/docker.io@v0.0.0-20230510090727-ea55db55fac7/swarmkit/manager/state/raft/transport/mock_raft_test.go (about) 1 package transport 2 3 import ( 4 "context" 5 "io" 6 "net" 7 "time" 8 9 "github.com/coreos/etcd/raft" 10 "github.com/coreos/etcd/raft/raftpb" 11 "github.com/docker/swarmkit/api" 12 "github.com/docker/swarmkit/log" 13 "github.com/docker/swarmkit/manager/health" 14 "github.com/docker/swarmkit/manager/state/raft/membership" 15 "google.golang.org/grpc" 16 "google.golang.org/grpc/codes" 17 "google.golang.org/grpc/status" 18 ) 19 20 type snapshotReport struct { 21 id uint64 22 status raft.SnapshotStatus 23 } 24 25 type updateInfo struct { 26 id uint64 27 addr string 28 } 29 30 type mockRaft struct { 31 lis net.Listener 32 s *grpc.Server 33 tr *Transport 34 35 nodeRemovedSignal chan struct{} 36 37 removed map[uint64]bool 38 39 processedMessages chan *raftpb.Message 40 processedSnapshots chan snapshotReport 41 42 reportedUnreachables chan uint64 43 updatedNodes chan updateInfo 44 45 forceErrorStream bool 46 } 47 48 func newMockRaft() (*mockRaft, error) { 49 l, err := net.Listen("tcp", "0.0.0.0:0") 50 if err != nil { 51 return nil, err 52 } 53 mr := &mockRaft{ 54 lis: l, 55 s: grpc.NewServer(), 56 removed: make(map[uint64]bool), 57 nodeRemovedSignal: make(chan struct{}), 58 processedMessages: make(chan *raftpb.Message, 4096), 59 processedSnapshots: make(chan snapshotReport, 4096), 60 reportedUnreachables: make(chan uint64, 4096), 61 updatedNodes: make(chan updateInfo, 4096), 62 } 63 cfg := &Config{ 64 HeartbeatInterval: 3 * time.Second, 65 SendTimeout: 2 * time.Second, 66 Raft: mr, 67 } 68 tr := New(cfg) 69 mr.tr = tr 70 hs := health.NewHealthServer() 71 hs.SetServingStatus("Raft", api.HealthCheckResponse_SERVING) 72 api.RegisterRaftServer(mr.s, mr) 73 api.RegisterHealthServer(mr.s, hs) 74 go mr.s.Serve(l) 75 return mr, nil 76 } 77 78 func (r *mockRaft) Addr() string { 79 return r.lis.Addr().String() 80 } 81 82 func (r *mockRaft) Stop() { 83 r.tr.Stop() 84 r.s.Stop() 85 } 86 87 func (r *mockRaft) RemovePeer(id uint64) error { 88 r.removed[id] = true 89 return r.tr.RemovePeer(id) 90 } 91 92 func (r *mockRaft) ProcessRaftMessage(ctx context.Context, req *api.ProcessRaftMessageRequest) (*api.ProcessRaftMessageResponse, error) { 93 if r.removed[req.Message.From] { 94 return nil, status.Errorf(codes.NotFound, "%s", membership.ErrMemberRemoved.Error()) 95 } 96 r.processedMessages <- req.Message 97 return &api.ProcessRaftMessageResponse{}, nil 98 } 99 100 // StreamRaftMessage is the mock server endpoint for streaming messages of type StreamRaftMessageRequest. 101 func (r *mockRaft) StreamRaftMessage(stream api.Raft_StreamRaftMessageServer) error { 102 if r.forceErrorStream { 103 return status.Errorf(codes.Unimplemented, "streaming not supported") 104 } 105 var recvdMsg, assembledMessage *api.StreamRaftMessageRequest 106 var err error 107 for { 108 recvdMsg, err = stream.Recv() 109 if err == io.EOF { 110 break 111 } else if err != nil { 112 log.G(context.Background()).WithError(err).Error("error while reading from stream") 113 return err 114 } 115 116 if r.removed[recvdMsg.Message.From] { 117 return status.Errorf(codes.NotFound, "%s", membership.ErrMemberRemoved.Error()) 118 } 119 120 if assembledMessage == nil { 121 assembledMessage = recvdMsg 122 continue 123 } 124 125 // For all message types except raftpb.MsgSnap, 126 // we don't expect more than a single message 127 // on the stream. 128 if recvdMsg.Message.Type != raftpb.MsgSnap { 129 panic("Unexpected message type received on stream: " + string(recvdMsg.Message.Type)) 130 } 131 132 // Append received snapshot chunk to the chunk that was already received. 133 assembledMessage.Message.Snapshot.Data = append(assembledMessage.Message.Snapshot.Data, recvdMsg.Message.Snapshot.Data...) 134 } 135 136 // We should have the complete snapshot. Verify and process. 137 if err == io.EOF { 138 if assembledMessage.Message.Type == raftpb.MsgSnap { 139 if !verifySnapshot(assembledMessage.Message) { 140 log.G(context.Background()).Error("snapshot data mismatch") 141 panic("invalid snapshot data") 142 } 143 } 144 145 r.processedMessages <- assembledMessage.Message 146 147 return stream.SendAndClose(&api.StreamRaftMessageResponse{}) 148 } 149 150 return nil 151 } 152 153 func (r *mockRaft) ResolveAddress(ctx context.Context, req *api.ResolveAddressRequest) (*api.ResolveAddressResponse, error) { 154 addr, err := r.tr.PeerAddr(req.RaftID) 155 if err != nil { 156 return nil, err 157 } 158 return &api.ResolveAddressResponse{ 159 Addr: addr, 160 }, nil 161 } 162 163 func (r *mockRaft) ReportUnreachable(id uint64) { 164 r.reportedUnreachables <- id 165 } 166 167 func (r *mockRaft) IsIDRemoved(id uint64) bool { 168 return r.removed[id] 169 } 170 171 func (r *mockRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) { 172 r.processedSnapshots <- snapshotReport{ 173 id: id, 174 status: status, 175 } 176 } 177 178 func (r *mockRaft) UpdateNode(id uint64, addr string) { 179 r.updatedNodes <- updateInfo{ 180 id: id, 181 addr: addr, 182 } 183 } 184 185 func (r *mockRaft) NodeRemoved() { 186 close(r.nodeRemovedSignal) 187 } 188 189 type mockCluster struct { 190 rafts map[uint64]*mockRaft 191 } 192 193 func newCluster() *mockCluster { 194 return &mockCluster{ 195 rafts: make(map[uint64]*mockRaft), 196 } 197 } 198 199 func (c *mockCluster) Stop() { 200 for _, r := range c.rafts { 201 r.s.Stop() 202 } 203 } 204 205 func (c *mockCluster) Add(id uint64) error { 206 mr, err := newMockRaft() 207 if err != nil { 208 return err 209 } 210 for otherID, otherRaft := range c.rafts { 211 if err := mr.tr.AddPeer(otherID, otherRaft.Addr()); err != nil { 212 return err 213 } 214 if err := otherRaft.tr.AddPeer(id, mr.Addr()); err != nil { 215 return err 216 } 217 } 218 c.rafts[id] = mr 219 return nil 220 } 221 222 func (c *mockCluster) Get(id uint64) *mockRaft { 223 return c.rafts[id] 224 }