github.com/micro/go-micro/v2@v2.9.1/registry/mdns_registry.go (about)

// Package registry provides service discovery; this file implements a registry
// backed by multicast DNS (mDNS).
     2  package registry
     3  
     4  import (
     5  	"bytes"
     6  	"compress/zlib"
     7  	"context"
     8  	"encoding/hex"
     9  	"encoding/json"
    10  	"fmt"
    11  	"io/ioutil"
    12  	"net"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/google/uuid"
    19  	"github.com/micro/go-micro/v2/logger"
    20  	"github.com/micro/go-micro/v2/util/mdns"
    21  )
    22  
    23  var (
    24  	// use a .micro domain rather than .local
    25  	mdnsDomain = "micro"
    26  )
    27  
// mdnsTxt is the payload carried in a node's TXT record. encode marshals it to
// JSON (so the field order below is also the JSON key order), compresses and
// hex-encodes it; decode reverses that.
type mdnsTxt struct {
	Service   string            // service name
	Version   string            // service version
	Endpoints []*Endpoint       // endpoints of the service
	Metadata  map[string]string // metadata of the advertising node
}
    34  
    35  type mdnsEntry struct {
    36  	id   string
    37  	node *mdns.Server
    38  }
    39  
    40  type mdnsRegistry struct {
    41  	opts Options
    42  	// the mdns domain
    43  	domain string
    44  
    45  	sync.Mutex
    46  	services map[string][]*mdnsEntry
    47  
    48  	mtx sync.RWMutex
    49  
    50  	// watchers
    51  	watchers map[string]*mdnsWatcher
    52  
    53  	// listener
    54  	listener chan *mdns.ServiceEntry
    55  }
    56  
    57  type mdnsWatcher struct {
    58  	id   string
    59  	wo   WatchOptions
    60  	ch   chan *mdns.ServiceEntry
    61  	exit chan struct{}
    62  	// the mdns domain
    63  	domain string
    64  	// the registry
    65  	registry *mdnsRegistry
    66  }
    67  
    68  func encode(txt *mdnsTxt) ([]string, error) {
    69  	b, err := json.Marshal(txt)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	var buf bytes.Buffer
    75  	defer buf.Reset()
    76  
    77  	w := zlib.NewWriter(&buf)
    78  	if _, err := w.Write(b); err != nil {
    79  		return nil, err
    80  	}
    81  	w.Close()
    82  
    83  	encoded := hex.EncodeToString(buf.Bytes())
    84  
    85  	// individual txt limit
    86  	if len(encoded) <= 255 {
    87  		return []string{encoded}, nil
    88  	}
    89  
    90  	// split encoded string
    91  	var record []string
    92  
    93  	for len(encoded) > 255 {
    94  		record = append(record, encoded[:255])
    95  		encoded = encoded[255:]
    96  	}
    97  
    98  	record = append(record, encoded)
    99  
   100  	return record, nil
   101  }
   102  
   103  func decode(record []string) (*mdnsTxt, error) {
   104  	encoded := strings.Join(record, "")
   105  
   106  	hr, err := hex.DecodeString(encoded)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	br := bytes.NewReader(hr)
   112  	zr, err := zlib.NewReader(br)
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	rbuf, err := ioutil.ReadAll(zr)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  
   122  	var txt *mdnsTxt
   123  
   124  	if err := json.Unmarshal(rbuf, &txt); err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	return txt, nil
   129  }
   130  func newRegistry(opts ...Option) Registry {
   131  	options := Options{
   132  		Context: context.Background(),
   133  		Timeout: time.Millisecond * 100,
   134  	}
   135  
   136  	for _, o := range opts {
   137  		o(&options)
   138  	}
   139  
   140  	// set the domain
   141  	domain := mdnsDomain
   142  
   143  	d, ok := options.Context.Value("mdns.domain").(string)
   144  	if ok {
   145  		domain = d
   146  	}
   147  
   148  	return &mdnsRegistry{
   149  		opts:     options,
   150  		domain:   domain,
   151  		services: make(map[string][]*mdnsEntry),
   152  		watchers: make(map[string]*mdnsWatcher),
   153  	}
   154  }
   155  
   156  func (m *mdnsRegistry) Init(opts ...Option) error {
   157  	for _, o := range opts {
   158  		o(&m.opts)
   159  	}
   160  	return nil
   161  }
   162  
   163  func (m *mdnsRegistry) Options() Options {
   164  	return m.opts
   165  }
   166  
   167  func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error {
   168  	m.Lock()
   169  	defer m.Unlock()
   170  
   171  	entries, ok := m.services[service.Name]
   172  	// first entry, create wildcard used for list queries
   173  	if !ok {
   174  		s, err := mdns.NewMDNSService(
   175  			service.Name,
   176  			"_services",
   177  			m.domain+".",
   178  			"",
   179  			9999,
   180  			[]net.IP{net.ParseIP("0.0.0.0")},
   181  			nil,
   182  		)
   183  		if err != nil {
   184  			return err
   185  		}
   186  
   187  		srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{MDNSService: s}})
   188  		if err != nil {
   189  			return err
   190  		}
   191  
   192  		// append the wildcard entry
   193  		entries = append(entries, &mdnsEntry{id: "*", node: srv})
   194  	}
   195  
   196  	var gerr error
   197  
   198  	for _, node := range service.Nodes {
   199  		var seen bool
   200  		var e *mdnsEntry
   201  
   202  		for _, entry := range entries {
   203  			if node.Id == entry.id {
   204  				seen = true
   205  				e = entry
   206  				break
   207  			}
   208  		}
   209  
   210  		// already registered, continue
   211  		if seen {
   212  			continue
   213  			// doesn't exist
   214  		} else {
   215  			e = &mdnsEntry{}
   216  		}
   217  
   218  		txt, err := encode(&mdnsTxt{
   219  			Service:   service.Name,
   220  			Version:   service.Version,
   221  			Endpoints: service.Endpoints,
   222  			Metadata:  node.Metadata,
   223  		})
   224  
   225  		if err != nil {
   226  			gerr = err
   227  			continue
   228  		}
   229  
   230  		host, pt, err := net.SplitHostPort(node.Address)
   231  		if err != nil {
   232  			gerr = err
   233  			continue
   234  		}
   235  		port, _ := strconv.Atoi(pt)
   236  
   237  		if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   238  			logger.Debugf("[mdns] registry create new service with ip: %s for: %s", net.ParseIP(host).String(), host)
   239  		}
   240  		// we got here, new node
   241  		s, err := mdns.NewMDNSService(
   242  			node.Id,
   243  			service.Name,
   244  			m.domain+".",
   245  			"",
   246  			port,
   247  			[]net.IP{net.ParseIP(host)},
   248  			txt,
   249  		)
   250  		if err != nil {
   251  			gerr = err
   252  			continue
   253  		}
   254  
   255  		srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true})
   256  		if err != nil {
   257  			gerr = err
   258  			continue
   259  		}
   260  
   261  		e.id = node.Id
   262  		e.node = srv
   263  		entries = append(entries, e)
   264  	}
   265  
   266  	// save
   267  	m.services[service.Name] = entries
   268  
   269  	return gerr
   270  }
   271  
   272  func (m *mdnsRegistry) Deregister(service *Service, opts ...DeregisterOption) error {
   273  	m.Lock()
   274  	defer m.Unlock()
   275  
   276  	var newEntries []*mdnsEntry
   277  
   278  	// loop existing entries, check if any match, shutdown those that do
   279  	for _, entry := range m.services[service.Name] {
   280  		var remove bool
   281  
   282  		for _, node := range service.Nodes {
   283  			if node.Id == entry.id {
   284  				entry.node.Shutdown()
   285  				remove = true
   286  				break
   287  			}
   288  		}
   289  
   290  		// keep it?
   291  		if !remove {
   292  			newEntries = append(newEntries, entry)
   293  		}
   294  	}
   295  
   296  	// last entry is the wildcard for list queries. Remove it.
   297  	if len(newEntries) == 1 && newEntries[0].id == "*" {
   298  		newEntries[0].node.Shutdown()
   299  		delete(m.services, service.Name)
   300  	} else {
   301  		m.services[service.Name] = newEntries
   302  	}
   303  
   304  	return nil
   305  }
   306  
   307  func (m *mdnsRegistry) GetService(service string, opts ...GetOption) ([]*Service, error) {
   308  	serviceMap := make(map[string]*Service)
   309  	entries := make(chan *mdns.ServiceEntry, 10)
   310  	done := make(chan bool)
   311  
   312  	p := mdns.DefaultParams(service)
   313  	// set context with timeout
   314  	var cancel context.CancelFunc
   315  	p.Context, cancel = context.WithTimeout(context.Background(), m.opts.Timeout)
   316  	defer cancel()
   317  	// set entries channel
   318  	p.Entries = entries
   319  	// set the domain
   320  	p.Domain = m.domain
   321  
   322  	go func() {
   323  		for {
   324  			select {
   325  			case e := <-entries:
   326  				// list record so skip
   327  				if p.Service == "_services" {
   328  					continue
   329  				}
   330  				if p.Domain != m.domain {
   331  					continue
   332  				}
   333  				if e.TTL == 0 {
   334  					continue
   335  				}
   336  
   337  				txt, err := decode(e.InfoFields)
   338  				if err != nil {
   339  					continue
   340  				}
   341  
   342  				if txt.Service != service {
   343  					continue
   344  				}
   345  
   346  				s, ok := serviceMap[txt.Version]
   347  				if !ok {
   348  					s = &Service{
   349  						Name:      txt.Service,
   350  						Version:   txt.Version,
   351  						Endpoints: txt.Endpoints,
   352  					}
   353  				}
   354  				addr := ""
   355  				// prefer ipv4 addrs
   356  				if len(e.AddrV4) > 0 {
   357  					addr = e.AddrV4.String()
   358  					// else use ipv6
   359  				} else if len(e.AddrV6) > 0 {
   360  					addr = "[" + e.AddrV6.String() + "]"
   361  				} else {
   362  					if logger.V(logger.InfoLevel, logger.DefaultLogger) {
   363  						logger.Infof("[mdns]: invalid endpoint received: %v", e)
   364  					}
   365  					continue
   366  				}
   367  				s.Nodes = append(s.Nodes, &Node{
   368  					Id:       strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+"."),
   369  					Address:  fmt.Sprintf("%s:%d", addr, e.Port),
   370  					Metadata: txt.Metadata,
   371  				})
   372  
   373  				serviceMap[txt.Version] = s
   374  			case <-p.Context.Done():
   375  				close(done)
   376  				return
   377  			}
   378  		}
   379  	}()
   380  
   381  	// execute the query
   382  	if err := mdns.Query(p); err != nil {
   383  		return nil, err
   384  	}
   385  
   386  	// wait for completion
   387  	<-done
   388  
   389  	// create list and return
   390  	services := make([]*Service, 0, len(serviceMap))
   391  
   392  	for _, service := range serviceMap {
   393  		services = append(services, service)
   394  	}
   395  
   396  	return services, nil
   397  }
   398  
   399  func (m *mdnsRegistry) ListServices(opts ...ListOption) ([]*Service, error) {
   400  	serviceMap := make(map[string]bool)
   401  	entries := make(chan *mdns.ServiceEntry, 10)
   402  	done := make(chan bool)
   403  
   404  	p := mdns.DefaultParams("_services")
   405  	// set context with timeout
   406  	var cancel context.CancelFunc
   407  	p.Context, cancel = context.WithTimeout(context.Background(), m.opts.Timeout)
   408  	defer cancel()
   409  	// set entries channel
   410  	p.Entries = entries
   411  	// set domain
   412  	p.Domain = m.domain
   413  
   414  	var services []*Service
   415  
   416  	go func() {
   417  		for {
   418  			select {
   419  			case e := <-entries:
   420  				if e.TTL == 0 {
   421  					continue
   422  				}
   423  				if !strings.HasSuffix(e.Name, p.Domain+".") {
   424  					continue
   425  				}
   426  				name := strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+".")
   427  				if !serviceMap[name] {
   428  					serviceMap[name] = true
   429  					services = append(services, &Service{Name: name})
   430  				}
   431  			case <-p.Context.Done():
   432  				close(done)
   433  				return
   434  			}
   435  		}
   436  	}()
   437  
   438  	// execute query
   439  	if err := mdns.Query(p); err != nil {
   440  		return nil, err
   441  	}
   442  
   443  	// wait till done
   444  	<-done
   445  
   446  	return services, nil
   447  }
   448  
