github.com/btccom/go-micro/v2@v2.9.3/api/router/static/static.go (about)

     1  package static
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/http"
     7  	"regexp"
     8  	"strings"
     9  	"sync"
    10  
    11  	"github.com/btccom/go-micro/v2/api"
    12  	"github.com/btccom/go-micro/v2/api/router"
    13  	"github.com/btccom/go-micro/v2/api/router/util"
    14  	"github.com/btccom/go-micro/v2/logger"
    15  	"github.com/btccom/go-micro/v2/metadata"
    16  	"github.com/btccom/go-micro/v2/registry"
    17  	rutil "github.com/btccom/go-micro/v2/util/registry"
    18  )
    19  
    20  type endpoint struct {
    21  	apiep    *api.Endpoint
    22  	hostregs []*regexp.Regexp
    23  	pathregs []util.Pattern
    24  	pcreregs []*regexp.Regexp
    25  }
    26  
    27  // router is the default router
    28  type staticRouter struct {
    29  	exit chan bool
    30  	opts router.Options
    31  	sync.RWMutex
    32  	eps map[string]*endpoint
    33  }
    34  
    35  func (r *staticRouter) isClosed() bool {
    36  	select {
    37  	case <-r.exit:
    38  		return true
    39  	default:
    40  		return false
    41  	}
    42  }
    43  
    44  /*
    45  // watch for endpoint changes
    46  func (r *staticRouter) watch() {
    47  	var attempts int
    48  
    49  	for {
    50  		if r.isClosed() {
    51  			return
    52  		}
    53  
    54  		// watch for changes
    55  		w, err := r.opts.Registry.Watch()
    56  		if err != nil {
    57  			attempts++
    58  			log.Println("Error watching endpoints", err)
    59  			time.Sleep(time.Duration(attempts) * time.Second)
    60  			continue
    61  		}
    62  
    63  		ch := make(chan bool)
    64  
    65  		go func() {
    66  			select {
    67  			case <-ch:
    68  				w.Stop()
    69  			case <-r.exit:
    70  				w.Stop()
    71  			}
    72  		}()
    73  
    74  		// reset if we get here
    75  		attempts = 0
    76  
    77  		for {
    78  			// process next event
    79  			res, err := w.Next()
    80  			if err != nil {
    81  				log.Println("Error getting next endpoint", err)
    82  				close(ch)
    83  				break
    84  			}
    85  			r.process(res)
    86  		}
    87  	}
    88  }
    89  */
    90  
    91  func (r *staticRouter) Register(ep *api.Endpoint) error {
    92  	if err := api.Validate(ep); err != nil {
    93  		return err
    94  	}
    95  
    96  	var pathregs []util.Pattern
    97  	var hostregs []*regexp.Regexp
    98  	var pcreregs []*regexp.Regexp
    99  
   100  	for _, h := range ep.Host {
   101  		if h == "" || h == "*" {
   102  			continue
   103  		}
   104  		hostreg, err := regexp.CompilePOSIX(h)
   105  		if err != nil {
   106  			return err
   107  		}
   108  		hostregs = append(hostregs, hostreg)
   109  	}
   110  
   111  	for _, p := range ep.Path {
   112  		var pcreok bool
   113  
   114  		// pcre only when we have start and end markers
   115  		if p[0] == '^' && p[len(p)-1] == '$' {
   116  			pcrereg, err := regexp.CompilePOSIX(p)
   117  			if err == nil {
   118  				pcreregs = append(pcreregs, pcrereg)
   119  				pcreok = true
   120  			}
   121  		}
   122  
   123  		rule, err := util.Parse(p)
   124  		if err != nil && !pcreok {
   125  			return err
   126  		} else if err != nil && pcreok {
   127  			continue
   128  		}
   129  
   130  		tpl := rule.Compile()
   131  		pathreg, err := util.NewPattern(tpl.Version, tpl.OpCodes, tpl.Pool, "")
   132  		if err != nil {
   133  			return err
   134  		}
   135  		pathregs = append(pathregs, pathreg)
   136  	}
   137  
   138  	r.Lock()
   139  	r.eps[ep.Name] = &endpoint{
   140  		apiep:    ep,
   141  		pcreregs: pcreregs,
   142  		pathregs: pathregs,
   143  		hostregs: hostregs,
   144  	}
   145  	r.Unlock()
   146  	return nil
   147  }
   148  
   149  func (r *staticRouter) Deregister(ep *api.Endpoint) error {
   150  	if err := api.Validate(ep); err != nil {
   151  		return err
   152  	}
   153  	r.Lock()
   154  	delete(r.eps, ep.Name)
   155  	r.Unlock()
   156  	return nil
   157  }
   158  
   159  func (r *staticRouter) Options() router.Options {
   160  	return r.opts
   161  }
   162  
   163  func (r *staticRouter) Close() error {
   164  	select {
   165  	case <-r.exit:
   166  		return nil
   167  	default:
   168  		close(r.exit)
   169  	}
   170  	return nil
   171  }
   172  
   173  func (r *staticRouter) Endpoint(req *http.Request) (*api.Service, error) {
   174  	ep, err := r.endpoint(req)
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  
   179  	epf := strings.Split(ep.apiep.Name, ".")
   180  	services, err := r.opts.Registry.GetService(epf[0])
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  
   185  	// hack for stream endpoint
   186  	if ep.apiep.Stream {
   187  		svcs := rutil.Copy(services)
   188  		for _, svc := range svcs {
   189  			if len(svc.Endpoints) == 0 {
   190  				e := &registry.Endpoint{}
   191  				e.Name = strings.Join(epf[1:], ".")
   192  				e.Metadata = make(map[string]string)
   193  				e.Metadata["stream"] = "true"
   194  				svc.Endpoints = append(svc.Endpoints, e)
   195  			}
   196  			for _, e := range svc.Endpoints {
   197  				e.Name = strings.Join(epf[1:], ".")
   198  				e.Metadata = make(map[string]string)
   199  				e.Metadata["stream"] = "true"
   200  			}
   201  		}
   202  
   203  		services = svcs
   204  	}
   205  
   206  	svc := &api.Service{
   207  		Name: epf[0],
   208  		Endpoint: &api.Endpoint{
   209  			Name:    strings.Join(epf[1:], "."),
   210  			Handler: "rpc",
   211  			Host:    ep.apiep.Host,
   212  			Method:  ep.apiep.Method,
   213  			Path:    ep.apiep.Path,
   214  			Body:    ep.apiep.Body,
   215  			Stream:  ep.apiep.Stream,
   216  		},
   217  		Services: services,
   218  	}
   219  
   220  	return svc, nil
   221  }
   222  
   223  func (r *staticRouter) endpoint(req *http.Request) (*endpoint, error) {
   224  	if r.isClosed() {
   225  		return nil, errors.New("router closed")
   226  	}
   227  
   228  	r.RLock()
   229  	defer r.RUnlock()
   230  
   231  	var idx int
   232  	if len(req.URL.Path) > 0 && req.URL.Path != "/" {
   233  		idx = 1
   234  	}
   235  	path := strings.Split(req.URL.Path[idx:], "/")
   236  	// use the first match
   237  	// TODO: weighted matching
   238  
   239  	for _, ep := range r.eps {
   240  		var mMatch, hMatch, pMatch bool
   241  
   242  		// 1. try method
   243  		for _, m := range ep.apiep.Method {
   244  			if m == req.Method {
   245  				mMatch = true
   246  				break
   247  			}
   248  		}
   249  		if !mMatch {
   250  			continue
   251  		}
   252  		if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   253  			logger.Debugf("api method match %s", req.Method)
   254  		}
   255  
   256  		// 2. try host
   257  		if len(ep.apiep.Host) == 0 {
   258  			hMatch = true
   259  		} else {
   260  			for idx, h := range ep.apiep.Host {
   261  				if h == "" || h == "*" {
   262  					hMatch = true
   263  					break
   264  				} else {
   265  					if ep.hostregs[idx].MatchString(req.URL.Host) {
   266  						hMatch = true
   267  						break
   268  					}
   269  				}
   270  			}
   271  		}
   272  		if !hMatch {
   273  			continue
   274  		}
   275  		if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   276  			logger.Debugf("api host match %s", req.URL.Host)
   277  		}
   278  
   279  		// 3. try google.api path
   280  		for _, pathreg := range ep.pathregs {
   281  			matches, err := pathreg.Match(path, "")
   282  			if err != nil {
   283  				if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   284  					logger.Debugf("api gpath not match %s != %v", path, pathreg)
   285  				}
   286  				continue
   287  			}
   288  			if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   289  				logger.Debugf("api gpath match %s = %v", path, pathreg)
   290  			}
   291  			pMatch = true
   292  			ctx := req.Context()
   293  			md, ok := metadata.FromContext(ctx)
   294  			if !ok {
   295  				md = make(metadata.Metadata)
   296  			}
   297  			for k, v := range matches {
   298  				md[fmt.Sprintf("x-api-field-%s", k)] = v
   299  			}
   300  			md["x-api-body"] = ep.apiep.Body
   301  			*req = *req.Clone(metadata.NewContext(ctx, md))
   302  			break
   303  		}
   304  
   305  		if !pMatch {
   306  			// 4. try path via pcre path matching
   307  			for _, pathreg := range ep.pcreregs {
   308  				if !pathreg.MatchString(req.URL.Path) {
   309  					if logger.V(logger.DebugLevel, logger.DefaultLogger) {
   310  						logger.Debugf("api pcre path not match %s != %v", req.URL.Path, pathreg)
   311  					}
   312  					continue
   313  				}
   314  				pMatch = true
   315  				break
   316  			}
   317  		}
   318  
   319  		if !pMatch {
   320  			continue
   321  		}
   322  		// TODO: Percentage traffic
   323  
   324  		// we got here, so its a match
   325  		return ep, nil
   326  	}
   327  
   328  	// no match
   329  	return nil, fmt.Errorf("endpoint not found for %v", req.URL)
   330  }
   331  
   332  func (r *staticRouter) Route(req *http.Request) (*api.Service, error) {
   333  	if r.isClosed() {
   334  		return nil, errors.New("router closed")
   335  	}
   336  
   337  	// try get an endpoint
   338  	ep, err := r.Endpoint(req)
   339  	if err != nil {
   340  		return nil, err
   341  	}
   342  
   343  	return ep, nil
   344  }
   345  
   346  func NewRouter(opts ...router.Option) *staticRouter {
   347  	options := router.NewOptions(opts...)
   348  	r := &staticRouter{
   349  		exit: make(chan bool),
   350  		opts: options,
   351  		eps:  make(map[string]*endpoint),
   352  	}
   353  	//go r.watch()
   354  	//go r.refresh()
   355  	return r
   356  }