github.com/asynkron/protoactor-go@v0.0.0-20240308120642-ef91a6abee75/cluster/clusterproviders/etcd/etcd_provider.go

package etcd

import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/asynkron/protoactor-go/cluster"
	clientv3 "go.etcd.io/etcd/client/v3"
)

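// Provider is an etcd-backed cluster provider. It registers the local member
// under a key derived from baseKey and the cluster name, keeps that key alive
// via an etcd lease, and watches the cluster prefix to track topology changes.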
type Provider struct {
	leaseID       clientv3.LeaseID
	cluster       *cluster.Cluster
	baseKey       string
	clusterName   string
	deregistered  bool
	shutdown      bool
	self          *Node
	members       map[string]*Node // all, contains self.
	clusterError  error
	client        *clientv3.Client
	cancelWatch   func()
	cancelWatchCh chan bool
	keepAliveTTL  time.Duration
	retryInterval time.Duration
	revision      uint64
	// deregisterCritical time.Duration
}

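// New creates a Provider with defaults: base key "/protoactor", a single etcd
// endpoint at 127.0.0.1:2379, and a 5-second dial timeout.
//
// Rough usage sketch (the surrounding cluster wiring is illustrative and
// depends on how the rest of protoactor-go is configured in your code):
//
//	provider, err := etcd.New()
//	if err != nil {
//		// handle error
//	}
//	// pass provider as the ClusterProvider when building the cluster config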
func New() (*Provider, error) {
	return NewWithConfig("/protoactor", clientv3.Config{
		Endpoints:   []string{"127.0.0.1:2379"},
		DialTimeout: time.Second * 5,
	})
}

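// NewWithConfig creates a Provider that stores membership data under baseKey,
// using an etcd client built from the given clientv3.Config. The keepalive TTL
// defaults to 3 seconds and the retry interval to 1 second.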
func NewWithConfig(baseKey string, cfg clientv3.Config) (*Provider, error) {
	client, err := clientv3.New(cfg)
	if err != nil {
		return nil, err
	}
	p := &Provider{
		client:        client,
		keepAliveTTL:  3 * time.Second,
		retryInterval: 1 * time.Second,
		baseKey:       baseKey,
		members:       map[string]*Node{},
		cancelWatchCh: make(chan bool),
	}
	return p, nil
}

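// init captures the cluster reference and builds the local Node (self) from
// the actor system's address, member ID and known kinds.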
func (p *Provider) init(c *cluster.Cluster) error {
	p.cluster = c
	addr := p.cluster.ActorSystem.Address()
	host, port, err := splitHostPort(addr)
	if err != nil {
		return err
	}

	p.clusterName = p.cluster.Config.Name
	memberID := p.cluster.ActorSystem.ID
	knownKinds := c.GetClusterKinds()
	nodeName := fmt.Sprintf("%v@%v", p.clusterName, memberID)
	p.self = NewNode(nodeName, host, port, knownKinds)
	p.self.SetMeta("id", p.getID())
	return nil
}

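// StartMember joins the cluster as a full member: it fetches the current
// member list, publishes an initial topology, starts watching for changes,
// registers the local node in etcd and starts the lease keepalive loop.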
func (p *Provider) StartMember(c *cluster.Cluster) error {
	if err := p.init(c); err != nil {
		return err
	}

	// fetch memberlist
	nodes, err := p.fetchNodes()
	if err != nil {
		return err
	}
	// initialize members
	p.updateNodesWithSelf(nodes)
	p.publishClusterTopologyEvent()
	p.startWatching()

	// register self
	if err := p.registerService(); err != nil {
		return err
	}
	ctx := context.TODO()
	p.startKeepAlive(ctx)
	return nil
}

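// StartClient joins the cluster as a client only: it fetches the current
// member list and watches for changes, but does not register the local node.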
func (p *Provider) StartClient(c *cluster.Cluster) error {
	if err := p.init(c); err != nil {
		return err
	}
	nodes, err := p.fetchNodes()
	if err != nil {
		return err
	}
	// initialize members
	p.updateNodes(nodes)
	p.publishClusterTopologyEvent()
	p.startWatching()
	return nil
}

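// Shutdown deregisters the local node (if it has not been deregistered yet)
// and cancels the etcd watch. The graceful flag is currently ignored.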
func (p *Provider) Shutdown(graceful bool) error {
	p.shutdown = true
	if !p.deregistered {
		err := p.deregisterService()
		if err != nil {
			p.cluster.Logger().Error("deregisterMember", slog.Any("error", err))
			return err
		}
		p.deregistered = true
	}
	if p.cancelWatch != nil {
		p.cancelWatch()
		p.cancelWatch = nil
	}
	return nil
}

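// keepAliveForever grants a new lease, writes the local node under that lease
// and then blocks, consuming keepalive responses until the stream closes, an
// error occurs or the provider is shut down.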
func (p *Provider) keepAliveForever(ctx context.Context) error {
	if p.self == nil {
		return fmt.Errorf("keepalive must be called after init")
	}

	data, err := p.self.Serialize()
	if err != nil {
		return err
	}
	fullKey := p.getEtcdKey()

	leaseId, err := p.newLeaseID()
	if err != nil {
		return err
	}
	p.setLeaseID(leaseId)

	if leaseId <= 0 {
		return fmt.Errorf("grant lease failed. leaseId=%d", leaseId)
	}
	_, err = p.client.Put(ctx, fullKey, string(data), clientv3.WithLease(leaseId))
	if err != nil {
		return err
	}
	kaRespCh, err := p.client.KeepAlive(ctx, leaseId)
	if err != nil {
		return err
	}

	for resp := range kaRespCh {
		if resp == nil {
			// defensive: never dereference a nil keepalive response
			return fmt.Errorf("keep alive failed: received nil response")
		}
		// plog.Infof("keep alive %s ttl=%d", p.getID(), resp.TTL)
		if p.shutdown {
			return nil
		}
	}
	return nil
}

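// startKeepAlive runs keepAliveForever in a background goroutine, retrying
// after retryInterval until the provider is shut down or ctx is cancelled.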
func (p *Provider) startKeepAlive(ctx context.Context) {
	go func() {
		for !p.shutdown {
			if err := ctx.Err(); err != nil {
				p.cluster.Logger().Info("Keepalive was stopped.", slog.Any("error", err))
				return
			}

			if err := p.keepAliveForever(ctx); err != nil {
				p.cluster.Logger().Info("Failure refreshing service TTL. Retrying...", slog.Duration("after", p.retryInterval), slog.Any("error", err))
			}
			time.Sleep(p.retryInterval)
		}
	}()
}

func (p *Provider) getID() string {
	return p.self.ID
}

func (p *Provider) getEtcdKey() string {
	return p.buildKey(p.clusterName, p.getID())
}

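// registerService writes the serialized local node to etcd under the member
// key, attached to the current lease (granting a new one if necessary).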
func (p *Provider) registerService() error {
	data, err := p.self.Serialize()
	if err != nil {
		return err
	}
	fullKey := p.getEtcdKey()
	leaseId := p.getLeaseID()
	if leaseId <= 0 {
		_leaseId, err := p.newLeaseID()
		if err != nil {
			return err
		}
		leaseId = _leaseId
		p.setLeaseID(leaseId)
	}
	_, err = p.client.Put(context.TODO(), fullKey, string(data), clientv3.WithLease(leaseId))
	if err != nil {
		return err
	}
	return nil
}

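// deregisterService removes the local node's key from etcd.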
func (p *Provider) deregisterService() error {
	fullKey := p.getEtcdKey()
	_, err := p.client.Delete(context.TODO(), fullKey)
	return err
}

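// handleWatchResponse converts etcd watch events into a map of changed nodes
// keyed by node ID. Deleted members are returned as clones marked not alive,
// and events for the local node are ignored.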
func (p *Provider) handleWatchResponse(resp clientv3.WatchResponse) map[string]*Node {
	changes := map[string]*Node{}
	for _, ev := range resp.Events {
		key := string(ev.Kv.Key)
		nodeId, err := getNodeID(key, "/")
		if err != nil {
			p.cluster.Logger().Error("Invalid member.", slog.String("key", key))
			continue
		}

		switch ev.Type {
		case clientv3.EventTypePut:
			node, err := NewNodeFromBytes(ev.Kv.Value)
			if err != nil {
				p.cluster.Logger().Error("Invalid member.", slog.String("key", key))
				continue
			}
			if p.self.Equal(node) {
				p.cluster.Logger().Debug("Skip self.", slog.String("key", key))
				continue
			}
			if _, ok := p.members[nodeId]; ok {
				p.cluster.Logger().Debug("Update member.", slog.String("key", key))
			} else {
				p.cluster.Logger().Debug("New member.", slog.String("key", key))
			}
			changes[nodeId] = node
		case clientv3.EventTypeDelete:
			node, ok := p.members[nodeId]
			if !ok {
				continue
			}
			p.cluster.Logger().Debug("Delete member.", slog.String("key", key))
			cloned := *node
			cloned.SetAlive(false)
			changes[nodeId] = &cloned
		default:
			p.cluster.Logger().Error("Invalid etcd event.type.", slog.String("key", key),
				slog.String("type", ev.Type.String()))
		}
	}
	p.revision = uint64(resp.Header.GetRevision())
	return changes
}

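// keepWatching opens a prefix watch on the cluster key and processes events
// until the watch channel is closed or returns an error.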
func (p *Provider) keepWatching(ctx context.Context) error {
	clusterKey := p.buildKey(p.clusterName)
	stream := p.client.Watch(ctx, clusterKey, clientv3.WithPrefix())
	return p._keepWatching(stream)
}

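// _keepWatching drains the watch channel, applying each batch of changes to
// the member map and publishing the resulting topology.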
func (p *Provider) _keepWatching(stream clientv3.WatchChan) error {
	for resp := range stream {
		if err := resp.Err(); err != nil {
			p.cluster.Logger().Error("Failure watching service.", slog.Any("error", err))
			return err
		}
		if len(resp.Events) == 0 {
			p.cluster.Logger().Error("Empty etcd.events.", slog.Int("events", len(resp.Events)))
			continue
		}
		nodesChanges := p.handleWatchResponse(resp)
		p.updateNodesWithChanges(nodesChanges)
		p.publishClusterTopologyEvent()
	}
	return nil
}

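// startWatching starts the background watch loop and stores its cancel
// function so Shutdown can stop it.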
func (p *Provider) startWatching() {
	ctx := context.TODO()
	ctx, cancel := context.WithCancel(ctx)
	p.cancelWatch = cancel
	go func() {
		for !p.shutdown {
			if err := p.keepWatching(ctx); err != nil {
				p.cluster.Logger().Error("Failed to keepWatching.", slog.Any("error", err))
				p.clusterError = err
			}
		}
	}()
}

// GetHealthStatus returns an error if the cluster health status has problems
func (p *Provider) GetHealthStatus() error {
	return p.clusterError
}

func newContext(timeout time.Duration) (context.Context, context.CancelFunc) {
	return context.WithTimeout(context.TODO(), timeout)
}

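// buildKey joins baseKey and the given path segments with "/".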
func (p *Provider) buildKey(names ...string) string {
	return strings.Join(append([]string{p.baseKey}, names...), "/")
}

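// fetchNodes reads all nodes currently registered under the cluster prefix
// and records the etcd revision of the response.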
func (p *Provider) fetchNodes() ([]*Node, error) {
	key := p.buildKey(p.clusterName)
	resp, err := p.client.Get(context.TODO(), key, clientv3.WithPrefix())
	if err != nil {
		return nil, err
	}
	var nodes []*Node
	for _, v := range resp.Kvs {
		n := Node{}
		if err := n.Deserialize(v.Value); err != nil {
			return nil, err
		}
		nodes = append(nodes, &n)
	}
	p.revision = uint64(resp.Header.GetRevision())
	// plog.Debug("fetch nodes",
	// 	log.Uint64("raft term", resp.Header.GetRaftTerm()),
	// 	log.Int64("revision", resp.Header.GetRevision()))
	return nodes, nil
}

func (p *Provider) updateNodes(members []*Node) {
	for _, n := range members {
		p.members[n.ID] = n
	}
}

func (p *Provider) updateNodesWithSelf(members []*Node) {
	p.updateNodes(members)
	p.members[p.self.ID] = p.self
}

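// updateNodesWithChanges applies watch changes to the member map; nodes that
// are no longer alive are removed.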
func (p *Provider) updateNodesWithChanges(changes map[string]*Node) {
	for memberId, member := range changes {
		p.members[memberId] = member
		if !member.IsAlive() {
			delete(p.members, memberId)
		}
	}
}

func (p *Provider) createClusterTopologyEvent() []*cluster.Member {
	res := make([]*cluster.Member, len(p.members))
	i := 0
	for _, m := range p.members {
		res[i] = m.MemberStatus()
		i++
	}
	return res
}

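// publishClusterTopologyEvent pushes the current member set to the cluster's
// MemberList so the topology is recalculated.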
func (p *Provider) publishClusterTopologyEvent() {
	res := p.createClusterTopologyEvent()
	p.cluster.Logger().Info("Update cluster.", slog.Int("members", len(res)))
	// for _, m := range res {
	// 	plog.Info("\t", log.Object("member", m))
	// }
	p.cluster.MemberList.UpdateClusterTopology(res)
	// p.cluster.ActorSystem.EventStream.Publish(res)
}

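// getLeaseID and setLeaseID access the current lease ID atomically, since the
// keepalive goroutine and registration may touch it concurrently.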
func (p *Provider) getLeaseID() clientv3.LeaseID {
	return (clientv3.LeaseID)(atomic.LoadInt64((*int64)(&p.leaseID)))
}

func (p *Provider) setLeaseID(leaseID clientv3.LeaseID) {
	atomic.StoreInt64((*int64)(&p.leaseID), (int64)(leaseID))
}

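// newLeaseID grants a fresh etcd lease whose TTL is keepAliveTTL, truncated to
// whole seconds.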
func (p *Provider) newLeaseID() (clientv3.LeaseID, error) {
	ttlSecs := int64(p.keepAliveTTL / time.Second)
	resp, err := p.client.Grant(context.TODO(), ttlSecs)
	if err != nil {
		return 0, err
	}
	return resp.ID, nil
}

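// splitHostPort splits an actor system address into host and port. On a parse
// error it returns host "nonhost" and port -1; the error is suppressed only
// for the literal address "nonhost" (the actor system's address when no
// remote is configured).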
func splitHostPort(addr string) (host string, port int, err error) {
	if h, p, e := net.SplitHostPort(addr); e != nil {
		if addr != "nonhost" {
			err = e
		}
		host = "nonhost"
		port = -1
	} else {
		host = h
		port, err = strconv.Atoi(p)
	}
	return
}