github.com/Uhtred009/v2ray-core-1@v4.31.2+incompatible/proxy/vmess/validator.go (about)

     1  // +build !confonly
     2  
     3  package vmess
     4  
     5  import (
     6  	"crypto/hmac"
     7  	"crypto/sha256"
     8  	"hash"
     9  	"hash/crc64"
    10  	"strings"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"v2ray.com/core/common"
    16  	"v2ray.com/core/common/dice"
    17  	"v2ray.com/core/common/protocol"
    18  	"v2ray.com/core/common/serial"
    19  	"v2ray.com/core/common/task"
    20  	"v2ray.com/core/proxy/vmess/aead"
    21  )
    22  
    23  const (
    24  	updateInterval   = 10 * time.Second
    25  	cacheDurationSec = 120
    26  )
    27  
    28  type user struct {
    29  	user    protocol.MemoryUser
    30  	lastSec protocol.Timestamp
    31  }
    32  
    33  // TimedUserValidator is a user Validator based on time.
    34  type TimedUserValidator struct {
    35  	sync.RWMutex
    36  	users    []*user
    37  	userHash map[[16]byte]indexTimePair
    38  	hasher   protocol.IDHash
    39  	baseTime protocol.Timestamp
    40  	task     *task.Periodic
    41  
    42  	behaviorSeed  uint64
    43  	behaviorFused bool
    44  
    45  	aeadDecoderHolder *aead.AuthIDDecoderHolder
    46  }
    47  
    48  type indexTimePair struct {
    49  	user    *user
    50  	timeInc uint32
    51  
    52  	taintedFuse *uint32
    53  }
    54  
    55  // NewTimedUserValidator creates a new TimedUserValidator.
    56  func NewTimedUserValidator(hasher protocol.IDHash) *TimedUserValidator {
    57  	tuv := &TimedUserValidator{
    58  		users:             make([]*user, 0, 16),
    59  		userHash:          make(map[[16]byte]indexTimePair, 1024),
    60  		hasher:            hasher,
    61  		baseTime:          protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2),
    62  		aeadDecoderHolder: aead.NewAuthIDDecoderHolder(),
    63  	}
    64  	tuv.task = &task.Periodic{
    65  		Interval: updateInterval,
    66  		Execute: func() error {
    67  			tuv.updateUserHash()
    68  			return nil
    69  		},
    70  	}
    71  	common.Must(tuv.task.Start())
    72  	return tuv
    73  }
    74  
    75  func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *user) {
    76  	var hashValue [16]byte
    77  	genEndSec := nowSec + cacheDurationSec
    78  	genHashForID := func(id *protocol.ID) {
    79  		idHash := v.hasher(id.Bytes())
    80  		genBeginSec := user.lastSec
    81  		if genBeginSec < nowSec-cacheDurationSec {
    82  			genBeginSec = nowSec - cacheDurationSec
    83  		}
    84  		for ts := genBeginSec; ts <= genEndSec; ts++ {
    85  			common.Must2(serial.WriteUint64(idHash, uint64(ts)))
    86  			idHash.Sum(hashValue[:0])
    87  			idHash.Reset()
    88  
    89  			v.userHash[hashValue] = indexTimePair{
    90  				user:        user,
    91  				timeInc:     uint32(ts - v.baseTime),
    92  				taintedFuse: new(uint32),
    93  			}
    94  		}
    95  	}
    96  
    97  	account := user.user.Account.(*MemoryAccount)
    98  
    99  	genHashForID(account.ID)
   100  	for _, id := range account.AlterIDs {
   101  		genHashForID(id)
   102  	}
   103  	user.lastSec = genEndSec
   104  }
   105  
   106  func (v *TimedUserValidator) removeExpiredHashes(expire uint32) {
   107  	for key, pair := range v.userHash {
   108  		if pair.timeInc < expire {
   109  			delete(v.userHash, key)
   110  		}
   111  	}
   112  }
   113  
   114  func (v *TimedUserValidator) updateUserHash() {
   115  	now := time.Now()
   116  	nowSec := protocol.Timestamp(now.Unix())
   117  	v.Lock()
   118  	defer v.Unlock()
   119  
   120  	for _, user := range v.users {
   121  		v.generateNewHashes(nowSec, user)
   122  	}
   123  
   124  	expire := protocol.Timestamp(now.Unix() - cacheDurationSec)
   125  	if expire > v.baseTime {
   126  		v.removeExpiredHashes(uint32(expire - v.baseTime))
   127  	}
   128  }
   129  
   130  func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error {
   131  	v.Lock()
   132  	defer v.Unlock()
   133  
   134  	nowSec := time.Now().Unix()
   135  
   136  	uu := &user{
   137  		user:    *u,
   138  		lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
   139  	}
   140  	v.users = append(v.users, uu)
   141  	v.generateNewHashes(protocol.Timestamp(nowSec), uu)
   142  
   143  	account := uu.user.Account.(*MemoryAccount)
   144  	if !v.behaviorFused {
   145  		hashkdf := hmac.New(func() hash.Hash { return sha256.New() }, []byte("VMESSBSKDF"))
   146  		hashkdf.Write(account.ID.Bytes())
   147  		v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), hashkdf.Sum(nil))
   148  	}
   149  
   150  	var cmdkeyfl [16]byte
   151  	copy(cmdkeyfl[:], account.ID.CmdKey())
   152  	v.aeadDecoderHolder.AddUser(cmdkeyfl, u)
   153  
   154  	return nil
   155  }
   156  
   157  func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool, error) {
   158  	defer v.RUnlock()
   159  	v.RLock()
   160  
   161  	v.behaviorFused = true
   162  
   163  	var fixedSizeHash [16]byte
   164  	copy(fixedSizeHash[:], userHash)
   165  	pair, found := v.userHash[fixedSizeHash]
   166  	if found {
   167  		user := pair.user.user
   168  		if atomic.LoadUint32(pair.taintedFuse) == 0 {
   169  			return &user, protocol.Timestamp(pair.timeInc) + v.baseTime, true, nil
   170  		}
   171  		return nil, 0, false, ErrTainted
   172  	}
   173  	return nil, 0, false, ErrNotFound
   174  }
   175  
   176  func (v *TimedUserValidator) GetAEAD(userHash []byte) (*protocol.MemoryUser, bool, error) {
   177  	defer v.RUnlock()
   178  	v.RLock()
   179  	var userHashFL [16]byte
   180  	copy(userHashFL[:], userHash)
   181  
   182  	userd, err := v.aeadDecoderHolder.Match(userHashFL)
   183  	if err != nil {
   184  		return nil, false, err
   185  	}
   186  	return userd.(*protocol.MemoryUser), true, err
   187  }
   188  
   189  func (v *TimedUserValidator) Remove(email string) bool {
   190  	v.Lock()
   191  	defer v.Unlock()
   192  
   193  	email = strings.ToLower(email)
   194  	idx := -1
   195  	for i, u := range v.users {
   196  		if strings.EqualFold(u.user.Email, email) {
   197  			idx = i
   198  			var cmdkeyfl [16]byte
   199  			copy(cmdkeyfl[:], u.user.Account.(*MemoryAccount).ID.CmdKey())
   200  			v.aeadDecoderHolder.RemoveUser(cmdkeyfl)
   201  			break
   202  		}
   203  	}
   204  	if idx == -1 {
   205  		return false
   206  	}
   207  	ulen := len(v.users)
   208  
   209  	v.users[idx] = v.users[ulen-1]
   210  	v.users[ulen-1] = nil
   211  	v.users = v.users[:ulen-1]
   212  
   213  	return true
   214  }
   215  
   216  // Close implements common.Closable.
   217  func (v *TimedUserValidator) Close() error {
   218  	return v.task.Close()
   219  }
   220  
   221  func (v *TimedUserValidator) GetBehaviorSeed() uint64 {
   222  	v.Lock()
   223  	defer v.Unlock()
   224  	v.behaviorFused = true
   225  	if v.behaviorSeed == 0 {
   226  		v.behaviorSeed = dice.RollUint64()
   227  	}
   228  	return v.behaviorSeed
   229  }
   230  
   231  func (v *TimedUserValidator) BurnTaintFuse(userHash []byte) error {
   232  	v.RLock()
   233  	defer v.RUnlock()
   234  	var userHashFL [16]byte
   235  	copy(userHashFL[:], userHash)
   236  
   237  	pair, found := v.userHash[userHashFL]
   238  	if found {
   239  		if atomic.CompareAndSwapUint32(pair.taintedFuse, 0, 1) {
   240  			return nil
   241  		}
   242  		return ErrTainted
   243  	}
   244  	return ErrNotFound
   245  }
   246  
   247  var ErrNotFound = newError("Not Found")
   248  
   249  var ErrTainted = newError("ErrTainted")