// Watch creates a watcher, registers it in m.watchers under a fresh UUID and
// returns it.
//
// If no listener is running yet, a background goroutine is started that
// (re)creates one. It calls mdns.Listen with a new entry channel and fans every
// received entry out to all current watchers. The fan-out uses a non-blocking
// send, so a watcher whose 32-entry buffer is full silently misses that entry.
//
// NOTE(review): nothing in this block ever closes the per-listener `exit`
// channel. Stop only closes the watcher's own exit channel and removes it
// from m.watchers. The "no watchers left" check runs only after mdns.Listen
// returns, so the listener appears to keep running after the last watcher
// stops; confirm whether mdns.Listen ever returns on its own. Likewise, if
// mdns.Listen returns immediately (e.g. on an error), the outer loop restarts
// it without any delay.
func (m *mdnsRegistry) Watch(opts ...WatchOption) (Watcher, error) {
	var wo WatchOptions
	for _, o := range opts {
		o(&wo)
	}

	md := &mdnsWatcher{
		id:       uuid.New().String(),
		wo:       wo,
		ch:       make(chan *mdns.ServiceEntry, 32),
		exit:     make(chan struct{}),
		domain:   m.domain,
		registry: m,
	}

	m.mtx.Lock()
	defer m.mtx.Unlock()

	// save the watcher
	m.watchers[md.id] = md

	// check if a listener already exists; if so it will feed this watcher too
	if m.listener != nil {
		return md, nil
	}

	// start the listener. The goroutine cannot take m.mtx until this function
	// returns. If two Watch calls both get here before either goroutine runs,
	// whichever goroutine locks second sees the listener set by the other and
	// exits.
	go func() {
		// go to infinity: restart the listener each time mdns.Listen returns
		for {
			m.mtx.Lock()

			// just return if there are no watchers
			if len(m.watchers) == 0 {
				m.listener = nil
				m.mtx.Unlock()
				return
			}

			// check existing listener (another goroutine already owns one)
			if m.listener != nil {
				m.mtx.Unlock()
				return
			}

			// reset the listener
			exit := make(chan struct{})
			ch := make(chan *mdns.ServiceEntry, 32)
			m.listener = ch

			m.mtx.Unlock()

			// send messages to the watchers
			go func() {
				// non-blocking send: drop the entry if the watcher's buffer is full
				send := func(w *mdnsWatcher, e *mdns.ServiceEntry) {
					select {
					case w.ch <- e:
					default:
					}
				}

				for {
					select {
					case <-exit:
						return
					case e, ok := <-ch:
						// ch is closed below once mdns.Listen has returned
						if !ok {
							return
						}
						m.mtx.RLock()
						// send service entry to all watchers
						for _, w := range m.watchers {
							send(w, e)
						}
						m.mtx.RUnlock()
					}
				}

			}()

			// start listening, blocking call; any result it returns is discarded
			mdns.Listen(ch, exit)

			// mdns.Listen has unblocked
			// kill the saved listener so the next iteration can start a new one
			m.mtx.Lock()
			m.listener = nil
			close(ch)
			m.mtx.Unlock()
		}
	}()

	return md, nil
}
   543  
   544  func (m *mdnsRegistry) String() string {
   545  	return "mdns"
   546  }
   547  
   548  func (m *mdnsWatcher) Next() (*Result, error) {
   549  	for {
   550  		select {
   551  		case e := <-m.ch:
   552  			txt, err := decode(e.InfoFields)
   553  			if err != nil {
   554  				continue
   555  			}
   556  
   557  			if len(txt.Service) == 0 || len(txt.Version) == 0 {
   558  				continue
   559  			}
   560  
   561  			// Filter watch options
   562  			// wo.Service: Only keep services we care about
   563  			if len(m.wo.Service) > 0 && txt.Service != m.wo.Service {
   564  				continue
   565  			}
   566  			var action string
   567  			if e.TTL == 0 {
   568  				action = "delete"
   569  			} else {
   570  				action = "create"
   571  			}
   572  
   573  			service := &Service{
   574  				Name:      txt.Service,
   575  				Version:   txt.Version,
   576  				Endpoints: txt.Endpoints,
   577  			}
   578  
   579  			// skip anything without the domain we care about
   580  			suffix := fmt.Sprintf(".%s.%s.", service.Name, m.domain)
   581  			if !strings.HasSuffix(e.Name, suffix) {
   582  				continue
   583  			}
   584  
   585  			var addr string
   586  			if len(e.AddrV4) > 0 {
   587  				addr = e.AddrV4.String()
   588  			} else if len(e.AddrV6) > 0 {
   589  				addr = "[" + e.AddrV6.String() + "]"
   590  			} else {
   591  				addr = e.Addr.String()
   592  			}
   593  
   594  			service.Nodes = append(service.Nodes, &Node{
   595  				Id:       strings.TrimSuffix(e.Name, suffix),
   596  				Address:  fmt.Sprintf("%s:%d", addr, e.Port),
   597  				Metadata: txt.Metadata,
   598  			})
   599  
   600  			return &Result{
   601  				Action:  action,
   602  				Service: service,
   603  			}, nil
   604  		case <-m.exit:
   605  			return nil, ErrWatcherStopped
   606  		}
   607  	}
   608  }
   609  
   610  func (m *mdnsWatcher) Stop() {
   611  	select {
   612  	case <-m.exit:
   613  		return
   614  	default:
   615  		close(m.exit)
   616  		// remove self from the registry
   617  		m.registry.mtx.Lock()
   618  		delete(m.registry.watchers, m.id)
   619  		m.registry.mtx.Unlock()
   620  	}
   621  }
   622  
   623  // NewRegistry returns a new default registry which is mdns
   624  func NewRegistry(opts ...Option) Registry {
   625  	return newRegistry(opts...)
   626  }