github.com/shuguocloud/go-zero@v1.3.0/core/discov/internal/registry.go (about)

     1  package internal
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"sort"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/shuguocloud/go-zero/core/contextx"
    13  	"github.com/shuguocloud/go-zero/core/lang"
    14  	"github.com/shuguocloud/go-zero/core/logx"
    15  	"github.com/shuguocloud/go-zero/core/syncx"
    16  	"github.com/shuguocloud/go-zero/core/threading"
    17  	clientv3 "go.etcd.io/etcd/client/v3"
    18  )
    19  
    20  var (
    21  	registry = Registry{
    22  		clusters: make(map[string]*cluster),
    23  	}
    24  	connManager = syncx.NewResourceManager()
    25  )
    26  
    27  // A Registry is a registry that manages the etcd client connections.
    28  type Registry struct {
    29  	clusters map[string]*cluster
    30  	lock     sync.Mutex
    31  }
    32  
    33  // GetRegistry returns a global Registry.
    34  func GetRegistry() *Registry {
    35  	return &registry
    36  }
    37  
    38  // GetConn returns an etcd client connection associated with given endpoints.
    39  func (r *Registry) GetConn(endpoints []string) (EtcdClient, error) {
    40  	c, _ := r.getCluster(endpoints)
    41  	return c.getClient()
    42  }
    43  
    44  // Monitor monitors the key on given etcd endpoints, notify with the given UpdateListener.
    45  func (r *Registry) Monitor(endpoints []string, key string, l UpdateListener) error {
    46  	c, exists := r.getCluster(endpoints)
    47  	// if exists, the existing values should be updated to the listener.
    48  	if exists {
    49  		kvs := c.getCurrent(key)
    50  		for _, kv := range kvs {
    51  			l.OnAdd(kv)
    52  		}
    53  	}
    54  
    55  	return c.monitor(key, l)
    56  }
    57  
    58  func (r *Registry) getCluster(endpoints []string) (c *cluster, exists bool) {
    59  	clusterKey := getClusterKey(endpoints)
    60  	r.lock.Lock()
    61  	defer r.lock.Unlock()
    62  	c, exists = r.clusters[clusterKey]
    63  	if !exists {
    64  		c = newCluster(endpoints)
    65  		r.clusters[clusterKey] = c
    66  	}
    67  
    68  	return
    69  }
    70  
    71  type cluster struct {
    72  	endpoints  []string
    73  	key        string
    74  	values     map[string]map[string]string
    75  	listeners  map[string][]UpdateListener
    76  	watchGroup *threading.RoutineGroup
    77  	done       chan lang.PlaceholderType
    78  	lock       sync.Mutex
    79  }
    80  
    81  func newCluster(endpoints []string) *cluster {
    82  	return &cluster{
    83  		endpoints:  endpoints,
    84  		key:        getClusterKey(endpoints),
    85  		values:     make(map[string]map[string]string),
    86  		listeners:  make(map[string][]UpdateListener),
    87  		watchGroup: threading.NewRoutineGroup(),
    88  		done:       make(chan lang.PlaceholderType),
    89  	}
    90  }
    91  
    92  func (c *cluster) context(cli EtcdClient) context.Context {
    93  	return contextx.ValueOnlyFrom(cli.Ctx())
    94  }
    95  
    96  func (c *cluster) getClient() (EtcdClient, error) {
    97  	val, err := connManager.GetResource(c.key, func() (io.Closer, error) {
    98  		return c.newClient()
    99  	})
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	return val.(EtcdClient), nil
   105  }
   106  
   107  func (c *cluster) getCurrent(key string) []KV {
   108  	c.lock.Lock()
   109  	defer c.lock.Unlock()
   110  
   111  	var kvs []KV
   112  	for k, v := range c.values[key] {
   113  		kvs = append(kvs, KV{
   114  			Key: k,
   115  			Val: v,
   116  		})
   117  	}
   118  
   119  	return kvs
   120  }
   121  
   122  func (c *cluster) handleChanges(key string, kvs []KV) {
   123  	var add []KV
   124  	var remove []KV
   125  	c.lock.Lock()
   126  	listeners := append([]UpdateListener(nil), c.listeners[key]...)
   127  	vals, ok := c.values[key]
   128  	if !ok {
   129  		add = kvs
   130  		vals = make(map[string]string)
   131  		for _, kv := range kvs {
   132  			vals[kv.Key] = kv.Val
   133  		}
   134  		c.values[key] = vals
   135  	} else {
   136  		m := make(map[string]string)
   137  		for _, kv := range kvs {
   138  			m[kv.Key] = kv.Val
   139  		}
   140  		for k, v := range vals {
   141  			if val, ok := m[k]; !ok || v != val {
   142  				remove = append(remove, KV{
   143  					Key: k,
   144  					Val: v,
   145  				})
   146  			}
   147  		}
   148  		for k, v := range m {
   149  			if val, ok := vals[k]; !ok || v != val {
   150  				add = append(add, KV{
   151  					Key: k,
   152  					Val: v,
   153  				})
   154  			}
   155  		}
   156  		c.values[key] = m
   157  	}
   158  	c.lock.Unlock()
   159  
   160  	for _, kv := range add {
   161  		for _, l := range listeners {
   162  			l.OnAdd(kv)
   163  		}
   164  	}
   165  	for _, kv := range remove {
   166  		for _, l := range listeners {
   167  			l.OnDelete(kv)
   168  		}
   169  	}
   170  }
   171  
   172  func (c *cluster) handleWatchEvents(key string, events []*clientv3.Event) {
   173  	c.lock.Lock()
   174  	listeners := append([]UpdateListener(nil), c.listeners[key]...)
   175  	c.lock.Unlock()
   176  
   177  	for _, ev := range events {
   178  		switch ev.Type {
   179  		case clientv3.EventTypePut:
   180  			c.lock.Lock()
   181  			if vals, ok := c.values[key]; ok {
   182  				vals[string(ev.Kv.Key)] = string(ev.Kv.Value)
   183  			} else {
   184  				c.values[key] = map[string]string{string(ev.Kv.Key): string(ev.Kv.Value)}
   185  			}
   186  			c.lock.Unlock()
   187  			for _, l := range listeners {
   188  				l.OnAdd(KV{
   189  					Key: string(ev.Kv.Key),
   190  					Val: string(ev.Kv.Value),
   191  				})
   192  			}
   193  		case clientv3.EventTypeDelete:
   194  			if vals, ok := c.values[key]; ok {
   195  				delete(vals, string(ev.Kv.Key))
   196  			}
   197  			for _, l := range listeners {
   198  				l.OnDelete(KV{
   199  					Key: string(ev.Kv.Key),
   200  					Val: string(ev.Kv.Value),
   201  				})
   202  			}
   203  		default:
   204  			logx.Errorf("Unknown event type: %v", ev.Type)
   205  		}
   206  	}
   207  }
   208  
   209  func (c *cluster) load(cli EtcdClient, key string) {
   210  	var resp *clientv3.GetResponse
   211  	for {
   212  		var err error
   213  		ctx, cancel := context.WithTimeout(c.context(cli), RequestTimeout)
   214  		resp, err = cli.Get(ctx, makeKeyPrefix(key), clientv3.WithPrefix())
   215  		cancel()
   216  		if err == nil {
   217  			break
   218  		}
   219  
   220  		logx.Error(err)
   221  		time.Sleep(coolDownInterval)
   222  	}
   223  
   224  	var kvs []KV
   225  	for _, ev := range resp.Kvs {
   226  		kvs = append(kvs, KV{
   227  			Key: string(ev.Key),
   228  			Val: string(ev.Value),
   229  		})
   230  	}
   231  
   232  	c.handleChanges(key, kvs)
   233  }
   234  
   235  func (c *cluster) monitor(key string, l UpdateListener) error {
   236  	c.lock.Lock()
   237  	c.listeners[key] = append(c.listeners[key], l)
   238  	c.lock.Unlock()
   239  
   240  	cli, err := c.getClient()
   241  	if err != nil {
   242  		return err
   243  	}
   244  
   245  	c.load(cli, key)
   246  	c.watchGroup.Run(func() {
   247  		c.watch(cli, key)
   248  	})
   249  
   250  	return nil
   251  }
   252  
   253  func (c *cluster) newClient() (EtcdClient, error) {
   254  	cli, err := NewClient(c.endpoints)
   255  	if err != nil {
   256  		return nil, err
   257  	}
   258  
   259  	go c.watchConnState(cli)
   260  
   261  	return cli, nil
   262  }
   263  
