github.com/kaisenlinux/docker.io@v0.0.0-20230510090727-ea55db55fac7/swarmkit/manager/state/raft/transport/peer.go (about)

     1  package transport
     2  
import (
	"context"
	"fmt"
	"io"
	"sync"
	"time"

	"github.com/coreos/etcd/raft"
	"github.com/coreos/etcd/raft/raftpb"
	"github.com/docker/swarmkit/api"
	"github.com/docker/swarmkit/log"
	"github.com/docker/swarmkit/manager/state/raft/membership"
	"github.com/pkg/errors"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)
    20  
    21  const (
    22  	// GRPCMaxMsgSize is the max allowed gRPC message size for raft messages.
    23  	GRPCMaxMsgSize = 4 << 20
    24  )
    25  
// peer represents one remote raft member and owns the outgoing message queue
// and the gRPC connection used to reach it.
type peer struct {
	// id is the raft ID of the remote member.
	id uint64

	// tr is the Transport this peer belongs to; it supplies dial, the parent
	// context and the config callbacks.
	tr *Transport

	// msgc queues outgoing messages. send enqueues without blocking, run
	// dequeues, and run closes the channel on exit before draining it.
	msgc chan raftpb.Message

	// ctx is cancelled by cancel (see stop); done is closed once run has
	// finished, including the final drain.
	ctx    context.Context
	cancel context.CancelFunc
	done   chan struct{}

	// mu guards the fields below.
	mu sync.Mutex
	// cc is the current connection and addr the address it was dialed with.
	cc   *grpc.ClientConn
	addr string
	// newAddr is a pending address recorded by updateAddr; it is tried by
	// handleAddressChange after a send failure.
	newAddr string

	// active reports whether the last health check or send succeeded;
	// becameActive is the time the peer last switched to active (zero while
	// inactive).
	active       bool
	becameActive time.Time
}
    45  
    46  func newPeer(id uint64, addr string, tr *Transport) (*peer, error) {
    47  	cc, err := tr.dial(addr)
    48  	if err != nil {
    49  		return nil, errors.Wrapf(err, "failed to create conn for %x with addr %s", id, addr)
    50  	}
    51  	ctx, cancel := context.WithCancel(tr.ctx)
    52  	ctx = log.WithField(ctx, "peer_id", fmt.Sprintf("%x", id))
    53  	p := &peer{
    54  		id:     id,
    55  		addr:   addr,
    56  		cc:     cc,
    57  		tr:     tr,
    58  		ctx:    ctx,
    59  		cancel: cancel,
    60  		msgc:   make(chan raftpb.Message, 4096),
    61  		done:   make(chan struct{}),
    62  	}
    63  	go p.run(ctx)
    64  	return p, nil
    65  }
    66  
    67  func (p *peer) send(m raftpb.Message) (err error) {
    68  	p.mu.Lock()
    69  	defer func() {
    70  		if err != nil {
    71  			p.active = false
    72  			p.becameActive = time.Time{}
    73  		}
    74  		p.mu.Unlock()
    75  	}()
    76  	select {
    77  	case <-p.ctx.Done():
    78  		return p.ctx.Err()
    79  	default:
    80  	}
    81  	select {
    82  	case p.msgc <- m:
    83  	case <-p.ctx.Done():
    84  		return p.ctx.Err()
    85  	default:
    86  		p.tr.config.ReportUnreachable(p.id)
    87  		return errors.Errorf("peer is unreachable")
    88  	}
    89  	return nil
    90  }
    91  
    92  func (p *peer) update(addr string) error {
    93  	p.mu.Lock()
    94  	defer p.mu.Unlock()
    95  	if p.addr == addr {
    96  		return nil
    97  	}
    98  	cc, err := p.tr.dial(addr)
    99  	if err != nil {
   100  		return err
   101  	}
   102  
   103  	p.cc.Close()
   104  	p.cc = cc
   105  	p.addr = addr
   106  	return nil
   107  }
   108  
   109  func (p *peer) updateAddr(addr string) error {
   110  	p.mu.Lock()
   111  	defer p.mu.Unlock()
   112  	if p.addr == addr {
   113  		return nil
   114  	}
   115  	log.G(p.ctx).Debugf("peer %x updated to address %s, it will be used if old failed", p.id, addr)
   116  	p.newAddr = addr
   117  	return nil
   118  }
   119  
   120  func (p *peer) conn() *grpc.ClientConn {
   121  	p.mu.Lock()
   122  	defer p.mu.Unlock()
   123  	return p.cc
   124  }
   125  
   126  func (p *peer) address() string {
   127  	p.mu.Lock()
   128  	defer p.mu.Unlock()
   129  	return p.addr
   130  }
   131  
   132  func (p *peer) resolveAddr(ctx context.Context, id uint64) (string, error) {
   133  	resp, err := api.NewRaftClient(p.conn()).ResolveAddress(ctx, &api.ResolveAddressRequest{RaftID: id})
   134  	if err != nil {
   135  		return "", errors.Wrap(err, "failed to resolve address")
   136  	}
   137  	return resp.Addr, nil
   138  }
   139  
   140  // Returns the raft message struct size (not including the payload size) for the given raftpb.Message.
   141  // The payload is typically the snapshot or append entries.
   142  func raftMessageStructSize(m *raftpb.Message) int {
   143  	return (&api.ProcessRaftMessageRequest{Message: m}).Size() - len(m.Snapshot.Data)
   144  }
   145  
   146  // Returns the max allowable payload based on MaxRaftMsgSize and
   147  // the struct size for the given raftpb.Message.
   148  func raftMessagePayloadSize(m *raftpb.Message) int {
   149  	return GRPCMaxMsgSize - raftMessageStructSize(m)
   150  }
   151  
   152  // Split a large raft message into smaller messages.
   153  // Currently this means splitting the []Snapshot.Data into chunks whose size
   154  // is dictacted by MaxRaftMsgSize.
   155  func splitSnapshotData(ctx context.Context, m *raftpb.Message) []api.StreamRaftMessageRequest {
   156  	var messages []api.StreamRaftMessageRequest
   157  	if m.Type != raftpb.MsgSnap {
   158  		return messages
   159  	}
   160  
   161  	// get the size of the data to be split.
   162  	size := len(m.Snapshot.Data)
   163  
   164  	// Get the max payload size.
   165  	payloadSize := raftMessagePayloadSize(m)
   166  
   167  	// split the snapshot into smaller messages.
   168  	for snapDataIndex := 0; snapDataIndex < size; {
   169  		chunkSize := size - snapDataIndex
   170  		if chunkSize > payloadSize {
   171  			chunkSize = payloadSize
   172  		}
   173  
   174  		raftMsg := *m
   175  
   176  		// sub-slice for this snapshot chunk.
   177  		raftMsg.Snapshot.Data = m.Snapshot.Data[snapDataIndex : snapDataIndex+chunkSize]
   178  
   179  		snapDataIndex += chunkSize
   180  
   181  		// add message to the list of messages to be sent.
   182  		msg := api.StreamRaftMessageRequest{Message: &raftMsg}
   183  		messages = append(messages, msg)
   184  	}
   185  
   186  	return messages
   187  }
   188  
   189  // Function to check if this message needs to be split to be streamed
   190  // (because it is larger than GRPCMaxMsgSize).
   191  // Returns true if the message type is MsgSnap
   192  // and size larger than MaxRaftMsgSize.
   193  func needsSplitting(m *raftpb.Message) bool {
   194  	raftMsg := api.ProcessRaftMessageRequest{Message: m}
   195  	return m.Type == raftpb.MsgSnap && raftMsg.Size() > GRPCMaxMsgSize
   196  }
   197  
   198  func (p *peer) sendProcessMessage(ctx context.Context, m raftpb.Message) error {
   199  	ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
   200  	defer cancel()
   201  
   202  	var err error
   203  	var stream api.Raft_StreamRaftMessageClient
   204  	stream, err = api.NewRaftClient(p.conn()).StreamRaftMessage(ctx)
   205  
   206  	if err == nil {
   207  		// Split the message if needed.
   208  		// Currently only supported for MsgSnap.
   209  		var msgs []api.StreamRaftMessageRequest
   210  		if needsSplitting(&m) {
   211  			msgs = splitSnapshotData(ctx, &m)
   212  		} else {
   213  			raftMsg := api.StreamRaftMessageRequest{Message: &m}
   214  			msgs = append(msgs, raftMsg)
   215  		}
   216  
   217  		// Stream
   218  		for _, msg := range msgs {
   219  			err = stream.Send(&msg)
   220  			if err != nil {
   221  				log.G(ctx).WithError(err).Error("error streaming message to peer")
   222  				stream.CloseAndRecv()
   223  				break
   224  			}
   225  		}
   226  
   227  		// Finished sending all the messages.
   228  		// Close and receive response.
   229  		if err == nil {
   230  			_, err = stream.CloseAndRecv()
   231  
   232  			if err != nil {
   233  				log.G(ctx).WithError(err).Error("error receiving response")
   234  			}
   235  		}
   236  	} else {
   237  		log.G(ctx).WithError(err).Error("error sending message to peer")
   238  	}
   239  
   240  	// Try doing a regular rpc if the receiver doesn't support streaming.
   241  	s, _ := status.FromError(err)
   242  	if s.Code() == codes.Unimplemented {
   243  		log.G(ctx).Info("sending message to raft peer using ProcessRaftMessage()")
   244  		_, err = api.NewRaftClient(p.conn()).ProcessRaftMessage(ctx, &api.ProcessRaftMessageRequest{Message: &m})
   245  	}
   246  
   247  	// Handle errors.
   248  	s, _ = status.FromError(err)
   249  	if s.Code() == codes.NotFound && s.Message() == membership.ErrMemberRemoved.Error() {
   250  		p.tr.config.NodeRemoved()
   251  	}
   252  	if m.Type == raftpb.MsgSnap {
   253  		if err != nil {
   254  			p.tr.config.ReportSnapshot(m.To, raft.SnapshotFailure)
   255  		} else {
   256  			p.tr.config.ReportSnapshot(m.To, raft.SnapshotFinish)
   257  		}
   258  	}
   259  	if err != nil {
   260  		p.tr.config.ReportUnreachable(m.To)
   261  		return err
   262  	}
   263  	return nil
   264  }
   265  
   266  func healthCheckConn(ctx context.Context, cc *grpc.ClientConn) error {
   267  	resp, err := api.NewHealthClient(cc).Check(ctx, &api.HealthCheckRequest{Service: "Raft"})
   268  	if err != nil {
   269  		return errors.Wrap(err, "failed to check health")
   270  	}
   271  	if resp.Status != api.HealthCheckResponse_SERVING {
   272  		return errors.Errorf("health check returned status %s", resp.Status)
   273  	}
   274  	return nil
   275  }
   276  
   277  func (p *peer) healthCheck(ctx context.Context) error {
   278  	ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
   279  	defer cancel()
   280  	return healthCheckConn(ctx, p.conn())
   281  }
   282  
   283  func (p *peer) setActive() {
   284  	p.mu.Lock()
   285  	if !p.active {
   286  		p.active = true
   287  		p.becameActive = time.Now()
   288  	}
   289  	p.mu.Unlock()
   290  }
   291  
   292  func (p *peer) setInactive() {
   293  	p.mu.Lock()
   294  	p.active = false
   295  	p.becameActive = time.Time{}
   296  	p.mu.Unlock()
   297  }
   298  
   299  func (p *peer) activeTime() time.Time {
   300  	p.mu.Lock()
   301  	defer p.mu.Unlock()
   302  	return p.becameActive
   303  }
   304  
   305  func (p *peer) drain() error {
   306  	ctx, cancel := context.WithTimeout(context.Background(), 16*time.Second)
   307  	defer cancel()
   308  	for {
   309  		select {
   310  		case m, ok := <-p.msgc:
   311  			if !ok {
   312  				// all messages proceeded
   313  				return nil
   314  			}
   315  			if err := p.sendProcessMessage(ctx, m); err != nil {
   316  				return errors.Wrap(err, "send drain message")
   317  			}
   318  		case <-ctx.Done():
   319  			return ctx.Err()
   320  		}
   321  	}
   322  }
   323  
   324  func (p *peer) handleAddressChange(ctx context.Context) error {
   325  	p.mu.Lock()
   326  	newAddr := p.newAddr
   327  	p.newAddr = ""
   328  	p.mu.Unlock()
   329  	if newAddr == "" {
   330  		return nil
   331  	}
   332  	cc, err := p.tr.dial(newAddr)
   333  	if err != nil {
   334  		return err
   335  	}
   336  	ctx, cancel := context.WithTimeout(ctx, p.tr.config.SendTimeout)
   337  	defer cancel()
   338  	if err := healthCheckConn(ctx, cc); err != nil {
   339  		cc.Close()
   340  		return err
   341  	}
   342  	// there is possibility of race if host changing address too fast, but
   343  	// it's unlikely and eventually thing should be settled
   344  	p.mu.Lock()
   345  	p.cc.Close()
   346  	p.cc = cc
   347  	p.addr = newAddr
   348  	p.tr.config.UpdateNode(p.id, p.addr)
   349  	p.mu.Unlock()
   350  	return nil
   351  }
   352  
   353  func (p *peer) run(ctx context.Context) {
   354  	defer func() {
   355  		p.mu.Lock()
   356  		p.active = false
   357  		p.becameActive = time.Time{}
   358  		// at this point we can be sure that nobody will write to msgc
   359  		if p.msgc != nil {
   360  			close(p.msgc)
   361  		}
   362  		p.mu.Unlock()
   363  		if err := p.drain(); err != nil {
   364  			log.G(ctx).WithError(err).Error("failed to drain message queue")
   365  		}
   366  		close(p.done)
   367  	}()
   368  	if err := p.healthCheck(ctx); err == nil {
   369  		p.setActive()
   370  	}
   371  	for {
   372  		select {
   373  		case <-ctx.Done():
   374  			return
   375  		default:
   376  		}
   377  
   378  		select {
   379  		case m := <-p.msgc:
   380  			// we do not propagate context here, because this operation should be finished
   381  			// or timed out for correct raft work.
   382  			err := p.sendProcessMessage(context.Background(), m)
   383  			if err != nil {
   384  				log.G(ctx).WithError(err).Debugf("failed to send message %s", m.Type)
   385  				p.setInactive()
   386  				if err := p.handleAddressChange(ctx); err != nil {
   387  					log.G(ctx).WithError(err).Error("failed to change address after failure")
   388  				}
   389  				continue
   390  			}
   391  			p.setActive()
   392  		case <-ctx.Done():
   393  			return
   394  		}
   395  	}
   396  }
   397  
// stop cancels the peer's context and blocks until done is closed, i.e. until
// run has returned and finished draining the message queue.
func (p *peer) stop() {
	p.cancel()
	// done is closed at the very end of run's deferred cleanup.
	<-p.done
}