github.com/asynkron/protoactor-go@v0.0.0-20240308120642-ef91a6abee75/cluster/clusterproviders/zk/zk_provider.go (about)

     1  package zk
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log/slog"
     7  	"net"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/asynkron/protoactor-go/cluster"
    13  	"github.com/go-zookeeper/zk"
    14  )
    15  
// Compile-time check that *Provider satisfies cluster.ClusterProvider.
var _ cluster.ClusterProvider = new(Provider)
    17  
    18  type RoleType int
    19  
    20  const (
    21  	Follower RoleType = iota
    22  	Leader
    23  )
    24  
    25  func (r RoleType) String() string {
    26  	if r == Leader {
    27  		return "LEADER"
    28  	}
    29  	return "FOLLOWER"
    30  }
    31  
// Provider is a cluster provider that stores cluster membership in ZooKeeper.
type Provider struct {
	// cluster is the cluster this provider serves; init replaces it with the
	// cluster passed to StartMember/StartClient.
	cluster             *cluster.Cluster
	// baseKey is the root path taken from the provider configuration.
	baseKey             string
	// clusterName is copied from the cluster configuration by init.
	clusterName         string
	// clusterKey is baseKey joined with clusterName; member nodes are created below it.
	clusterKey          string
	// deregistered is set once Shutdown has deregistered this member.
	deregistered        bool
	// shutdown is set by Shutdown and polled by the background loops.
	shutdown            bool
	// self describes the local node.
	self                *Node
	members             map[string]*Node // all, contains self.
	// clusterError is recorded by the watch loop and returned by GetHealthStatus.
	clusterError        error
	// conn is the ZooKeeper connection created in New.
	conn                zkConn
	// revision holds the children version (Cversion) of clusterKey seen by the last update.
	revision            uint64
	// fullpath is the path of this member's own node; empty while not registered.
	fullpath            string
	// roleChangedListener, when non-nil, is invoked by the role-change notify loop.
	roleChangedListener RoleChangedListener
	// role is the role currently held by this provider.
	role                RoleType
	// roleChangedChan carries role changes to the notify loop (buffered, capacity 1).
	roleChangedChan     chan RoleType
}
    49  
    50  // New zk cluster provider with config
    51  func New(endpoints []string, opts ...Option) (*Provider, error) {
    52  	zkCfg := defaultConfig()
    53  	withEndpoints(endpoints)(zkCfg)
    54  	for _, fn := range opts {
    55  		fn(zkCfg)
    56  	}
    57  	p := &Provider{
    58  		cluster:             &cluster.Cluster{},
    59  		baseKey:             zkCfg.BaseKey,
    60  		clusterKey:          "",
    61  		clusterName:         "",
    62  		deregistered:        false,
    63  		shutdown:            false,
    64  		self:                &Node{},
    65  		members:             map[string]*Node{},
    66  		revision:            0,
    67  		fullpath:            "",
    68  		roleChangedListener: zkCfg.RoleChanged,
    69  		roleChangedChan:     make(chan RoleType, 1),
    70  		role:                Follower,
    71  	}
    72  	conn, err := connectZk(endpoints, zkCfg.SessionTimeout, WithEventCallback(p.onEvent))
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	if auth := zkCfg.Auth; !auth.isEmpty() {
    77  		if err = conn.AddAuth(auth.Scheme, []byte(auth.Credential)); err != nil {
    78  			return nil, err
    79  		}
    80  	}
    81  	p.conn = conn
    82  
    83  	return p, nil
    84  }
    85  
    86  func (p *Provider) IsLeader() bool {
    87  	return p.role == Leader
    88  }
    89  
    90  func (p *Provider) init(c *cluster.Cluster) error {
    91  	p.cluster = c
    92  	addr := p.cluster.ActorSystem.Address()
    93  	host, port, err := splitHostPort(addr)
    94  	if err != nil {
    95  		return err
    96  	}
    97  
    98  	p.cluster = c
    99  	p.clusterName = p.cluster.Config.Name
   100  	p.clusterKey = joinPath(p.baseKey, p.clusterName)
   101  	knownKinds := c.GetClusterKinds()
   102  	nodeName := fmt.Sprintf("%v@%v:%v", p.clusterName, host, port)
   103  	p.self = NewNode(nodeName, host, port, knownKinds)
   104  	p.self.SetMeta(metaKeyID, p.getID())
   105  
   106  	if err = p.createClusterNode(p.clusterKey); err != nil {
   107  		return err
   108  	}
   109  	return nil
   110  }
   111  
   112  func (p *Provider) StartMember(c *cluster.Cluster) error {
   113  	if err := p.init(c); err != nil {
   114  		p.cluster.Logger().Error("init fail " + err.Error())
   115  		return err
   116  	}
   117  
   118  	p.startRoleChangedNotifyLoop()
   119  
   120  	// register self
   121  	if err := p.registerService(); err != nil {
   122  		p.cluster.Logger().Error("register service fail " + err.Error())
   123  		return err
   124  	}
   125  	p.cluster.Logger().Info("StartMember register service.", slog.String("node", p.self.ID), slog.String("seq", p.self.Meta[metaKeySeq]))
   126  
   127  	// fetch member list
   128  	nodes, version, err := p.fetchNodes()
   129  	if err != nil {
   130  		p.cluster.Logger().Error("fetch nodes fail " + err.Error())
   131  		return err
   132  	}
   133  	// initialize members
   134  	p.updateNodesWithSelf(nodes, version)
   135  	p.publishClusterTopologyEvent()
   136  	p.updateLeadership(nodes)
   137  	p.startWatching(true)
   138  
   139  	return nil
   140  }
   141  
   142  func (p *Provider) StartClient(c *cluster.Cluster) error {
   143  	if err := p.init(c); err != nil {
   144  		return err
   145  	}
   146  	nodes, version, err := p.fetchNodes()
   147  	if err != nil {
   148  		return err
   149  	}
   150  	// initialize members
   151  	p.updateNodes(nodes, version)
   152  	p.publishClusterTopologyEvent()
   153  	p.startWatching(false)
   154  
   155  	return nil
   156  }
   157  
   158  func (p *Provider) Shutdown(graceful bool) error {
   159  	p.shutdown = true
   160  	if !p.deregistered {
   161  		p.updateLeadership(nil)
   162  		err := p.deregisterService()
   163  		if err != nil {
   164  			p.cluster.Logger().Error("deregisterMember", slog.Any("error", err))
   165  			return err
   166  		}
   167  		p.deregistered = true
   168  	}
   169  	return nil
   170  }
   171  
   172  func (p *Provider) getID() string {
   173  	return p.self.ID
   174  }
   175  
   176  func (p *Provider) registerService() error {
   177  	data, err := p.self.Serialize()
   178  	if err != nil {
   179  		p.cluster.Logger().Error("registerService Serialize fail.", slog.Any("error", err))
   180  		return err
   181  	}
   182  
   183  	path, err := p.createEphemeralChildNode(data)
   184  	if err != nil {
   185  		p.cluster.Logger().Error("createEphemeralChildNode fail.", slog.String("node", p.clusterKey), slog.Any("error", err))
   186  		return err
   187  	}
   188  	p.fullpath = path
   189  	seq, _ := parseSeq(path)
   190  	p.self.SetMeta(metaKeySeq, intToStr(seq))
   191  	p.cluster.Logger().Info("RegisterService.", slog.String("id", p.self.ID), slog.Int("seq", seq))
   192  
   193  	return nil
   194  }
   195  
   196  func (p *Provider) createClusterNode(dir string) error {
   197  	if dir == "/" {
   198  		return nil
   199  	}
   200  	exist, _, err := p.conn.Exists(dir)
   201  	if err != nil {
   202  		p.cluster.Logger().Error("check exist of node fail", slog.String("dir", dir), slog.Any("error", err))
   203  		return err
   204  	}
   205  	if exist {
   206  		return nil
   207  	}
   208  	if err = p.createClusterNode(getParentDir(dir)); err != nil {
   209  		return err
   210  	}
   211  	if _, err = p.conn.Create(dir, []byte{}, 0, zk.WorldACL(zk.PermAll)); err != nil {
   212  		p.cluster.Logger().Error("create dir node fail", slog.String("dir", dir), slog.Any("error", err))
   213  		return err
   214  	}
   215  	return nil
   216  }
   217  
   218  func (p *Provider) deregisterService() error {
   219  	if p.fullpath != "" {
   220  		p.conn.Delete(p.fullpath, -1)
   221  	}
   222  	p.fullpath = ""
   223  	p.conn.Close()
   224  	return nil
   225  }
   226  
   227  func (p *Provider) keepWatching(ctx context.Context, registerSelf bool) error {
   228  	evtChan, err := p.addWatcher(ctx, p.clusterKey)
   229  	if err != nil {
   230  		p.cluster.Logger().Error("list children fail", slog.String("node", p.clusterKey), slog.Any("error", err))
   231  		return err
   232  	}
   233  
   234  	return p._keepWatching(registerSelf, evtChan)
   235  }
   236  
   237  func (p *Provider) addWatcher(ctx context.Context, clusterKey string) (<-chan zk.Event, error) {
   238  	_, stat, evtChan, err := p.conn.ChildrenW(clusterKey)
   239  	if err != nil {
   240  		p.cluster.Logger().Error("list children fail", slog.String("node", clusterKey), slog.Any("error", err))
   241  		return nil, err
   242  	}
   243  
   244  	p.cluster.Logger().Info("KeepWatching cluster.", slog.String("cluster", clusterKey), slog.Int("children", int(stat.NumChildren)))
   245  	if !p.isChildrenChanged(ctx, stat) {
   246  		return evtChan, nil
   247  	}
   248  
   249  	p.cluster.Logger().Info("Chilren changed, wait 1 sec and watch again", slog.Int("old_cversion", int(p.revision)), slog.Int("new_revison", int(stat.Cversion)))
   250  	time.Sleep(1 * time.Second)
   251  	nodes, version, err := p.fetchNodes()
   252  	if err != nil {
   253  		return nil, err
   254  	}
   255  	// initialize members
   256  	p.updateNodes(nodes, version)
   257  	p.publishClusterTopologyEvent()
   258  	p.updateLeadership(nodes)
   259  	return p.addWatcher(ctx, clusterKey)
   260  }
   261  
   262  func (p *Provider) isChildrenChanged(ctx context.Context, stat *zk.Stat) bool {
   263  	return stat.Cversion != int32(p.revision)
   264  }
   265  
   266  func (p *Provider) _keepWatching(registerSelf bool, stream <-chan zk.Event) error {
   267  	event := <-stream
   268  	if err := event.Err; err != nil {
   269  		p.cluster.Logger().Error("Failure watching service.", slog.Any("error", err))
   270  		if registerSelf && p.clusterNotContainsSelfPath() {
   271  			p.cluster.Logger().Info("Register info lost, register self again")
   272  			p.registerService()
   273  		}
   274  		return err
   275  	}
   276  	nodes, version, err := p.fetchNodes()
   277  	if err != nil {
   278  		p.cluster.Logger().Error("Failure fetch nodes when watching service.", slog.Any("error", err))
   279  		return err
   280  	}
   281  	if !p.containSelf(nodes) && registerSelf {
   282  		// i am lost, register self
   283  		if err = p.registerService(); err != nil {
   284  			return err
   285  		}
   286  		// reload nodes
   287  		nodes, version, err = p.fetchNodes()
   288  		if err != nil {
   289  			p.cluster.Logger().Error("Failure fetch nodes when watching service.", slog.Any("error", err))
   290  			return err
   291  		}
   292  	}
   293  	p.updateNodes(nodes, version)
   294  	p.publishClusterTopologyEvent()
   295  	if registerSelf {
   296  		p.updateLeadership(nodes)
   297  	}
   298  
   299  	return nil
   300  }
   301  
   302  func (p *Provider) clusterNotContainsSelfPath() bool {
   303  	children, _, err := p.conn.Children(p.clusterKey)
   304  	return err == nil && !stringContains(mapString(children, func(s string) string {
   305  		return joinPath(p.clusterKey, s)
   306  	}), p.fullpath)
   307  }
   308  
   309  func (p *Provider) containSelf(ns []*Node) bool {
   310  	for _, node := range ns {
   311  		if p.self != nil && node.ID == p.self.ID {
   312  			return true
   313  		}
   314  	}
   315  	return false
   316  }
   317  
   318  func (p *Provider) startRoleChangedNotifyLoop() {
   319  	go func() {
   320  		for !p.shutdown {
   321  			role := <-p.roleChangedChan
   322  			if lis := p.roleChangedListener; lis != nil {
   323  				safeRun(p.cluster.Logger(), func() { lis.OnRoleChanged(role) })
   324  			}
   325  		}
   326  	}()
   327  }
   328  
   329  func (p *Provider) updateLeadership(ns []*Node) {
   330  	role := Follower
   331  	if p.isLeaderOf(ns) {
   332  		role = Leader
   333  	}
   334  	if role != p.role {
   335  		p.cluster.Logger().Info("Role changed.", slog.String("from", p.role.String()), slog.String("to", role.String()))
   336  		p.role = role
   337  		p.roleChangedChan <- role
   338  	}
   339  }
   340  
   341  func (p *Provider) onEvent(evt zk.Event) {
   342  	if evt.Type != zk.EventSession {
   343  		return
   344  	}
   345  	switch evt.State {
   346  	case zk.StateConnecting, zk.StateDisconnected, zk.StateExpired:
   347  		if p.role == Leader {
   348  			p.role = Follower
   349  			p.roleChangedChan <- Follower
   350  		}
   351  	case zk.StateConnected, zk.StateHasSession:
   352  	}
   353  }
   354  
   355  func (p *Provider) isLeaderOf(ns []*Node) bool {
   356  	if len(ns) == 1 && p.self != nil && ns[0].ID == p.self.ID {
   357  		return true
   358  	}
   359  	var minSeq int
   360  	for _, node := range ns {
   361  		if seq := node.GetSeq(); (seq > 0 && seq < minSeq) || minSeq == 0 {
   362  			minSeq = seq
   363  		}
   364  	}
   365  	for _, node := range ns {
   366  		if p.self != nil && node.ID == p.self.ID {
   367  			return minSeq > 0 && minSeq == p.self.GetSeq()
   368  		}
   369  	}
   370  	return false
   371  }
   372  
   373  func (p *Provider) startWatching(registerSelf bool) {
   374  	ctx := context.TODO()
   375  	go func() {
   376  		for !p.shutdown {
   377  			if err := p.keepWatching(ctx, registerSelf); err != nil {
   378  				p.cluster.Logger().Error("Failed to keepWatching.", slog.Any("error", err))
   379  				p.clusterError = err
   380  			}
   381  		}
   382  	}()
   383  }
   384  
   385  // GetHealthStatus returns an error if the cluster health status has problems
   386  func (p *Provider) GetHealthStatus() error {
   387  	return p.clusterError
   388  }
   389  
   390  func (p *Provider) fetchNodes() ([]*Node, int32, error) {
   391  	children, stat, err := p.conn.Children(p.clusterKey)
   392  	if err != nil {
   393  		p.cluster.Logger().Error("FetchNodes fail.", slog.String("node", p.clusterKey), slog.Any("error", err))
   394  		return nil, 0, err
   395  	}
   396  
   397  	var nodes []*Node
   398  	for _, short := range children {
   399  		long := joinPath(p.clusterKey, short)
   400  		value, _, err := p.conn.Get(long)
   401  		if err != nil {
   402  			p.cluster.Logger().Error("FetchNodes fail.", slog.String("node", long), slog.Any("error", err))
   403  			return nil, stat.Cversion, err
   404  		}
   405  		n := Node{Meta: make(map[string]string)}
   406  		if err := n.Deserialize(value); err != nil {
   407  			p.cluster.Logger().Error("FetchNodes Deserialize fail.", slog.String("node", long), slog.String("val", string(value)), slog.Any("error", err))
   408  			return nil, stat.Cversion, err
   409  		}
   410  		seq, err := parseSeq(long)
   411  		if err != nil {
   412  			p.cluster.Logger().Error("FetchNodes parse seq fail.", slog.String("node", long), slog.String("val", string(value)), slog.Any("error", err))
   413  		} else {
   414  			n.SetMeta(metaKeySeq, intToStr(seq))
   415  		}
   416  		p.cluster.Logger().Info("FetchNodes new node.", slog.String("id", n.ID), slog.String("path", long), slog.Int("seq", seq))
   417  		nodes = append(nodes, &n)
   418  	}
   419  	return p.uniqNodes(nodes), stat.Cversion, nil
   420  }
   421  
   422  func (p *Provider) updateNodes(members []*Node, reversion int32) {
   423  	nm := make(map[string]*Node)
   424  	for _, n := range members {
   425  		nm[n.ID] = n
   426  	}
   427  	p.members = nm
   428  	p.revision = uint64(reversion)
   429  }
   430  
   431  func (p *Provider) uniqNodes(nodes []*Node) []*Node {
   432  	nodeMap := make(map[string]*Node)
   433  	for _, node := range nodes {
   434  		if n, ok := nodeMap[node.GetAddressString()]; ok {
   435  			// keep node with higher version
   436  			if node.GetSeq() > n.GetSeq() {
   437  				nodeMap[node.GetAddressString()] = node
   438  			}
   439  		} else {
   440  			nodeMap[node.GetAddressString()] = node
   441  		}
   442  	}
   443  
   444  	var out []*Node
   445  	for _, node := range nodeMap {
   446  		out = append(out, node)
   447  	}
   448  	return out
   449  }
   450  
   451  func (p *Provider) updateNodesWithSelf(members []*Node, version int32) {
   452  	p.updateNodes(members, version)
   453  	p.members[p.self.ID] = p.self
   454  }
   455  
   456  func (p *Provider) createClusterTopologyEvent() []*cluster.Member {
   457  	res := make([]*cluster.Member, len(p.members))
   458  	i := 0
   459  	for _, m := range p.members {
   460  		res[i] = m.MemberStatus()
   461  		i++
   462  	}
   463  	return res
   464  }
   465  
   466  func (p *Provider) publishClusterTopologyEvent() {
   467  	res := p.createClusterTopologyEvent()
   468  	p.cluster.Logger().Info("Update cluster.", slog.Int("members", len(res)))
   469  	p.cluster.MemberList.UpdateClusterTopology(res)
   470  }
   471  
   472  func splitHostPort(addr string) (host string, port int, err error) {
   473  	if h, p, e := net.SplitHostPort(addr); e != nil {
   474  		if addr != "nonhost" {
   475  			err = e
   476  		}
   477  		host = "nonhost"
   478  		port = -1
   479  	} else {
   480  		host = h
   481  		port, err = strconv.Atoi(p)
   482  	}
   483  	return
   484  }
   485  
   486  func (pro *Provider) createEphemeralChildNode(data []byte) (string, error) {
   487  	acl := zk.WorldACL(zk.PermAll)
   488  	prefix := joinPath(pro.clusterKey, "actor-")
   489  	path := ""
   490  	var err error
   491  	for i := 0; i < 3; i++ {
   492  		path, err = pro.conn.CreateProtectedEphemeralSequential(prefix, data, acl)
   493  		if err == zk.ErrNoNode {
   494  			// Create parent node.
   495  			parts := strings.Split(pro.clusterKey, "/")
   496  			pth := ""
   497  			for _, p := range parts[1:] {
   498  				var exists bool
   499  				pth += "/" + p
   500  				exists, _, err = pro.conn.Exists(pth)
   501  				if err != nil {
   502  					return "", err
   503  				}
   504  				if exists == true {
   505  					continue
   506  				}
   507  				_, err = pro.conn.Create(pth, []byte{}, 0, acl)
   508  				if err != nil && err != zk.ErrNodeExists {
   509  					return "", err
   510  				}
   511  			}
   512  		} else if err == nil {
   513  			break
   514  		} else {
   515  			return "", err
   516  		}
   517  	}
   518  	return path, err
   519  }