go-micro.dev/v5@v5.12.0/registry/mdns_registry.go (about)

     1  // Package mdns is a multicast dns registry
     2  package registry
     3  
     4  import (
     5  	"bytes"
     6  	"compress/zlib"
     7  	"context"
     8  	"encoding/hex"
     9  	"encoding/json"
    10  	"fmt"
    11  	"io"
    12  	"net"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/google/uuid"
    19  	log "go-micro.dev/v5/logger"
    20  	"go-micro.dev/v5/util/mdns"
    21  )
    22  
    23  var (
    24  	// use a .micro domain rather than .local.
    25  	mdnsDomain = "micro"
    26  )
    27  
// mdnsTxt is the payload carried in a node's TXT records. encode marshals it
// to JSON, zlib-compresses and hex-encodes it; decode reverses those steps.
// The fields have no json tags, so the Go field names are the JSON keys and
// the declaration order below is the key order of the encoded payload.
type mdnsTxt struct {
	// Metadata is the per-node metadata (Register copies node.Metadata).
	Metadata map[string]string
	// Service is the service name; readers use it to filter entries.
	Service string
	// Version is the service version; GetService groups nodes by it.
	Version string
	// Endpoints are the service's endpoints, copied from the Service.
	Endpoints []*Endpoint
}
    34  
    35  type mdnsEntry struct {
    36  	node *mdns.Server
    37  	id   string
    38  }
    39  
    40  type mdnsRegistry struct {
    41  	opts     *Options
    42  	services map[string][]*mdnsEntry
    43  
    44  	// watchers
    45  	watchers map[string]*mdnsWatcher
    46  
    47  	// listener
    48  	listener chan *mdns.ServiceEntry
    49  	// the mdns domain
    50  	domain string
    51  
    52  	mtx sync.RWMutex
    53  
    54  	sync.Mutex
    55  }
    56  
    57  type mdnsWatcher struct {
    58  	wo   WatchOptions
    59  	ch   chan *mdns.ServiceEntry
    60  	exit chan struct{}
    61  	// the registry
    62  	registry *mdnsRegistry
    63  	id       string
    64  	// the mdns domain
    65  	domain string
    66  }
    67  
    68  func encode(txt *mdnsTxt) ([]string, error) {
    69  	b, err := json.Marshal(txt)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	var buf bytes.Buffer
    75  	defer buf.Reset()
    76  
    77  	w := zlib.NewWriter(&buf)
    78  	if _, err := w.Write(b); err != nil {
    79  		return nil, err
    80  	}
    81  	w.Close()
    82  
    83  	encoded := hex.EncodeToString(buf.Bytes())
    84  
    85  	// individual txt limit
    86  	if len(encoded) <= 255 {
    87  		return []string{encoded}, nil
    88  	}
    89  
    90  	// split encoded string
    91  	var record []string
    92  
    93  	for len(encoded) > 255 {
    94  		record = append(record, encoded[:255])
    95  		encoded = encoded[255:]
    96  	}
    97  
    98  	record = append(record, encoded)
    99  
   100  	return record, nil
   101  }
   102  
   103  func decode(record []string) (*mdnsTxt, error) {
   104  	encoded := strings.Join(record, "")
   105  
   106  	hr, err := hex.DecodeString(encoded)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	br := bytes.NewReader(hr)
   112  	zr, err := zlib.NewReader(br)
   113  	if err != nil {
   114  		return nil, err
   115  	}
   116  
   117  	rbuf, err := io.ReadAll(zr)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  
   122  	var txt *mdnsTxt
   123  
   124  	if err := json.Unmarshal(rbuf, &txt); err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	return txt, nil
   129  }
   130  func newRegistry(opts ...Option) Registry {
   131  	mergedOpts := append([]Option{Timeout(time.Millisecond * 100)}, opts...)
   132  	options := NewOptions(mergedOpts...)
   133  
   134  	// set the domain
   135  	domain := mdnsDomain
   136  
   137  	d, ok := options.Context.Value("mdns.domain").(string)
   138  	if ok {
   139  		domain = d
   140  	}
   141  
   142  	return &mdnsRegistry{
   143  		opts:     options,
   144  		domain:   domain,
   145  		services: make(map[string][]*mdnsEntry),
   146  		watchers: make(map[string]*mdnsWatcher),
   147  	}
   148  }
   149  
   150  func (m *mdnsRegistry) Init(opts ...Option) error {
   151  	for _, o := range opts {
   152  		o(m.opts)
   153  	}
   154  	return nil
   155  }
   156  
   157  func (m *mdnsRegistry) Options() Options {
   158  	return *m.opts
   159  }
   160  
   161  func (m *mdnsRegistry) Register(service *Service, opts ...RegisterOption) error {
   162  	m.Lock()
   163  	defer m.Unlock()
   164  
   165  	logger := m.opts.Logger
   166  	entries, ok := m.services[service.Name]
   167  	// first entry, create wildcard used for list queries
   168  	if !ok {
   169  		s, err := mdns.NewMDNSService(
   170  			service.Name,
   171  			"_services",
   172  			m.domain+".",
   173  			"",
   174  			9999,
   175  			[]net.IP{net.ParseIP("0.0.0.0")},
   176  			nil,
   177  		)
   178  		if err != nil {
   179  			return err
   180  		}
   181  
   182  		srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{MDNSService: s}})
   183  		if err != nil {
   184  			return err
   185  		}
   186  
   187  		// append the wildcard entry
   188  		entries = append(entries, &mdnsEntry{id: "*", node: srv})
   189  	}
   190  
   191  	var gerr error
   192  
   193  	for _, node := range service.Nodes {
   194  		var seen bool
   195  		var e *mdnsEntry
   196  
   197  		for _, entry := range entries {
   198  			if node.Id == entry.id {
   199  				seen = true
   200  				e = entry
   201  				break
   202  			}
   203  		}
   204  
   205  		// already registered, continue
   206  		if seen {
   207  			continue
   208  			// doesn't exist
   209  		} else {
   210  			e = &mdnsEntry{}
   211  		}
   212  
   213  		txt, err := encode(&mdnsTxt{
   214  			Service:   service.Name,
   215  			Version:   service.Version,
   216  			Endpoints: service.Endpoints,
   217  			Metadata:  node.Metadata,
   218  		})
   219  
   220  		if err != nil {
   221  			gerr = err
   222  			continue
   223  		}
   224  
   225  		host, pt, err := net.SplitHostPort(node.Address)
   226  		if err != nil {
   227  			gerr = err
   228  			continue
   229  		}
   230  		port, _ := strconv.Atoi(pt)
   231  
   232  		logger.Logf(log.DebugLevel, "[mdns] registry create new service with ip: %s for: %s", net.ParseIP(host).String(), host)
   233  
   234  		// we got here, new node
   235  		s, err := mdns.NewMDNSService(
   236  			node.Id,
   237  			service.Name,
   238  			m.domain+".",
   239  			"",
   240  			port,
   241  			[]net.IP{net.ParseIP(host)},
   242  			txt,
   243  		)
   244  		if err != nil {
   245  			gerr = err
   246  			continue
   247  		}
   248  
   249  		srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true})
   250  		if err != nil {
   251  			gerr = err
   252  			continue
   253  		}
   254  
   255  		e.id = node.Id
   256  		e.node = srv
   257  		entries = append(entries, e)
   258  	}
   259  
   260  	// save
   261  	m.services[service.Name] = entries
   262  
   263  	return gerr
   264  }
   265  
   266  func (m *mdnsRegistry) Deregister(service *Service, opts ...DeregisterOption) error {
   267  	m.Lock()
   268  	defer m.Unlock()
   269  
   270  	var newEntries []*mdnsEntry
   271  
   272  	// loop existing entries, check if any match, shutdown those that do
   273  	for _, entry := range m.services[service.Name] {
   274  		var remove bool
   275  
   276  		for _, node := range service.Nodes {
   277  			if node.Id == entry.id {
   278  				entry.node.Shutdown()
   279  				remove = true
   280  				break
   281  			}
   282  		}
   283  
   284  		// keep it?
   285  		if !remove {
   286  			newEntries = append(newEntries, entry)
   287  		}
   288  	}
   289  
   290  	// last entry is the wildcard for list queries. Remove it.
   291  	if len(newEntries) == 1 && newEntries[0].id == "*" {
   292  		newEntries[0].node.Shutdown()
   293  		delete(m.services, service.Name)
   294  	} else {
   295  		m.services[service.Name] = newEntries
   296  	}
   297  
   298  	return nil
   299  }
   300  
   301  func (m *mdnsRegistry) GetService(service string, opts ...GetOption) ([]*Service, error) {
   302  	logger := m.opts.Logger
   303  	serviceMap := make(map[string]*Service)
   304  	entries := make(chan *mdns.ServiceEntry, 10)
   305  	done := make(chan bool)
   306  
   307  	p := mdns.DefaultParams(service)
   308  	// set context with timeout
   309  	var cancel context.CancelFunc
   310  	p.Context, cancel = context.WithTimeout(context.Background(), m.opts.Timeout)
   311  	defer cancel()
   312  	// set entries channel
   313  	p.Entries = entries
   314  	// set the domain
   315  	p.Domain = m.domain
   316  
   317  	go func() {
   318  		for {
   319  			select {
   320  			case e := <-entries:
   321  				// list record so skip
   322  				if p.Service == "_services" {
   323  					continue
   324  				}
   325  				if p.Domain != m.domain {
   326  					continue
   327  				}
   328  				if e.TTL == 0 {
   329  					continue
   330  				}
   331  
   332  				txt, err := decode(e.InfoFields)
   333  				if err != nil {
   334  					continue
   335  				}
   336  
   337  				if txt.Service != service {
   338  					continue
   339  				}
   340  
   341  				s, ok := serviceMap[txt.Version]
   342  				if !ok {
   343  					s = &Service{
   344  						Name:      txt.Service,
   345  						Version:   txt.Version,
   346  						Endpoints: txt.Endpoints,
   347  					}
   348  				}
   349  				addr := ""
   350  				// prefer ipv4 addrs
   351  				if len(e.AddrV4) > 0 {
   352  					addr = net.JoinHostPort(e.AddrV4.String(), fmt.Sprint(e.Port))
   353  					// else use ipv6
   354  				} else if len(e.AddrV6) > 0 {
   355  					addr = net.JoinHostPort(e.AddrV6.String(), fmt.Sprint(e.Port))
   356  				} else {
   357  					logger.Logf(log.InfoLevel, "[mdns]: invalid endpoint received: %v", e)
   358  					continue
   359  				}
   360  				s.Nodes = append(s.Nodes, &Node{
   361  					Id:       strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+"."),
   362  					Address:  addr,
   363  					Metadata: txt.Metadata,
   364  				})
   365  
   366  				serviceMap[txt.Version] = s
   367  			case <-p.Context.Done():
   368  				close(done)
   369  				return
   370  			}
   371  		}
   372  	}()
   373  
   374  	// execute the query
   375  	if err := mdns.Query(p); err != nil {
   376  		return nil, err
   377  	}
   378  
   379  	// wait for completion
   380  	<-done
   381  
   382  	// create list and return
   383  	services := make([]*Service, 0, len(serviceMap))
   384  
   385  	for _, service := range serviceMap {
   386  		services = append(services, service)
   387  	}
   388  
   389  	return services, nil
   390  }
   391  
   392  func (m *mdnsRegistry) ListServices(opts ...ListOption) ([]*Service, error) {
   393  	serviceMap := make(map[string]bool)
   394  	entries := make(chan *mdns.ServiceEntry, 10)
   395  	done := make(chan bool)
   396  
   397  	p := mdns.DefaultParams("_services")
   398  	// set context with timeout
   399  	var cancel context.CancelFunc
   400  	p.Context, cancel = context.WithTimeout(context.Background(), m.opts.Timeout)
   401  	defer cancel()
   402  	// set entries channel
   403  	p.Entries = entries
   404  	// set domain
   405  	p.Domain = m.domain
   406  
   407  	var services []*Service
   408  
   409  	go func() {
   410  		for {
   411  			select {
   412  			case e := <-entries:
   413  				if e.TTL == 0 {
   414  					continue
   415  				}
   416  				if !strings.HasSuffix(e.Name, p.Domain+".") {
   417  					continue
   418  				}
   419  				name := strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+".")
   420  				if !serviceMap[name] {
   421  					serviceMap[name] = true
   422  					services = append(services, &Service{Name: name})
   423  				}
   424  			case <-p.Context.Done():
   425  				close(done)
   426  				return
   427  			}
   428  		}
   429  	}()
   430  
   431  	// execute query
   432  	if err := mdns.Query(p); err != nil {
   433  		return nil, err
   434  	}
   435  
   436  	// wait till done
   437  	<-done
   438  
   439  	return services, nil
   440  }
   441  
