github.com/xraypb/Xray-core@v1.8.1/proxy/vmess/validator.go (about)

     1  package vmess
     2  
     3  import (
     4  	"crypto/hmac"
     5  	"crypto/sha256"
     6  	"hash/crc64"
     7  	"strings"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/xraypb/Xray-core/common"
    13  	"github.com/xraypb/Xray-core/common/dice"
    14  	"github.com/xraypb/Xray-core/common/protocol"
    15  	"github.com/xraypb/Xray-core/common/serial"
    16  	"github.com/xraypb/Xray-core/common/task"
    17  	"github.com/xraypb/Xray-core/proxy/vmess/aead"
    18  )
    19  
    20  const (
    21  	updateInterval   = 10 * time.Second
    22  	cacheDurationSec = 120
    23  )
    24  
    25  type user struct {
    26  	user    protocol.MemoryUser
    27  	lastSec protocol.Timestamp
    28  }
    29  
    30  // TimedUserValidator is a user Validator based on time.
    31  type TimedUserValidator struct {
    32  	sync.RWMutex
    33  	users    []*user
    34  	userHash map[[16]byte]indexTimePair
    35  	hasher   protocol.IDHash
    36  	baseTime protocol.Timestamp
    37  	task     *task.Periodic
    38  
    39  	behaviorSeed  uint64
    40  	behaviorFused bool
    41  
    42  	aeadDecoderHolder *aead.AuthIDDecoderHolder
    43  
    44  	legacyWarnShown bool
    45  }
    46  
    47  type indexTimePair struct {
    48  	user    *user
    49  	timeInc uint32
    50  
    51  	taintedFuse *uint32
    52  }
    53  
    54  // NewTimedUserValidator creates a new TimedUserValidator.
    55  func NewTimedUserValidator(hasher protocol.IDHash) *TimedUserValidator {
    56  	tuv := &TimedUserValidator{
    57  		users:             make([]*user, 0, 16),
    58  		userHash:          make(map[[16]byte]indexTimePair, 1024),
    59  		hasher:            hasher,
    60  		baseTime:          protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2),
    61  		aeadDecoderHolder: aead.NewAuthIDDecoderHolder(),
    62  	}
    63  	tuv.task = &task.Periodic{
    64  		Interval: updateInterval,
    65  		Execute: func() error {
    66  			tuv.updateUserHash()
    67  			return nil
    68  		},
    69  	}
    70  	common.Must(tuv.task.Start())
    71  	return tuv
    72  }
    73  
    74  // visible for testing
    75  func (v *TimedUserValidator) GetBaseTime() protocol.Timestamp {
    76  	return v.baseTime
    77  }
    78  
    79  func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *user) {
    80  	var hashValue [16]byte
    81  	genEndSec := nowSec + cacheDurationSec
    82  	genHashForID := func(id *protocol.ID) {
    83  		idHash := v.hasher(id.Bytes())
    84  		genBeginSec := user.lastSec
    85  		if genBeginSec < nowSec-cacheDurationSec {
    86  			genBeginSec = nowSec - cacheDurationSec
    87  		}
    88  		for ts := genBeginSec; ts <= genEndSec; ts++ {
    89  			common.Must2(serial.WriteUint64(idHash, uint64(ts)))
    90  			idHash.Sum(hashValue[:0])
    91  			idHash.Reset()
    92  
    93  			v.userHash[hashValue] = indexTimePair{
    94  				user:        user,
    95  				timeInc:     uint32(ts - v.baseTime),
    96  				taintedFuse: new(uint32),
    97  			}
    98  		}
    99  	}
   100  
   101  	account := user.user.Account.(*MemoryAccount)
   102  
   103  	genHashForID(account.ID)
   104  	for _, id := range account.AlterIDs {
   105  		genHashForID(id)
   106  	}
   107  	user.lastSec = genEndSec
   108  }
   109  
   110  func (v *TimedUserValidator) removeExpiredHashes(expire uint32) {
   111  	for key, pair := range v.userHash {
   112  		if pair.timeInc < expire {
   113  			delete(v.userHash, key)
   114  		}
   115  	}
   116  }
   117  
   118  func (v *TimedUserValidator) updateUserHash() {
   119  	now := time.Now()
   120  	nowSec := protocol.Timestamp(now.Unix())
   121  
   122  	v.Lock()
   123  	defer v.Unlock()
   124  
   125  	for _, user := range v.users {
   126  		v.generateNewHashes(nowSec, user)
   127  	}
   128  
   129  	expire := protocol.Timestamp(now.Unix() - cacheDurationSec)
   130  	if expire > v.baseTime {
   131  		v.removeExpiredHashes(uint32(expire - v.baseTime))
   132  	}
   133  }
   134  
   135  func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error {
   136  	v.Lock()
   137  	defer v.Unlock()
   138  
   139  	nowSec := time.Now().Unix()
   140  
   141  	uu := &user{
   142  		user:    *u,
   143  		lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
   144  	}
   145  	v.users = append(v.users, uu)
   146  	v.generateNewHashes(protocol.Timestamp(nowSec), uu)
   147  
   148  	account := uu.user.Account.(*MemoryAccount)
   149  	if !v.behaviorFused {
   150  		hashkdf := hmac.New(sha256.New, []byte("VMESSBSKDF"))
   151  		hashkdf.Write(account.ID.Bytes())
   152  		v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), hashkdf.Sum(nil))
   153  	}
   154  
   155  	var cmdkeyfl [16]byte
   156  	copy(cmdkeyfl[:], account.ID.CmdKey())
   157  	v.aeadDecoderHolder.AddUser(cmdkeyfl, u)
   158  
   159  	return nil
   160  }
   161  
   162  func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool, error) {
   163  	v.RLock()
   164  	defer v.RUnlock()
   165  
   166  	v.behaviorFused = true
   167  
   168  	var fixedSizeHash [16]byte
   169  	copy(fixedSizeHash[:], userHash)
   170  	pair, found := v.userHash[fixedSizeHash]
   171  	if found {
   172  		user := pair.user.user
   173  		if atomic.LoadUint32(pair.taintedFuse) == 0 {
   174  			return &user, protocol.Timestamp(pair.timeInc) + v.baseTime, true, nil
   175  		}
   176  		return nil, 0, false, ErrTainted
   177  	}
   178  	return nil, 0, false, ErrNotFound
   179  }
   180  
   181  func (v *TimedUserValidator) GetAEAD(userHash []byte) (*protocol.MemoryUser, bool, error) {
   182  	v.RLock()
   183  	defer v.RUnlock()
   184  
   185  	var userHashFL [16]byte
   186  	copy(userHashFL[:], userHash)
   187  
   188  	userd, err := v.aeadDecoderHolder.Match(userHashFL)
   189  	if err != nil {
   190  		return nil, false, err
   191  	}
   192  	return userd.(*protocol.MemoryUser), true, err
   193  }
   194  
   195  func (v *TimedUserValidator) Remove(email string) bool {
   196  	v.Lock()
   197  	defer v.Unlock()
   198  
   199  	email = strings.ToLower(email)
   200  	idx := -1
   201  	for i, u := range v.users {
   202  		if strings.EqualFold(u.user.Email, email) {
   203  			idx = i
   204  			var cmdkeyfl [16]byte
   205  			copy(cmdkeyfl[:], u.user.Account.(*MemoryAccount).ID.CmdKey())
   206  			v.aeadDecoderHolder.RemoveUser(cmdkeyfl)
   207  			break
   208  		}
   209  	}
   210  	if idx == -1 {
   211  		return false
   212  	}
   213  	ulen := len(v.users)
   214  
   215  	v.users[idx] = v.users[ulen-1]
   216  	v.users[ulen-1] = nil
   217  	v.users = v.users[:ulen-1]
   218  
   219  	return true
   220  }
   221  
   222  // Close implements common.Closable.
   223  func (v *TimedUserValidator) Close() error {
   224  	return v.task.Close()
   225  }
   226  
   227  func (v *TimedUserValidator) GetBehaviorSeed() uint64 {
   228  	v.Lock()
   229  	defer v.Unlock()
   230  
   231  	v.behaviorFused = true
   232  	if v.behaviorSeed == 0 {
   233  		v.behaviorSeed = dice.RollUint64()
   234  	}
   235  	return v.behaviorSeed
   236  }
   237  
   238  func (v *TimedUserValidator) BurnTaintFuse(userHash []byte) error {
   239  	v.RLock()
   240  	defer v.RUnlock()
   241  
   242  	var userHashFL [16]byte
   243  	copy(userHashFL[:], userHash)
   244  
   245  	pair, found := v.userHash[userHashFL]
   246  	if found {
   247  		if atomic.CompareAndSwapUint32(pair.taintedFuse, 0, 1) {
   248  			return nil
   249  		}
   250  		return ErrTainted
   251  	}
   252  	return ErrNotFound
   253  }
   254  
   255  /*
   256  	ShouldShowLegacyWarn will return whether a Legacy Warning should be shown
   257  
   258  Not guaranteed to only return true once for every inbound, but it is okay.
   259  */
   260  func (v *TimedUserValidator) ShouldShowLegacyWarn() bool {
   261  	if v.legacyWarnShown {
   262  		return false
   263  	}
   264  	v.legacyWarnShown = true
   265  	return true
   266  }
   267  
   268  var ErrNotFound = newError("Not Found")
   269  
   270  var ErrTainted = newError("ErrTainted")