github.com/kaisenlinux/docker.io@v0.0.0-20230510090727-ea55db55fac7/swarmkit/manager/state/raft/transport/mock_raft_test.go (about)

     1  package transport
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net"
     7  	"time"
     8  
     9  	"github.com/coreos/etcd/raft"
    10  	"github.com/coreos/etcd/raft/raftpb"
    11  	"github.com/docker/swarmkit/api"
    12  	"github.com/docker/swarmkit/log"
    13  	"github.com/docker/swarmkit/manager/health"
    14  	"github.com/docker/swarmkit/manager/state/raft/membership"
    15  	"google.golang.org/grpc"
    16  	"google.golang.org/grpc/codes"
    17  	"google.golang.org/grpc/status"
    18  )
    19  
    20  type snapshotReport struct {
    21  	id     uint64
    22  	status raft.SnapshotStatus
    23  }
    24  
    25  type updateInfo struct {
    26  	id   uint64
    27  	addr string
    28  }
    29  
    30  type mockRaft struct {
    31  	lis net.Listener
    32  	s   *grpc.Server
    33  	tr  *Transport
    34  
    35  	nodeRemovedSignal chan struct{}
    36  
    37  	removed map[uint64]bool
    38  
    39  	processedMessages  chan *raftpb.Message
    40  	processedSnapshots chan snapshotReport
    41  
    42  	reportedUnreachables chan uint64
    43  	updatedNodes         chan updateInfo
    44  
    45  	forceErrorStream bool
    46  }
    47  
    48  func newMockRaft() (*mockRaft, error) {
    49  	l, err := net.Listen("tcp", "0.0.0.0:0")
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	mr := &mockRaft{
    54  		lis:                  l,
    55  		s:                    grpc.NewServer(),
    56  		removed:              make(map[uint64]bool),
    57  		nodeRemovedSignal:    make(chan struct{}),
    58  		processedMessages:    make(chan *raftpb.Message, 4096),
    59  		processedSnapshots:   make(chan snapshotReport, 4096),
    60  		reportedUnreachables: make(chan uint64, 4096),
    61  		updatedNodes:         make(chan updateInfo, 4096),
    62  	}
    63  	cfg := &Config{
    64  		HeartbeatInterval: 3 * time.Second,
    65  		SendTimeout:       2 * time.Second,
    66  		Raft:              mr,
    67  	}
    68  	tr := New(cfg)
    69  	mr.tr = tr
    70  	hs := health.NewHealthServer()
    71  	hs.SetServingStatus("Raft", api.HealthCheckResponse_SERVING)
    72  	api.RegisterRaftServer(mr.s, mr)
    73  	api.RegisterHealthServer(mr.s, hs)
    74  	go mr.s.Serve(l)
    75  	return mr, nil
    76  }
    77  
    78  func (r *mockRaft) Addr() string {
    79  	return r.lis.Addr().String()
    80  }
    81  
    82  func (r *mockRaft) Stop() {
    83  	r.tr.Stop()
    84  	r.s.Stop()
    85  }
    86  
    87  func (r *mockRaft) RemovePeer(id uint64) error {
    88  	r.removed[id] = true
    89  	return r.tr.RemovePeer(id)
    90  }
    91  
    92  func (r *mockRaft) ProcessRaftMessage(ctx context.Context, req *api.ProcessRaftMessageRequest) (*api.ProcessRaftMessageResponse, error) {
    93  	if r.removed[req.Message.From] {
    94  		return nil, status.Errorf(codes.NotFound, "%s", membership.ErrMemberRemoved.Error())
    95  	}
    96  	r.processedMessages <- req.Message
    97  	return &api.ProcessRaftMessageResponse{}, nil
    98  }
    99  
   100  // StreamRaftMessage is the mock server endpoint for streaming messages of type StreamRaftMessageRequest.
   101  func (r *mockRaft) StreamRaftMessage(stream api.Raft_StreamRaftMessageServer) error {
   102  	if r.forceErrorStream {
   103  		return status.Errorf(codes.Unimplemented, "streaming not supported")
   104  	}
   105  	var recvdMsg, assembledMessage *api.StreamRaftMessageRequest
   106  	var err error
   107  	for {
   108  		recvdMsg, err = stream.Recv()
   109  		if err == io.EOF {
   110  			break
   111  		} else if err != nil {
   112  			log.G(context.Background()).WithError(err).Error("error while reading from stream")
   113  			return err
   114  		}
   115  
   116  		if r.removed[recvdMsg.Message.From] {
   117  			return status.Errorf(codes.NotFound, "%s", membership.ErrMemberRemoved.Error())
   118  		}
   119  
   120  		if assembledMessage == nil {
   121  			assembledMessage = recvdMsg
   122  			continue
   123  		}
   124  
   125  		// For all message types except raftpb.MsgSnap,
   126  		// we don't expect more than a single message
   127  		// on the stream.
   128  		if recvdMsg.Message.Type != raftpb.MsgSnap {
   129  			panic("Unexpected message type received on stream: " + string(recvdMsg.Message.Type))
   130  		}
   131  
   132  		// Append received snapshot chunk to the chunk that was already received.
   133  		assembledMessage.Message.Snapshot.Data = append(assembledMessage.Message.Snapshot.Data, recvdMsg.Message.Snapshot.Data...)
   134  	}
   135  
   136  	// We should have the complete snapshot. Verify and process.
   137  	if err == io.EOF {
   138  		if assembledMessage.Message.Type == raftpb.MsgSnap {
   139  			if !verifySnapshot(assembledMessage.Message) {
   140  				log.G(context.Background()).Error("snapshot data mismatch")
   141  				panic("invalid snapshot data")
   142  			}
   143  		}
   144  
   145  		r.processedMessages <- assembledMessage.Message
   146  
   147  		return stream.SendAndClose(&api.StreamRaftMessageResponse{})
   148  	}
   149  
   150  	return nil
   151  }
   152  
   153  func (r *mockRaft) ResolveAddress(ctx context.Context, req *api.ResolveAddressRequest) (*api.ResolveAddressResponse, error) {
   154  	addr, err := r.tr.PeerAddr(req.RaftID)
   155  	if err != nil {
   156  		return nil, err
   157  	}
   158  	return &api.ResolveAddressResponse{
   159  		Addr: addr,
   160  	}, nil
   161  }
   162  
   163  func (r *mockRaft) ReportUnreachable(id uint64) {
   164  	r.reportedUnreachables <- id
   165  }
   166  
   167  func (r *mockRaft) IsIDRemoved(id uint64) bool {
   168  	return r.removed[id]
   169  }
   170  
   171  func (r *mockRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {
   172  	r.processedSnapshots <- snapshotReport{
   173  		id:     id,
   174  		status: status,
   175  	}
   176  }
   177  
   178  func (r *mockRaft) UpdateNode(id uint64, addr string) {
   179  	r.updatedNodes <- updateInfo{
   180  		id:   id,
   181  		addr: addr,
   182  	}
   183  }
   184  
   185  func (r *mockRaft) NodeRemoved() {
   186  	close(r.nodeRemovedSignal)
   187  }
   188  
   189  type mockCluster struct {
   190  	rafts map[uint64]*mockRaft
   191  }
   192  
   193  func newCluster() *mockCluster {
   194  	return &mockCluster{
   195  		rafts: make(map[uint64]*mockRaft),
   196  	}
   197  }
   198  
   199  func (c *mockCluster) Stop() {
   200  	for _, r := range c.rafts {
   201  		r.s.Stop()
   202  	}
   203  }
   204  
   205  func (c *mockCluster) Add(id uint64) error {
   206  	mr, err := newMockRaft()
   207  	if err != nil {
   208  		return err
   209  	}
   210  	for otherID, otherRaft := range c.rafts {
   211  		if err := mr.tr.AddPeer(otherID, otherRaft.Addr()); err != nil {
   212  			return err
   213  		}
   214  		if err := otherRaft.tr.AddPeer(id, mr.Addr()); err != nil {
   215  			return err
   216  		}
   217  	}
   218  	c.rafts[id] = mr
   219  	return nil
   220  }
   221  
   222  func (c *mockCluster) Get(id uint64) *mockRaft {
   223  	return c.rafts[id]
   224  }