github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/registry/mdns/mdns.go (about)

     1  package mdns
     2  
     3  import (
     4  	"bytes"
     5  	"compress/zlib"
     6  	"context"
     7  	"encoding/hex"
     8  	"encoding/json"
     9  	"fmt"
    10  	"io/ioutil"
    11  	"net"
    12  	"strconv"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/google/uuid"
    18  	"github.com/volts-dev/volts/internal/mdns"
    19  	"github.com/volts-dev/volts/registry"
    20  )
    21  
    22  var log = registry.Logger()
    23  
    24  type (
    25  	mdnsRegistry struct {
    26  		sync.Mutex
    27  		mtx      sync.RWMutex
    28  		domain   string // the mdns domain
    29  		config   *registry.Config
    30  		services map[string][]*mdnsEntry
    31  		watchers map[string]*mdnsWatcher // watchers
    32  		listener chan *mdns.ServiceEntry // listener
    33  	}
    34  
    35  	mdnsEntry struct {
    36  		id   string
    37  		node *mdns.Server
    38  	}
    39  
    40  	mdnsWatcher struct {
    41  		id   string
    42  		wo   *registry.WatchConfig
    43  		ch   chan *mdns.ServiceEntry
    44  		exit chan struct{}
    45  		// the mdns domain
    46  		domain string
    47  		// the registry
    48  		registry *mdnsRegistry
    49  	}
    50  
    51  	mdnsTxt struct {
    52  		Service   string
    53  		Version   string
    54  		Endpoints []*registry.Endpoint
    55  		Metadata  map[string]string
    56  	}
    57  )
    58  
    59  func init() {
    60  	registry.Register("mdns", New)
    61  }
    62  
    63  // NewMdnsRegistry
    64  func New(opts ...registry.Option) registry.IRegistry {
    65  	var defaultOpts []registry.Option
    66  	defaultOpts = append(defaultOpts,
    67  		registry.WithName("mdns"),
    68  		registry.Timeout(time.Millisecond*100),
    69  	)
    70  	cfg := registry.NewConfig(append(defaultOpts, opts...)...)
    71  	// set the domain
    72  	domain := mdnsDomain
    73  	d, ok := cfg.Context.Value("mdns.domain").(string)
    74  	if ok {
    75  		domain = d
    76  	}
    77  
    78  	reg := &mdnsRegistry{
    79  		domain:   domain,
    80  		config:   cfg,
    81  		services: make(map[string][]*mdnsEntry),
    82  		watchers: make(map[string]*mdnsWatcher),
    83  		listener: make(chan *mdns.ServiceEntry),
    84  	}
    85  	reg.config.Name = reg.String()
    86  	return reg
    87  }
    88  
    89  func encode(txt *mdnsTxt) ([]string, error) {
    90  	b, err := json.Marshal(txt)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	var buf bytes.Buffer
    96  	defer buf.Reset()
    97  
    98  	w := zlib.NewWriter(&buf)
    99  	if _, err := w.Write(b); err != nil {
   100  		return nil, err
   101  	}
   102  	w.Close()
   103  
   104  	encoded := hex.EncodeToString(buf.Bytes())
   105  
   106  	// individual txt limit
   107  	if len(encoded) <= 255 {
   108  		return []string{encoded}, nil
   109  	}
   110  
   111  	// split encoded string
   112  	var record []string
   113  
   114  	for len(encoded) > 255 {
   115  		record = append(record, encoded[:255])
   116  		encoded = encoded[255:]
   117  	}
   118  
   119  	record = append(record, encoded)
   120  
   121  	return record, nil
   122  }
   123  
   124  func decode(record []string) (*mdnsTxt, error) {
   125  	encoded := strings.Join(record, "")
   126  
   127  	hr, err := hex.DecodeString(encoded)
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  
   132  	br := bytes.NewReader(hr)
   133  	zr, err := zlib.NewReader(br)
   134  	if err != nil {
   135  		return nil, err
   136  	}
   137  
   138  	rbuf, err := ioutil.ReadAll(zr)
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  
   143  	var txt *mdnsTxt
   144  
   145  	if err := json.Unmarshal(rbuf, &txt); err != nil {
   146  		return nil, err
   147  	}
   148  
   149  	return txt, nil
   150  }
   151  
   152  func (self *mdnsRegistry) Init(opts ...registry.Option) error {
   153  	for _, o := range opts {
   154  		o(self.config)
   155  	}
   156  	return nil
   157  }
   158  
   159  func (self *mdnsRegistry) Config() *registry.Config {
   160  	return self.config
   161  }
   162  
   163  func (m *mdnsRegistry) Register(service *registry.Service, opts ...registry.Option) error {
   164  	m.Lock()
   165  	defer m.Unlock()
   166  
   167  	entries, ok := m.services[service.Name]
   168  	// first entry, create wildcard used for list queries
   169  	if !ok {
   170  		s, err := mdns.NewMDNSService(
   171  			service.Name,
   172  			"_services",
   173  			m.domain+".",
   174  			"",
   175  			9999,
   176  			[]net.IP{net.ParseIP("0.0.0.0")},
   177  			nil,
   178  		)
   179  		if err != nil {
   180  			return err
   181  		}
   182  
   183  		srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{MDNSService: s}})
   184  		if err != nil {
   185  			return err
   186  		}
   187  
   188  		// append the wildcard entry
   189  		entries = append(entries, &mdnsEntry{id: "*", node: srv})
   190  	}
   191  
   192  	var gerr error
   193  
   194  	for _, node := range service.Nodes {
   195  		var seen bool
   196  		var e *mdnsEntry
   197  
   198  		for _, entry := range entries {
   199  			if node.Id == entry.id {
   200  				seen = true
   201  				e = entry
   202  				break
   203  			}
   204  		}
   205  
   206  		// already registered, continue
   207  		if seen {
   208  			continue
   209  			// doesn't exist
   210  		} else {
   211  			e = &mdnsEntry{}
   212  		}
   213  
   214  		txt, err := encode(&mdnsTxt{
   215  			Service:   service.Name,
   216  			Version:   service.Version,
   217  			Endpoints: service.Endpoints,
   218  			Metadata:  node.Metadata,
   219  		})
   220  
   221  		if err != nil {
   222  			gerr = err
   223  			continue
   224  		}
   225  
   226  		host, pt, err := net.SplitHostPort(node.Address)
   227  		if err != nil {
   228  			gerr = err
   229  			continue
   230  		}
   231  		port, _ := strconv.Atoi(pt)
   232  
   233  		//if logger.GetLevel()=={
   234  		log.Dbgf("[mdns] registry create new service with ip: %s for: %s", net.ParseIP(host).String(), host)
   235  		//}
   236  		// we got here, new node
   237  		s, err := mdns.NewMDNSService(
   238  			node.Id,
   239  			service.Name,
   240  			m.domain+".",
   241  			"",
   242  			port,
   243  			[]net.IP{net.ParseIP(host)},
   244  			txt,
   245  		)
   246  		if err != nil {
   247  			gerr = err
   248  			continue
   249  		}
   250  
   251  		srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true})
   252  		if err != nil {
   253  			gerr = err
   254  			continue
   255  		}
   256  
   257  		e.id = node.Id
   258  		e.node = srv
   259  		entries = append(entries, e)
   260  	}
   261  
   262  	if gerr == nil {
   263  		// save
   264  		m.services[service.Name] = entries
   265  		m.config.LocalServices = append(m.config.LocalServices, service)
   266  	}
   267  
   268  	return gerr
   269  }
   270  
   271  func (m *mdnsRegistry) Deregister(service *registry.Service, opts ...registry.Option) error {
   272  	m.Lock()
   273  	defer m.Unlock()
   274  
   275  	var newEntries []*mdnsEntry
   276  
   277  	// loop existing entries, check if any match, shutdown those that do
   278  	for _, entry := range m.services[service.Name] {
   279  		var remove bool
   280  
   281  		for _, node := range service.Nodes {
   282  			if node.Id == entry.id {
   283  				entry.node.Shutdown()
   284  				remove = true
   285  				break
   286  			}
   287  		}
   288  
   289  		// keep it?
   290  		if !remove {
   291  			newEntries = append(newEntries, entry)
   292  		}
   293  	}
   294  
   295  	// last entry is the wildcard for list queries. Remove it.
   296  	if len(newEntries) == 1 && newEntries[0].id == "*" {
   297  		newEntries[0].node.Shutdown()
   298  		delete(m.services, service.Name)
   299  	} else {
   300  		m.services[service.Name] = newEntries
   301  	}
   302  
   303  	return nil
   304  }
   305  
   306  func (m *mdnsRegistry) LocalServices() []*registry.Service {
   307  	return m.config.LocalServices
   308  }
   309  
   310  func (m *mdnsRegistry) GetService(service string) ([]*registry.Service, error) {
   311  	serviceMap := make(map[string]*registry.Service)
   312  	entries := make(chan *mdns.ServiceEntry, 10)
   313  	done := make(chan bool)
   314  
   315  	p := mdns.DefaultParams(service)
   316  	// set context with timeout
   317  	var cancel context.CancelFunc
   318  	p.Context, cancel = context.WithTimeout(context.Background(), m.config.Timeout)
   319  	defer cancel()
   320  	// set entries channel
   321  	p.Entries = entries
   322  	// set the domain
   323  	p.Domain = m.domain
   324  
   325  	go func() {
   326  		for {
   327  			select {
   328  			case e := <-entries:
   329  				// list record so skip
   330  				if p.Service == "_services" {
   331  					continue
   332  				}
   333  				if p.Domain != m.domain {
   334  					continue
   335  				}
   336  				if e.TTL == 0 {
   337  					continue
   338  				}
   339  
   340  				txt, err := decode(e.InfoFields)
   341  				if err != nil {
   342  					continue
   343  				}
   344  
   345  				if txt.Service != service {
   346  					continue
   347  				}
   348  
   349  				s, ok := serviceMap[txt.Version]
   350  				if !ok {
   351  					s = &registry.Service{
   352  						Name:      txt.Service,
   353  						Version:   txt.Version,
   354  						Endpoints: txt.Endpoints,
   355  					}
   356  				}
   357  				addr := ""
   358  				// prefer ipv4 addrs
   359  				if len(e.AddrV4) > 0 {
   360  					addr = e.AddrV4.String()
   361  					// else use ipv6
   362  				} else if len(e.AddrV6) > 0 {
   363  					addr = "[" + e.AddrV6.String() + "]"
   364  				} else {
   365  					//if logger.V(logger.InfoLevel, logger.DefaultLogger) {
   366  					log.Infof("[mdns]: invalid endpoint received: %v", e)
   367  					//}
   368  					continue
   369  				}
   370  				s.Nodes = append(s.Nodes, &registry.Node{
   371  					Id:       strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+"."),
   372  					Address:  fmt.Sprintf("%s:%d", addr, e.Port),
   373  					Metadata: txt.Metadata,
   374  				})
   375  
   376  				serviceMap[txt.Version] = s
   377  			case <-p.Context.Done():
   378  				close(done)
   379  				return
   380  			}
   381  		}
   382  	}()
   383  
   384  	// execute the query
   385  	if err := mdns.Query(p); err != nil {
   386  		return nil, err
   387  	}
   388  
   389  	// wait for completion
   390  	<-done
   391  
   392  	// create list and return
   393  	services := make([]*registry.Service, 0, len(serviceMap))
   394  
   395  	for _, service := range serviceMap {
   396  		services = append(services, service)
   397  	}
   398  
   399  	return services, nil
   400  }
   401  
   402  func (m *mdnsRegistry) ListServices() ([]*registry.Service, error) {
   403  	serviceMap := make(map[string]bool)
   404  	entries := make(chan *mdns.ServiceEntry, 10)
   405  	done := make(chan bool)
   406  
   407  	p := mdns.DefaultParams("_services")
   408  	// set context with timeout
   409  	var cancel context.CancelFunc
   410  	p.Context, cancel = context.WithTimeout(context.Background(), m.config.Timeout)
   411  	defer cancel()
   412  	// set entries channel
   413  	p.Entries = entries
   414  	// set domain
   415  	p.Domain = m.domain
   416  
   417  	var services []*registry.Service
   418  
   419  	go func() {
   420  		for {
   421  			select {
   422  			case e := <-entries:
   423  				if e.TTL == 0 {
   424  					continue
   425  				}
   426  				if !strings.HasSuffix(e.Name, p.Domain+".") {
   427  					continue
   428  				}
   429  				name := strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+".")
   430  				if !serviceMap[name] {
   431  					serviceMap[name] = true
   432  					services = append(services, &registry.Service{Name: name})
   433  				}
   434  			case <-p.Context.Done():
   435  				close(done)
   436  				return
   437  			}
   438  		}
   439  	}()
   440  
   441  	// execute query
   442  	if err := mdns.Query(p); err != nil {
   443  		return nil, err
   444  	}
   445  
   446  	// wait till done
   447  	<-done
   448  
   449  	return services, nil
   450  }
   451  
