github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/conn/pool.go (about)

     1  /*
     2   * Copyright 2016-2018 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package conn
    18  
    19  import (
    20  	"context"
    21  	"sync"
    22  	"time"
    23  
    24  	"github.com/dgraph-io/badger/y"
    25  	"github.com/dgraph-io/dgo/protos/api"
    26  	"github.com/dgraph-io/dgraph/protos/pb"
    27  	"github.com/dgraph-io/dgraph/x"
    28  	"github.com/golang/glog"
    29  	"github.com/pkg/errors"
    30  	"go.opencensus.io/plugin/ocgrpc"
    31  
    32  	"google.golang.org/grpc"
    33  )
    34  
    35  var (
    36  	// ErrNoConnection indicates no connection exists to a node.
    37  	ErrNoConnection = errors.New("No connection exists")
    38  	// ErrUnhealthyConnection indicates the connection to a node is unhealthy.
    39  	ErrUnhealthyConnection = errors.New("Unhealthy connection")
    40  	echoDuration           = 500 * time.Millisecond
    41  )
    42  
    43  // Pool is used to manage the grpc client connection(s) for communicating with other
    44  // worker instances.  Right now it just holds one of them.
    45  type Pool struct {
    46  	sync.RWMutex
    47  	// A pool now consists of one connection.  gRPC uses HTTP2 transport to combine
    48  	// messages in the same TCP stream.
    49  	conn *grpc.ClientConn
    50  
    51  	lastEcho time.Time
    52  	Addr     string
    53  	closer   *y.Closer
    54  }
    55  
    56  // Pools manages a concurrency-safe set of Pool.
    57  type Pools struct {
    58  	sync.RWMutex
    59  	all map[string]*Pool
    60  }
    61  
    62  var pi *Pools
    63  
    64  func init() {
    65  	pi = new(Pools)
    66  	pi.all = make(map[string]*Pool)
    67  }
    68  
    69  // GetPools returns the list of pools.
    70  func GetPools() *Pools {
    71  	return pi
    72  }
    73  
    74  // Get returns the list for the given address.
    75  func (p *Pools) Get(addr string) (*Pool, error) {
    76  	p.RLock()
    77  	defer p.RUnlock()
    78  	pool, ok := p.all[addr]
    79  	if !ok {
    80  		return nil, ErrNoConnection
    81  	}
    82  	if !pool.IsHealthy() {
    83  		return nil, ErrUnhealthyConnection
    84  	}
    85  	return pool, nil
    86  }
    87  
    88  // RemoveInvalid removes invalid nodes from the list of pools.
    89  func (p *Pools) RemoveInvalid(state *pb.MembershipState) {
    90  	// Keeps track of valid IP addresses, assigned to active nodes. We do this
    91  	// to avoid removing valid IP addresses from the Removed list.
    92  	validAddr := make(map[string]struct{})
    93  	for _, group := range state.Groups {
    94  		for _, member := range group.Members {
    95  			validAddr[member.Addr] = struct{}{}
    96  		}
    97  	}
    98  	for _, member := range state.Zeros {
    99  		validAddr[member.Addr] = struct{}{}
   100  	}
   101  	for _, member := range state.Removed {
   102  		// Some nodes could have the same IP address. So, check before disconnecting.
   103  		if _, valid := validAddr[member.Addr]; !valid {
   104  			p.remove(member.Addr)
   105  		}
   106  	}
   107  }
   108  
   109  func (p *Pools) remove(addr string) {
   110  	p.Lock()
   111  	defer p.Unlock()
   112  	pool, ok := p.all[addr]
   113  	if !ok {
   114  		return
   115  	}
   116  	glog.Warningf("DISCONNECTING from %s\n", addr)
   117  	delete(p.all, addr)
   118  	pool.shutdown()
   119  }
   120  
   121  func (p *Pools) getPool(addr string) (*Pool, bool) {
   122  	p.RLock()
   123  	defer p.RUnlock()
   124  	existingPool, has := p.all[addr]
   125  	return existingPool, has
   126  }
   127  
   128  // Connect creates a Pool instance for the node with the given address or returns the existing one.
   129  func (p *Pools) Connect(addr string) *Pool {
   130  	existingPool, has := p.getPool(addr)
   131  	if has {
   132  		return existingPool
   133  	}
   134  
   135  	pool, err := newPool(addr)
   136  	if err != nil {
   137  		glog.Errorf("Unable to connect to host: %s", addr)
   138  		return nil
   139  	}
   140  
   141  	p.Lock()
   142  	defer p.Unlock()
   143  	existingPool, has = p.all[addr]
   144  	if has {
   145  		go pool.shutdown() // Not being used, so release the resources.
   146  		return existingPool
   147  	}
   148  	glog.Infof("CONNECTED to %v\n", addr)
   149  	p.all[addr] = pool
   150  	return pool
   151  }
   152  
   153  // newPool creates a new "pool" with one gRPC connection, refcount 0.
   154  func newPool(addr string) (*Pool, error) {
   155  	conn, err := grpc.Dial(addr,
   156  		grpc.WithStatsHandler(&ocgrpc.ClientHandler{}),
   157  		grpc.WithDefaultCallOptions(
   158  			grpc.MaxCallRecvMsgSize(x.GrpcMaxSize),
   159  			grpc.MaxCallSendMsgSize(x.GrpcMaxSize)),
   160  		grpc.WithBackoffMaxDelay(time.Second),
   161  		grpc.WithInsecure())
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  	pl := &Pool{conn: conn, Addr: addr, lastEcho: time.Now(), closer: y.NewCloser(1)}
   166  	go pl.MonitorHealth()
   167  	return pl, nil
   168  }
   169  
   170  // Get returns the connection to use from the pool of connections.
   171  func (p *Pool) Get() *grpc.ClientConn {
   172  	p.RLock()
   173  	defer p.RUnlock()
   174  	return p.conn
   175  }
   176  
   177  func (p *Pool) shutdown() {
   178  	glog.Warningf("Shutting down extra connection to %s", p.Addr)
   179  	p.closer.SignalAndWait()
   180  	p.conn.Close()
   181  }
   182  
   183  // SetUnhealthy marks a pool as unhealthy.
   184  func (p *Pool) SetUnhealthy() {
   185  	p.Lock()
   186  	defer p.Unlock()
   187  	p.lastEcho = time.Time{}
   188  }
   189  
   190  func (p *Pool) listenToHeartbeat() error {
   191  	conn := p.Get()
   192  	c := pb.NewRaftClient(conn)
   193  
   194  	ctx, cancel := context.WithCancel(context.Background())
   195  	defer cancel()
   196  
   197  	s, err := c.Heartbeat(ctx, &api.Payload{})
   198  	if err != nil {
   199  		return err
   200  	}
   201  
   202  	go func() {
   203  		select {
   204  		case <-ctx.Done():
   205  		case <-p.closer.HasBeenClosed():
   206  			cancel()
   207  		}
   208  	}()
   209  
   210  	// This loop can block indefinitely as long as it keeps on receiving pings back.
   211  	for {
   212  		_, err := s.Recv()
   213  		if err != nil {
   214  			return err
   215  		}
   216  		// We do this periodic stream receive based approach to defend against network partitions.
   217  		p.Lock()
   218  		p.lastEcho = time.Now()
   219  		p.Unlock()
   220  	}
   221  }
   222  
   223  // MonitorHealth monitors the health of the connection via Echo. This function blocks forever.
   224  func (p *Pool) MonitorHealth() {
   225  	defer p.closer.Done()
   226  
   227  	var lastErr error
   228  	for {
   229  		select {
   230  		case <-p.closer.HasBeenClosed():
   231  			return
   232  		default:
   233  			err := p.listenToHeartbeat()
   234  			if lastErr != nil && err == nil {
   235  				glog.Infof("Connection established with %v\n", p.Addr)
   236  			} else if err != nil && lastErr == nil {
   237  				glog.Warningf("Connection lost with %v. Error: %v\n", p.Addr, err)
   238  			}
   239  			lastErr = err
   240  			// Sleep for a bit before retrying.
   241  			time.Sleep(echoDuration)
   242  		}
   243  	}
   244  }
   245  
   246  // IsHealthy returns whether the pool is healthy.
   247  func (p *Pool) IsHealthy() bool {
   248  	if p == nil {
   249  		return false
   250  	}
   251  	p.RLock()
   252  	defer p.RUnlock()
   253  	return time.Since(p.lastEcho) < 4*echoDuration
   254  }