github.com/eagleql/xray-core@v1.4.4/proxy/vmess/validator.go (about) 1 package vmess 2 3 import ( 4 "crypto/hmac" 5 "crypto/sha256" 6 "hash/crc64" 7 "strings" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 "github.com/eagleql/xray-core/common" 13 "github.com/eagleql/xray-core/common/dice" 14 "github.com/eagleql/xray-core/common/protocol" 15 "github.com/eagleql/xray-core/common/serial" 16 "github.com/eagleql/xray-core/common/task" 17 "github.com/eagleql/xray-core/proxy/vmess/aead" 18 ) 19 20 const ( 21 updateInterval = 10 * time.Second 22 cacheDurationSec = 120 23 ) 24 25 type user struct { 26 user protocol.MemoryUser 27 lastSec protocol.Timestamp 28 } 29 30 // TimedUserValidator is a user Validator based on time. 31 type TimedUserValidator struct { 32 sync.RWMutex 33 users []*user 34 userHash map[[16]byte]indexTimePair 35 hasher protocol.IDHash 36 baseTime protocol.Timestamp 37 task *task.Periodic 38 39 behaviorSeed uint64 40 behaviorFused bool 41 42 aeadDecoderHolder *aead.AuthIDDecoderHolder 43 } 44 45 type indexTimePair struct { 46 user *user 47 timeInc uint32 48 49 taintedFuse *uint32 50 } 51 52 // NewTimedUserValidator creates a new TimedUserValidator. 53 func NewTimedUserValidator(hasher protocol.IDHash) *TimedUserValidator { 54 tuv := &TimedUserValidator{ 55 users: make([]*user, 0, 16), 56 userHash: make(map[[16]byte]indexTimePair, 1024), 57 hasher: hasher, 58 baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2), 59 aeadDecoderHolder: aead.NewAuthIDDecoderHolder(), 60 } 61 tuv.task = &task.Periodic{ 62 Interval: updateInterval, 63 Execute: func() error { 64 tuv.updateUserHash() 65 return nil 66 }, 67 } 68 common.Must(tuv.task.Start()) 69 return tuv 70 } 71 72 func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *user) { 73 var hashValue [16]byte 74 genEndSec := nowSec + cacheDurationSec 75 genHashForID := func(id *protocol.ID) { 76 idHash := v.hasher(id.Bytes()) 77 genBeginSec := user.lastSec 78 if genBeginSec < nowSec-cacheDurationSec { 79 genBeginSec = nowSec - cacheDurationSec 80 } 81 for ts := genBeginSec; ts <= genEndSec; ts++ { 82 common.Must2(serial.WriteUint64(idHash, uint64(ts))) 83 idHash.Sum(hashValue[:0]) 84 idHash.Reset() 85 86 v.userHash[hashValue] = indexTimePair{ 87 user: user, 88 timeInc: uint32(ts - v.baseTime), 89 taintedFuse: new(uint32), 90 } 91 } 92 } 93 94 account := user.user.Account.(*MemoryAccount) 95 96 genHashForID(account.ID) 97 for _, id := range account.AlterIDs { 98 genHashForID(id) 99 } 100 user.lastSec = genEndSec 101 } 102 103 func (v *TimedUserValidator) removeExpiredHashes(expire uint32) { 104 for key, pair := range v.userHash { 105 if pair.timeInc < expire { 106 delete(v.userHash, key) 107 } 108 } 109 } 110 111 func (v *TimedUserValidator) updateUserHash() { 112 now := time.Now() 113 nowSec := protocol.Timestamp(now.Unix()) 114 115 v.Lock() 116 defer v.Unlock() 117 118 for _, user := range v.users { 119 v.generateNewHashes(nowSec, user) 120 } 121 122 expire := protocol.Timestamp(now.Unix() - cacheDurationSec) 123 if expire > v.baseTime { 124 v.removeExpiredHashes(uint32(expire - v.baseTime)) 125 } 126 } 127 128 func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error { 129 v.Lock() 130 defer v.Unlock() 131 132 nowSec := time.Now().Unix() 133 134 uu := &user{ 135 user: *u, 136 lastSec: protocol.Timestamp(nowSec - cacheDurationSec), 137 } 138 v.users = append(v.users, uu) 139 v.generateNewHashes(protocol.Timestamp(nowSec), uu) 140 141 account := uu.user.Account.(*MemoryAccount) 142 if !v.behaviorFused { 143 hashkdf := hmac.New(sha256.New, []byte("VMESSBSKDF")) 144 hashkdf.Write(account.ID.Bytes()) 145 v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), hashkdf.Sum(nil)) 146 } 147 148 var cmdkeyfl [16]byte 149 copy(cmdkeyfl[:], account.ID.CmdKey()) 150 v.aeadDecoderHolder.AddUser(cmdkeyfl, u) 151 152 return nil 153 } 154 155 func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool, error) { 156 v.RLock() 157 defer v.RUnlock() 158 159 v.behaviorFused = true 160 161 var fixedSizeHash [16]byte 162 copy(fixedSizeHash[:], userHash) 163 pair, found := v.userHash[fixedSizeHash] 164 if found { 165 user := pair.user.user 166 if atomic.LoadUint32(pair.taintedFuse) == 0 { 167 return &user, protocol.Timestamp(pair.timeInc) + v.baseTime, true, nil 168 } 169 return nil, 0, false, ErrTainted 170 } 171 return nil, 0, false, ErrNotFound 172 } 173 174 func (v *TimedUserValidator) GetAEAD(userHash []byte) (*protocol.MemoryUser, bool, error) { 175 v.RLock() 176 defer v.RUnlock() 177 178 var userHashFL [16]byte 179 copy(userHashFL[:], userHash) 180 181 userd, err := v.aeadDecoderHolder.Match(userHashFL) 182 if err != nil { 183 return nil, false, err 184 } 185 return userd.(*protocol.MemoryUser), true, err 186 } 187 188 func (v *TimedUserValidator) Remove(email string) bool { 189 v.Lock() 190 defer v.Unlock() 191 192 email = strings.ToLower(email) 193 idx := -1 194 for i, u := range v.users { 195 if strings.EqualFold(u.user.Email, email) { 196 idx = i 197 var cmdkeyfl [16]byte 198 copy(cmdkeyfl[:], u.user.Account.(*MemoryAccount).ID.CmdKey()) 199 v.aeadDecoderHolder.RemoveUser(cmdkeyfl) 200 break 201 } 202 } 203 if idx == -1 { 204 return false 205 } 206 ulen := len(v.users) 207 208 v.users[idx] = v.users[ulen-1] 209 v.users[ulen-1] = nil 210 v.users = v.users[:ulen-1] 211 212 return true 213 } 214 215 // Close implements common.Closable. 216 func (v *TimedUserValidator) Close() error { 217 return v.task.Close() 218 } 219 220 func (v *TimedUserValidator) GetBehaviorSeed() uint64 { 221 v.Lock() 222 defer v.Unlock() 223 224 v.behaviorFused = true 225 if v.behaviorSeed == 0 { 226 v.behaviorSeed = dice.RollUint64() 227 } 228 return v.behaviorSeed 229 } 230 231 func (v *TimedUserValidator) BurnTaintFuse(userHash []byte) error { 232 v.RLock() 233 defer v.RUnlock() 234 235 var userHashFL [16]byte 236 copy(userHashFL[:], userHash) 237 238 pair, found := v.userHash[userHashFL] 239 if found { 240 if atomic.CompareAndSwapUint32(pair.taintedFuse, 0, 1) { 241 return nil 242 } 243 return ErrTainted 244 } 245 return ErrNotFound 246 } 247 248 var ErrNotFound = newError("Not Found") 249 250 var ErrTainted = newError("ErrTainted")