// Watcher returns a registry.Watcher that receives raw mDNS service entries
// from a listener goroutine shared by all watchers of this registry.
//
// opts are applied to a fresh registry.WatchConfig, which Next uses for
// filtering. The watcher is stored in m.watchers under a new UUID. A listener
// goroutine is started only when m.listener is nil. That goroutine
// repeatedly runs mdns.Listen and fans each entry out to every registered
// watcher with a non-blocking send. A watcher whose 32-entry buffer is full
// therefore silently drops entries. The error result is always nil.
//
// NOTE(review): if m.listener is already non-nil when the first watcher is
// created (e.g. pre-initialised by the constructor), no listener is ever
// started and watchers receive nothing.
// NOTE(review): the exit channel handed to mdns.Listen is never closed in
// this file. Once started, the listener therefore keeps running after every
// watcher has stopped; the "no watchers" check only runs when mdns.Listen
// returns on its own.
func (m *mdnsRegistry) Watcher(opts ...registry.WatchOptions) (registry.Watcher, error) {
	wcfg := &registry.WatchConfig{}
	for _, opt := range opts {
		opt(wcfg)
	}

	md := &mdnsWatcher{
		id:       uuid.New().String(),
		wo:       wcfg,
		ch:       make(chan *mdns.ServiceEntry, 32),
		exit:     make(chan struct{}),
		domain:   m.domain,
		registry: m,
	}

	m.mtx.Lock()
	defer m.mtx.Unlock()

	// save the watcher
	m.watchers[md.id] = md

	// a listener is already set: it forwards entries to md as well
	if m.listener != nil {
		return md, nil
	}

	// start the listener; its first Lock waits until this call has returned
	go func() {
		// restart mdns.Listen each time it returns, as long as watchers remain
		for {
			m.mtx.Lock()

			// just return if there are no watchers
			if len(m.watchers) == 0 {
				m.listener = nil
				m.mtx.Unlock()
				return
			}

			// some other goroutine has already installed a listener
			if m.listener != nil {
				m.mtx.Unlock()
				return
			}

			// reset the listener
			exit := make(chan struct{})
			ch := make(chan *mdns.ServiceEntry, 32)
			m.listener = ch

			m.mtx.Unlock()

			// send messages to the watchers
			go func() {
				// non-blocking: a watcher with a full buffer misses the entry
				send := func(w *mdnsWatcher, e *mdns.ServiceEntry) {
					select {
					case w.ch <- e:
					default:
					}
				}

				for {
					select {
					case <-exit:
						return
					case e, ok := <-ch:
						// ch is closed below once mdns.Listen has returned
						if !ok {
							return
						}
						m.mtx.RLock()
						// send service entry to all watchers
						for _, w := range m.watchers {
							send(w, e)
						}
						m.mtx.RUnlock()
					}
				}

			}()

			// start listening, blocking call
			mdns.Listen(ch, exit)

			// mdns.Listen has unblocked:
			// clear the saved listener and stop the forwarder via close(ch)
			m.mtx.Lock()
			m.listener = nil
			close(ch)
			m.mtx.Unlock()
		}
	}()

	return md, nil
}
   546  
   547  func (m *mdnsRegistry) String() string {
   548  	return m.config.Name
   549  }
   550  
   551  func (m *mdnsWatcher) Next() (*registry.Result, error) {
   552  	for {
   553  		select {
   554  		case e := <-m.ch:
   555  			txt, err := decode(e.InfoFields)
   556  			if err != nil {
   557  				continue
   558  			}
   559  
   560  			if len(txt.Service) == 0 || len(txt.Version) == 0 {
   561  				continue
   562  			}
   563  
   564  			// Filter watch options
   565  			// wo.Service: Only keep services we care about
   566  			if len(m.wo.Service) > 0 && txt.Service != m.wo.Service {
   567  				continue
   568  			}
   569  			var action string
   570  			if e.TTL == 0 {
   571  				action = "delete"
   572  			} else {
   573  				action = "create"
   574  			}
   575  
   576  			service := &registry.Service{
   577  				Name:      txt.Service,
   578  				Version:   txt.Version,
   579  				Endpoints: txt.Endpoints,
   580  			}
   581  
   582  			// skip anything without the domain we care about
   583  			suffix := fmt.Sprintf(".%s.%s.", service.Name, m.domain)
   584  			if !strings.HasSuffix(e.Name, suffix) {
   585  				continue
   586  			}
   587  
   588  			var addr string
   589  			if len(e.AddrV4) > 0 {
   590  				addr = e.AddrV4.String()
   591  			} else if len(e.AddrV6) > 0 {
   592  				addr = "[" + e.AddrV6.String() + "]"
   593  			} else {
   594  				addr = e.Addr.String()
   595  			}
   596  
   597  			service.Nodes = append(service.Nodes, &registry.Node{
   598  				Id:       strings.TrimSuffix(e.Name, suffix),
   599  				Address:  fmt.Sprintf("%s:%d", addr, e.Port),
   600  				Metadata: txt.Metadata,
   601  			})
   602  
   603  			return &registry.Result{
   604  				Action:  action,
   605  				Service: service,
   606  			}, nil
   607  		case <-m.exit:
   608  			return nil, registry.ErrWatcherStopped
   609  		}
   610  	}
   611  }
   612  
   613  func (m *mdnsWatcher) Stop() {
   614  	select {
   615  	case <-m.exit:
   616  		return
   617  	default:
   618  		close(m.exit)
   619  		// remove self from the registry
   620  		m.registry.mtx.Lock()
   621  		delete(m.registry.watchers, m.id)
   622  		m.registry.mtx.Unlock()
   623  	}
   624  }