github.com/projecteru2/core@v0.0.0-20240321043226-06bcc1c23f58/client/clientpool.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/projecteru2/core/log"
     9  	pb "github.com/projecteru2/core/rpc/gen"
    10  	"github.com/projecteru2/core/types"
    11  	"github.com/projecteru2/core/utils"
    12  )
    13  
    14  type clientWithStatus struct {
    15  	client pb.CoreRPCClient
    16  	addr   string
    17  	alive  bool
    18  }
    19  
    20  // PoolConfig config for client pool
    21  type PoolConfig struct {
    22  	EruAddrs          []string
    23  	Auth              types.AuthConfig
    24  	ConnectionTimeout time.Duration
    25  }
    26  
    27  // Pool implement of RPCClientPool
    28  type Pool struct {
    29  	mu         sync.Mutex
    30  	rpcClients []*clientWithStatus
    31  }
    32  
    33  // NewCoreRPCClientPool .
    34  func NewCoreRPCClientPool(ctx context.Context, config *PoolConfig) (*Pool, error) {
    35  	if len(config.EruAddrs) == 0 {
    36  		return nil, types.ErrInvaildEruIPAddress
    37  	}
    38  	c := &Pool{rpcClients: []*clientWithStatus{}}
    39  	for _, addr := range config.EruAddrs {
    40  		var rpc *Client
    41  		var err error
    42  		utils.WithTimeout(ctx, config.ConnectionTimeout, func(ctx context.Context) {
    43  			rpc, err = NewClient(ctx, addr, config.Auth)
    44  		})
    45  		if err != nil {
    46  			log.WithFunc("client.NewCoreRPCClientPool").Errorf(ctx, err, "connect to %s failed", addr)
    47  			continue
    48  		}
    49  		rpcClient := rpc.GetRPCClient()
    50  		c.rpcClients = append(c.rpcClients, &clientWithStatus{client: rpcClient, addr: addr})
    51  	}
    52  
    53  	// init client status
    54  	c.updateClientsStatus(ctx, config.ConnectionTimeout)
    55  
    56  	allFailed := true
    57  	for _, rpc := range c.rpcClients {
    58  		if rpc.alive {
    59  			allFailed = false
    60  		}
    61  	}
    62  
    63  	if allFailed {
    64  		return nil, types.ErrAllConnectionsFailed
    65  	}
    66  
    67  	go func() {
    68  		ticker := time.NewTicker(config.ConnectionTimeout * 2)
    69  		defer ticker.Stop()
    70  		for {
    71  			select {
    72  			case <-ticker.C:
    73  				c.updateClientsStatus(ctx, config.ConnectionTimeout)
    74  			case <-ctx.Done():
    75  				return
    76  			}
    77  		}
    78  	}()
    79  
    80  	return c, nil
    81  }
    82  
    83  // GetClient finds the first *client.Client instance with an active connection. If all connections are dead, returns the first one.
    84  func (c *Pool) GetClient() pb.CoreRPCClient {
    85  	c.mu.Lock()
    86  	defer c.mu.Unlock()
    87  
    88  	for _, rpc := range c.rpcClients {
    89  		if rpc.alive {
    90  			return rpc.client
    91  		}
    92  	}
    93  	return c.rpcClients[0].client
    94  }
    95  
    96  func checkAlive(ctx context.Context, rpc *clientWithStatus, timeout time.Duration) bool {
    97  	var err error
    98  	utils.WithTimeout(ctx, timeout, func(ctx context.Context) {
    99  		_, err = rpc.client.Info(ctx, &pb.Empty{})
   100  	})
   101  	logger := log.WithFunc("client.checkAlive")
   102  	if err != nil {
   103  		logger.Errorf(ctx, err, "connect to %s failed", rpc.addr)
   104  		return false
   105  	}
   106  	logger.Debugf(ctx, "connect to %s success", rpc.addr)
   107  	return true
   108  }
   109  
   110  func (c *Pool) updateClientsStatus(ctx context.Context, timeout time.Duration) {
   111  	c.mu.Lock()
   112  	defer c.mu.Unlock()
   113  
   114  	wg := &sync.WaitGroup{}
   115  	defer wg.Wait()
   116  	for _, rpc := range c.rpcClients {
   117  		wg.Add(1)
   118  		go func(r *clientWithStatus) {
   119  			defer wg.Done()
   120  			r.alive = checkAlive(ctx, r, timeout)
   121  		}(rpc)
   122  	}
   123  }