github.com/imgk/caddy-trojan@v0.0.0-20221206043256-2631719e16c8/app/upstream.go (about)

     1  package app
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"strings"
     8  	"sync"
     9  
    10  	"github.com/caddyserver/caddy/v2"
    11  	"github.com/caddyserver/certmagic"
    12  	"go.uber.org/zap"
    13  
    14  	"github.com/imgk/caddy-trojan/trojan"
    15  	"github.com/imgk/caddy-trojan/utils"
    16  )
    17  
    18  func init() {
    19  	caddy.RegisterModule(CaddyUpstream{})
    20  	caddy.RegisterModule(MemoryUpstream{})
    21  }
    22  
    23  // Upstream is ...
    24  type Upstream interface {
    25  	// Add is ...
    26  	Add(string) error
    27  	// Delete is ...
    28  	Delete(string) error
    29  	// Range is ...
    30  	Range(func(string, int64, int64))
    31  	// Validate is ...
    32  	Validate(string) bool
    33  	// Consume is ...
    34  	Consume(string, int64, int64) error
    35  }
    36  
    37  // TaskType is ...
    38  type TaskType int
    39  
    40  const (
    41  	TaskAdd TaskType = iota
    42  	TaskDelete
    43  	TaskConsume
    44  )
    45  
    46  // Task is ...
    47  type Task struct {
    48  	Type  TaskType
    49  	Value struct {
    50  		Password string
    51  		Key      string
    52  		Traffic
    53  	}
    54  }
    55  
    56  // MemoryUpstream is ...
    57  type MemoryUpstream struct {
    58  	// UpstreamRaw is ...
    59  	UpstreamRaw json.RawMessage `json:"persist" caddy:"namespace=trojan.upstreams inline_key=upstream"`
    60  
    61  	ch chan Task
    62  	up Upstream
    63  
    64  	mu sync.RWMutex
    65  	mm map[string]Traffic
    66  }
    67  
    68  // CaddyModule is ...
    69  func (MemoryUpstream) CaddyModule() caddy.ModuleInfo {
    70  	return caddy.ModuleInfo{
    71  		ID:  "trojan.upstreams.memory",
    72  		New: func() caddy.Module { return new(MemoryUpstream) },
    73  	}
    74  }
    75  
    76  // Provision is ...
    77  func (u *MemoryUpstream) Provision(ctx caddy.Context) error {
    78  	u.mm = make(map[string]Traffic)
    79  
    80  	if u.UpstreamRaw == nil {
    81  		return nil
    82  	}
    83  
    84  	mod, err := ctx.LoadModule(u, "UpstreamRaw")
    85  	if err != nil {
    86  		return err
    87  	}
    88  	up := mod.(Upstream)
    89  
    90  	up.Range(func(k string, nr, nw int64) {
    91  		u.AddKey(k)
    92  		u.Consume(k, nr, nw)
    93  	})
    94  
    95  	u.up = up
    96  	u.ch = make(chan Task, 16)
    97  
    98  	go func(up Upstream, ch chan Task) {
    99  		for {
   100  			t, ok := <-ch
   101  			if !ok {
   102  				break
   103  			}
   104  			switch t.Type {
   105  			case TaskAdd:
   106  				up.Add(t.Value.Password)
   107  			case TaskDelete:
   108  				up.Delete(t.Value.Password)
   109  			case TaskConsume:
   110  				up.Consume(t.Value.Key, t.Value.Up, t.Value.Down)
   111  			default:
   112  			}
   113  		}
   114  	}(u.up, u.ch)
   115  
   116  	return nil
   117  }
   118  
   119  // Cleanup is ...
   120  func (u *MemoryUpstream) Cleanup() error {
   121  	close(u.ch)
   122  	return nil
   123  }
   124  
   125  // Add is ...
   126  func (u *MemoryUpstream) Add(s string) error {
   127  	b := [trojan.HeaderLen]byte{}
   128  	trojan.GenKey(s, b[:])
   129  
   130  	u.AddKey(string(b[:]))
   131  
   132  	if u.up == nil {
   133  		return nil
   134  	}
   135  
   136  	t := Task{Type: TaskAdd}
   137  	t.Value.Password = s
   138  	u.ch <- t
   139  	return nil
   140  }
   141  
   142  // AddKey is ...
   143  func (u *MemoryUpstream) AddKey(key string) {
   144  	u.mu.Lock()
   145  	u.mm[key] = Traffic{
   146  		Up:   0,
   147  		Down: 0,
   148  	}
   149  	u.mu.Unlock()
   150  }
   151  
   152  // Delete is ...
   153  func (u *MemoryUpstream) Delete(s string) error {
   154  	b := [trojan.HeaderLen]byte{}
   155  	trojan.GenKey(s, b[:])
   156  	key := utils.ByteSliceToString(b[:])
   157  	u.mu.Lock()
   158  	delete(u.mm, key)
   159  	u.mu.Unlock()
   160  
   161  	if u.up == nil {
   162  		return nil
   163  	}
   164  
   165  	t := Task{Type: TaskDelete}
   166  	t.Value.Password = s
   167  	u.ch <- t
   168  	return nil
   169  }
   170  
   171  // Range is ...
   172  func (u *MemoryUpstream) Range(fn func(string, int64, int64)) {
   173  	u.mu.RLock()
   174  	for k, v := range u.mm {
   175  		fn(k, v.Up, v.Down)
   176  	}
   177  	u.mu.RUnlock()
   178  }
   179  
   180  // Validate is ...
   181  func (u *MemoryUpstream) Validate(k string) bool {
   182  	u.mu.RLock()
   183  	_, ok := u.mm[k]
   184  	u.mu.RUnlock()
   185  	return ok
   186  }
   187  
   188  // Consume is ...
   189  func (u *MemoryUpstream) Consume(k string, nr, nw int64) error {
   190  	u.mu.Lock()
   191  	traffic := u.mm[k]
   192  	traffic.Up += nr
   193  	traffic.Down += nw
   194  	u.mm[k] = traffic
   195  	u.mu.Unlock()
   196  
   197  	if u.up == nil {
   198  		return nil
   199  	}
   200  
   201  	t := Task{Type: TaskConsume}
   202  	t.Value.Key = k
   203  	t.Value.Up = nr
   204  	t.Value.Down = nw
   205  	u.ch <- t
   206  	return nil
   207  }
   208  
   209  // CaddyUpstream is ...
   210  type CaddyUpstream struct {
   211  	// Prefix is ...
   212  	Prefix string `json:"-,omitempty"`
   213  	// Storage is ...
   214  	Storage certmagic.Storage `json:"-,omitempty"`
   215  	// Logger is ...
   216  	Logger *zap.Logger `json:"-,omitempty"`
   217  }
   218  
   219  // CaddyModule is ...
   220  func (CaddyUpstream) CaddyModule() caddy.ModuleInfo {
   221  	return caddy.ModuleInfo{
   222  		ID:  "trojan.upstreams.caddy",
   223  		New: func() caddy.Module { return new(CaddyUpstream) },
   224  	}
   225  }
   226  
   227  // Provision is ...
   228  func (u *CaddyUpstream) Provision(ctx caddy.Context) error {
   229  	u.Prefix = "trojan/"
   230  	u.Storage = ctx.Storage()
   231  	u.Logger = ctx.Logger(u)
   232  	return nil
   233  }
   234  
   235  // Add is ...
   236  func (u *CaddyUpstream) Add(s string) error {
   237  	b := [trojan.HeaderLen]byte{}
   238  	trojan.GenKey(s, b[:])
   239  	key := u.Prefix + string(b[:])
   240  	if u.Storage.Exists(context.Background(), key) {
   241  		return nil
   242  	}
   243  	traffic := Traffic{
   244  		Up:   0,
   245  		Down: 0,
   246  	}
   247  	bb, err := json.Marshal(&traffic)
   248  	if err != nil {
   249  		return err
   250  	}
   251  	return u.Storage.Store(context.Background(), key, bb)
   252  }
   253  
   254  // Delete is ...
   255  func (u *CaddyUpstream) Delete(s string) error {
   256  	b := [trojan.HeaderLen]byte{}
   257  	trojan.GenKey(s, b[:])
   258  	key := u.Prefix + utils.ByteSliceToString(b[:])
   259  	if !u.Storage.Exists(context.Background(), key) {
   260  		return nil
   261  	}
   262  	return u.Storage.Delete(context.Background(), key)
   263  }
   264  
   265  // Range is ...
   266  func (u *CaddyUpstream) Range(fn func(k string, up, down int64)) {
   267  	prekeys, err := u.Storage.List(context.Background(), u.Prefix, false)
   268  	if err != nil {
   269  		return
   270  	}
   271  
   272  	traffic := Traffic{}
   273  	for _, k := range prekeys {
   274  		b, err := u.Storage.Load(context.Background(), k)
   275  		if err != nil {
   276  			u.Logger.Error(fmt.Sprintf("load user error: %v", err))
   277  			continue
   278  		}
   279  		if err := json.Unmarshal(b, &traffic); err != nil {
   280  			u.Logger.Error(fmt.Sprintf("load user error: %v", err))
   281  			continue
   282  		}
   283  		fn(strings.TrimPrefix(k, u.Prefix), traffic.Up, traffic.Down)
   284  	}
   285  
   286  	return
   287  }
   288  
   289  // Validate is ...
   290  func (u *CaddyUpstream) Validate(k string) bool {
   291  	key := u.Prefix + k
   292  	return u.Storage.Exists(context.Background(), key)
   293  }
   294  
   295  // Consume is ...
   296  func (u *CaddyUpstream) Consume(k string, nr, nw int64) error {
   297  	key := u.Prefix + k
   298  
   299  	u.Storage.Lock(context.Background(), key)
   300  	defer u.Storage.Unlock(context.Background(), key)
   301  
   302  	b, err := u.Storage.Load(context.Background(), key)
   303  	if err != nil {
   304  		return err
   305  	}
   306  
   307  	traffic := Traffic{}
   308  	if err := json.Unmarshal(b, &traffic); err != nil {
   309  		return err
   310  	}
   311  
   312  	traffic.Up += nr
   313  	traffic.Down += nw
   314  
   315  	b, err = json.Marshal(&traffic)
   316  	if err != nil {
   317  		return err
   318  	}
   319  
   320  	return u.Storage.Store(context.Background(), key, b)
   321  }
   322  
   323  var (
   324  	_ Upstream           = (*CaddyUpstream)(nil)
   325  	_ Upstream           = (*MemoryUpstream)(nil)
   326  	_ caddy.CleanerUpper = (*MemoryUpstream)(nil)
   327  	_ caddy.Provisioner  = (*MemoryUpstream)(nil)
   328  )