github.com/avenga/couper@v1.12.2/server/mux.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"sort"
     7  	"strings"
     8  
     9  	gmux "github.com/gorilla/mux"
    10  
    11  	"github.com/avenga/couper/config"
    12  	"github.com/avenga/couper/config/request"
    13  	"github.com/avenga/couper/config/runtime"
    14  	"github.com/avenga/couper/errors"
    15  	"github.com/avenga/couper/handler"
    16  	"github.com/avenga/couper/handler/middleware"
    17  	"github.com/avenga/couper/utils"
    18  )
    19  
// Mux is an HTTP request router that dispatches requests to their
// corresponding HTTP handlers. It keeps one router per route kind; FindHandler
// consults the endpoint routes first, then the file routes, then the SPA routes.
type Mux struct {
	endpointRoot *gmux.Router        // endpoint routes (opts.EndpointRoutes)
	fileRoot     *gmux.Router        // file routes, each registered with a "/**" suffix
	opts         *runtime.MuxOptions // route maps and server options; NewMux substitutes defaults for nil
	spaRoot      *gmux.Router        // SPA routes (opts.SPARoutes)
}
    28  
    29  const (
    30  	serverOptionsKey = "serverContextOptions"
    31  	wildcardSearch   = "/**"
    32  )
    33  
    34  func isParamSegment(segment string) bool {
    35  	return strings.HasPrefix(segment, "{") && strings.HasSuffix(segment, "}")
    36  }
    37  
    38  func SortPathPatterns(pathPatterns []string) {
    39  	sort.Slice(pathPatterns, func(i, j int) bool {
    40  		iSegments := strings.Split(strings.TrimPrefix(pathPatterns[i], "/"), "/")
    41  		jSegments := strings.Split(strings.TrimPrefix(pathPatterns[j], "/"), "/")
    42  		iLastSegment := iSegments[len(iSegments)-1]
    43  		jLastSegment := jSegments[len(jSegments)-1]
    44  		if iLastSegment != "**" && jLastSegment == "**" {
    45  			return true
    46  		}
    47  		if iLastSegment == "**" && jLastSegment != "**" {
    48  			return false
    49  		}
    50  		if len(iSegments) > len(jSegments) {
    51  			return true
    52  		}
    53  		if len(iSegments) < len(jSegments) {
    54  			return false
    55  		}
    56  		for k, iSegment := range iSegments {
    57  			jSegment := jSegments[k]
    58  			if !isParamSegment(iSegment) && isParamSegment(jSegment) {
    59  				return true
    60  			}
    61  			if isParamSegment(iSegment) && !isParamSegment(jSegment) {
    62  				return false
    63  			}
    64  		}
    65  		return sort.StringSlice{pathPatterns[i], pathPatterns[j]}.Less(0, 1)
    66  	})
    67  }
    68  
    69  func sortedPathPatterns(routes map[string]http.Handler) []string {
    70  	pathPatterns := make([]string, len(routes))
    71  	i := 0
    72  	for k := range routes {
    73  		pathPatterns[i] = k
    74  		i++
    75  	}
    76  	SortPathPatterns(pathPatterns)
    77  	return pathPatterns
    78  }
    79  
    80  func NewMux(options *runtime.MuxOptions) *Mux {
    81  	opts := options
    82  	if opts == nil {
    83  		opts = runtime.NewMuxOptions()
    84  	}
    85  
    86  	mux := &Mux{
    87  		opts:         opts,
    88  		endpointRoot: gmux.NewRouter(),
    89  		fileRoot:     gmux.NewRouter(),
    90  		spaRoot:      gmux.NewRouter(),
    91  	}
    92  
    93  	return mux
    94  }
    95  
    96  func (m *Mux) RegisterConfigured() {
    97  	for _, path := range sortedPathPatterns(m.opts.EndpointRoutes) {
    98  		// TODO: handle method option per endpoint configuration
    99  		mustAddRoute(m.endpointRoot, path, m.opts.EndpointRoutes[path], true)
   100  	}
   101  
   102  	for _, path := range sortedPathPatterns(m.opts.FileRoutes) {
   103  		mustAddRoute(m.fileRoot, utils.JoinOpenAPIPath(path, "/**"), m.opts.FileRoutes[path], false)
   104  	}
   105  
   106  	for _, path := range sortedPathPatterns(m.opts.SPARoutes) {
   107  		mustAddRoute(m.spaRoot, path, m.opts.SPARoutes[path], true)
   108  	}
   109  }
   110  
   111  var noDefaultMethods []string
   112  
   113  func registerHandler(root *gmux.Router, methods []string, path string, handler http.Handler) {
   114  	notAllowedMethodsHandler := errors.DefaultJSON.WithError(errors.MethodNotAllowed)
   115  	allowedMethodsHandler := middleware.NewAllowedMethodsHandler(methods, noDefaultMethods, handler, notAllowedMethodsHandler)
   116  	mustAddRoute(root, path, allowedMethodsHandler, false)
   117  }
   118  
   119  func (m *Mux) FindHandler(req *http.Request) http.Handler {
   120  	ctx := context.WithValue(req.Context(), request.ServerName, m.opts.ServerOptions.ServerName)
   121  	routeMatch, matches := m.match(m.endpointRoot, req)
   122  	if !matches {
   123  		// No matches for api or free endpoints. Determine if we have entered an api basePath
   124  		// and handle api related errors accordingly.
   125  		// Otherwise, look for existing files or spa fallback.
   126  		if tpl, api := m.getAPIErrorTemplate(req.URL.Path); tpl != nil {
   127  			*req = *req.WithContext(ctx)
   128  			return tpl.WithError(errors.RouteNotFound.Label(api.BasePath)) // TODO: api label
   129  		}
   130  
   131  		fileHandler, exist := m.hasFileResponse(req)
   132  		if exist {
   133  			*req = *req.WithContext(ctx)
   134  			return fileHandler
   135  		}
   136  
   137  		routeMatch, matches = m.match(m.spaRoot, req)
   138  
   139  		if !matches {
   140  			if fileHandler != nil {
   141  				return fileHandler
   142  			}
   143  
   144  			// Fallback
   145  			*req = *req.WithContext(ctx)
   146  			return m.opts.ServerOptions.ServerErrTpl.WithError(errors.RouteNotFound)
   147  		}
   148  	}
   149  
   150  	pathParams := make(request.PathParameter, len(routeMatch.Vars))
   151  	for k, value := range routeMatch.Vars {
   152  		key := strings.TrimSuffix(strings.TrimSuffix(k, "*"), "|.+")
   153  		pathParams[key] = value
   154  	}
   155  
   156  	pt, _ := routeMatch.Route.GetPathTemplate()
   157  	p := pt
   158  	for k, v := range routeMatch.Vars {
   159  		p = strings.Replace(p, "{"+k+"}", v, 1)
   160  	}
   161  	wc := strings.TrimPrefix(req.URL.Path, p)
   162  	wc = strings.TrimPrefix(wc, "/")
   163  
   164  	if routeMatch.Route.GetName() == "**" {
   165  		ctx = context.WithValue(ctx, request.Wildcard, wc)
   166  	}
   167  
   168  	ctx = context.WithValue(ctx, request.PathParams, pathParams)
   169  	*req = *req.WithContext(ctx)
   170  
   171  	return routeMatch.Handler
   172  }
   173  
   174  func (m *Mux) match(root *gmux.Router, req *http.Request) (*gmux.RouteMatch, bool) {
   175  	var routeMatch gmux.RouteMatch
   176  	if root.Match(req, &routeMatch) {
   177  		return &routeMatch, true
   178  	}
   179  
   180  	return nil, false
   181  }
   182  
   183  func (m *Mux) hasFileResponse(req *http.Request) (http.Handler, bool) {
   184  	routeMatch, matches := m.match(m.fileRoot, req)
   185  	if !matches {
   186  		return nil, false
   187  	}
   188  
   189  	fileHandler := routeMatch.Handler
   190  	unprotectedHandler := getChildHandler(fileHandler)
   191  	if fh, ok := unprotectedHandler.(*handler.File); ok {
   192  		return fileHandler, fh.HasResponse(req)
   193  	}
   194  
   195  	if fh, ok := fileHandler.(*handler.File); ok {
   196  		return fileHandler, fh.HasResponse(req)
   197  	}
   198  
   199  	return fileHandler, false
   200  }
   201  
   202  func (m *Mux) getAPIErrorTemplate(reqPath string) (*errors.Template, *config.API) {
   203  	for api, path := range m.opts.ServerOptions.APIBasePaths {
   204  		if !isConfigured(path) {
   205  			continue
   206  		}
   207  
   208  		var spaPaths, filesPaths []string
   209  
   210  		if len(m.opts.ServerOptions.SPABasePaths) == 0 {
   211  			spaPaths = []string{""}
   212  		} else {
   213  			spaPaths = m.opts.ServerOptions.SPABasePaths
   214  		}
   215  
   216  		if len(m.opts.ServerOptions.FilesBasePaths) == 0 {
   217  			filesPaths = []string{""}
   218  		} else {
   219  			filesPaths = m.opts.ServerOptions.FilesBasePaths
   220  		}
   221  
   222  		for _, spaPath := range spaPaths {
   223  			for _, filesPath := range filesPaths {
   224  				if isAPIError(path, filesPath, spaPath, reqPath) {
   225  					return m.opts.ServerOptions.APIErrTpls[api], api
   226  				}
   227  			}
   228  		}
   229  	}
   230  
   231  	return nil, nil
   232  }
   233  
   234  func mustAddRoute(root *gmux.Router, path string, handler http.Handler, trailingSlash bool) {
   235  	if strings.HasSuffix(path, wildcardSearch) {
   236  		path = path[:len(path)-len(wildcardSearch)]
   237  		if len(path) == 0 {
   238  			root.PathPrefix("/").Name("**").Handler(handler)
   239  			return
   240  		}
   241  		root.Path(path).Name("**").Handler(handler) // register /path ...
   242  		if !strings.HasSuffix(path, "/") {
   243  			path = path + "/" // ... and /path/**
   244  		}
   245  		root.PathPrefix(path).Name("**").Handler(handler)
   246  		return
   247  	}
   248  
   249  	if len(path) == 0 {
   250  		path = "/" // path at least be /
   251  	}
   252  	// cannot use Router.StrictSlash(true) because redirect and subsequent GET request would cause problem with CORS
   253  	if trailingSlash {
   254  		path = strings.TrimSuffix(path, "/")
   255  
   256  		if len(path) > 0 {
   257  			root.Path(path).Handler(handler) // register /path ...
   258  		}
   259  		path = path + "/" // ... and /path/
   260  	}
   261  	root.Path(path).Handler(handler)
   262  }
   263  
   264  // isAPIError checks the path w/ and w/o the
   265  // trailing slash against the request path.
   266  func isAPIError(apiPath, filesBasePath, spaBasePath, reqPath string) bool {
   267  	if matchesPath(apiPath, reqPath) {
   268  		if isConfigured(filesBasePath) && apiPath == filesBasePath {
   269  			return false
   270  		}
   271  		if isConfigured(spaBasePath) && apiPath == spaBasePath {
   272  			return false
   273  		}
   274  
   275  		return true
   276  	}
   277  
   278  	return false
   279  }
   280  
   281  // matchesPath checks the path w/ and w/o the
   282  // trailing slash against the request path.
   283  func matchesPath(path, reqPath string) bool {
   284  	p1 := path
   285  	p2 := path
   286  
   287  	if p1 != "/" && !strings.HasSuffix(p1, "/") {
   288  		p1 += "/"
   289  	}
   290  	if p2 != "/" && strings.HasSuffix(p2, "/") {
   291  		p2 = p2[:len(p2)-len("/")]
   292  	}
   293  
   294  	if strings.HasPrefix(reqPath, p1) || reqPath == p2 {
   295  		return true
   296  	}
   297  
   298  	return false
   299  }
   300  
   301  func isConfigured(basePath string) bool {
   302  	return basePath != ""
   303  }