github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/tree.go (about)

     1  package znet
     2  
import (
	"sort"
	"strings"
)
     6  
     7  type (
     8  	// Tree records node
     9  	Tree struct {
    10  		root       *Node
    11  		routes     map[string]*Node
    12  		parameters Parameters
    13  	}
    14  
    15  	// Node records any URL params, and executes an end handlerFn.
    16  	Node struct {
    17  		value      interface{}
    18  		handle     handlerFn
    19  		children   map[string]*Node
    20  		key        string
    21  		path       string
    22  		middleware []handlerFn
    23  		depth      int
    24  		isPattern  bool
    25  	}
    26  )
    27  
    28  func NewNode(key string, depth int) *Node {
    29  	return &Node{
    30  		key:      key,
    31  		depth:    depth,
    32  		children: make(map[string]*Node),
    33  	}
    34  }
    35  
    36  func (t *Node) WithValue(v interface{}) *Node {
    37  	t.value = v
    38  	return t
    39  }
    40  
    41  func (t *Node) Value() interface{} {
    42  	return t.value
    43  }
    44  
    45  func (t *Node) Path() string {
    46  	return t.path
    47  }
    48  
    49  func (t *Node) Handle() handlerFn {
    50  	return t.handle
    51  }
    52  
    53  func NewTree() *Tree {
    54  	return &Tree{
    55  		root:   NewNode("/", 1),
    56  		routes: make(map[string]*Node),
    57  	}
    58  }
    59  
    60  func (t *Tree) Add(path string, handle handlerFn, middleware ...handlerFn) (currentNode *Node) {
    61  	currentNode = t.root
    62  	wareLen := len(middleware)
    63  	if path != currentNode.key {
    64  		res := strings.Split(path, "/")
    65  		end := len(res) - 1
    66  		for i, key := range res {
    67  			if key == "" {
    68  				if i != end {
    69  					continue
    70  				}
    71  				key = "/"
    72  			}
    73  			node, ok := currentNode.children[key]
    74  			if !ok {
    75  				node = NewNode(key, currentNode.depth+1)
    76  				if wareLen > 0 && i == end {
    77  					node.middleware = append(node.middleware, middleware...)
    78  				}
    79  				currentNode.children[key] = node
    80  			} else if node.handle == nil {
    81  				if wareLen > 0 && i == end {
    82  					node.middleware = append(node.middleware, middleware...)
    83  				}
    84  			}
    85  			currentNode = node
    86  		}
    87  	}
    88  
    89  	if wareLen > 0 && currentNode.depth == 1 {
    90  		currentNode.middleware = append(currentNode.middleware, middleware...)
    91  	}
    92  
    93  	currentNode.handle = handle
    94  	currentNode.isPattern = true
    95  	currentNode.path = path
    96  	if routeName := t.parameters.routeName; routeName != "" {
    97  		t.routes[routeName] = currentNode
    98  	}
    99  	return
   100  }
   101  
   102  func (t *Tree) Find(pattern string, isRegex bool) (nodes []*Node) {
   103  	var (
   104  		node  = t.root
   105  		queue []*Node
   106  	)
   107  	if pattern == node.path {
   108  		nodes = append(nodes, node)
   109  		return
   110  	}
   111  
   112  	res := strings.Split(pattern, "/")
   113  	for i := range res {
   114  		key := res[i]
   115  		if key == "" {
   116  			continue
   117  		}
   118  		child, ok := node.children[key]
   119  		if !ok && isRegex {
   120  			break
   121  		}
   122  		if !ok && !isRegex {
   123  			return
   124  		}
   125  		if pattern == child.path && !isRegex {
   126  			nodes = append(nodes, child)
   127  			return
   128  		}
   129  		node = child
   130  	}
   131  
   132  	queue = append(queue, node)
   133  	for len(queue) > 0 {
   134  		var queueTemp []*Node
   135  		for _, n := range queue {
   136  			if n.isPattern {
   137  				nodes = append(nodes, n)
   138  			}
   139  			for _, childNode := range n.children {
   140  				queueTemp = append(queueTemp, childNode)
   141  			}
   142  		}
   143  		queue = queueTemp
   144  	}
   145  	return
   146  }