// Watch registers a new watcher with the registry and returns it. Entries
// seen by a single shared mDNS listener are fanned out to every registered
// watcher.
//
// If m.listener is already set, the new watcher simply joins the running
// listener. Otherwise a listener goroutine is started. That goroutine loops:
// it exits when no watchers remain, or when another listener is already set.
// Otherwise it runs mdns.Listen until Listen returns, then starts over.
//
// NOTE(review): nothing in this function closes the per-iteration exit
// channel passed to mdns.Listen. The watcher count is therefore re-checked
// only when Listen returns on its own (presumably on error), so stopping the
// last watcher does not appear to stop the listener. If Listen fails
// immediately and repeatedly, the loop also retries with no backoff.
// Confirm against util/mdns.Listen.
func (m *mdnsRegistry) Watch(opts ...WatchOption) (Watcher, error) {
	var wo WatchOptions
	for _, o := range opts {
		o(&wo)
	}

	md := &mdnsWatcher{
		id:       uuid.New().String(),
		wo:       wo,
		ch:       make(chan *mdns.ServiceEntry, 32),
		exit:     make(chan struct{}),
		domain:   m.domain,
		registry: m,
	}

	// Held until Watch returns; the goroutine started below therefore blocks
	// on its first Lock until then.
	m.mtx.Lock()
	defer m.mtx.Unlock()

	// save the watcher
	m.watchers[md.id] = md

	// check if the listener exists; if so it will fan out to md as well
	if m.listener != nil {
		return md, nil
	}

	// start the listener
	go func() {
		// go to infinity
		for {
			m.mtx.Lock()

			// just return if there are no watchers
			if len(m.watchers) == 0 {
				m.listener = nil
				m.mtx.Unlock()
				return
			}

			// check existing listener: covers two Watch calls that both
			// started a goroutine before either set m.listener
			if m.listener != nil {
				m.mtx.Unlock()
				return
			}

			// reset the listener
			exit := make(chan struct{})
			ch := make(chan *mdns.ServiceEntry, 32)
			m.listener = ch

			m.mtx.Unlock()

			// send messages to the watchers
			go func() {
				// Non-blocking send: an entry is dropped for a watcher whose
				// channel buffer (32) is full, so one slow watcher cannot stall
				// the others.
				send := func(w *mdnsWatcher, e *mdns.ServiceEntry) {
					select {
					case w.ch <- e:
					default:
					}
				}

				for {
					select {
					case <-exit:
						return
					case e, ok := <-ch:
						// ch is closed below once mdns.Listen returns
						if !ok {
							return
						}
						m.mtx.RLock()
						// send service entry to all watchers
						for _, w := range m.watchers {
							send(w, e)
						}
						m.mtx.RUnlock()
					}
				}
			}()

			// start listening, blocking call
			mdns.Listen(ch, exit)

			// mdns.Listen has unblocked
			// kill the saved listener
			m.mtx.Lock()
			m.listener = nil
			close(ch)
			m.mtx.Unlock()
		}
	}()

	return md, nil
}
   535  
   536  func (m *mdnsRegistry) String() string {
   537  	return "mdns"
   538  }
   539  
   540  func (m *mdnsWatcher) Next() (*Result, error) {
   541  	for {
   542  		select {
   543  		case e := <-m.ch:
   544  			txt, err := decode(e.InfoFields)
   545  			if err != nil {
   546  				continue
   547  			}
   548  
   549  			if len(txt.Service) == 0 || len(txt.Version) == 0 {
   550  				continue
   551  			}
   552  
   553  			// Filter watch options
   554  			// wo.Service: Only keep services we care about
   555  			if len(m.wo.Service) > 0 && txt.Service != m.wo.Service {
   556  				continue
   557  			}
   558  			var action string
   559  			if e.TTL == 0 {
   560  				action = "delete"
   561  			} else {
   562  				action = "create"
   563  			}
   564  
   565  			service := &Service{
   566  				Name:      txt.Service,
   567  				Version:   txt.Version,
   568  				Endpoints: txt.Endpoints,
   569  			}
   570  
   571  			// skip anything without the domain we care about
   572  			suffix := fmt.Sprintf(".%s.%s.", service.Name, m.domain)
   573  			if !strings.HasSuffix(e.Name, suffix) {
   574  				continue
   575  			}
   576  
   577  			var addr string
   578  			if len(e.AddrV4) > 0 {
   579  				addr = net.JoinHostPort(e.AddrV4.String(), fmt.Sprint(e.Port))
   580  			} else if len(e.AddrV6) > 0 {
   581  				addr = net.JoinHostPort(e.AddrV6.String(), fmt.Sprint(e.Port))
   582  			} else {
   583  				addr = e.Addr.String()
   584  			}
   585  
   586  			service.Nodes = append(service.Nodes, &Node{
   587  				Id:       strings.TrimSuffix(e.Name, suffix),
   588  				Address:  addr,
   589  				Metadata: txt.Metadata,
   590  			})
   591  
   592  			return &Result{
   593  				Action:  action,
   594  				Service: service,
   595  			}, nil
   596  		case <-m.exit:
   597  			return nil, ErrWatcherStopped
   598  		}
   599  	}
   600  }
   601  
   602  func (m *mdnsWatcher) Stop() {
   603  	select {
   604  	case <-m.exit:
   605  		return
   606  	default:
   607  		close(m.exit)
   608  		// remove self from the registry
   609  
   610  		m.registry.mtx.Lock()
   611  		delete(m.registry.watchers, m.id)
   612  		m.registry.mtx.Unlock()
   613  	}
   614  }
   615  
   616  // NewRegistry returns a new default registry which is mdns.
   617  func NewMDNSRegistry(opts ...Option) Registry {
   618  	return newRegistry(opts...)
   619  }