github.com/tickoalcantara12/micro/v3@v3.0.0-20221007104245-9d75b9bcbab9/service/api/router/registry/registry.go (about)

     1  // Copyright 2020 Asim Aslam
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // Original source: github.com/micro/go-micro/v3/api/router/registry/registry.go
    16  
    17  // Package registry provides a dynamic api service router
    18  package registry
    19  
    20  import (
    21  	"errors"
    22  	"fmt"
    23  	"net/http"
    24  	"regexp"
    25  	"strings"
    26  	"sync"
    27  	"time"
    28  
    29  	"github.com/tickoalcantara12/micro/v3/service/api"
    30  	"github.com/tickoalcantara12/micro/v3/service/api/router"
    31  	"github.com/tickoalcantara12/micro/v3/service/context/metadata"
    32  	"github.com/tickoalcantara12/micro/v3/service/logger"
    33  	"github.com/tickoalcantara12/micro/v3/service/registry"
    34  	"github.com/tickoalcantara12/micro/v3/service/registry/cache"
    35  	"github.com/tickoalcantara12/micro/v3/util/namespace"
    36  	util "github.com/tickoalcantara12/micro/v3/util/router"
    37  )
    38  
// Sentinel errors used inside this package:
//   - errEmptyNamespace: refreshNamespace got an empty service list for a namespace.
//   - errNotFound: Endpoint found no namespace entry or no matching endpoint.
var (
	errEmptyNamespace = errors.New("namespace is empty")
	errNotFound       = errors.New("not found")
)
    43  
// endpoint holds the compiled host and path matchers for a single api
// endpoint. It is built by store and read by Endpoint.
type endpoint struct {
	// hostregs are POSIX regexps compiled from the endpoint's Host entries
	// (empty and "*" hosts are not compiled)
	hostregs []*regexp.Regexp
	// pathregs are patterns built from the endpoint's Path entries via
	// util.Parse and util.NewPattern
	pathregs []util.Pattern
	// pcreregs are POSIX regexps compiled from Path entries of the form ^...$
	pcreregs []*regexp.Regexp
}
    50  
// namespaceEntry holds the services and endpoint regexs for a namespace.
// The embedded RWMutex is taken by store (write) and Endpoint (read)
// around access to both maps.
type namespaceEntry struct {
	sync.RWMutex
	// eps maps "<service name>.<endpoint name>" to the api service for that endpoint
	eps map[string]*api.Service
	// compiled regexp for host and path, keyed the same way as eps
	ceps map[string]*endpoint
}
    58  
