github.com/eagleql/xray-core@v1.4.4/proxy/vmess/validator.go (about)

     1  package vmess
     2  
     3  import (
     4  	"crypto/hmac"
     5  	"crypto/sha256"
     6  	"hash/crc64"
     7  	"strings"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/eagleql/xray-core/common"
    13  	"github.com/eagleql/xray-core/common/dice"
    14  	"github.com/eagleql/xray-core/common/protocol"
    15  	"github.com/eagleql/xray-core/common/serial"
    16  	"github.com/eagleql/xray-core/common/task"
    17  	"github.com/eagleql/xray-core/proxy/vmess/aead"
    18  )
    19  
    20  const (
    21  	updateInterval   = 10 * time.Second
    22  	cacheDurationSec = 120
    23  )
    24  
        // user is one account registered with a TimedUserValidator.
        // Add initializes lastSec and generateNewHashes advances it; it marks the
        // last timestamp for which this user's legacy auth hashes were generated.
    25  type user struct {
    26  	user    protocol.MemoryUser // copy of the registered user (Add stores *u)
    27  	lastSec protocol.Timestamp // end of the hash window generated so far
    28  }
    29  
    30  // TimedUserValidator is a user Validator based on time.
        //
        // It authenticates VMess clients in two ways. Legacy auth hashes are looked
        // up in userHash; they are precomputed per user ID and per second and are
        // refreshed by task. AEAD auth IDs are matched by aeadDecoderHolder.
        // Methods take the embedded RWMutex around access to these fields. The
        // per-hash taint fuses referenced from userHash are read and set with
        // sync/atomic.
    31  type TimedUserValidator struct {
    32  	sync.RWMutex
    33  	users    []*user // registered users; Remove swap-deletes, so order is not preserved
    34  	userHash map[[16]byte]indexTimePair // legacy auth hash -> owning user and timestamp offset
    35  	hasher   protocol.IDHash // builds the per-ID hash that legacy auth hashes are computed with
    36  	baseTime protocol.Timestamp // origin for indexTimePair.timeInc offsets
    37  	task     *task.Periodic // runs updateUserHash every updateInterval; stopped by Close
    38  
    39  	behaviorSeed  uint64 // CRC-64 accumulated from users' ID HMACs, or random if it stayed 0
    40  	behaviorFused bool // once true, Add no longer mixes new users into behaviorSeed
    41  
    42  	aeadDecoderHolder *aead.AuthIDDecoderHolder // AEAD auth-ID matcher keyed by each user's cmd key
    43  }
    44  
        // indexTimePair is the value stored in TimedUserValidator.userHash for one
        // legacy auth hash.
    45  type indexTimePair struct {
    46  	user    *user // user whose ID produced the hash
    47  	timeInc uint32 // timestamp the hash was generated for, as an offset from baseTime
    48  
    49  	taintedFuse *uint32 // 0 = usable, non-zero = burned; read and set with sync/atomic
    50  }
    51  
    52  // NewTimedUserValidator creates a new TimedUserValidator.
    53  func NewTimedUserValidator(hasher protocol.IDHash) *TimedUserValidator {
    54  	tuv := &TimedUserValidator{
    55  		users:             make([]*user, 0, 16),
    56  		userHash:          make(map[[16]byte]indexTimePair, 1024),
    57  		hasher:            hasher,
    58  		baseTime:          protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2),
    59  		aeadDecoderHolder: aead.NewAuthIDDecoderHolder(),
    60  	}
    61  	tuv.task = &task.Periodic{
    62  		Interval: updateInterval,
    63  		Execute: func() error {
    64  			tuv.updateUserHash()
    65  			return nil
    66  		},
    67  	}
    68  	common.Must(tuv.task.Start())
    69  	return tuv
    70  }
    71  
    72  func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *user) {
    73  	var hashValue [16]byte
    74  	genEndSec := nowSec + cacheDurationSec
    75  	genHashForID := func(id *protocol.ID) {
    76  		idHash := v.hasher(id.Bytes())
    77  		genBeginSec := user.lastSec
    78  		if genBeginSec < nowSec-cacheDurationSec {
    79  			genBeginSec = nowSec - cacheDurationSec
    80  		}
    81  		for ts := genBeginSec; ts <= genEndSec; ts++ {
    82  			common.Must2(serial.WriteUint64(idHash, uint64(ts)))
    83  			idHash.Sum(hashValue[:0])
    84  			idHash.Reset()
    85  
    86  			v.userHash[hashValue] = indexTimePair{
    87  				user:        user,
    88  				timeInc:     uint32(ts - v.baseTime),
    89  				taintedFuse: new(uint32),
    90  			}
    91  		}
    92  	}
    93  
    94  	account := user.user.Account.(*MemoryAccount)
    95  
    96  	genHashForID(account.ID)
    97  	for _, id := range account.AlterIDs {
    98  		genHashForID(id)
    99  	}
   100  	user.lastSec = genEndSec
   101  }
   102  
   103  func (v *TimedUserValidator) removeExpiredHashes(expire uint32) {
   104  	for key, pair := range v.userHash {
   105  		if pair.timeInc < expire {
   106  			delete(v.userHash, key)
   107  		}
   108  	}
   109  }
   110  
   111  func (v *TimedUserValidator) updateUserHash() {
   112  	now := time.Now()
   113  	nowSec := protocol.Timestamp(now.Unix())
   114  
   115  	v.Lock()
   116  	defer v.Unlock()
   117  
   118  	for _, user := range v.users {
   119  		v.generateNewHashes(nowSec, user)
   120  	}
   121  
   122  	expire := protocol.Timestamp(now.Unix() - cacheDurationSec)
   123  	if expire > v.baseTime {
   124  		v.removeExpiredHashes(uint32(expire - v.baseTime))
   125  	}
   126  }
   127  
   128  func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error {
   129  	v.Lock()
   130  	defer v.Unlock()
   131  
   132  	nowSec := time.Now().Unix()
   133  
   134  	uu := &user{
   135  		user:    *u,
   136  		lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
   137  	}
   138  	v.users = append(v.users, uu)
   139  	v.generateNewHashes(protocol.Timestamp(nowSec), uu)
   140  
   141  	account := uu.user.Account.(*MemoryAccount)
   142  	if !v.behaviorFused {
   143  		hashkdf := hmac.New(sha256.New, []byte("VMESSBSKDF"))
   144  		hashkdf.Write(account.ID.Bytes())
   145  		v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), hashkdf.Sum(nil))
   146  	}
   147  
   148  	var cmdkeyfl [16]byte
   149  	copy(cmdkeyfl[:], account.ID.CmdKey())
   150  	v.aeadDecoderHolder.AddUser(cmdkeyfl, u)
   151  
   152  	return nil
   153  }
   154  
   155  func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool, error) {
   156  	v.RLock()
   157  	defer v.RUnlock()
   158  
   159  	v.behaviorFused = true
   160  
   161  	var fixedSizeHash [16]byte
   162  	copy(fixedSizeHash[:], userHash)
   163  	pair, found := v.userHash[fixedSizeHash]
   164  	if found {
   165  		user := pair.user.user
   166  		if atomic.LoadUint32(pair.taintedFuse) == 0 {
   167  			return &user, protocol.Timestamp(pair.timeInc) + v.baseTime, true, nil
   168  		}
   169  		return nil, 0, false, ErrTainted
   170  	}
   171  	return nil, 0, false, ErrNotFound
   172  }
   173  
   174  func (v *TimedUserValidator) GetAEAD(userHash []byte) (*protocol.MemoryUser, bool, error) {
   175  	v.RLock()
   176  	defer v.RUnlock()
   177  
   178  	var userHashFL [16]byte
   179  	copy(userHashFL[:], userHash)
   180  
   181  	userd, err := v.aeadDecoderHolder.Match(userHashFL)
   182  	if err != nil {
   183  		return nil, false, err
   184  	}
   185  	return userd.(*protocol.MemoryUser), true, err
   186  }
   187  
   188  func (v *TimedUserValidator) Remove(email string) bool {
   189  	v.Lock()
   190  	defer v.Unlock()
   191  
   192  	email = strings.ToLower(email)
   193  	idx := -1
   194  	for i, u := range v.users {
   195  		if strings.EqualFold(u.user.Email, email) {
   196  			idx = i
   197  			var cmdkeyfl [16]byte
   198  			copy(cmdkeyfl[:], u.user.Account.(*MemoryAccount).ID.CmdKey())
   199  			v.aeadDecoderHolder.RemoveUser(cmdkeyfl)
   200  			break
   201  		}
   202  	}
   203  	if idx == -1 {
   204  		return false
   205  	}
   206  	ulen := len(v.users)
   207  
   208  	v.users[idx] = v.users[ulen-1]
   209  	v.users[ulen-1] = nil
   210  	v.users = v.users[:ulen-1]
   211  
   212  	return true
   213  }
   214  
   215  // Close implements common.Closable.
   216  func (v *TimedUserValidator) Close() error {
   217  	return v.task.Close()
   218  }
   219  
   220  func (v *TimedUserValidator) GetBehaviorSeed() uint64 {
   221  	v.Lock()
   222  	defer v.Unlock()
   223  
   224  	v.behaviorFused = true
   225  	if v.behaviorSeed == 0 {
   226  		v.behaviorSeed = dice.RollUint64()
   227  	}
   228  	return v.behaviorSeed
   229  }
   230  
   231  func (v *TimedUserValidator) BurnTaintFuse(userHash []byte) error {
   232  	v.RLock()
   233  	defer v.RUnlock()
   234  
   235  	var userHashFL [16]byte
   236  	copy(userHashFL[:], userHash)
   237  
   238  	pair, found := v.userHash[userHashFL]
   239  	if found {
   240  		if atomic.CompareAndSwapUint32(pair.taintedFuse, 0, 1) {
   241  			return nil
   242  		}
   243  		return ErrTainted
   244  	}
   245  	return ErrNotFound
   246  }
   247  
   248  var ErrNotFound = newError("Not Found")
   249  
   250  var ErrTainted = newError("ErrTainted")