github.com/lingyao2333/mo-zero@v1.4.1/rest/router/patrouter.go (about) 1 package router 2 3 import ( 4 "errors" 5 "net/http" 6 "path" 7 "strings" 8 9 "github.com/lingyao2333/mo-zero/core/search" 10 "github.com/lingyao2333/mo-zero/rest/httpx" 11 "github.com/lingyao2333/mo-zero/rest/pathvar" 12 ) 13 14 const ( 15 allowHeader = "Allow" 16 allowMethodSeparator = ", " 17 ) 18 19 var ( 20 // ErrInvalidMethod is an error that indicates not a valid http method. 21 ErrInvalidMethod = errors.New("not a valid http method") 22 // ErrInvalidPath is an error that indicates path is not start with /. 23 ErrInvalidPath = errors.New("path must begin with '/'") 24 ) 25 26 type patRouter struct { 27 trees map[string]*search.Tree 28 notFound http.Handler 29 notAllowed http.Handler 30 } 31 32 // NewRouter returns a httpx.Router. 33 func NewRouter() httpx.Router { 34 return &patRouter{ 35 trees: make(map[string]*search.Tree), 36 } 37 } 38 39 func (pr *patRouter) Handle(method, reqPath string, handler http.Handler) error { 40 if !validMethod(method) { 41 return ErrInvalidMethod 42 } 43 44 if len(reqPath) == 0 || reqPath[0] != '/' { 45 return ErrInvalidPath 46 } 47 48 cleanPath := path.Clean(reqPath) 49 tree, ok := pr.trees[method] 50 if ok { 51 return tree.Add(cleanPath, handler) 52 } 53 54 tree = search.NewTree() 55 pr.trees[method] = tree 56 return tree.Add(cleanPath, handler) 57 } 58 59 func (pr *patRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { 60 reqPath := path.Clean(r.URL.Path) 61 if tree, ok := pr.trees[r.Method]; ok { 62 if result, ok := tree.Search(reqPath); ok { 63 if len(result.Params) > 0 { 64 r = pathvar.WithVars(r, result.Params) 65 } 66 result.Item.(http.Handler).ServeHTTP(w, r) 67 return 68 } 69 } 70 71 allows, ok := pr.methodsAllowed(r.Method, reqPath) 72 if !ok { 73 pr.handleNotFound(w, r) 74 return 75 } 76 77 if pr.notAllowed != nil { 78 pr.notAllowed.ServeHTTP(w, r) 79 } else { 80 w.Header().Set(allowHeader, allows) 81 w.WriteHeader(http.StatusMethodNotAllowed) 82 } 83 } 84 85 func (pr *patRouter) SetNotFoundHandler(handler http.Handler) { 86 pr.notFound = handler 87 } 88 89 func (pr *patRouter) SetNotAllowedHandler(handler http.Handler) { 90 pr.notAllowed = handler 91 } 92 93 func (pr *patRouter) handleNotFound(w http.ResponseWriter, r *http.Request) { 94 if pr.notFound != nil { 95 pr.notFound.ServeHTTP(w, r) 96 } else { 97 http.NotFound(w, r) 98 } 99 } 100 101 func (pr *patRouter) methodsAllowed(method, path string) (string, bool) { 102 var allows []string 103 104 for treeMethod, tree := range pr.trees { 105 if treeMethod == method { 106 continue 107 } 108 109 _, ok := tree.Search(path) 110 if ok { 111 allows = append(allows, treeMethod) 112 } 113 } 114 115 if len(allows) > 0 { 116 return strings.Join(allows, allowMethodSeparator), true 117 } 118 119 return "", false 120 } 121 122 func validMethod(method string) bool { 123 return method == http.MethodDelete || method == http.MethodGet || 124 method == http.MethodHead || method == http.MethodOptions || 125 method == http.MethodPatch || method == http.MethodPost || 126 method == http.MethodPut 127 }