// registryRouter is the default router. It keeps a per-namespace table of
// api endpoints that is filled from the registry by the watch and refresh
// goroutines started in newRouter.
type registryRouter struct {
	// exit is closed by Close to signal the background goroutines to stop
	exit chan bool
	// opts are the options the router was created with
	opts router.Options

	// registry cache
	rc cache.Cache

	// refresh channel; Route sends the domain of a request it could not
	// match and refresh receives it
	refreshChan chan string

	// the embedded mutex guards the namespaces map
	sync.RWMutex
	// namespaces maps a namespace (domain) to its endpoint table
	namespaces map[string]*namespaceEntry
}
    73  
    74  func getDomain(srv *registry.Service) string {
    75  	// check the service metadata for domain
    76  	// TODO: domain as Domain field in registry?
    77  	if srv.Metadata != nil && len(srv.Metadata["domain"]) > 0 {
    78  		return srv.Metadata["domain"]
    79  	} else if nodes := srv.Nodes; len(nodes) > 0 && nodes[0].Metadata != nil {
    80  		// only return the domain if its set
    81  		if len(nodes[0].Metadata["domain"]) > 0 {
    82  			return nodes[0].Metadata["domain"]
    83  		}
    84  	}
    85  
    86  	// otherwise return wildcard
    87  	// TODO: return GlobalDomain or PublicDomain
    88  	return registry.DefaultDomain
    89  }
    90  
    91  func (r *registryRouter) isClosed() bool {
    92  	select {
    93  	case <-r.exit:
    94  		return true
    95  	default:
    96  		return false
    97  	}
    98  }
    99  
   100  // refreshNamespace refreshes the list of api services in the given namespace
   101  func (r *registryRouter) refreshNamespace(ns string) error {
   102  	services, err := r.opts.Registry.ListServices(registry.ListDomain(ns))
   103  	if err != nil {
   104  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   105  			logger.Errorf("unable to list services: %v", err)
   106  		}
   107  		return err
   108  	}
   109  	if len(services) == 0 {
   110  		return errEmptyNamespace
   111  	}
   112  
   113  	// for each service, get service and store endpoints
   114  	for _, s := range services {
   115  		// if we have nodes then use them
   116  		dns := getDomain(s)
   117  		if len(s.Nodes) > 0 && len(dns) > 0 {
   118  			r.store(dns, []*registry.Service{s})
   119  			continue
   120  		}
   121  
   122  		service, err := r.rc.GetService(s.Name, registry.GetDomain(ns))
   123  		if err != nil {
   124  			if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   125  				logger.Errorf("unable to get service: %v", err)
   126  			}
   127  			continue
   128  		}
   129  
   130  		// store each independently as if we have a wildcard query
   131  		// the domain for each service may differ
   132  		for _, srv := range service {
   133  			// get the namespace from the service
   134  			ns = getDomain(srv)
   135  			r.store(ns, []*registry.Service{srv})
   136  		}
   137  	}
   138  
   139  	return nil
   140  }
   141  
   142  // refresh list of api services
   143  func (r *registryRouter) refresh() {
   144  	refreshed := make(map[string]time.Time)
   145  
   146  	// do first load
   147  	r.refreshNamespace(registry.WildcardDomain)
   148  
   149  	for {
   150  		r.RLock()
   151  		namespaces := r.namespaces
   152  		r.RUnlock()
   153  
   154  		for ns, _ := range namespaces {
   155  			err := r.refreshNamespace(ns)
   156  			if err == errEmptyNamespace {
   157  				r.Lock()
   158  				delete(namespaces, ns)
   159  				r.Unlock()
   160  			}
   161  		}
   162  
   163  		// refresh the list every minute
   164  		// TODO: rely solely on watcher
   165  		select {
   166  		case domain := <-r.refreshChan:
   167  			v, ok := refreshed[domain]
   168  			if ok && time.Since(v) < time.Minute {
   169  				break
   170  			}
   171  			r.refreshNamespace(domain)
   172  		case <-time.After(time.Minute):
   173  		case <-r.exit:
   174  			return
   175  		}
   176  	}
   177  }
   178  
   179  // process watch event
   180  func (r *registryRouter) process(res *registry.Result) {
   181  	// skip these things
   182  	if res == nil || res.Service == nil {
   183  		return
   184  	}
   185  
   186  	// get entry from cache
   187  	// only deals with default namespace
   188  	service, err := r.rc.GetService(res.Service.Name)
   189  	if err != nil {
   190  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   191  			logger.Errorf("unable to get %v service: %v", res.Service.Name, err)
   192  		}
   193  		return
   194  	}
   195  
   196  	// only process if there's data
   197  	if len(service) == 0 {
   198  		return
   199  	}
   200  
   201  	// get te namespace
   202  	namespace := getDomain(service[0])
   203  
   204  	// update our local endpoints
   205  	r.store(namespace, service)
   206  }
   207  
   208  // store local endpoint cache
   209  func (r *registryRouter) store(namespace string, services []*registry.Service) {
   210  	// endpoints
   211  	eps := map[string]*api.Service{}
   212  
   213  	// services
   214  	names := map[string]bool{}
   215  
   216  	// create a new endpoint mapping
   217  	for _, service := range services {
   218  		// set names we need later
   219  		names[service.Name] = true
   220  
   221  		// map per endpoint
   222  		for _, sep := range service.Endpoints {
   223  			// create a key service:endpoint_name
   224  			key := fmt.Sprintf("%s.%s", service.Name, sep.Name)
   225  			// decode endpoint
   226  			end := api.Decode(sep.Metadata)
   227  			// no endpoint or no name
   228  			if end == nil || len(end.Name) == 0 {
   229  				continue
   230  			}
   231  			// if we got nothing skip
   232  			if err := api.Validate(end); err != nil {
   233  				if logger.V(logger.TraceLevel, logger.DefaultLogger) {
   234  					logger.Tracef("endpoint validation failed: %v", err)
   235  				}
   236  				continue
   237  			}
   238  
   239  			// try get endpoint
   240  			ep, ok := eps[key]
   241  			if !ok {
   242  				ep = &api.Service{Name: service.Name}
   243  			}
   244  
   245  			// overwrite the endpoint
   246  			ep.Endpoint = end
   247  			// append services
   248  			ep.Services = append(ep.Services, service)
   249  			// store it
   250  			eps[key] = ep
   251  		}
   252  	}
   253  
   254  	r.Lock()
   255  	nse, ok := r.namespaces[namespace]
   256  	if !ok {
   257  		nse = &namespaceEntry{
   258  			eps:  map[string]*api.Service{},
   259  			ceps: map[string]*endpoint{},
   260  		}
   261  		r.namespaces[namespace] = nse
   262  	}
   263  	r.Unlock()
   264  
   265  	nse.Lock()
   266  	defer nse.Unlock()
   267  
   268  	// delete any existing eps for services we know
   269  	for key, service := range nse.eps {
   270  		// skip what we don't care about
   271  		if !names[service.Name] {
   272  			continue
   273  		}
   274  
   275  		// ok we know this thing
   276  		// delete delete delete
   277  		delete(nse.eps, key)
   278  	}
   279  
   280  	// now set the eps we have
   281  	for name, ep := range eps {
   282  		nse.eps[name] = ep
   283  		cep := &endpoint{}
   284  
   285  		for _, h := range ep.Endpoint.Host {
   286  			if h == "" || h == "*" {
   287  				continue
   288  			}
   289  			hostreg, err := regexp.CompilePOSIX(h)
   290  			if err != nil {
   291  				if logger.V(logger.TraceLevel, logger.DefaultLogger) {
   292  					logger.Tracef("endpoint have invalid host regexp: %v", err)
   293  				}
   294  				continue
   295  			}
   296  			cep.hostregs = append(cep.hostregs, hostreg)
   297  		}
   298  
   299  		for _, p := range ep.Endpoint.Path {
   300  			var pcreok bool
   301  
   302  			if p[0] == '^' && p[len(p)-1] == '$' {
   303  				pcrereg, err := regexp.CompilePOSIX(p)
   304  				if err == nil {
   305  					cep.pcreregs = append(cep.pcreregs, pcrereg)
   306  					pcreok = true
   307  				}
   308  			}
   309  
   310  			rule, err := util.Parse(p)
   311  			if err != nil && !pcreok {
   312  				if logger.V(logger.TraceLevel, logger.DefaultLogger) {
   313  					logger.Tracef("endpoint have invalid path pattern: %v", err)
   314  				}
   315  				continue
   316  			} else if err != nil && pcreok {
   317  				continue
   318  			}
   319  
   320  			tpl := rule.Compile()
   321  			pathreg, err := util.NewPattern(tpl.Version, tpl.OpCodes, tpl.Pool, "")
   322  			if err != nil {
   323  				if logger.V(logger.TraceLevel, logger.DefaultLogger) {
   324  					logger.Tracef("endpoint have invalid path pattern: %v", err)
   325  				}
   326  				continue
   327  			}
   328  			cep.pathregs = append(cep.pathregs, pathreg)
   329  		}
   330  
   331  		nse.ceps[name] = cep
   332  	}
   333  }
   334  
   335  // watch for endpoint changes
   336  func (r *registryRouter) watch() {
   337  	var attempts int
   338  
   339  	for {
   340  		if r.isClosed() {
   341  			return
   342  		}
   343  
   344  		// watch for changes
   345  		w, err := r.opts.Registry.Watch(registry.WatchDomain(registry.WildcardDomain))
   346  		if err != nil {
   347  			attempts++
   348  			if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   349  				logger.Errorf("error watching endpoints: %v", err)
   350  			}
   351  			time.Sleep(time.Duration(attempts) * time.Second)
   352  			continue
   353  		}
   354  
   355  		ch := make(chan bool)
   356  
   357  		go func() {
   358  			select {
   359  			case <-ch:
   360  				w.Stop()
   361  			case <-r.exit:
   362  				w.Stop()
   363  			}
   364  		}()
   365  
   366  		// reset if we get here
   367  		attempts = 0
   368  
   369  		for {
   370  			// process next event
   371  			res, err := w.Next()
   372  			if err != nil {
   373  				if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   374  					logger.Errorf("error getting next endpoint: %v", err)
   375  				}
   376  				close(ch)
   377  				break
   378  			}
   379  			r.process(res)
   380  		}
   381  	}
   382  }
   383  
