github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/core/client/loadbalance/balancer.go (about)

     1  package loadbalance
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"github.com/nyan233/littlerpc/core/common/logger"
     7  	"github.com/nyan233/littlerpc/core/common/transport"
     8  	"github.com/nyan233/littlerpc/core/container"
     9  	"github.com/nyan233/littlerpc/core/utils/convert"
    10  	"github.com/nyan233/littlerpc/core/utils/hash"
    11  	"sync/atomic"
    12  	"time"
    13  )
    14  
// nodeSource is the connection pool of a single node together with the
// round-robin cursor Target uses to pick one of its connections.
type nodeSource struct {
	// Count is advanced with sync/atomic every time Target selects this node;
	// the chosen connection is Conns[Count % len(Conns)]. Because it is used
	// with 64-bit atomic operations it must remain the first field: only the
	// first word of an allocated struct is guaranteed to be 64-bit aligned on
	// 32-bit platforms (386, ARM, 32-bit MIPS).
	Count uint64
	// Conns are the connections opened for this node by createConns.
	Conns []transport.ConnAdapter
}
    19  
// balancerImpl is the Balancer implementation returned by New.
//
// Target hashes the service name onto the current node list and then picks
// one of that node's connections round-robin. When updateInterval > 0 the
// node list is re-resolved periodically in the background (see startResolver).
type balancerImpl struct {
	logger logger.LLogger
	// ctx is cancelled by Exit through cancelFunc; the background resolver
	// goroutine started by startResolver selects on ctx.Done().
	ctx        context.Context
	cancelFunc context.CancelFunc
	// scheme is the value reported by Scheme.
	scheme string
	// resolve fetches the current node list.
	resolve ResolverFunc
	// Update interval of the address list; a value <= 0 disables periodic
	// re-resolving.
	updateInterval time.Duration
	// nodeList is the node list Target hashes into; it is always replaced
	// as a whole through the atomic pointer.
	nodeList atomic.Pointer[[]RpcNode]
	// connFactory opens a single connection to a node (used by createConns).
	connFactory NewConn
	// muxConnSize is the number of connections opened per node.
	muxConnSize int
	// closeFunc closes a single connection (used by closeConns).
	closeFunc CloseConn
	// nodeConnManager maps a node address to that node's connection pool.
	nodeConnManager *container.RCUMap[string, *nodeSource]
}
    34  
    35  func (b *balancerImpl) Exit() error {
    36  	b.cancelFunc()
    37  	return nil
    38  }
    39  
    40  func New(cfg Config) Balancer {
    41  	b := new(balancerImpl)
    42  	b.muxConnSize = cfg.MuxConnSize
    43  	b.scheme = cfg.Scheme
    44  	b.logger = cfg.Logger
    45  	b.closeFunc = cfg.CloseFunc
    46  	b.updateInterval = cfg.ResolverUpdateInterval
    47  	b.ctx, b.cancelFunc = context.WithCancel(context.Background())
    48  	b.resolve = cfg.Resolver
    49  	b.connFactory = cfg.ConnectionFactory
    50  	b.nodeConnManager = container.NewRCUMap[string, *nodeSource](128)
    51  	b.startResolver()
    52  	return b
    53  }
    54  
    55  func (b *balancerImpl) startResolver() {
    56  	nodeList, err := b.resolve()
    57  	if err != nil {
    58  		panic(fmt.Errorf("startResolver resolve faild: %v", err))
    59  	}
    60  	nodeConnInitSet := make([]container.RCUMapElement[string, *nodeSource], 0, len(nodeList))
    61  	for _, node := range nodeList {
    62  		conns, err := b.createConns(node, b.muxConnSize)
    63  		if err != nil {
    64  			panic(fmt.Errorf("startResolver init conns faild: %v", err))
    65  		}
    66  		nodeConnInitSet = append(nodeConnInitSet, container.RCUMapElement[string, *nodeSource]{
    67  			Key: node.Address,
    68  			Value: &nodeSource{
    69  				Count: 0,
    70  				Conns: conns,
    71  			},
    72  		})
    73  	}
    74  	b.nodeList.Store(&nodeList)
    75  	b.nodeConnManager.StoreMulti(nodeConnInitSet)
    76  	if b.updateInterval <= 0 {
    77  		return
    78  	}
    79  	ticker := time.NewTicker(b.updateInterval)
    80  	go func() {
    81  		for {
    82  			select {
    83  			case <-ticker.C:
    84  				tmp, err := b.resolve()
    85  				if err != nil {
    86  					b.logger.Error("LRPC: runtime resolve failed: %v", err)
    87  					continue
    88  				}
    89  				b.modifyNodeList(tmp)
    90  			case <-b.ctx.Done():
    91  				break
    92  			}
    93  		}
    94  	}()
    95  }
    96  
    97  func (b *balancerImpl) modifyNodeList(newNodeList container.Slice[RpcNode]) {
    98  	if newNodeList.Len() == 0 {
    99  		b.logger.Warn("LRPC: loadBalancer resolve result list length equal zero")
   100  		return
   101  	}
   102  	// 找出需要建立新连接的节点的节点, 即不存在于旧列表中的的节点, 且不重复
   103  	newNodeList.Unique()
   104  	oldList := *b.nodeList.Load()
   105  	oldCmpMap := make(map[string]struct{}, len(oldList))
   106  	for _, node := range oldList {
   107  		oldCmpMap[node.Address] = struct{}{}
   108  	}
   109  	// 旧节点和新节点之间存在映射
   110  	existMapping := make(map[string]struct{})
   111  	newNodeConnSet := make([]container.RCUMapElement[string, *nodeSource], 0, 16)
   112  	for _, newNode := range newNodeList {
   113  		_, exist := oldCmpMap[newNode.Address]
   114  		if exist {
   115  			existMapping[newNode.Address] = struct{}{}
   116  			continue
   117  		}
   118  		// 某个节点的连接建立失败不会中断整个更新过程, 而是忽略这个节点
   119  		conns, err := b.createConns(newNode, b.muxConnSize)
   120  		if err != nil {
   121  			b.logger.Warn("LRPC: loadBalancer new conn failed: %v", err)
   122  			break
   123  		}
   124  		newNodeConnSet = append(newNodeConnSet, container.RCUMapElement[string, *nodeSource]{
   125  			Key: newNode.Address,
   126  			Value: &nodeSource{
   127  				Count: 0,
   128  				Conns: conns,
   129  			},
   130  		})
   131  	}
   132  	// 准备关闭与new list不重叠的节点的连接
   133  	ableCloseNode := make([]string, 0, 16)
   134  	for _, oldNode := range oldList {
   135  		_, exist := existMapping[oldNode.Address]
   136  		if exist {
   137  			continue
   138  		}
   139  		ableCloseNode = append(ableCloseNode, oldNode.Address)
   140  	}
   141  	ableCloseNodeSource := b.nodeConnManager.StoreAndDeleteMulti(newNodeConnSet, ableCloseNode)
   142  	for _, v := range ableCloseNodeSource {
   143  		b.closeConns(v.Value.Conns)
   144  	}
   145  	// 此时旧节点列表的状态已经清理完毕
   146  	b.nodeList.Store((*[]RpcNode)(&newNodeList))
   147  }
   148  
   149  func (b *balancerImpl) createConns(node RpcNode, size int) ([]transport.ConnAdapter, error) {
   150  	conns := make([]transport.ConnAdapter, size)
   151  	for i := 0; i < size; i++ {
   152  		conn, err := b.connFactory(node)
   153  		if err != nil {
   154  			return nil, err
   155  		}
   156  		conns[i] = conn
   157  	}
   158  	return conns, nil
   159  }
   160  
   161  func (b *balancerImpl) closeConns(conns []transport.ConnAdapter) {
   162  	for _, conn := range conns {
   163  		b.closeFunc(conn)
   164  	}
   165  }
   166  
   167  func (b *balancerImpl) Scheme() string {
   168  	return b.scheme
   169  }
   170  
   171  func (b *balancerImpl) Target(service string) transport.ConnAdapter {
   172  	const (
   173  		HashSeed = 1024
   174  	)
   175  	hashCode := hash.Murmurhash3Onx8632(convert.StringToBytes(service), HashSeed)
   176  	nodeList := *b.nodeList.Load()
   177  	node := nodeList[hashCode%uint32(len(nodeList))]
   178  	src, _ := b.nodeConnManager.LoadOk(node.Address)
   179  	conn := src.Conns[atomic.LoadUint64(&src.Count)%uint64(len(src.Conns))]
   180  	atomic.AddUint64(&src.Count, 1)
   181  	return conn
   182  }