github.com/cnotch/ipchub@v1.1.0/provider/route/routetable.go (about)

     1  // Copyright (c) 2019,CAOHONGJU All rights reserved.
     2  // Use of this source code is governed by a MIT-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package route
     6  
     7  import (
     8  	"sync"
     9  
    10  	"github.com/cnotch/ipchub/utils"
    11  	"github.com/cnotch/xlog"
    12  )
    13  
    14  var globalT = &routetable{
    15  	m: make(map[string]*Route),
    16  }
    17  
    18  func init() {
    19  	// 默认为内存提供者,避免没有初始化全局函数调用问题
    20  	globalT.Reset(&memProvider{})
    21  }
    22  
    23  // Reset 重置路由表提供者
    24  func Reset(provider Provider) {
    25  	globalT.Reset(provider)
    26  }
    27  
    28  // Match 从路由表中获取和路径匹配的路由实例
    29  func Match(path string) *Route {
    30  	return globalT.Match(path)
    31  }
    32  
    33  // All 获取所有的路由
    34  func All() []*Route {
    35  	return globalT.All()
    36  }
    37  
    38  // Get 获取取指定模式的路由
    39  func Get(pattern string) *Route {
    40  	return globalT.Get(pattern)
    41  }
    42  
    43  // Del 删除指定模式的路由
    44  func Del(pattern string) error {
    45  	return globalT.Del(pattern)
    46  }
    47  
    48  // Save 保存路由
    49  func Save(src *Route) error {
    50  	return globalT.Save(src)
    51  }
    52  
    53  // Flush 刷新路由
    54  func Flush() error {
    55  	return globalT.Flush()
    56  }
    57  
    58  type routetable struct {
    59  	lock sync.RWMutex
    60  	m    map[string]*Route // 路由map
    61  	l    []*Route          // 路由list
    62  
    63  	saves   []*Route // 自上次Flush后新的保存和删除的路由
    64  	removes []*Route
    65  
    66  	provider Provider
    67  }
    68  
    69  func (t *routetable) Reset(provider Provider) {
    70  	t.lock.Lock()
    71  	defer t.lock.Unlock()
    72  
    73  	t.m = make(map[string]*Route)
    74  	t.l = t.l[:0]
    75  	t.saves = t.saves[:0]
    76  	t.removes = t.removes[:0]
    77  	t.provider = provider
    78  
    79  	routes, err := provider.LoadAll()
    80  	if err != nil {
    81  		panic("Load route table fail")
    82  	}
    83  
    84  	if cap(t.l) < len(routes) {
    85  		t.l = make([]*Route, 0, len(routes))
    86  	}
    87  
    88  	// 加入缓存
    89  	for _, r := range routes {
    90  		if err := r.init(); err != nil {
    91  			xlog.Warnf("route table init failed: `%v`", err)
    92  			continue // 忽略错误的配置
    93  		}
    94  		t.m[r.Pattern] = r
    95  		t.l = append(t.l, r)
    96  	}
    97  }
    98  
    99  func (t *routetable) Match(path string) *Route {
   100  	t.lock.RLock()
   101  	defer t.lock.RUnlock()
   102  
   103  	path = utils.CanonicalPath(path)
   104  	if path[len(path)-1] == '/' { // 必须有具体的子路径
   105  		return nil
   106  	}
   107  
   108  	r, ok := t.m[path]
   109  	if ok { // 精确匹配
   110  		ret := *r
   111  		return &ret
   112  	}
   113  
   114  	// 获取最长有效匹配的路由
   115  	var n = 0
   116  	for k, v := range t.m {
   117  		if !pathMatch(k, path) {
   118  			continue
   119  		}
   120  
   121  		if r == nil || len(k) > n {
   122  			n = len(k)
   123  			r = v
   124  		}
   125  	}
   126  
   127  	if r != nil {
   128  		ret := *r
   129  		r = &ret
   130  		if r.URL[len(r.URL)-1] == '/' {
   131  			r.URL = r.URL + path[len(r.Pattern):]
   132  		} else {
   133  			r.URL = r.URL + path[len(r.Pattern)-1:]
   134  		}
   135  		r.Pattern = path
   136  	}
   137  	return r
   138  }
   139  
   140  func (t *routetable) Get(pattern string) *Route {
   141  	t.lock.RLock()
   142  	defer t.lock.RUnlock()
   143  
   144  	pattern = utils.CanonicalPath(pattern)
   145  	r, _ := t.m[pattern]
   146  	return r
   147  }
   148  
   149  func (t *routetable) Del(pattern string) error {
   150  	t.lock.Lock()
   151  	defer t.lock.Unlock()
   152  
   153  	pattern = utils.CanonicalPath(pattern)
   154  	r, ok := t.m[pattern]
   155  
   156  	if ok {
   157  		delete(t.m, pattern)
   158  
   159  		// 从完整列表中删除
   160  		for i, r2 := range t.l {
   161  			if r.Pattern == r2.Pattern {
   162  				t.l = append(t.l[:i], t.l[i+1:]...)
   163  				break
   164  			}
   165  		}
   166  
   167  		// 从保存列表中删除
   168  		for i, r2 := range t.saves {
   169  			if r.Pattern == r2.Pattern {
   170  				t.saves = append(t.saves[:i], t.saves[i+1:]...)
   171  				break
   172  			}
   173  		}
   174  
   175  		t.removes = append(t.removes, r)
   176  	}
   177  	return nil
   178  }
   179  
   180  func (t *routetable) Save(newr *Route) error {
   181  	t.lock.Lock()
   182  	defer t.lock.Unlock()
   183  
   184  	err := newr.init()
   185  	if err != nil {
   186  		return err
   187  	}
   188  
   189  	r, ok := t.m[newr.Pattern]
   190  
   191  	if ok { // 更新
   192  		r.CopyFrom(newr)
   193  
   194  		save := true
   195  		// 如果保存列表存在,不新增
   196  		for _, r2 := range t.saves {
   197  			if r.Pattern == r2.Pattern {
   198  				save = false
   199  				break
   200  			}
   201  		}
   202  
   203  		if save {
   204  			t.saves = append(t.saves, r)
   205  		}
   206  	} else { // 新增
   207  		r = newr
   208  		t.m[r.Pattern] = r
   209  
   210  		t.l = append(t.l, r)
   211  		t.saves = append(t.saves, r)
   212  
   213  		for i, r2 := range t.removes {
   214  			if r.Pattern == r2.Pattern {
   215  				t.removes = append(t.removes[:i], t.removes[i+1:]...)
   216  				break
   217  			}
   218  		}
   219  	}
   220  	return nil
   221  }
   222  
   223  func (t *routetable) Flush() error {
   224  	t.lock.Lock()
   225  	defer t.lock.Unlock()
   226  
   227  	if len(t.saves)+len(t.removes) == 0 {
   228  		return nil
   229  	}
   230  
   231  	err := t.provider.Flush(t.l, t.saves, t.removes)
   232  	if err != nil {
   233  		return err
   234  	}
   235  
   236  	t.saves = t.saves[:0]
   237  	t.removes = t.removes[:0]
   238  	return nil
   239  }
   240  
   241  func (t *routetable) All() []*Route {
   242  	t.lock.RLock()
   243  	defer t.lock.RUnlock()
   244  
   245  	routes := make([]*Route, len(t.l))
   246  	copy(routes, t.l)
   247  	return routes
   248  }
   249  
   250  // Does path match pattern?
   251  func pathMatch(pattern, path string) bool {
   252  	if len(pattern) == 0 {
   253  		// should not happen
   254  		return false
   255  	}
   256  	n := len(pattern)
   257  	if pattern[n-1] != '/' {
   258  		return pattern == path
   259  	}
   260  	return len(path) >= n && path[0:n] == pattern
   261  }