github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/limiter/rule.go (about) 1 package limiter 2 3 import ( 4 "sync" 5 "time" 6 ) 7 8 type singleRule struct { 9 notRecordsIndex map[int]struct{} 10 locker *sync.Mutex 11 usedRecordsIndex sync.Map 12 records []*circleQueue 13 defaultExpiration time.Duration 14 cleanupInterval time.Duration 15 allowed int 16 estimated int 17 } 18 19 // newRule Initialize an access control policy 20 func newRule(defaultExpiration time.Duration, allowed int, estimated ...int) *singleRule { 21 if allowed <= 0 { 22 allowed = 1 23 } 24 userEstimated := 0 25 if len(estimated) > 0 { 26 userEstimated = estimated[0] 27 } 28 if userEstimated <= 0 { 29 userEstimated = allowed 30 } 31 cleanupInterval := defaultExpiration / 100 32 if cleanupInterval < time.Second*1 { 33 cleanupInterval = time.Second * 1 34 } 35 if cleanupInterval > time.Second*60 { 36 cleanupInterval = time.Second * 60 37 } 38 vc := createRule(defaultExpiration, cleanupInterval, allowed, userEstimated) 39 go vc.deleteExpired() 40 return vc 41 } 42 43 func createRule(defaultExpiration, cleanupInterval time.Duration, allowed, userEstimated int) *singleRule { 44 var vc singleRule 45 var locker sync.Mutex 46 vc.defaultExpiration = defaultExpiration 47 vc.cleanupInterval = cleanupInterval 48 vc.allowed = allowed 49 vc.estimated = userEstimated 50 vc.notRecordsIndex = make(map[int]struct{}) 51 vc.locker = &locker 52 vc.records = make([]*circleQueue, vc.estimated) 53 for i := range vc.records { 54 vc.records[i] = newCircleQueue(vc.allowed) 55 vc.notRecordsIndex[i] = struct{}{} 56 } 57 return &vc 58 59 } 60 61 // allowVisit Whether access is allowed or not. If access is allowed, an access record is added to the access record 62 func (r *singleRule) allowVisit(key interface{}) bool { 63 return r.add(key) == nil 64 } 65 66 // remainingVisits Remaining visits 67 func (r *singleRule) remainingVisits(key interface{}) int { 68 if index, exist := r.usedRecordsIndex.Load(key); exist { 69 r.records[index.(int)].deleteExpired() 70 return r.records[index.(int)].unUsedSize() 71 } 72 return r.allowed 73 } 74 75 // add access record 76 func (r *singleRule) add(key interface{}) (err error) { 77 r.locker.Lock() 78 defer r.locker.Unlock() 79 80 if index, exist := r.usedRecordsIndex.Load(key); exist { 81 r.records[index.(int)].deleteExpired() 82 return r.records[index.(int)].push(time.Now().Add(r.defaultExpiration).UnixNano()) 83 } 84 85 if len(r.notRecordsIndex) > 0 { 86 for index := range r.notRecordsIndex { 87 delete(r.notRecordsIndex, index) 88 r.usedRecordsIndex.Store(key, index) 89 return r.records[index].push(time.Now().Add(r.defaultExpiration).UnixNano()) 90 } 91 } 92 queue := newCircleQueue(r.allowed) 93 r.records = append(r.records, queue) 94 index := len(r.records) - 1 95 r.usedRecordsIndex.Store(key, index) 96 return r.records[index].push(time.Now().Add(r.defaultExpiration).UnixNano()) 97 } 98 99 // deleteExpired Delete expired data 100 func (r *singleRule) deleteExpired() { 101 for range time.Tick(r.cleanupInterval) { 102 r.deleteExpiredOnce() 103 r.recovery() 104 } 105 } 106 107 // deleteExpiredOnce Delete expired data once in a specific time interval 108 func (r *singleRule) deleteExpiredOnce() { 109 r.usedRecordsIndex.Range(func(k, v interface{}) bool { 110 r.locker.Lock() 111 index := v.(int) 112 if index < len(r.records) && index >= 0 { 113 r.records[index].deleteExpired() 114 if r.records[index].usedSize() == 0 { 115 r.usedRecordsIndex.Delete(k) 116 r.notRecordsIndex[index] = struct{}{} 117 } 118 } else { 119 r.usedRecordsIndex.Delete(k) 120 } 121 r.locker.Unlock() 122 return true 123 }) 124 } 125 126 func (r *singleRule) recovery() { 127 r.locker.Lock() 128 defer r.locker.Unlock() 129 if r.needRecovery() { 130 curLen := len(r.records) 131 unUsedLen := len(r.notRecordsIndex) 132 usedLen := curLen - unUsedLen 133 var newLen int 134 if usedLen < r.estimated { 135 newLen = r.estimated 136 } else { 137 newLen = usedLen * 2 138 } 139 visitorRecordsNew := make([]*circleQueue, newLen) 140 for i := range visitorRecordsNew { 141 visitorRecordsNew[i] = newCircleQueue(r.allowed) 142 } 143 r.notRecordsIndex = make(map[int]struct{}) 144 indexNew := 0 145 r.usedRecordsIndex.Range(func(k, v interface{}) bool { 146 indexOld := v.(int) 147 visitorRecordsNew[indexNew] = r.records[indexOld] 148 indexNew++ 149 return true 150 }) 151 r.records = visitorRecordsNew 152 for index := range r.records { 153 if index >= indexNew { 154 r.notRecordsIndex[index] = struct{}{} 155 } 156 } 157 } 158 } 159 160 func (r *singleRule) needRecovery() bool { 161 curLen := len(r.records) 162 unUsedLen := len(r.notRecordsIndex) 163 usedLen := curLen - unUsedLen 164 if curLen < 2*r.estimated { 165 return false 166 } 167 if usedLen*2 < unUsedLen { 168 return true 169 } 170 return false 171 }