// reload restarts every watch of this cluster on cli. watchConnState starts it
// in its own goroutine whenever the connection state watcher fires.
//
// It closes done so running watchStream loops return, waits for the current
// watchGroup to drain, then installs a fresh done channel and watchGroup. For
// every key that has listeners it then runs load (to resync the cached
// entries) followed by watch.
//
// NOTE(review): c.lock is held across c.watchGroup.Wait(). A watch goroutine
// that is inside handleWatchEvents, or inside handleChanges via load, when
// reload takes the lock blocks on c.lock and never finishes. Wait then never
// returns and the cluster deadlocks. Separately, watchStream reads c.done and
// monitor reads c.watchGroup without c.lock while this function replaces them.
// A correct fix needs coordinated changes to reload, watchStream and monitor.
func (c *cluster) reload(cli EtcdClient) {
	c.lock.Lock()
	// Make every watchStream select loop return.
	close(c.done)
	c.watchGroup.Wait()
	c.done = make(chan lang.PlaceholderType)
	c.watchGroup = threading.NewRoutineGroup()
	// Snapshot the watched keys while still holding the lock.
	var keys []string
	for k := range c.listeners {
		keys = append(keys, k)
	}
	c.lock.Unlock()

	for _, key := range keys {
		// Per-iteration copy for the closure (required before Go 1.22).
		k := key
		c.watchGroup.Run(func() {
			c.load(cli, k)
			c.watch(cli, k)
		})
	}
}
   284  
   285  func (c *cluster) watch(cli EtcdClient, key string) {
   286  	for {
   287  		if c.watchStream(cli, key) {
   288  			return
   289  		}
   290  	}
   291  }
   292  
   293  func (c *cluster) watchStream(cli EtcdClient, key string) bool {
   294  	rch := cli.Watch(clientv3.WithRequireLeader(c.context(cli)), makeKeyPrefix(key), clientv3.WithPrefix())
   295  	for {
   296  		select {
   297  		case wresp, ok := <-rch:
   298  			if !ok {
   299  				logx.Error("etcd monitor chan has been closed")
   300  				return false
   301  			}
   302  			if wresp.Canceled {
   303  				logx.Errorf("etcd monitor chan has been canceled, error: %v", wresp.Err())
   304  				return false
   305  			}
   306  			if wresp.Err() != nil {
   307  				logx.Error(fmt.Sprintf("etcd monitor chan error: %v", wresp.Err()))
   308  				return false
   309  			}
   310  
   311  			c.handleWatchEvents(key, wresp.Events)
   312  		case <-c.done:
   313  			return true
   314  		}
   315  	}
   316  }
   317  
   318  func (c *cluster) watchConnState(cli EtcdClient) {
   319  	watcher := newStateWatcher()
   320  	watcher.addListener(func() {
   321  		go c.reload(cli)
   322  	})
   323  	watcher.watch(cli.ActiveConnection())
   324  }
   325  
   326  // DialClient dials an etcd cluster with given endpoints.
   327  func DialClient(endpoints []string) (EtcdClient, error) {
   328  	cfg := clientv3.Config{
   329  		Endpoints:            endpoints,
   330  		AutoSyncInterval:     autoSyncInterval,
   331  		DialTimeout:          DialTimeout,
   332  		DialKeepAliveTime:    dialKeepAliveTime,
   333  		DialKeepAliveTimeout: DialTimeout,
   334  		RejectOldCluster:     true,
   335  	}
   336  	if account, ok := GetAccount(endpoints); ok {
   337  		cfg.Username = account.User
   338  		cfg.Password = account.Pass
   339  	}
   340  	if tlsCfg, ok := GetTLS(endpoints); ok {
   341  		cfg.TLS = tlsCfg
   342  	}
   343  
   344  	return clientv3.New(cfg)
   345  }
   346  
   347  func getClusterKey(endpoints []string) string {
   348  	sort.Strings(endpoints)
   349  	return strings.Join(endpoints, endpointsSeparator)
   350  }
   351  
   352  func makeKeyPrefix(key string) string {
   353  	return fmt.Sprintf("%s%c", key, Delimiter)
   354  }