github.com/v2fly/v2ray-core/v4@v4.45.2/proxy/vmess/validator.go (about)

     1  //go:build !confonly
     2  // +build !confonly
     3  
     4  package vmess
     5  
     6  import (
     7  	"crypto/hmac"
     8  	"crypto/sha256"
     9  	"hash/crc64"
    10  	"strings"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/v2fly/v2ray-core/v4/common"
    16  	"github.com/v2fly/v2ray-core/v4/common/dice"
    17  	"github.com/v2fly/v2ray-core/v4/common/protocol"
    18  	"github.com/v2fly/v2ray-core/v4/common/serial"
    19  	"github.com/v2fly/v2ray-core/v4/common/task"
    20  	"github.com/v2fly/v2ray-core/v4/proxy/vmess/aead"
    21  )
    22  
    23  const (
    24  	updateInterval   = 10 * time.Second
    25  	cacheDurationSec = 120
    26  )
    27  
    28  type user struct {
    29  	user    protocol.MemoryUser
    30  	lastSec protocol.Timestamp
    31  }
    32  
    33  // TimedUserValidator is a user Validator based on time.
    34  type TimedUserValidator struct {
    35  	sync.RWMutex
    36  	users    []*user
    37  	userHash map[[16]byte]indexTimePair
    38  	hasher   protocol.IDHash
    39  	baseTime protocol.Timestamp
    40  	task     *task.Periodic
    41  
    42  	behaviorSeed  uint64
    43  	behaviorFused bool
    44  
    45  	aeadDecoderHolder *aead.AuthIDDecoderHolder
    46  
    47  	legacyWarnShown bool
    48  }
    49  
    50  type indexTimePair struct {
    51  	user    *user
    52  	timeInc uint32
    53  
    54  	taintedFuse *uint32
    55  }
    56  
    57  // NewTimedUserValidator creates a new TimedUserValidator.
    58  func NewTimedUserValidator(hasher protocol.IDHash) *TimedUserValidator {
    59  	tuv := &TimedUserValidator{
    60  		users:             make([]*user, 0, 16),
    61  		userHash:          make(map[[16]byte]indexTimePair, 1024),
    62  		hasher:            hasher,
    63  		baseTime:          protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2),
    64  		aeadDecoderHolder: aead.NewAuthIDDecoderHolder(),
    65  	}
    66  	tuv.task = &task.Periodic{
    67  		Interval: updateInterval,
    68  		Execute: func() error {
    69  			tuv.updateUserHash()
    70  			return nil
    71  		},
    72  	}
    73  	common.Must(tuv.task.Start())
    74  	return tuv
    75  }
    76  
    77  func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, user *user) {
    78  	var hashValue [16]byte
    79  	genEndSec := nowSec + cacheDurationSec
    80  	genHashForID := func(id *protocol.ID) {
    81  		idHash := v.hasher(id.Bytes())
    82  		genBeginSec := user.lastSec
    83  		if genBeginSec < nowSec-cacheDurationSec {
    84  			genBeginSec = nowSec - cacheDurationSec
    85  		}
    86  		for ts := genBeginSec; ts <= genEndSec; ts++ {
    87  			common.Must2(serial.WriteUint64(idHash, uint64(ts)))
    88  			idHash.Sum(hashValue[:0])
    89  			idHash.Reset()
    90  
    91  			v.userHash[hashValue] = indexTimePair{
    92  				user:        user,
    93  				timeInc:     uint32(ts - v.baseTime),
    94  				taintedFuse: new(uint32),
    95  			}
    96  		}
    97  	}
    98  
    99  	account := user.user.Account.(*MemoryAccount)
   100  
   101  	genHashForID(account.ID)
   102  	for _, id := range account.AlterIDs {
   103  		genHashForID(id)
   104  	}
   105  	user.lastSec = genEndSec
   106  }
   107  
   108  func (v *TimedUserValidator) removeExpiredHashes(expire uint32) {
   109  	for key, pair := range v.userHash {
   110  		if pair.timeInc < expire {
   111  			delete(v.userHash, key)
   112  		}
   113  	}
   114  }
   115  
   116  func (v *TimedUserValidator) updateUserHash() {
   117  	now := time.Now()
   118  	nowSec := protocol.Timestamp(now.Unix())
   119  
   120  	v.Lock()
   121  	defer v.Unlock()
   122  
   123  	for _, user := range v.users {
   124  		v.generateNewHashes(nowSec, user)
   125  	}
   126  
   127  	expire := protocol.Timestamp(now.Unix() - cacheDurationSec)
   128  	if expire > v.baseTime {
   129  		v.removeExpiredHashes(uint32(expire - v.baseTime))
   130  	}
   131  }
   132  
   133  func (v *TimedUserValidator) Add(u *protocol.MemoryUser) error {
   134  	v.Lock()
   135  	defer v.Unlock()
   136  
   137  	nowSec := time.Now().Unix()
   138  
   139  	uu := &user{
   140  		user:    *u,
   141  		lastSec: protocol.Timestamp(nowSec - cacheDurationSec),
   142  	}
   143  	v.users = append(v.users, uu)
   144  	v.generateNewHashes(protocol.Timestamp(nowSec), uu)
   145  
   146  	account := uu.user.Account.(*MemoryAccount)
   147  	if !v.behaviorFused {
   148  		hashkdf := hmac.New(sha256.New, []byte("VMESSBSKDF"))
   149  		hashkdf.Write(account.ID.Bytes())
   150  		v.behaviorSeed = crc64.Update(v.behaviorSeed, crc64.MakeTable(crc64.ECMA), hashkdf.Sum(nil))
   151  	}
   152  
   153  	var cmdkeyfl [16]byte
   154  	copy(cmdkeyfl[:], account.ID.CmdKey())
   155  	v.aeadDecoderHolder.AddUser(cmdkeyfl, u)
   156  
   157  	return nil
   158  }
   159  
   160  func (v *TimedUserValidator) Get(userHash []byte) (*protocol.MemoryUser, protocol.Timestamp, bool, error) {
   161  	v.RLock()
   162  	defer v.RUnlock()
   163  
   164  	v.behaviorFused = true
   165  
   166  	var fixedSizeHash [16]byte
   167  	copy(fixedSizeHash[:], userHash)
   168  	pair, found := v.userHash[fixedSizeHash]
   169  	if found {
   170  		user := pair.user.user
   171  		if atomic.LoadUint32(pair.taintedFuse) == 0 {
   172  			return &user, protocol.Timestamp(pair.timeInc) + v.baseTime, true, nil
   173  		}
   174  		return nil, 0, false, ErrTainted
   175  	}
   176  	return nil, 0, false, ErrNotFound
   177  }
   178  
   179  func (v *TimedUserValidator) GetAEAD(userHash []byte) (*protocol.MemoryUser, bool, error) {
   180  	v.RLock()
   181  	defer v.RUnlock()
   182  
   183  	var userHashFL [16]byte
   184  	copy(userHashFL[:], userHash)
   185  
   186  	userd, err := v.aeadDecoderHolder.Match(userHashFL)
   187  	if err != nil {
   188  		return nil, false, err
   189  	}
   190  	return userd.(*protocol.MemoryUser), true, err
   191  }
   192  
   193  func (v *TimedUserValidator) Remove(email string) bool {
   194  	v.Lock()
   195  	defer v.Unlock()
   196  
   197  	email = strings.ToLower(email)
   198  	idx := -1
   199  	for i, u := range v.users {
   200  		if strings.EqualFold(u.user.Email, email) {
   201  			idx = i
   202  			var cmdkeyfl [16]byte
   203  			copy(cmdkeyfl[:], u.user.Account.(*MemoryAccount).ID.CmdKey())
   204  			v.aeadDecoderHolder.RemoveUser(cmdkeyfl)
   205  			break
   206  		}
   207  	}
   208  	if idx == -1 {
   209  		return false
   210  	}
   211  	ulen := len(v.users)
   212  
   213  	v.users[idx] = v.users[ulen-1]
   214  	v.users[ulen-1] = nil
   215  	v.users = v.users[:ulen-1]
   216  
   217  	return true
   218  }
   219  
   220  // Close implements common.Closable.
   221  func (v *TimedUserValidator) Close() error {
   222  	return v.task.Close()
   223  }
   224  
   225  func (v *TimedUserValidator) GetBehaviorSeed() uint64 {
   226  	v.Lock()
   227  	defer v.Unlock()
   228  
   229  	v.behaviorFused = true
   230  	if v.behaviorSeed == 0 {
   231  		v.behaviorSeed = dice.RollUint64()
   232  	}
   233  	return v.behaviorSeed
   234  }
   235  
   236  func (v *TimedUserValidator) BurnTaintFuse(userHash []byte) error {
   237  	v.RLock()
   238  	defer v.RUnlock()
   239  
   240  	var userHashFL [16]byte
   241  	copy(userHashFL[:], userHash)
   242  
   243  	pair, found := v.userHash[userHashFL]
   244  	if found {
   245  		if atomic.CompareAndSwapUint32(pair.taintedFuse, 0, 1) {
   246  			return nil
   247  		}
   248  		return ErrTainted
   249  	}
   250  	return ErrNotFound
   251  }
   252  
   253  /* ShouldShowLegacyWarn will return whether a Legacy Warning should be shown
   254  Not guaranteed to only return true once for every inbound, but it is okay.
   255  */
   256  func (v *TimedUserValidator) ShouldShowLegacyWarn() bool {
   257  	if v.legacyWarnShown {
   258  		return false
   259  	}
   260  	v.legacyWarnShown = true
   261  	return true
   262  }
   263  
   264  var ErrNotFound = newError("Not Found")
   265  
   266  var ErrTainted = newError("ErrTainted")