// Options returns the options the router was created with.
func (r *registryRouter) Options() router.Options {
	return r.opts
}
   387  
   388  func (r *registryRouter) Close() error {
   389  	select {
   390  	case <-r.exit:
   391  		return nil
   392  	default:
   393  		close(r.exit)
   394  		r.rc.Stop()
   395  	}
   396  	return nil
   397  }
   398  
// Register is a no-op for this router and always returns nil; the endpoint
// tables are filled from the registry instead (see store).
func (r *registryRouter) Register(ep *api.Endpoint) error {
	return nil
}
   402  
// Deregister is a no-op for this router and always returns nil.
func (r *registryRouter) Deregister(ep *api.Endpoint) error {
	return nil
}
   406  
   407  func (r *registryRouter) Endpoint(req *http.Request) (*api.Service, error) {
   408  	if r.isClosed() {
   409  		return nil, errors.New("router closed")
   410  	}
   411  
   412  	var idx int
   413  	if len(req.URL.Path) > 0 && req.URL.Path != "/" {
   414  		idx = 1
   415  	}
   416  	path := strings.Split(req.URL.Path[idx:], "/")
   417  
   418  	// resolve so we can get the namespace
   419  	rp, err := r.opts.Resolver.Resolve(req)
   420  	if err != nil {
   421  		return nil, err
   422  	}
   423  	var ret *api.Service
   424  	r.RLock()
   425  	nse, ok := r.namespaces[rp.Domain]
   426  	r.RUnlock()
   427  	if !ok {
   428  		// no entry in cache
   429  		// TODO should we refresh the cache here?
   430  		return nil, errNotFound
   431  	}
   432  	nse.RLock()
   433  	defer nse.RUnlock()
   434  endpointLoop:
   435  	// loop through all endpoints to find either a path match or a regex match
   436  	// prefer path matches over regexp matches e.g. prefer /foobar over ^/.*$
   437  	// TODO: weighted matching
   438  	for n, e := range nse.eps {
   439  		cep, ok := nse.ceps[n]
   440  		if !ok {
   441  			continue
   442  		}
   443  		ep := e.Endpoint
   444  		var mMatch, hMatch bool
   445  		// 1. try method
   446  		for _, m := range ep.Method {
   447  			if m == req.Method {
   448  				mMatch = true
   449  				break
   450  			}
   451  		}
   452  		if !mMatch {
   453  			continue
   454  		}
   455  		if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   456  			logger.Debugf("api method match %s", req.Method)
   457  		}
   458  
   459  		// 2. try host
   460  		if len(ep.Host) == 0 {
   461  			hMatch = true
   462  		} else {
   463  			for idx, h := range ep.Host {
   464  				if h == "" || h == "*" {
   465  					hMatch = true
   466  					break
   467  				} else {
   468  					if cep.hostregs[idx].MatchString(req.URL.Host) {
   469  						hMatch = true
   470  						break
   471  					}
   472  				}
   473  			}
   474  		}
   475  		if !hMatch {
   476  			continue
   477  		}
   478  		if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   479  			logger.Debugf("api host match %s", req.URL.Host)
   480  		}
   481  
   482  		// 3. try path via google.api path matching
   483  		for _, pathreg := range cep.pathregs {
   484  			matches, err := pathreg.Match(path, "")
   485  			if err != nil {
   486  				if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   487  					logger.Debugf("api gpath not match %s != %v", path, pathreg)
   488  				}
   489  				continue
   490  			}
   491  			if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   492  				logger.Debugf("api gpath match %s = %v", path, pathreg)
   493  			}
   494  			ctx := req.Context()
   495  			md, ok := metadata.FromContext(ctx)
   496  			if !ok {
   497  				md = make(metadata.Metadata)
   498  			}
   499  			for k, v := range matches {
   500  				md[fmt.Sprintf("x-api-field-%s", k)] = v
   501  			}
   502  			md["x-api-body"] = ep.Body
   503  			*req = *req.Clone(metadata.NewContext(ctx, md))
   504  			ret = e
   505  			break endpointLoop
   506  		}
   507  
   508  		// 4. try path via pcre path matching
   509  		for _, pathreg := range cep.pcreregs {
   510  			if !pathreg.MatchString(req.URL.Path) {
   511  				if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   512  					logger.Debugf("api pcre path not match %s != %v", path, pathreg)
   513  				}
   514  				continue
   515  			}
   516  			if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   517  				logger.Debugf("api pcre path match %s != %v", path, pathreg)
   518  			}
   519  			ret = e
   520  			break
   521  		}
   522  
   523  		// TODO: Percentage traffic
   524  	}
   525  	if ret != nil {
   526  		return ret, nil
   527  	}
   528  
   529  	// no match
   530  	return nil, errNotFound
   531  }
   532  
   533  func (r *registryRouter) Route(req *http.Request) (*api.Service, error) {
   534  	if r.isClosed() {
   535  		return nil, errors.New("router closed")
   536  	}
   537  
   538  	// try get an endpoint from cache
   539  	ep, err := r.Endpoint(req)
   540  	if err == nil {
   541  		return ep, nil
   542  	}
   543  
   544  	// error not nil
   545  	// ignore that shit
   546  	// TODO: don't ignore that shit
   547  
   548  	// get the service name
   549  	rp, err := r.opts.Resolver.Resolve(req)
   550  	if err != nil {
   551  		return nil, err
   552  	}
   553  	// service name
   554  	name := rp.Name
   555  
   556  	// trigger an endpoint refresh
   557  	select {
   558  	case r.refreshChan <- rp.Domain:
   559  	default:
   560  	}
   561  
   562  	// get service
   563  	services, err := r.rc.GetService(name, registry.GetDomain(rp.Domain))
   564  	if err != nil {
   565  		return nil, err
   566  	}
   567  
   568  	// only use endpoint matching when the meta handler is set aka api.Default
   569  	switch r.opts.Handler {
   570  	// rpc handlers
   571  	case "meta", "api", "rpc":
   572  		handler := r.opts.Handler
   573  
   574  		// set default handler to api
   575  		if r.opts.Handler == "meta" {
   576  			handler = "rpc"
   577  		}
   578  
   579  		// construct api service
   580  		return &api.Service{
   581  			Name: name,
   582  			Endpoint: &api.Endpoint{
   583  				Name:    rp.Method,
   584  				Handler: handler,
   585  			},
   586  			Services: services,
   587  		}, nil
   588  	// http handler
   589  	case "http", "proxy", "web":
   590  		// construct api service
   591  		return &api.Service{
   592  			Name: name,
   593  			Endpoint: &api.Endpoint{
   594  				Name:    req.URL.String(),
   595  				Handler: r.opts.Handler,
   596  				Host:    []string{req.Host},
   597  				Method:  []string{req.Method},
   598  				Path:    []string{req.URL.Path},
   599  			},
   600  			Services: services,
   601  		}, nil
   602  	}
   603  
   604  	return nil, errors.New("unknown handler")
   605  }
   606  
   607  func newRouter(opts ...router.Option) *registryRouter {
   608  	options := router.NewOptions(opts...)
   609  	r := &registryRouter{
   610  		exit:        make(chan bool),
   611  		refreshChan: make(chan string),
   612  		opts:        options,
   613  		rc:          cache.New(options.Registry),
   614  		namespaces: map[string]*namespaceEntry{
   615  			namespace.DefaultNamespace: &namespaceEntry{
   616  				eps:  make(map[string]*api.Service),
   617  				ceps: make(map[string]*endpoint),
   618  			}},
   619  	}
   620  	go r.watch()
   621  	go r.refresh()
   622  	return r
   623  }
   624  
// NewRouter returns the default router: a registry backed router created by
// newRouter with the given options.
func NewRouter(opts ...router.Option) router.Router {
	return newRouter(opts...)
}