github.com/hellobchain/third_party@v0.0.0-20230331131523-deb0478a2e52/go-chi/chi/tree.go (about) 1 package chi 2 3 // Radix tree implementation below is a based on the original work by 4 // Armon Dadgar in https://github.com/armon/go-radix/blob/master/radix.go 5 // (MIT licensed). It's been heavily modified for use as a HTTP routing tree. 6 7 import ( 8 "fmt" 9 "math" 10 "regexp" 11 "sort" 12 "strconv" 13 "strings" 14 15 "github.com/hellobchain/newcryptosm/http" 16 ) 17 18 type methodTyp int 19 20 const ( 21 mSTUB methodTyp = 1 << iota 22 mCONNECT 23 mDELETE 24 mGET 25 mHEAD 26 mOPTIONS 27 mPATCH 28 mPOST 29 mPUT 30 mTRACE 31 ) 32 33 var mALL = mCONNECT | mDELETE | mGET | mHEAD | 34 mOPTIONS | mPATCH | mPOST | mPUT | mTRACE 35 36 var methodMap = map[string]methodTyp{ 37 "CONNECT": mCONNECT, 38 "DELETE": mDELETE, 39 "GET": mGET, 40 "HEAD": mHEAD, 41 "OPTIONS": mOPTIONS, 42 "PATCH": mPATCH, 43 "POST": mPOST, 44 "PUT": mPUT, 45 "TRACE": mTRACE, 46 } 47 48 // RegisterMethod adds support for custom HTTP method handlers, available 49 // via Router#Method and Router#MethodFunc 50 func RegisterMethod(method string) { 51 if method == "" { 52 return 53 } 54 method = strings.ToUpper(method) 55 if _, ok := methodMap[method]; ok { 56 return 57 } 58 n := len(methodMap) 59 if n > strconv.IntSize { 60 panic(fmt.Sprintf("chi: max number of methods reached (%d)", strconv.IntSize)) 61 } 62 mt := methodTyp(math.Exp2(float64(n))) 63 methodMap[method] = mt 64 mALL |= mt 65 } 66 67 type nodeTyp uint8 68 69 const ( 70 ntStatic nodeTyp = iota // /home 71 ntRegexp // /{id:[0-9]+} 72 ntParam // /{user} 73 ntCatchAll // /api/v1/* 74 ) 75 76 type node struct { 77 // node type: static, regexp, param, catchAll 78 typ nodeTyp 79 80 // first byte of the prefix 81 label byte 82 83 // first byte of the child prefix 84 tail byte 85 86 // prefix is the common prefix we ignore 87 prefix string 88 89 // regexp matcher for regexp nodes 90 rex *regexp.Regexp 91 92 // HTTP handler endpoints on the leaf node 93 endpoints endpoints 94 95 // subroutes on the leaf node 96 subroutes Routes 97 98 // child nodes should be stored in-order for iteration, 99 // in groups of the node type. 100 children [ntCatchAll + 1]nodes 101 } 102 103 // endpoints is a mapping of http method constants to handlers 104 // for a given route. 105 type endpoints map[methodTyp]*endpoint 106 107 type endpoint struct { 108 // endpoint handler 109 handler http.Handler 110 111 // pattern is the routing pattern for handler nodes 112 pattern string 113 114 // parameter keys recorded on handler nodes 115 paramKeys []string 116 } 117 118 func (s endpoints) Value(method methodTyp) *endpoint { 119 mh, ok := s[method] 120 if !ok { 121 mh = &endpoint{} 122 s[method] = mh 123 } 124 return mh 125 } 126 127 func (n *node) InsertRoute(method methodTyp, pattern string, handler http.Handler) *node { 128 var parent *node 129 search := pattern 130 131 for { 132 // Handle key exhaustion 133 if len(search) == 0 { 134 // Insert or update the node's leaf handler 135 n.setEndpoint(method, handler, pattern) 136 return n 137 } 138 139 // We're going to be searching for a wild node next, 140 // in this case, we need to get the tail 141 var label = search[0] 142 var segTail byte 143 var segEndIdx int 144 var segTyp nodeTyp 145 var segRexpat string 146 if label == '{' || label == '*' { 147 segTyp, _, segRexpat, segTail, _, segEndIdx = patNextSegment(search) 148 } 149 150 var prefix string 151 if segTyp == ntRegexp { 152 prefix = segRexpat 153 } 154 155 // Look for the edge to attach to 156 parent = n 157 n = n.getEdge(segTyp, label, segTail, prefix) 158 159 // No edge, create one 160 if n == nil { 161 child := &node{label: label, tail: segTail, prefix: search} 162 hn := parent.addChild(child, search) 163 hn.setEndpoint(method, handler, pattern) 164 165 return hn 166 } 167 168 // Found an edge to match the pattern 169 170 if n.typ > ntStatic { 171 // We found a param node, trim the param from the search path and continue. 172 // This param/wild pattern segment would already be on the tree from a previous 173 // call to addChild when creating a new node. 174 search = search[segEndIdx:] 175 continue 176 } 177 178 // Static nodes fall below here. 179 // Determine longest prefix of the search key on match. 180 commonPrefix := longestPrefix(search, n.prefix) 181 if commonPrefix == len(n.prefix) { 182 // the common prefix is as long as the current node's prefix we're attempting to insert. 183 // keep the search going. 184 search = search[commonPrefix:] 185 continue 186 } 187 188 // Split the node 189 child := &node{ 190 typ: ntStatic, 191 prefix: search[:commonPrefix], 192 } 193 parent.replaceChild(search[0], segTail, child) 194 195 // Restore the existing node 196 n.label = n.prefix[commonPrefix] 197 n.prefix = n.prefix[commonPrefix:] 198 child.addChild(n, n.prefix) 199 200 // If the new key is a subset, set the method/handler on this node and finish. 201 search = search[commonPrefix:] 202 if len(search) == 0 { 203 child.setEndpoint(method, handler, pattern) 204 return child 205 } 206 207 // Create a new edge for the node 208 subchild := &node{ 209 typ: ntStatic, 210 label: search[0], 211 prefix: search, 212 } 213 hn := child.addChild(subchild, search) 214 hn.setEndpoint(method, handler, pattern) 215 return hn 216 } 217 } 218 219 // addChild appends the new `child` node to the tree using the `pattern` as the trie key. 220 // For a URL router like chi's, we split the static, param, regexp and wildcard segments 221 // into different nodes. In addition, addChild will recursively call itself until every 222 // pattern segment is added to the url pattern tree as individual nodes, depending on type. 223 func (n *node) addChild(child *node, prefix string) *node { 224 search := prefix 225 226 // handler leaf node added to the tree is the child. 227 // this may be overridden later down the flow 228 hn := child 229 230 // Parse next segment 231 segTyp, _, segRexpat, segTail, segStartIdx, segEndIdx := patNextSegment(search) 232 233 // Add child depending on next up segment 234 switch segTyp { 235 236 case ntStatic: 237 // Search prefix is all static (that is, has no params in path) 238 // noop 239 240 default: 241 // Search prefix contains a param, regexp or wildcard 242 243 if segTyp == ntRegexp { 244 rex, err := regexp.Compile(segRexpat) 245 if err != nil { 246 panic(fmt.Sprintf("chi: invalid regexp pattern '%s' in route param", segRexpat)) 247 } 248 child.prefix = segRexpat 249 child.rex = rex 250 } 251 252 if segStartIdx == 0 { 253 // Route starts with a param 254 child.typ = segTyp 255 256 if segTyp == ntCatchAll { 257 segStartIdx = -1 258 } else { 259 segStartIdx = segEndIdx 260 } 261 if segStartIdx < 0 { 262 segStartIdx = len(search) 263 } 264 child.tail = segTail // for params, we set the tail 265 266 if segStartIdx != len(search) { 267 // add static edge for the remaining part, split the end. 268 // its not possible to have adjacent param nodes, so its certainly 269 // going to be a static node next. 270 271 search = search[segStartIdx:] // advance search position 272 273 nn := &node{ 274 typ: ntStatic, 275 label: search[0], 276 prefix: search, 277 } 278 hn = child.addChild(nn, search) 279 } 280 281 } else if segStartIdx > 0 { 282 // Route has some param 283 284 // starts with a static segment 285 child.typ = ntStatic 286 child.prefix = search[:segStartIdx] 287 child.rex = nil 288 289 // add the param edge node 290 search = search[segStartIdx:] 291 292 nn := &node{ 293 typ: segTyp, 294 label: search[0], 295 tail: segTail, 296 } 297 hn = child.addChild(nn, search) 298 299 } 300 } 301 302 n.children[child.typ] = append(n.children[child.typ], child) 303 n.children[child.typ].Sort() 304 return hn 305 } 306 307 func (n *node) replaceChild(label, tail byte, child *node) { 308 for i := 0; i < len(n.children[child.typ]); i++ { 309 if n.children[child.typ][i].label == label && n.children[child.typ][i].tail == tail { 310 n.children[child.typ][i] = child 311 n.children[child.typ][i].label = label 312 n.children[child.typ][i].tail = tail 313 return 314 } 315 } 316 panic("chi: replacing missing child") 317 } 318 319 func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node { 320 nds := n.children[ntyp] 321 for i := 0; i < len(nds); i++ { 322 if nds[i].label == label && nds[i].tail == tail { 323 if ntyp == ntRegexp && nds[i].prefix != prefix { 324 continue 325 } 326 return nds[i] 327 } 328 } 329 return nil 330 } 331 332 func (n *node) setEndpoint(method methodTyp, handler http.Handler, pattern string) { 333 // Set the handler for the method type on the node 334 if n.endpoints == nil { 335 n.endpoints = make(endpoints, 0) 336 } 337 338 paramKeys := patParamKeys(pattern) 339 340 if method&mSTUB == mSTUB { 341 n.endpoints.Value(mSTUB).handler = handler 342 } 343 if method&mALL == mALL { 344 h := n.endpoints.Value(mALL) 345 h.handler = handler 346 h.pattern = pattern 347 h.paramKeys = paramKeys 348 for _, m := range methodMap { 349 h := n.endpoints.Value(m) 350 h.handler = handler 351 h.pattern = pattern 352 h.paramKeys = paramKeys 353 } 354 } else { 355 h := n.endpoints.Value(method) 356 h.handler = handler 357 h.pattern = pattern 358 h.paramKeys = paramKeys 359 } 360 } 361 362 func (n *node) FindRoute(rctx *Context, method methodTyp, path string) (*node, endpoints, http.Handler) { 363 // Reset the context routing pattern and params 364 rctx.routePattern = "" 365 rctx.routeParams.Keys = rctx.routeParams.Keys[:0] 366 rctx.routeParams.Values = rctx.routeParams.Values[:0] 367 368 // Find the routing handlers for the path 369 rn := n.findRoute(rctx, method, path) 370 if rn == nil { 371 return nil, nil, nil 372 } 373 374 // Record the routing params in the request lifecycle 375 rctx.URLParams.Keys = append(rctx.URLParams.Keys, rctx.routeParams.Keys...) 376 rctx.URLParams.Values = append(rctx.URLParams.Values, rctx.routeParams.Values...) 377 378 // Record the routing pattern in the request lifecycle 379 if rn.endpoints[method].pattern != "" { 380 rctx.routePattern = rn.endpoints[method].pattern 381 rctx.RoutePatterns = append(rctx.RoutePatterns, rctx.routePattern) 382 } 383 384 return rn, rn.endpoints, rn.endpoints[method].handler 385 } 386 387 // Recursive edge traversal by checking all nodeTyp groups along the way. 388 // It's like searching through a multi-dimensional radix trie. 389 func (n *node) findRoute(rctx *Context, method methodTyp, path string) *node { 390 nn := n 391 search := path 392 393 for t, nds := range nn.children { 394 ntyp := nodeTyp(t) 395 if len(nds) == 0 { 396 continue 397 } 398 399 var xn *node 400 xsearch := search 401 402 var label byte 403 if search != "" { 404 label = search[0] 405 } 406 407 switch ntyp { 408 case ntStatic: 409 xn = nds.findEdge(label) 410 if xn == nil || !strings.HasPrefix(xsearch, xn.prefix) { 411 continue 412 } 413 xsearch = xsearch[len(xn.prefix):] 414 415 case ntParam, ntRegexp: 416 // short-circuit and return no matching route for empty param values 417 if xsearch == "" { 418 continue 419 } 420 421 // serially loop through each node grouped by the tail delimiter 422 for idx := 0; idx < len(nds); idx++ { 423 xn = nds[idx] 424 425 // label for param nodes is the delimiter byte 426 p := strings.IndexByte(xsearch, xn.tail) 427 428 if p < 0 { 429 if xn.tail == '/' { 430 p = len(xsearch) 431 } else { 432 continue 433 } 434 } 435 436 if ntyp == ntRegexp && xn.rex != nil { 437 if xn.rex.Match([]byte(xsearch[:p])) == false { 438 continue 439 } 440 } else if strings.IndexByte(xsearch[:p], '/') != -1 { 441 // avoid a match across path segments 442 continue 443 } 444 445 rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p]) 446 xsearch = xsearch[p:] 447 break 448 } 449 450 default: 451 // catch-all nodes 452 rctx.routeParams.Values = append(rctx.routeParams.Values, search) 453 xn = nds[0] 454 xsearch = "" 455 } 456 457 if xn == nil { 458 continue 459 } 460 461 // did we find it yet? 462 if len(xsearch) == 0 { 463 if xn.isLeaf() { 464 h, _ := xn.endpoints[method] 465 if h != nil && h.handler != nil { 466 rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...) 467 return xn 468 } 469 470 // flag that the routing context found a route, but not a corresponding 471 // supported method 472 rctx.methodNotAllowed = true 473 } 474 } 475 476 // recursively find the next node.. 477 fin := xn.findRoute(rctx, method, xsearch) 478 if fin != nil { 479 return fin 480 } 481 482 // Did not find final handler, let's remove the param here if it was set 483 if xn.typ > ntStatic { 484 if len(rctx.routeParams.Values) > 0 { 485 rctx.routeParams.Values = rctx.routeParams.Values[:len(rctx.routeParams.Values)-1] 486 } 487 } 488 489 } 490 491 return nil 492 } 493 494 func (n *node) findEdge(ntyp nodeTyp, label byte) *node { 495 nds := n.children[ntyp] 496 num := len(nds) 497 idx := 0 498 499 switch ntyp { 500 case ntStatic, ntParam, ntRegexp: 501 i, j := 0, num-1 502 for i <= j { 503 idx = i + (j-i)/2 504 if label > nds[idx].label { 505 i = idx + 1 506 } else if label < nds[idx].label { 507 j = idx - 1 508 } else { 509 i = num // breaks cond 510 } 511 } 512 if nds[idx].label != label { 513 return nil 514 } 515 return nds[idx] 516 517 default: // catch all 518 return nds[idx] 519 } 520 } 521 522 func (n *node) isEmpty() bool { 523 for _, nds := range n.children { 524 if len(nds) > 0 { 525 return false 526 } 527 } 528 return true 529 } 530 531 func (n *node) isLeaf() bool { 532 return n.endpoints != nil 533 } 534 535 func (n *node) findPattern(pattern string) bool { 536 nn := n 537 for _, nds := range nn.children { 538 if len(nds) == 0 { 539 continue 540 } 541 542 n = nn.findEdge(nds[0].typ, pattern[0]) 543 if n == nil { 544 continue 545 } 546 547 var idx int 548 var xpattern string 549 550 switch n.typ { 551 case ntStatic: 552 idx = longestPrefix(pattern, n.prefix) 553 if idx < len(n.prefix) { 554 continue 555 } 556 557 case ntParam, ntRegexp: 558 idx = strings.IndexByte(pattern, '}') + 1 559 560 case ntCatchAll: 561 idx = longestPrefix(pattern, "*") 562 563 default: 564 panic("chi: unknown node type") 565 } 566 567 xpattern = pattern[idx:] 568 if len(xpattern) == 0 { 569 return true 570 } 571 572 return n.findPattern(xpattern) 573 } 574 return false 575 } 576 577 func (n *node) routes() []Route { 578 rts := []Route{} 579 580 n.walk(func(eps endpoints, subroutes Routes) bool { 581 if eps[mSTUB] != nil && eps[mSTUB].handler != nil && subroutes == nil { 582 return false 583 } 584 585 // Group methodHandlers by unique patterns 586 pats := make(map[string]endpoints, 0) 587 588 for mt, h := range eps { 589 if h.pattern == "" { 590 continue 591 } 592 p, ok := pats[h.pattern] 593 if !ok { 594 p = endpoints{} 595 pats[h.pattern] = p 596 } 597 p[mt] = h 598 } 599 600 for p, mh := range pats { 601 hs := make(map[string]http.Handler, 0) 602 if mh[mALL] != nil && mh[mALL].handler != nil { 603 hs["*"] = mh[mALL].handler 604 } 605 606 for mt, h := range mh { 607 if h.handler == nil { 608 continue 609 } 610 m := methodTypString(mt) 611 if m == "" { 612 continue 613 } 614 hs[m] = h.handler 615 } 616 617 rt := Route{p, hs, subroutes} 618 rts = append(rts, rt) 619 } 620 621 return false 622 }) 623 624 return rts 625 } 626 627 func (n *node) walk(fn func(eps endpoints, subroutes Routes) bool) bool { 628 // Visit the leaf values if any 629 if (n.endpoints != nil || n.subroutes != nil) && fn(n.endpoints, n.subroutes) { 630 return true 631 } 632 633 // Recurse on the children 634 for _, ns := range n.children { 635 for _, cn := range ns { 636 if cn.walk(fn) { 637 return true 638 } 639 } 640 } 641 return false 642 } 643 644 // patNextSegment returns the next segment details from a pattern: 645 // node type, param key, regexp string, param tail byte, param starting index, param ending index 646 func patNextSegment(pattern string) (nodeTyp, string, string, byte, int, int) { 647 ps := strings.Index(pattern, "{") 648 ws := strings.Index(pattern, "*") 649 650 if ps < 0 && ws < 0 { 651 return ntStatic, "", "", 0, 0, len(pattern) // we return the entire thing 652 } 653 654 // Sanity check 655 if ps >= 0 && ws >= 0 && ws < ps { 656 panic("chi: wildcard '*' must be the last pattern in a route, otherwise use a '{param}'") 657 } 658 659 var tail byte = '/' // Default endpoint tail to / byte 660 661 if ps >= 0 { 662 // Param/Regexp pattern is next 663 nt := ntParam 664 665 // Read to closing } taking into account opens and closes in curl count (cc) 666 cc := 0 667 pe := ps 668 for i, c := range pattern[ps:] { 669 if c == '{' { 670 cc++ 671 } else if c == '}' { 672 cc-- 673 if cc == 0 { 674 pe = ps + i 675 break 676 } 677 } 678 } 679 if pe == ps { 680 panic("chi: route param closing delimiter '}' is missing") 681 } 682 683 key := pattern[ps+1 : pe] 684 pe++ // set end to next position 685 686 if pe < len(pattern) { 687 tail = pattern[pe] 688 } 689 690 var rexpat string 691 if idx := strings.Index(key, ":"); idx >= 0 { 692 nt = ntRegexp 693 rexpat = key[idx+1:] 694 key = key[:idx] 695 } 696 697 if len(rexpat) > 0 { 698 if rexpat[0] != '^' { 699 rexpat = "^" + rexpat 700 } 701 if rexpat[len(rexpat)-1] != '$' { 702 rexpat = rexpat + "$" 703 } 704 } 705 706 return nt, key, rexpat, tail, ps, pe 707 } 708 709 // Wildcard pattern as finale 710 // TODO: should we panic if there is stuff after the * ??? 711 return ntCatchAll, "*", "", 0, ws, len(pattern) 712 } 713 714 func patParamKeys(pattern string) []string { 715 pat := pattern 716 paramKeys := []string{} 717 for { 718 ptyp, paramKey, _, _, _, e := patNextSegment(pat) 719 if ptyp == ntStatic { 720 return paramKeys 721 } 722 for i := 0; i < len(paramKeys); i++ { 723 if paramKeys[i] == paramKey { 724 panic(fmt.Sprintf("chi: routing pattern '%s' contains duplicate param key, '%s'", pattern, paramKey)) 725 } 726 } 727 paramKeys = append(paramKeys, paramKey) 728 pat = pat[e:] 729 } 730 } 731 732 // longestPrefix finds the length of the shared prefix 733 // of two strings 734 func longestPrefix(k1, k2 string) int { 735 max := len(k1) 736 if l := len(k2); l < max { 737 max = l 738 } 739 var i int 740 for i = 0; i < max; i++ { 741 if k1[i] != k2[i] { 742 break 743 } 744 } 745 return i 746 } 747 748 func methodTypString(method methodTyp) string { 749 for s, t := range methodMap { 750 if method == t { 751 return s 752 } 753 } 754 return "" 755 } 756 757 type nodes []*node 758 759 // Sort the list of nodes by label 760 func (ns nodes) Sort() { sort.Sort(ns); ns.tailSort() } 761 func (ns nodes) Len() int { return len(ns) } 762 func (ns nodes) Swap(i, j int) { ns[i], ns[j] = ns[j], ns[i] } 763 func (ns nodes) Less(i, j int) bool { return ns[i].label < ns[j].label } 764 765 // tailSort pushes nodes with '/' as the tail to the end of the list for param nodes. 766 // The list order determines the traversal order. 767 func (ns nodes) tailSort() { 768 for i := len(ns) - 1; i >= 0; i-- { 769 if ns[i].typ > ntStatic && ns[i].tail == '/' { 770 ns.Swap(i, len(ns)-1) 771 return 772 } 773 } 774 } 775 776 func (ns nodes) findEdge(label byte) *node { 777 num := len(ns) 778 idx := 0 779 i, j := 0, num-1 780 for i <= j { 781 idx = i + (j-i)/2 782 if label > ns[idx].label { 783 i = idx + 1 784 } else if label < ns[idx].label { 785 j = idx - 1 786 } else { 787 i = num // breaks cond 788 } 789 } 790 if ns[idx].label != label { 791 return nil 792 } 793 return ns[idx] 794 } 795 796 // Route describes the details of a routing handler. 797 type Route struct { 798 Pattern string 799 Handlers map[string]http.Handler 800 SubRoutes Routes 801 } 802 803 // WalkFunc is the type of the function called for each method and route visited by Walk. 804 type WalkFunc func(method string, route string, handler http.Handler, middlewares ...func(http.Handler) http.Handler) error 805 806 // Walk walks any router tree that implements Routes interface. 807 func Walk(r Routes, walkFn WalkFunc) error { 808 return walk(r, walkFn, "") 809 } 810 811 func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(http.Handler) http.Handler) error { 812 for _, route := range r.Routes() { 813 mws := make([]func(http.Handler) http.Handler, len(parentMw)) 814 copy(mws, parentMw) 815 mws = append(mws, r.Middlewares()...) 816 817 if route.SubRoutes != nil { 818 if err := walk(route.SubRoutes, walkFn, parentRoute+route.Pattern, mws...); err != nil { 819 return err 820 } 821 continue 822 } 823 824 for method, handler := range route.Handlers { 825 if method == "*" { 826 // Ignore a "catchAll" method, since we pass down all the specific methods for each route. 827 continue 828 } 829 830 fullRoute := parentRoute + route.Pattern 831 832 if chain, ok := handler.(*ChainHandler); ok { 833 if err := walkFn(method, fullRoute, chain.Endpoint, append(mws, chain.Middlewares...)...); err != nil { 834 return err 835 } 836 } else { 837 if err := walkFn(method, fullRoute, handler, mws...); err != nil { 838 return err 839 } 840 } 841 } 842 } 843 844 return nil 845 }