github.com/ipfans/trojan-go@v0.11.0/statistic/memory/memory.go (about)

     1  package memory
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"sync/atomic"
     7  	"time"
     8  
     9  	"golang.org/x/time/rate"
    10  
    11  	"github.com/ipfans/trojan-go/common"
    12  	"github.com/ipfans/trojan-go/config"
    13  	"github.com/ipfans/trojan-go/log"
    14  	"github.com/ipfans/trojan-go/statistic"
    15  )
    16  
// Name is the key under which this package registers its authenticator
// creator (see init) and under which NewAuthenticator looks up its config.
const Name = "MEMORY"
    18  
// User keeps the in-memory state for a single password hash: traffic
// counters, per-second speed figures, the set of admitted IPs and the
// optional send/receive rate limiters.
type User struct {
	// WARNING: do not change the order of these fields.
	// 64-bit fields that use `sync/atomic` package functions
	// must be 64-bit aligned on 32-bit systems.
	// Reference: https://github.com/golang/go/issues/599
	// Solution: https://github.com/golang/go/issues/11891#issuecomment-433623786
	sent      uint64 // running total of the `sent` amounts given to AddTraffic
	recv      uint64 // running total of the `recv` amounts given to AddTraffic
	lastSent  uint64 // value of sent recorded at the previous speedUpdater tick
	lastRecv  uint64 // value of recv recorded at the previous speedUpdater tick
	sendSpeed uint64 // growth of sent over the last one-second tick
	recvSpeed uint64 // growth of recv over the last one-second tick

	hash        string             // identifier returned by Hash
	ipTable     sync.Map           // IPs admitted by AddIP; the stored value is always true
	ipNum       int32              // counter incremented by AddIP and decremented by DelIP
	maxIPNum    int                // IP limit; values <= 0 switch IP tracking off
	limiterLock sync.RWMutex       // guards sendLimiter and recvLimiter
	sendLimiter *rate.Limiter      // nil when no send limit is configured
	recvLimiter *rate.Limiter      // nil when no receive limit is configured
	ctx         context.Context    // passed to the limiters; its cancellation stops speedUpdater
	cancel      context.CancelFunc // invoked by Close
}
    42  
    43  func (u *User) Close() error {
    44  	u.ResetTraffic()
    45  	u.cancel()
    46  	return nil
    47  }
    48  
    49  func (u *User) AddIP(ip string) bool {
    50  	if u.maxIPNum <= 0 {
    51  		return true
    52  	}
    53  	_, found := u.ipTable.Load(ip)
    54  	if found {
    55  		return true
    56  	}
    57  	if int(u.ipNum)+1 > u.maxIPNum {
    58  		return false
    59  	}
    60  	u.ipTable.Store(ip, true)
    61  	atomic.AddInt32(&u.ipNum, 1)
    62  	return true
    63  }
    64  
    65  func (u *User) DelIP(ip string) bool {
    66  	if u.maxIPNum <= 0 {
    67  		return true
    68  	}
    69  	_, found := u.ipTable.Load(ip)
    70  	if !found {
    71  		return false
    72  	}
    73  	u.ipTable.Delete(ip)
    74  	atomic.AddInt32(&u.ipNum, -1)
    75  	return true
    76  }
    77  
    78  func (u *User) GetIP() int {
    79  	return int(u.ipNum)
    80  }
    81  
// SetIPLimit sets the maximum number of distinct IPs that AddIP will admit.
// With n <= 0, AddIP and DelIP return true without tracking anything.
// The field is assigned without any synchronization.
func (u *User) SetIPLimit(n int) {
	u.maxIPNum = n
}
    85  
// GetIPLimit returns the limit last set with SetIPLimit (zero if it was never
// set). Values <= 0 mean that IP tracking is switched off.
func (u *User) GetIPLimit() int {
	return u.maxIPNum
}
    89  
    90  func (u *User) AddTraffic(sent, recv int) {
    91  	u.limiterLock.RLock()
    92  	defer u.limiterLock.RUnlock()
    93  
    94  	if u.sendLimiter != nil && sent >= 0 {
    95  		u.sendLimiter.WaitN(u.ctx, sent)
    96  	} else if u.recvLimiter != nil && recv >= 0 {
    97  		u.recvLimiter.WaitN(u.ctx, recv)
    98  	}
    99  	atomic.AddUint64(&u.sent, uint64(sent))
   100  	atomic.AddUint64(&u.recv, uint64(recv))
   101  }
   102  
   103  func (u *User) SetSpeedLimit(send, recv int) {
   104  	u.limiterLock.Lock()
   105  	defer u.limiterLock.Unlock()
   106  
   107  	if send <= 0 {
   108  		u.sendLimiter = nil
   109  	} else {
   110  		u.sendLimiter = rate.NewLimiter(rate.Limit(send), send*2)
   111  	}
   112  	if recv <= 0 {
   113  		u.recvLimiter = nil
   114  	} else {
   115  		u.recvLimiter = rate.NewLimiter(rate.Limit(recv), recv*2)
   116  	}
   117  }
   118  
   119  func (u *User) GetSpeedLimit() (send, recv int) {
   120  	u.limiterLock.RLock()
   121  	defer u.limiterLock.RUnlock()
   122  
   123  	if u.sendLimiter != nil {
   124  		send = int(u.sendLimiter.Limit())
   125  	}
   126  	if u.recvLimiter != nil {
   127  		recv = int(u.recvLimiter.Limit())
   128  	}
   129  	return
   130  }
   131  
