github.com/ipfans/trojan-go@v0.11.0/statistic/memory/memory.go (about) 1 package memory 2 3 import ( 4 "context" 5 "sync" 6 "sync/atomic" 7 "time" 8 9 "golang.org/x/time/rate" 10 11 "github.com/ipfans/trojan-go/common" 12 "github.com/ipfans/trojan-go/config" 13 "github.com/ipfans/trojan-go/log" 14 "github.com/ipfans/trojan-go/statistic" 15 ) 16 17 const Name = "MEMORY" 18 19 type User struct { 20 // WARNING: do not change the order of these fields. 21 // 64-bit fields that use `sync/atomic` package functions 22 // must be 64-bit aligned on 32-bit systems. 23 // Reference: https://github.com/golang/go/issues/599 24 // Solution: https://github.com/golang/go/issues/11891#issuecomment-433623786 25 sent uint64 26 recv uint64 27 lastSent uint64 28 lastRecv uint64 29 sendSpeed uint64 30 recvSpeed uint64 31 32 hash string 33 ipTable sync.Map 34 ipNum int32 35 maxIPNum int 36 limiterLock sync.RWMutex 37 sendLimiter *rate.Limiter 38 recvLimiter *rate.Limiter 39 ctx context.Context 40 cancel context.CancelFunc 41 } 42 43 func (u *User) Close() error { 44 u.ResetTraffic() 45 u.cancel() 46 return nil 47 } 48 49 func (u *User) AddIP(ip string) bool { 50 if u.maxIPNum <= 0 { 51 return true 52 } 53 _, found := u.ipTable.Load(ip) 54 if found { 55 return true 56 } 57 if int(u.ipNum)+1 > u.maxIPNum { 58 return false 59 } 60 u.ipTable.Store(ip, true) 61 atomic.AddInt32(&u.ipNum, 1) 62 return true 63 } 64 65 func (u *User) DelIP(ip string) bool { 66 if u.maxIPNum <= 0 { 67 return true 68 } 69 _, found := u.ipTable.Load(ip) 70 if !found { 71 return false 72 } 73 u.ipTable.Delete(ip) 74 atomic.AddInt32(&u.ipNum, -1) 75 return true 76 } 77 78 func (u *User) GetIP() int { 79 return int(u.ipNum) 80 } 81 82 func (u *User) SetIPLimit(n int) { 83 u.maxIPNum = n 84 } 85 86 func (u *User) GetIPLimit() int { 87 return u.maxIPNum 88 } 89 90 func (u *User) AddTraffic(sent, recv int) { 91 u.limiterLock.RLock() 92 defer u.limiterLock.RUnlock() 93 94 if u.sendLimiter != nil && sent >= 0 { 95 u.sendLimiter.WaitN(u.ctx, sent) 96 } else if u.recvLimiter != nil && recv >= 0 { 97 u.recvLimiter.WaitN(u.ctx, recv) 98 } 99 atomic.AddUint64(&u.sent, uint64(sent)) 100 atomic.AddUint64(&u.recv, uint64(recv)) 101 } 102 103 func (u *User) SetSpeedLimit(send, recv int) { 104 u.limiterLock.Lock() 105 defer u.limiterLock.Unlock() 106 107 if send <= 0 { 108 u.sendLimiter = nil 109 } else { 110 u.sendLimiter = rate.NewLimiter(rate.Limit(send), send*2) 111 } 112 if recv <= 0 { 113 u.recvLimiter = nil 114 } else { 115 u.recvLimiter = rate.NewLimiter(rate.Limit(recv), recv*2) 116 } 117 } 118 119 func (u *User) GetSpeedLimit() (send, recv int) { 120 u.limiterLock.RLock() 121 defer u.limiterLock.RUnlock() 122 123 if u.sendLimiter != nil { 124 send = int(u.sendLimiter.Limit()) 125 } 126 if u.recvLimiter != nil { 127 recv = int(u.recvLimiter.Limit()) 128 } 129 return 130 } 131 132 func (u *User) Hash() string { 133 return u.hash 134 } 135 136 func (u *User) SetTraffic(send, recv uint64) { 137 atomic.StoreUint64(&u.sent, send) 138 atomic.StoreUint64(&u.recv, recv) 139 } 140 141 func (u *User) GetTraffic() (uint64, uint64) { 142 return atomic.LoadUint64(&u.sent), atomic.LoadUint64(&u.recv) 143 } 144 145 func (u *User) ResetTraffic() (uint64, uint64) { 146 sent := atomic.SwapUint64(&u.sent, 0) 147 recv := atomic.SwapUint64(&u.recv, 0) 148 atomic.StoreUint64(&u.lastSent, 0) 149 atomic.StoreUint64(&u.lastRecv, 0) 150 return sent, recv 151 } 152 153 func (u *User) speedUpdater() { 154 ticker := time.NewTicker(time.Second) 155 for { 156 select { 157 case <-u.ctx.Done(): 158 return 159 case <-ticker.C: 160 sent, recv := u.GetTraffic() 161 atomic.StoreUint64(&u.sendSpeed, sent-u.lastSent) 162 atomic.StoreUint64(&u.recvSpeed, recv-u.lastRecv) 163 atomic.StoreUint64(&u.lastSent, sent) 164 atomic.StoreUint64(&u.lastRecv, recv) 165 } 166 } 167 } 168 169 func (u *User) GetSpeed() (uint64, uint64) { 170 return atomic.LoadUint64(&u.sendSpeed), atomic.LoadUint64(&u.recvSpeed) 171 } 172 173 type Authenticator struct { 174 users sync.Map 175 ctx context.Context 176 } 177 178 func (a *Authenticator) AuthUser(hash string) (bool, statistic.User) { 179 if user, found := a.users.Load(hash); found { 180 return true, user.(*User) 181 } 182 return false, nil 183 } 184 185 func (a *Authenticator) AddUser(hash string) error { 186 if _, found := a.users.Load(hash); found { 187 return common.NewError("hash " + hash + " is already exist") 188 } 189 ctx, cancel := context.WithCancel(a.ctx) 190 meter := &User{ 191 hash: hash, 192 ctx: ctx, 193 cancel: cancel, 194 } 195 go meter.speedUpdater() 196 a.users.Store(hash, meter) 197 return nil 198 } 199 200 func (a *Authenticator) DelUser(hash string) error { 201 meter, found := a.users.Load(hash) 202 if !found { 203 return common.NewError("hash " + hash + " not found") 204 } 205 meter.(*User).Close() 206 a.users.Delete(hash) 207 return nil 208 } 209 210 func (a *Authenticator) ListUsers() []statistic.User { 211 result := make([]statistic.User, 0) 212 a.users.Range(func(k, v interface{}) bool { 213 result = append(result, v.(*User)) 214 return true 215 }) 216 return result 217 } 218 219 func (a *Authenticator) Close() error { 220 return nil 221 } 222 223 func NewAuthenticator(ctx context.Context) (statistic.Authenticator, error) { 224 cfg := config.FromContext(ctx, Name).(*Config) 225 u := &Authenticator{ 226 ctx: ctx, 227 } 228 for _, password := range cfg.Passwords { 229 hash := common.SHA224String(password) 230 u.AddUser(hash) 231 } 232 log.Debug("memory authenticator created") 233 return u, nil 234 } 235 236 func init() { 237 statistic.RegisterAuthenticatorCreator(Name, NewAuthenticator) 238 }