github.com/btccom/go-micro/v2@v2.9.3/api/router/registry/registry.go (about)

     1  // Package registry provides a dynamic api service router
     2  package registry
     3  
     4  import (
     5  	"errors"
     6  	"fmt"
     7  	"net/http"
     8  	"regexp"
     9  	"strings"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/btccom/go-micro/v2/api"
    14  	"github.com/btccom/go-micro/v2/api/router"
    15  	"github.com/btccom/go-micro/v2/api/router/util"
    16  	"github.com/btccom/go-micro/v2/logger"
    17  	"github.com/btccom/go-micro/v2/metadata"
    18  	"github.com/btccom/go-micro/v2/registry"
    19  	"github.com/btccom/go-micro/v2/registry/cache"
    20  )
    21  
    22  // endpoint struct, that holds compiled pcre
    23  type endpoint struct {
    24  	hostregs []*regexp.Regexp
    25  	pathregs []util.Pattern
    26  	pcreregs []*regexp.Regexp
    27  }
    28  
    29  // router is the default router
    30  type registryRouter struct {
    31  	exit chan bool
    32  	opts router.Options
    33  
    34  	// registry cache
    35  	rc cache.Cache
    36  
    37  	sync.RWMutex
    38  	eps map[string]*api.Service
    39  	// compiled regexp for host and path
    40  	ceps map[string]*endpoint
    41  }
    42  
    43  func (r *registryRouter) isClosed() bool {
    44  	select {
    45  	case <-r.exit:
    46  		return true
    47  	default:
    48  		return false
    49  	}
    50  }
    51  
    52  // refresh list of api services
    53  func (r *registryRouter) refresh() {
    54  	var attempts int
    55  
    56  	for {
    57  		services, err := r.opts.Registry.ListServices()
    58  		if err != nil {
    59  			attempts++
    60  			if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
    61  				logger.Errorf("unable to list services: %v", err)
    62  			}
    63  			time.Sleep(time.Duration(attempts) * time.Second)
    64  			continue
    65  		}
    66  
    67  		attempts = 0
    68  
    69  		// for each service, get service and store endpoints
    70  		for _, s := range services {
    71  			service, err := r.rc.GetService(s.Name)
    72  			if err != nil {
    73  				if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
    74  					logger.Errorf("unable to get service: %v", err)
    75  				}
    76  				continue
    77  			}
    78  			r.store(service)
    79  		}
    80  
    81  		// refresh list in 10 minutes... cruft
    82  		// use registry watching
    83  		select {
    84  		case <-time.After(time.Minute * 10):
    85  		case <-r.exit:
    86  			return
    87  		}
    88  	}
    89  }
    90  
    91  // process watch event
    92  func (r *registryRouter) process(res *registry.Result) {
    93  	// skip these things
    94  	if res == nil || res.Service == nil {
    95  		return
    96  	}
    97  
    98  	// get entry from cache
    99  	service, err := r.rc.GetService(res.Service.Name)
   100  	if err != nil {
   101  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   102  			logger.Errorf("unable to get service: %v", err)
   103  		}
   104  		return
   105  	}
   106  
   107  	// update our local endpoints
   108  	r.store(service)
   109  }
   110  
   111  // store local endpoint cache
   112  func (r *registryRouter) store(services []*registry.Service) {
   113  	// endpoints
   114  	eps := map[string]*api.Service{}
   115  
   116  	// services
   117  	names := map[string]bool{}
   118  
   119  	// create a new endpoint mapping
   120  	for _, service := range services {
   121  		// set names we need later
   122  		names[service.Name] = true
   123  
   124  		// map per endpoint
   125  		for _, sep := range service.Endpoints {
   126  			// create a key service:endpoint_name
   127  			key := fmt.Sprintf("%s.%s", service.Name, sep.Name)
   128  			// decode endpoint
   129  			end := api.Decode(sep.Metadata)
   130  
   131  			// if we got nothing skip
   132  			if err := api.Validate(end); err != nil {
   133  				if logger.V(logger.TraceLevel, logger.DefaultLogger) {
   134  					logger.Tracef("endpoint validation failed: %v", err)
   135  				}
   136  				continue
   137  			}
   138  
   139  			// try get endpoint
   140  			ep, ok := eps[key]
   141  			if !ok {
   142  				ep = &api.Service{Name: service.Name}
   143  			}
   144  
   145  			// overwrite the endpoint
   146  			ep.Endpoint = end
   147  			// append services
   148  			ep.Services = append(ep.Services, service)
   149  			// store it
   150  			eps[key] = ep
   151  		}
   152  	}
   153  
   154  	r.Lock()
   155  	defer r.Unlock()
   156  
   157  	// delete any existing eps for services we know
   158  	for key, service := range r.eps {
   159  		// skip what we don't care about
   160  		if !names[service.Name] {
   161  			continue
   162  		}
   163  
   164  		// ok we know this thing
   165  		// delete delete delete
   166  		delete(r.eps, key)
   167  	}
   168  
   169  	// now set the eps we have
   170  	for name, ep := range eps {
   171  		r.eps[name] = ep
   172  		cep := &endpoint{}
   173  
   174  		for _, h := range ep.Endpoint.Host {
   175  			if h == "" || h == "*" {
   176  				continue
   177  			}
   178  			hostreg, err := regexp.CompilePOSIX(h)
   179  			if err != nil {
   180  				if logger.V(logger.TraceLevel, logger.DefaultLogger) {
   181  					logger.Tracef("endpoint have invalid host regexp: %v", err)
   182  				}
   183  				continue
   184  			}
   185  			cep.hostregs = append(cep.hostregs, hostreg)
   186  		}
   187  
   188  		for _, p := range ep.Endpoint.Path {
   189  			var pcreok bool
   190  
   191  			if p[0] == '^' && p[len(p)-1] == '$' {
   192  				pcrereg, err := regexp.CompilePOSIX(p)
   193  				if err == nil {
   194  					cep.pcreregs = append(cep.pcreregs, pcrereg)
   195  					pcreok = true
   196  				}
   197  			}
   198  
   199  			rule, err := util.Parse(p)
   200  			if err != nil && !pcreok {
   201  				if logger.V(logger.TraceLevel, logger.DefaultLogger) {
   202  					logger.Tracef("endpoint have invalid path pattern: %v", err)
   203  				}
   204  				continue
   205  			} else if err != nil && pcreok {
   206  				continue
   207  			}
   208  
   209  			tpl := rule.Compile()
   210  			pathreg, err := util.NewPattern(tpl.Version, tpl.OpCodes, tpl.Pool, "")
   211  			if err != nil {
   212  				if logger.V(logger.TraceLevel, logger.DefaultLogger) {
   213  					logger.Tracef("endpoint have invalid path pattern: %v", err)
   214  				}
   215  				continue
   216  			}
   217  			cep.pathregs = append(cep.pathregs, pathreg)
   218  		}
   219  
   220  		r.ceps[name] = cep
   221  	}
   222  }
   223  
   224  // watch for endpoint changes
   225  func (r *registryRouter) watch() {
   226  	var attempts int
   227  
   228  	for {
   229  		if r.isClosed() {
   230  			return
   231  		}
   232  
   233  		// watch for changes
   234  		w, err := r.opts.Registry.Watch()
   235  		if err != nil {
   236  			attempts++
   237  			if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   238  				logger.Errorf("error watching endpoints: %v", err)
   239  			}
   240  			time.Sleep(time.Duration(attempts) * time.Second)
   241  			continue
   242  		}
   243  
   244  		ch := make(chan bool)
   245  
   246  		go func() {
   247  			select {
   248  			case <-ch:
   249  				w.Stop()
   250  			case <-r.exit:
   251  				w.Stop()
   252  			}
   253  		}()
   254  
   255  		// reset if we get here
   256  		attempts = 0
   257  
   258  		for {
   259  			// process next event
   260  			res, err := w.Next()
   261  			if err != nil {
   262  				if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   263  					logger.Errorf("error getting next endoint: %v", err)
   264  				}
   265  				close(ch)
   266  				break
   267  			}
   268  			r.process(res)
   269  		}
   270  	}
   271  }
   272  
   273  func (r *registryRouter) Options() router.Options {
   274  	return r.opts
   275  }
   276  
   277  func (r *registryRouter) Close() error {
   278  	select {
   279  	case <-r.exit:
   280  		return nil
   281  	default:
   282  		close(r.exit)
   283  		r.rc.Stop()
   284  	}
   285  	return nil
   286  }
   287  
   288  func (r *registryRouter) Register(ep *api.Endpoint) error {
   289  	return nil
   290  }
   291  
   292  func (r *registryRouter) Deregister(ep *api.Endpoint) error {
   293  	return nil
   294  }
   295  
   296  func (r *registryRouter) Endpoint(req *http.Request) (*api.Service, error) {
   297  	if r.isClosed() {
   298  		return nil, errors.New("router closed")
   299  	}
   300  
   301  	r.RLock()
   302  	defer r.RUnlock()
   303  
   304  	var idx int
   305  	if len(req.URL.Path) > 0 && req.URL.Path != "/" {
   306  		idx = 1
   307  	}
   308  	path := strings.Split(req.URL.Path[idx:], "/")
   309  
   310  	// use the first match
   311  	// TODO: weighted matching
   312  	for n, e := range r.eps {
   313  		cep, ok := r.ceps[n]
   314  		if !ok {
   315  			continue
   316  		}
   317  		ep := e.Endpoint
   318  		var mMatch, hMatch, pMatch bool
   319  		// 1. try method
   320  		for _, m := range ep.Method {
   321  			if m == req.Method {
   322  				mMatch = true
   323  				break
   324  			}
   325  		}
   326  		if !mMatch {
   327  			continue
   328  		}
   329  		if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   330  			logger.Debugf("api method match %s", req.Method)
   331  		}
   332  
   333  		// 2. try host
   334  		if len(ep.Host) == 0 {
   335  			hMatch = true
   336  		} else {
   337  			for idx, h := range ep.Host {
   338  				if h == "" || h == "*" {
   339  					hMatch = true
   340  					break
   341  				} else {
   342  					if cep.hostregs[idx].MatchString(req.URL.Host) {
   343  						hMatch = true
   344  						break
   345  					}
   346  				}
   347  			}
   348  		}
   349  		if !hMatch {
   350  			continue
   351  		}
   352  		if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   353  			logger.Debugf("api host match %s", req.URL.Host)
   354  		}
   355  
   356  		// 3. try path via google.api path matching
   357  		for _, pathreg := range cep.pathregs {
   358  			matches, err := pathreg.Match(path, "")
   359  			if err != nil {
   360  				if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   361  					logger.Debugf("api gpath not match %s != %v", path, pathreg)
   362  				}
   363  				continue
   364  			}
   365  			if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   366  				logger.Debugf("api gpath match %s = %v", path, pathreg)
   367  			}
   368  			pMatch = true
   369  			ctx := req.Context()
   370  			md, ok := metadata.FromContext(ctx)
   371  			if !ok {
   372  				md = make(metadata.Metadata)
   373  			}
   374  			for k, v := range matches {
   375  				md[fmt.Sprintf("x-api-field-%s", k)] = v
   376  			}
   377  			md["x-api-body"] = ep.Body
   378  			*req = *req.Clone(metadata.NewContext(ctx, md))
   379  			break
   380  		}
   381  
   382  		if !pMatch {
   383  			// 4. try path via pcre path matching
   384  			for _, pathreg := range cep.pcreregs {
   385  				if !pathreg.MatchString(req.URL.Path) {
   386  					if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   387  						logger.Debugf("api pcre path not match %s != %v", path, pathreg)
   388  					}
   389  					continue
   390  				}
   391  				if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   392  					logger.Debugf("api pcre path match %s != %v", path, pathreg)
   393  				}
   394  				pMatch = true
   395  				break
   396  			}
   397  		}
   398  
   399  		if !pMatch {
   400  			continue
   401  		}
   402  
   403  		// TODO: Percentage traffic
   404  		// we got here, so its a match
   405  		return e, nil
   406  	}
   407  
   408  	// no match
   409  	return nil, errors.New("not found")
   410  }
   411  
   412  func (r *registryRouter) Route(req *http.Request) (*api.Service, error) {
   413  	if r.isClosed() {
   414  		return nil, errors.New("router closed")
   415  	}
   416  
   417  	// try get an endpoint
   418  	ep, err := r.Endpoint(req)
   419  	if err == nil {
   420  		return ep, nil
   421  	}
   422  
   423  	// error not nil
   424  	// ignore that shit
   425  	// TODO: don't ignore that shit
   426  
   427  	// get the service name
   428  	rp, err := r.opts.Resolver.Resolve(req)
   429  	if err != nil {
   430  		return nil, err
   431  	}
   432  
   433  	// service name
   434  	name := rp.Name
   435  
   436  	// get service
   437  	services, err := r.rc.GetService(name)
   438  	if err != nil {
   439  		return nil, err
   440  	}
   441  
   442  	// only use endpoint matching when the meta handler is set aka api.Default
   443  	switch r.opts.Handler {
   444  	// rpc handlers
   445  	case "meta", "api", "rpc":
   446  		handler := r.opts.Handler
   447  
   448  		// set default handler to api
   449  		if r.opts.Handler == "meta" {
   450  			handler = "rpc"
   451  		}
   452  
   453  		// construct api service
   454  		return &api.Service{
   455  			Name: name,
   456  			Endpoint: &api.Endpoint{
   457  				Name:    rp.Method,
   458  				Handler: handler,
   459  			},
   460  			Services: services,
   461  		}, nil
   462  	// http handler
   463  	case "http", "proxy", "web":
   464  		// construct api service
   465  		return &api.Service{
   466  			Name: name,
   467  			Endpoint: &api.Endpoint{
   468  				Name:    req.URL.String(),
   469  				Handler: r.opts.Handler,
   470  				Host:    []string{req.Host},
   471  				Method:  []string{req.Method},
   472  				Path:    []string{req.URL.Path},
   473  			},
   474  			Services: services,
   475  		}, nil
   476  	}
   477  
   478  	return nil, errors.New("unknown handler")
   479  }
   480  
   481  func newRouter(opts ...router.Option) *registryRouter {
   482  	options := router.NewOptions(opts...)
   483  	r := &registryRouter{
   484  		exit: make(chan bool),
   485  		opts: options,
   486  		rc:   cache.New(options.Registry),
   487  		eps:  make(map[string]*api.Service),
   488  		ceps: make(map[string]*endpoint),
   489  	}
   490  	go r.watch()
   491  	go r.refresh()
   492  	return r
   493  }
   494  
   495  // NewRouter returns the default router
   496  func NewRouter(opts ...router.Option) router.Router {
   497  	return newRouter(opts...)
   498  }