github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/limiter/rule.go (about)

     1  package limiter
     2  
     3  import (
     4  	"sync"
     5  	"time"
     6  )
     7  
// singleRule is one rate-limiting policy: every key gets its own circleQueue
// of visit expiry timestamps, sized to the number of allowed visits.
type singleRule struct {
	// notRecordsIndex is the set of indices into records that are currently
	// free (not assigned to any key); add takes from it and the cleanup loop
	// returns indices to it once their queue is empty.
	notRecordsIndex map[int]struct{}
	// locker guards records and notRecordsIndex in add, deleteExpiredOnce and
	// recovery.
	locker *sync.Mutex
	// usedRecordsIndex maps a visitor key to its index in records.
	usedRecordsIndex sync.Map
	// records holds one queue per tracked key, each created with capacity allowed.
	records []*circleQueue
	// defaultExpiration is how long a recorded visit stays in its queue
	// (add pushes now+defaultExpiration).
	defaultExpiration time.Duration
	// cleanupInterval is the period of the background deleteExpired loop.
	cleanupInterval time.Duration
	// allowed is the capacity passed to newCircleQueue for every key's queue.
	allowed int
	// estimated is the number of queues preallocated up front; recovery never
	// shrinks records below it.
	estimated int
}
    18  
    19  // newRule Initialize an access control policy
    20  func newRule(defaultExpiration time.Duration, allowed int, estimated ...int) *singleRule {
    21  	if allowed <= 0 {
    22  		allowed = 1
    23  	}
    24  	userEstimated := 0
    25  	if len(estimated) > 0 {
    26  		userEstimated = estimated[0]
    27  	}
    28  	if userEstimated <= 0 {
    29  		userEstimated = allowed
    30  	}
    31  	cleanupInterval := defaultExpiration / 100
    32  	if cleanupInterval < time.Second*1 {
    33  		cleanupInterval = time.Second * 1
    34  	}
    35  	if cleanupInterval > time.Second*60 {
    36  		cleanupInterval = time.Second * 60
    37  	}
    38  	vc := createRule(defaultExpiration, cleanupInterval, allowed, userEstimated)
    39  	go vc.deleteExpired()
    40  	return vc
    41  }
    42  
    43  func createRule(defaultExpiration, cleanupInterval time.Duration, allowed, userEstimated int) *singleRule {
    44  	var vc singleRule
    45  	var locker sync.Mutex
    46  	vc.defaultExpiration = defaultExpiration
    47  	vc.cleanupInterval = cleanupInterval
    48  	vc.allowed = allowed
    49  	vc.estimated = userEstimated
    50  	vc.notRecordsIndex = make(map[int]struct{})
    51  	vc.locker = &locker
    52  	vc.records = make([]*circleQueue, vc.estimated)
    53  	for i := range vc.records {
    54  		vc.records[i] = newCircleQueue(vc.allowed)
    55  		vc.notRecordsIndex[i] = struct{}{}
    56  	}
    57  	return &vc
    58  
    59  }
    60  
    61  // allowVisit Whether access is allowed or not. If access is allowed, an access record is added to the access record
    62  func (r *singleRule) allowVisit(key interface{}) bool {
    63  	return r.add(key) == nil
    64  }
    65  
    66  // remainingVisits Remaining visits
    67  func (r *singleRule) remainingVisits(key interface{}) int {
    68  	if index, exist := r.usedRecordsIndex.Load(key); exist {
    69  		r.records[index.(int)].deleteExpired()
    70  		return r.records[index.(int)].unUsedSize()
    71  	}
    72  	return r.allowed
    73  }
    74  
    75  // add access record
    76  func (r *singleRule) add(key interface{}) (err error) {
    77  	r.locker.Lock()
    78  	defer r.locker.Unlock()
    79  
    80  	if index, exist := r.usedRecordsIndex.Load(key); exist {
    81  		r.records[index.(int)].deleteExpired()
    82  		return r.records[index.(int)].push(time.Now().Add(r.defaultExpiration).UnixNano())
    83  	}
    84  
    85  	if len(r.notRecordsIndex) > 0 {
    86  		for index := range r.notRecordsIndex {
    87  			delete(r.notRecordsIndex, index)
    88  			r.usedRecordsIndex.Store(key, index)
    89  			return r.records[index].push(time.Now().Add(r.defaultExpiration).UnixNano())
    90  		}
    91  	}
    92  	queue := newCircleQueue(r.allowed)
    93  	r.records = append(r.records, queue)
    94  	index := len(r.records) - 1
    95  	r.usedRecordsIndex.Store(key, index)
    96  	return r.records[index].push(time.Now().Add(r.defaultExpiration).UnixNano())
    97  }
    98  
    99  // deleteExpired Delete expired data
   100  func (r *singleRule) deleteExpired() {
   101  	for range time.Tick(r.cleanupInterval) {
   102  		r.deleteExpiredOnce()
   103  		r.recovery()
   104  	}
   105  }
   106  
   107  // deleteExpiredOnce Delete expired data once in a specific time interval
   108  func (r *singleRule) deleteExpiredOnce() {
   109  	r.usedRecordsIndex.Range(func(k, v interface{}) bool {
   110  		r.locker.Lock()
   111  		index := v.(int)
   112  		if index < len(r.records) && index >= 0 {
   113  			r.records[index].deleteExpired()
   114  			if r.records[index].usedSize() == 0 {
   115  				r.usedRecordsIndex.Delete(k)
   116  				r.notRecordsIndex[index] = struct{}{}
   117  			}
   118  		} else {
   119  			r.usedRecordsIndex.Delete(k)
   120  		}
   121  		r.locker.Unlock()
   122  		return true
   123  	})
   124  }
   125  
   126  func (r *singleRule) recovery() {
   127  	r.locker.Lock()
   128  	defer r.locker.Unlock()
   129  	if r.needRecovery() {
   130  		curLen := len(r.records)
   131  		unUsedLen := len(r.notRecordsIndex)
   132  		usedLen := curLen - unUsedLen
   133  		var newLen int
   134  		if usedLen < r.estimated {
   135  			newLen = r.estimated
   136  		} else {
   137  			newLen = usedLen * 2
   138  		}
   139  		visitorRecordsNew := make([]*circleQueue, newLen)
   140  		for i := range visitorRecordsNew {
   141  			visitorRecordsNew[i] = newCircleQueue(r.allowed)
   142  		}
   143  		r.notRecordsIndex = make(map[int]struct{})
   144  		indexNew := 0
   145  		r.usedRecordsIndex.Range(func(k, v interface{}) bool {
   146  			indexOld := v.(int)
   147  			visitorRecordsNew[indexNew] = r.records[indexOld]
   148  			indexNew++
   149  			return true
   150  		})
   151  		r.records = visitorRecordsNew
   152  		for index := range r.records {
   153  			if index >= indexNew {
   154  				r.notRecordsIndex[index] = struct{}{}
   155  			}
   156  		}
   157  	}
   158  }
   159  
   160  func (r *singleRule) needRecovery() bool {
   161  	curLen := len(r.records)
   162  	unUsedLen := len(r.notRecordsIndex)
   163  	usedLen := curLen - unUsedLen
   164  	if curLen < 2*r.estimated {
   165  		return false
   166  	}
   167  	if usedLen*2 < unUsedLen {
   168  		return true
   169  	}
   170  	return false
   171  }