github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/conn/raft_server.go (about)

     1  /*
     2   * Copyright 2018 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package conn
    18  
    19  import (
    20  	"context"
    21  	"encoding/binary"
    22  	"math/rand"
    23  	"sync"
    24  	"sync/atomic"
    25  	"time"
    26  
    27  	"github.com/dgraph-io/dgo/protos/api"
    28  	"github.com/dgraph-io/dgraph/protos/pb"
    29  	"github.com/dgraph-io/dgraph/x"
    30  	"github.com/golang/glog"
    31  	"github.com/pkg/errors"
    32  	"go.etcd.io/etcd/raft/raftpb"
    33  	otrace "go.opencensus.io/trace"
    34  )
    35  
    36  type sendmsg struct {
    37  	to   uint64
    38  	data []byte
    39  }
    40  
    41  type lockedSource struct {
    42  	lk  sync.Mutex
    43  	src rand.Source
    44  }
    45  
    46  func (r *lockedSource) Int63() int64 {
    47  	r.lk.Lock()
    48  	defer r.lk.Unlock()
    49  	return r.src.Int63()
    50  }
    51  
    52  func (r *lockedSource) Seed(seed int64) {
    53  	r.lk.Lock()
    54  	defer r.lk.Unlock()
    55  	r.src.Seed(seed)
    56  }
    57  
    58  // ProposalCtx stores the context for a proposal with extra information.
    59  type ProposalCtx struct {
    60  	Found uint32
    61  	ErrCh chan error
    62  	Ctx   context.Context
    63  }
    64  
    65  type proposals struct {
    66  	sync.RWMutex
    67  	all map[string]*ProposalCtx
    68  }
    69  
    70  func (p *proposals) Store(key string, pctx *ProposalCtx) bool {
    71  	if len(key) == 0 {
    72  		return false
    73  	}
    74  	p.Lock()
    75  	defer p.Unlock()
    76  	if p.all == nil {
    77  		p.all = make(map[string]*ProposalCtx)
    78  	}
    79  	if _, has := p.all[key]; has {
    80  		return false
    81  	}
    82  	p.all[key] = pctx
    83  	return true
    84  }
    85  
    86  func (p *proposals) Ctx(key string) context.Context {
    87  	if pctx := p.Get(key); pctx != nil {
    88  		return pctx.Ctx
    89  	}
    90  	return context.Background()
    91  }
    92  
    93  func (p *proposals) Get(key string) *ProposalCtx {
    94  	p.RLock()
    95  	defer p.RUnlock()
    96  	return p.all[key]
    97  }
    98  
    99  func (p *proposals) Delete(key string) {
   100  	if len(key) == 0 {
   101  		return
   102  	}
   103  	p.Lock()
   104  	defer p.Unlock()
   105  	delete(p.all, key)
   106  }
   107  
   108  func (p *proposals) Done(key string, err error) {
   109  	if len(key) == 0 {
   110  		return
   111  	}
   112  	p.Lock()
   113  	defer p.Unlock()
   114  	pd, has := p.all[key]
   115  	if !has {
   116  		// If we assert here, there would be a race condition between a context
   117  		// timing out, and a proposal getting applied immediately after. That
   118  		// would cause assert to fail. So, don't assert.
   119  		return
   120  	}
   121  	delete(p.all, key)
   122  	pd.ErrCh <- err
   123  }
   124  
   125  // RaftServer is a wrapper around node that implements the Raft service.
   126  type RaftServer struct {
   127  	m    sync.RWMutex
   128  	node *Node
   129  }
   130  
   131  // UpdateNode safely updates the node.
   132  func (w *RaftServer) UpdateNode(n *Node) {
   133  	w.m.Lock()
   134  	defer w.m.Unlock()
   135  	w.node = n
   136  }
   137  
   138  // GetNode safely retrieves the node.
   139  func (w *RaftServer) GetNode() *Node {
   140  	w.m.RLock()
   141  	defer w.m.RUnlock()
   142  	return w.node
   143  }
   144  
   145  // NewRaftServer returns a pointer to a new RaftServer instance.
   146  func NewRaftServer(n *Node) *RaftServer {
   147  	return &RaftServer{node: n}
   148  }
   149  
   150  // IsPeer checks whether this node is a peer of the node sending the request.
   151  func (w *RaftServer) IsPeer(ctx context.Context, rc *pb.RaftContext) (
   152  	*pb.PeerResponse, error) {
   153  	node := w.GetNode()
   154  	if node == nil || node.Raft() == nil {
   155  		return &pb.PeerResponse{}, ErrNoNode
   156  	}
   157  
   158  	confState := node.ConfState()
   159  
   160  	if confState == nil {
   161  		return &pb.PeerResponse{}, nil
   162  	}
   163  
   164  	for _, raftIdx := range confState.Nodes {
   165  		if rc.Id == raftIdx {
   166  			return &pb.PeerResponse{Status: true}, nil
   167  		}
   168  	}
   169  	return &pb.PeerResponse{}, nil
   170  }
   171  
   172  // JoinCluster handles requests to join the cluster.
   173  func (w *RaftServer) JoinCluster(ctx context.Context,
   174  	rc *pb.RaftContext) (*api.Payload, error) {
   175  	if ctx.Err() != nil {
   176  		return &api.Payload{}, ctx.Err()
   177  	}
   178  
   179  	node := w.GetNode()
   180  	if node == nil || node.Raft() == nil {
   181  		return nil, ErrNoNode
   182  	}
   183  
   184  	return node.joinCluster(ctx, rc)
   185  }
   186  
   187  // RaftMessage handles RAFT messages.
   188  func (w *RaftServer) RaftMessage(server pb.Raft_RaftMessageServer) error {
   189  	ctx := server.Context()
   190  	if ctx.Err() != nil {
   191  		return ctx.Err()
   192  	}
   193  	span := otrace.FromContext(ctx)
   194  
   195  	node := w.GetNode()
   196  	if node == nil || node.Raft() == nil {
   197  		return ErrNoNode
   198  	}
   199  	span.Annotatef(nil, "Stream server is node %#x", node.Id)
   200  
   201  	var rc *pb.RaftContext
   202  	raft := node.Raft()
   203  	step := func(data []byte) error {
   204  		ctx, cancel := context.WithTimeout(ctx, time.Minute)
   205  		defer cancel()
   206  
   207  		for idx := 0; idx < len(data); {
   208  			x.AssertTruef(len(data[idx:]) >= 4,
   209  				"Slice left of size: %v. Expected at least 4.", len(data[idx:]))
   210  
   211  			sz := int(binary.LittleEndian.Uint32(data[idx : idx+4]))
   212  			idx += 4
   213  			msg := raftpb.Message{}
   214  			if idx+sz > len(data) {
   215  				return errors.Errorf(
   216  					"Invalid query. Specified size %v overflows slice [%v,%v)\n",
   217  					sz, idx, len(data))
   218  			}
   219  			if err := msg.Unmarshal(data[idx : idx+sz]); err != nil {
   220  				x.Check(err)
   221  			}
   222  			// This should be done in order, and not via a goroutine.
   223  			// Step can block forever. See: https://github.com/etcd-io/etcd/issues/10585
   224  			// So, add a context with timeout to allow it to get out of the blockage.
   225  			if glog.V(2) {
   226  				switch msg.Type {
   227  				case raftpb.MsgHeartbeat, raftpb.MsgHeartbeatResp:
   228  					atomic.AddInt64(&node.heartbeatsIn, 1)
   229  				case raftpb.MsgReadIndex, raftpb.MsgReadIndexResp:
   230  				case raftpb.MsgApp, raftpb.MsgAppResp:
   231  				case raftpb.MsgProp:
   232  				default:
   233  					glog.Infof("RaftComm: [%#x] Received msg of type: %s from %#x",
   234  						msg.To, msg.Type, msg.From)
   235  				}
   236  			}
   237  			if err := raft.Step(ctx, msg); err != nil {
   238  				glog.Warningf("Error while raft.Step from %#x: %v. Closing RaftMessage stream.",
   239  					rc.GetId(), err)
   240  				return errors.Wrapf(err, "error while raft.Step from %#x", rc.GetId())
   241  			}
   242  			idx += sz
   243  		}
   244  		return nil
   245  	}
   246  
   247  	for loop := 1; ; loop++ {
   248  		batch, err := server.Recv()
   249  		if err != nil {
   250  			return err
   251  		}
   252  		if loop%1e6 == 0 {
   253  			glog.V(2).Infof("%d messages received by %#x from %#x", loop, node.Id, rc.GetId())
   254  		}
   255  		if loop == 1 {
   256  			rc = batch.GetContext()
   257  			span.Annotatef(nil, "Stream from %#x", rc.GetId())
   258  			if rc != nil {
   259  				node.Connect(rc.Id, rc.Addr)
   260  			}
   261  		}
   262  		if batch.Payload == nil {
   263  			continue
   264  		}
   265  		data := batch.Payload.Data
   266  		if err := step(data); err != nil {
   267  			return err
   268  		}
   269  	}
   270  }
   271  
   272  // Heartbeat rpc call is used to check connection with other workers after worker
   273  // tcp server for this instance starts.
   274  func (w *RaftServer) Heartbeat(in *api.Payload, stream pb.Raft_HeartbeatServer) error {
   275  	ticker := time.NewTicker(echoDuration)
   276  	defer ticker.Stop()
   277  
   278  	ctx := stream.Context()
   279  	out := &api.Payload{Data: []byte("beat")}
   280  	for {
   281  		select {
   282  		case <-ctx.Done():
   283  			return ctx.Err()
   284  		case <-ticker.C:
   285  			if err := stream.Send(out); err != nil {
   286  				return err
   287  			}
   288  		}
   289  	}
   290  }