// Hash returns the hash string this user was created with (see AddUser).
func (u *User) Hash() string {
	return u.hash
}
   135  
// SetTraffic overwrites the sent and received traffic counters with the given
// values. Each counter is stored atomically, but the two stores are separate
// operations. lastSent and lastRecv are not touched.
func (u *User) SetTraffic(send, recv uint64) {
	atomic.StoreUint64(&u.sent, send)
	atomic.StoreUint64(&u.recv, recv)
}
   140  
   141  func (u *User) GetTraffic() (uint64, uint64) {
   142  	return atomic.LoadUint64(&u.sent), atomic.LoadUint64(&u.recv)
   143  }
   144  
   145  func (u *User) ResetTraffic() (uint64, uint64) {
   146  	sent := atomic.SwapUint64(&u.sent, 0)
   147  	recv := atomic.SwapUint64(&u.recv, 0)
   148  	atomic.StoreUint64(&u.lastSent, 0)
   149  	atomic.StoreUint64(&u.lastRecv, 0)
   150  	return sent, recv
   151  }
   152  
   153  func (u *User) speedUpdater() {
   154  	ticker := time.NewTicker(time.Second)
   155  	for {
   156  		select {
   157  		case <-u.ctx.Done():
   158  			return
   159  		case <-ticker.C:
   160  			sent, recv := u.GetTraffic()
   161  			atomic.StoreUint64(&u.sendSpeed, sent-u.lastSent)
   162  			atomic.StoreUint64(&u.recvSpeed, recv-u.lastRecv)
   163  			atomic.StoreUint64(&u.lastSent, sent)
   164  			atomic.StoreUint64(&u.lastRecv, recv)
   165  		}
   166  	}
   167  }
   168  
   169  func (u *User) GetSpeed() (uint64, uint64) {
   170  	return atomic.LoadUint64(&u.sendSpeed), atomic.LoadUint64(&u.recvSpeed)
   171  }
   172  
// Authenticator is an in-memory user registry keyed by password hash.
type Authenticator struct {
	users sync.Map        // hash string -> *User
	ctx   context.Context // parent of every user's context (see AddUser)
}
   177  
   178  func (a *Authenticator) AuthUser(hash string) (bool, statistic.User) {
   179  	if user, found := a.users.Load(hash); found {
   180  		return true, user.(*User)
   181  	}
   182  	return false, nil
   183  }
   184  
   185  func (a *Authenticator) AddUser(hash string) error {
   186  	if _, found := a.users.Load(hash); found {
   187  		return common.NewError("hash " + hash + " is already exist")
   188  	}
   189  	ctx, cancel := context.WithCancel(a.ctx)
   190  	meter := &User{
   191  		hash:   hash,
   192  		ctx:    ctx,
   193  		cancel: cancel,
   194  	}
   195  	go meter.speedUpdater()
   196  	a.users.Store(hash, meter)
   197  	return nil
   198  }
   199  
   200  func (a *Authenticator) DelUser(hash string) error {
   201  	meter, found := a.users.Load(hash)
   202  	if !found {
   203  		return common.NewError("hash " + hash + " not found")
   204  	}
   205  	meter.(*User).Close()
   206  	a.users.Delete(hash)
   207  	return nil
   208  }
   209  
   210  func (a *Authenticator) ListUsers() []statistic.User {
   211  	result := make([]statistic.User, 0)
   212  	a.users.Range(func(k, v interface{}) bool {
   213  		result = append(result, v.(*User))
   214  		return true
   215  	})
   216  	return result
   217  }
   218  
// Close does nothing and always returns nil. It does not close the registered
// users; their contexts are children of the context handed to
// NewAuthenticator (see AddUser).
func (a *Authenticator) Close() error {
	return nil
}
   222  
   223  func NewAuthenticator(ctx context.Context) (statistic.Authenticator, error) {
   224  	cfg := config.FromContext(ctx, Name).(*Config)
   225  	u := &Authenticator{
   226  		ctx: ctx,
   227  	}
   228  	for _, password := range cfg.Passwords {
   229  		hash := common.SHA224String(password)
   230  		u.AddUser(hash)
   231  	}
   232  	log.Debug("memory authenticator created")
   233  	return u, nil
   234  }
   235  
// init registers NewAuthenticator with the statistic package under Name.
func init() {
	statistic.RegisterAuthenticatorCreator(Name, NewAuthenticator)
}