github.com/lingyao2333/mo-zero@v1.4.1/rest/router/patrouter.go (about)

     1  package router
     2  
     3  import (
     4  	"errors"
     5  	"net/http"
     6  	"path"
     7  	"strings"
     8  
     9  	"github.com/lingyao2333/mo-zero/core/search"
    10  	"github.com/lingyao2333/mo-zero/rest/httpx"
    11  	"github.com/lingyao2333/mo-zero/rest/pathvar"
    12  )
    13  
    14  const (
    15  	allowHeader          = "Allow"
    16  	allowMethodSeparator = ", "
    17  )
    18  
    19  var (
    20  	// ErrInvalidMethod is an error that indicates not a valid http method.
    21  	ErrInvalidMethod = errors.New("not a valid http method")
    22  	// ErrInvalidPath is an error that indicates path is not start with /.
    23  	ErrInvalidPath = errors.New("path must begin with '/'")
    24  )
    25  
    26  type patRouter struct {
    27  	trees      map[string]*search.Tree
    28  	notFound   http.Handler
    29  	notAllowed http.Handler
    30  }
    31  
    32  // NewRouter returns a httpx.Router.
    33  func NewRouter() httpx.Router {
    34  	return &patRouter{
    35  		trees: make(map[string]*search.Tree),
    36  	}
    37  }
    38  
    39  func (pr *patRouter) Handle(method, reqPath string, handler http.Handler) error {
    40  	if !validMethod(method) {
    41  		return ErrInvalidMethod
    42  	}
    43  
    44  	if len(reqPath) == 0 || reqPath[0] != '/' {
    45  		return ErrInvalidPath
    46  	}
    47  
    48  	cleanPath := path.Clean(reqPath)
    49  	tree, ok := pr.trees[method]
    50  	if ok {
    51  		return tree.Add(cleanPath, handler)
    52  	}
    53  
    54  	tree = search.NewTree()
    55  	pr.trees[method] = tree
    56  	return tree.Add(cleanPath, handler)
    57  }
    58  
    59  func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    60  	reqPath := path.Clean(r.URL.Path)
    61  	if tree, ok := pr.trees[r.Method]; ok {
    62  		if result, ok := tree.Search(reqPath); ok {
    63  			if len(result.Params) > 0 {
    64  				r = pathvar.WithVars(r, result.Params)
    65  			}
    66  			result.Item.(http.Handler).ServeHTTP(w, r)
    67  			return
    68  		}
    69  	}
    70  
    71  	allows, ok := pr.methodsAllowed(r.Method, reqPath)
    72  	if !ok {
    73  		pr.handleNotFound(w, r)
    74  		return
    75  	}
    76  
    77  	if pr.notAllowed != nil {
    78  		pr.notAllowed.ServeHTTP(w, r)
    79  	} else {
    80  		w.Header().Set(allowHeader, allows)
    81  		w.WriteHeader(http.StatusMethodNotAllowed)
    82  	}
    83  }
    84  
    85  func (pr *patRouter) SetNotFoundHandler(handler http.Handler) {
    86  	pr.notFound = handler
    87  }
    88  
    89  func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) {
    90  	pr.notAllowed = handler
    91  }
    92  
    93  func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) {
    94  	if pr.notFound != nil {
    95  		pr.notFound.ServeHTTP(w, r)
    96  	} else {
    97  		http.NotFound(w, r)
    98  	}
    99  }
   100  
   101  func (pr *patRouter) methodsAllowed(method, path string) (string, bool) {
   102  	var allows []string
   103  
   104  	for treeMethod, tree := range pr.trees {
   105  		if treeMethod == method {
   106  			continue
   107  		}
   108  
   109  		_, ok := tree.Search(path)
   110  		if ok {
   111  			allows = append(allows, treeMethod)
   112  		}
   113  	}
   114  
   115  	if len(allows) > 0 {
   116  		return strings.Join(allows, allowMethodSeparator), true
   117  	}
   118  
   119  	return "", false
   120  }
   121  
   122  func validMethod(method string) bool {
   123  	return method == http.MethodDelete || method == http.MethodGet ||
   124  		method == http.MethodHead || method == http.MethodOptions ||
   125  		method == http.MethodPatch || method == http.MethodPost ||
   126  		method == http.MethodPut
   127  }