github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/cred/manager.go (about)

     1  package cred
     2  
     3  import (
     4  	"bytes"
     5  	"cmp"
     6  	"context"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"maps"
    11  	"os"
    12  	"slices"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  	"unsafe"
    17  
    18  	"github.com/database64128/shadowsocks-go/mmap"
    19  	"github.com/database64128/shadowsocks-go/ss2022"
    20  	"go.uber.org/zap"
    21  )
    22  
// Sentinel errors reported by credential operations.
var (
	// ErrEmptyUsername is returned by AddCredential when the username is the empty string.
	ErrEmptyUsername = errors.New("empty username")

	// ErrNonexistentUser is wrapped (with the username appended) by UpdateCredential
	// and DeleteCredential when the named user is not in the credential cache.
	ErrNonexistentUser = errors.New("nonexistent user")
)
    27  
// ManagedServer stores information about a server whose credentials are managed by the credential manager.
//
// It keeps an in-memory copy of the credential file and mirrors changes
// into the server's TCP and/or UDP credential stores.
type ManagedServer struct {
	// pskLength is the exact uPSK length, in bytes, accepted by this server.
	pskLength int
	// tcp is the TCP credential store. It may be nil.
	tcp *ss2022.CredStore
	// udp is the UDP credential store. It may be nil.
	udp *ss2022.CredStore
	// path is the location of the credential file.
	path string
	// cachedContent is the credential file content as last loaded or saved.
	// LoadFromFile compares against it to skip unchanged files.
	cachedContent string
	// cachedCredMap maps usernames to their cached credentials.
	cachedCredMap map[string]*cachedUserCredential
	// cachedUserLookupMap maps uPSK hashes to user cipher configs.
	cachedUserLookupMap ss2022.UserLookupMap
	// mu guards cachedContent, cachedCredMap and cachedUserLookupMap.
	mu sync.RWMutex
	// wg tracks the background save goroutine launched by Start.
	wg sync.WaitGroup
	// saveQueue signals the save goroutine that there are unsaved changes.
	saveQueue chan struct{}
	// logger receives warnings about failed saves.
	logger *zap.Logger
}
    42  
    43  // UserCredential stores a user's credential.
    44  type UserCredential struct {
    45  	Name string `json:"username"`
    46  	UPSK []byte `json:"uPSK"`
    47  }
    48  
    49  // Compare is useful for sorting user credentials by username.
    50  func (uc UserCredential) Compare(other UserCredential) int {
    51  	return cmp.Compare(uc.Name, other.Name)
    52  }
    53  
// cachedUserCredential is the in-memory record kept for each user.
type cachedUserCredential struct {
	// uPSK is the user's pre-shared key.
	uPSK []byte
	// uPSKHash is ss2022.PSKHash(uPSK); it is the user's key in the user lookup maps.
	uPSKHash [ss2022.IdentityHeaderLength]byte
}
    58  
    59  // Credentials returns the server credentials.
    60  func (s *ManagedServer) Credentials() []UserCredential {
    61  	s.mu.RLock()
    62  	ucs := make([]UserCredential, 0, len(s.cachedCredMap))
    63  	for username, cachedCred := range s.cachedCredMap {
    64  		ucs = append(ucs, UserCredential{
    65  			Name: username,
    66  			UPSK: cachedCred.uPSK,
    67  		})
    68  	}
    69  	s.mu.RUnlock()
    70  	slices.SortFunc(ucs, UserCredential.Compare)
    71  	return ucs
    72  }
    73  
    74  // GetCredential returns the user credential.
    75  func (s *ManagedServer) GetCredential(username string) (UserCredential, bool) {
    76  	s.mu.RLock()
    77  	cachedCred := s.cachedCredMap[username]
    78  	s.mu.RUnlock()
    79  	if cachedCred == nil {
    80  		return UserCredential{}, false
    81  	}
    82  	return UserCredential{
    83  		Name: username,
    84  		UPSK: cachedCred.uPSK,
    85  	}, true
    86  }
    87  
    88  func (s *ManagedServer) saveToFile() error {
    89  	uPSKMap := make(map[string][]byte, len(s.cachedCredMap))
    90  	for username, uc := range s.cachedCredMap {
    91  		uPSKMap[username] = uc.uPSK
    92  	}
    93  
    94  	b, err := json.MarshalIndent(uPSKMap, "", "    ")
    95  	if err != nil {
    96  		return err
    97  	}
    98  
    99  	if err = os.WriteFile(s.path, b, 0644); err != nil {
   100  		return err
   101  	}
   102  
   103  	s.cachedContent = unsafe.String(unsafe.SliceData(b), len(b))
   104  	return nil
   105  }
   106  
   107  func (s *ManagedServer) dequeueSave(ctx context.Context) {
   108  	for {
   109  		// Wait for incoming save job.
   110  		select {
   111  		case <-s.saveQueue:
   112  		case <-ctx.Done():
   113  			return
   114  		}
   115  
   116  		// Wait for cooldown.
   117  		select {
   118  		case <-time.After(5 * time.Second):
   119  		case <-ctx.Done():
   120  		}
   121  
   122  		// Clear save queue after cooldown.
   123  		select {
   124  		case <-s.saveQueue:
   125  		default:
   126  		}
   127  
   128  		// The save operation only reads cachedCredMap and writes cachedContent.
   129  		// It is without doubt that taking the read lock is enough for cachedCredMap.
   130  		// As for cachedContent, the only other place that reads and writes it is LoadFromFile,
   131  		// which takes the write lock. So it is safe to take just the read lock here.
   132  		s.mu.RLock()
   133  		if err := s.saveToFile(); err != nil {
   134  			s.logger.Warn("Failed to save credentials", zap.Error(err))
   135  		}
   136  		s.mu.RUnlock()
   137  	}
   138  }
   139  
   140  // Start starts the managed server.
   141  func (s *ManagedServer) Start(ctx context.Context) {
   142  	s.wg.Add(1)
   143  	go func() {
   144  		s.dequeueSave(ctx)
   145  		s.wg.Done()
   146  	}()
   147  }
   148  
// Stop waits for the background save goroutine started by Start to exit.
//
// It does not signal the goroutine itself; the goroutine exits when the
// context passed to Start is done, so cancel that context before calling Stop.
func (s *ManagedServer) Stop() {
	s.wg.Wait()
}
   153  
// enqueueSave requests a save without blocking.
// If the queue is already full, the request is dropped: a save is already
// pending and will pick up the latest state.
func (s *ManagedServer) enqueueSave() {
	select {
	case s.saveQueue <- struct{}{}:
	default:
	}
}
   160  
   161  func (s *ManagedServer) updateProdULM(f func(ss2022.UserLookupMap)) {
   162  	if s.tcp != nil {
   163  		s.tcp.UpdateUserLookupMap(f)
   164  	}
   165  	if s.udp != nil {
   166  		s.udp.UpdateUserLookupMap(f)
   167  	}
   168  }
   169  
   170  // AddCredential adds a user credential.
   171  func (s *ManagedServer) AddCredential(username string, uPSK []byte) error {
   172  	if username == "" {
   173  		return ErrEmptyUsername
   174  	}
   175  	if len(uPSK) != s.pskLength {
   176  		return &ss2022.PSKLengthError{PSK: uPSK, ExpectedLength: s.pskLength}
   177  	}
   178  	s.mu.Lock()
   179  	if s.cachedCredMap[username] != nil {
   180  		s.mu.Unlock()
   181  		return fmt.Errorf("user %s already exists", username)
   182  	}
   183  	c, err := ss2022.NewServerUserCipherConfig(username, uPSK, s.udp != nil)
   184  	if err != nil {
   185  		s.mu.Unlock()
   186  		return err
   187  	}
   188  	uc := &cachedUserCredential{
   189  		uPSK:     uPSK,
   190  		uPSKHash: ss2022.PSKHash(uPSK),
   191  	}
   192  	s.cachedCredMap[username] = uc
   193  	s.cachedUserLookupMap[uc.uPSKHash] = c
   194  	s.mu.Unlock()
   195  	s.enqueueSave()
   196  	s.updateProdULM(func(ulm ss2022.UserLookupMap) {
   197  		ulm[uc.uPSKHash] = c
   198  	})
   199  	return nil
   200  }
   201  
   202  // UpdateCredential updates a user credential.
   203  func (s *ManagedServer) UpdateCredential(username string, uPSK []byte) error {
   204  	if len(uPSK) != s.pskLength {
   205  		return &ss2022.PSKLengthError{PSK: uPSK, ExpectedLength: s.pskLength}
   206  	}
   207  	s.mu.Lock()
   208  	uc := s.cachedCredMap[username]
   209  	if uc == nil {
   210  		s.mu.Unlock()
   211  		return fmt.Errorf("%w: %s", ErrNonexistentUser, username)
   212  	}
   213  	if bytes.Equal(uc.uPSK, uPSK) {
   214  		s.mu.Unlock()
   215  		return fmt.Errorf("user %s already has the same uPSK", username)
   216  	}
   217  	c, err := ss2022.NewServerUserCipherConfig(username, uPSK, s.udp != nil)
   218  	if err != nil {
   219  		s.mu.Unlock()
   220  		return err
   221  	}
   222  	oldUPSKHash := uc.uPSKHash
   223  	uc.uPSK = uPSK
   224  	uc.uPSKHash = ss2022.PSKHash(uPSK)
   225  	delete(s.cachedUserLookupMap, oldUPSKHash)
   226  	s.cachedUserLookupMap[uc.uPSKHash] = c
   227  	s.mu.Unlock()
   228  	s.enqueueSave()
   229  	s.updateProdULM(func(ulm ss2022.UserLookupMap) {
   230  		delete(ulm, oldUPSKHash)
   231  		ulm[uc.uPSKHash] = c
   232  	})
   233  	return nil
   234  }
   235  
   236  // DeleteCredential deletes a user credential.
   237  func (s *ManagedServer) DeleteCredential(username string) error {
   238  	s.mu.Lock()
   239  	uc := s.cachedCredMap[username]
   240  	if uc == nil {
   241  		s.mu.Unlock()
   242  		return fmt.Errorf("%w: %s", ErrNonexistentUser, username)
   243  	}
   244  	delete(s.cachedCredMap, username)
   245  	delete(s.cachedUserLookupMap, uc.uPSKHash)
   246  	s.mu.Unlock()
   247  	s.enqueueSave()
   248  	s.updateProdULM(func(ulm ss2022.UserLookupMap) {
   249  		delete(ulm, uc.uPSKHash)
   250  	})
   251  	return nil
   252  }
   253  
   254  // LoadFromFile loads credentials from the configured credential file
   255  // and applies the changes to the associated credential stores.
   256  func (s *ManagedServer) LoadFromFile() error {
   257  	content, err := mmap.ReadFile[string](s.path)
   258  	if err != nil {
   259  		return err
   260  	}
   261  	defer mmap.Unmap(content)
   262  
   263  	s.mu.Lock()
   264  	// Skip if the file content is unchanged.
   265  	if content == s.cachedContent {
   266  		s.mu.Unlock()
   267  		return nil
   268  	}
   269  
   270  	r := strings.NewReader(content)
   271  	d := json.NewDecoder(r)
   272  	d.DisallowUnknownFields()
   273  	var uPSKMap map[string][]byte
   274  	if err = d.Decode(&uPSKMap); err != nil {
   275  		s.mu.Unlock()
   276  		return err
   277  	}
   278  
   279  	userLookupMap := make(ss2022.UserLookupMap, len(uPSKMap))
   280  	credMap := make(map[string]*cachedUserCredential, len(uPSKMap))
   281  	for username, uPSK := range uPSKMap {
   282  		if len(uPSK) != s.pskLength {
   283  			s.mu.Unlock()
   284  			return &ss2022.PSKLengthError{PSK: uPSK, ExpectedLength: s.pskLength}
   285  		}
   286  
   287  		uPSKHash := ss2022.PSKHash(uPSK)
   288  		c := userLookupMap[uPSKHash]
   289  		if c != nil {
   290  			s.mu.Unlock()
   291  			return fmt.Errorf("duplicate uPSK for user %s and %s", c.Name, username)
   292  		}
   293  		c, err := ss2022.NewServerUserCipherConfig(username, uPSK, s.udp != nil)
   294  		if err != nil {
   295  			s.mu.Unlock()
   296  			return err
   297  		}
   298  
   299  		userLookupMap[uPSKHash] = c
   300  		credMap[username] = &cachedUserCredential{uPSK, uPSKHash}
   301  	}
   302  
   303  	s.cachedContent = strings.Clone(content)
   304  	s.cachedUserLookupMap = userLookupMap
   305  	s.cachedCredMap = credMap
   306  	s.mu.Unlock()
   307  
   308  	if s.tcp != nil {
   309  		s.tcp.ReplaceUserLookupMap(maps.Clone(s.cachedUserLookupMap))
   310  	}
   311  	if s.udp != nil {
   312  		s.udp.ReplaceUserLookupMap(maps.Clone(s.cachedUserLookupMap))
   313  	}
   314  
   315  	return nil
   316  }
   317  
// Manager manages credentials for servers of supported protocols.
type Manager struct {
	// logger is shared with every registered server.
	logger *zap.Logger
	// servers maps server names to their managed servers.
	// It is accessed without locking.
	servers map[string]*ManagedServer
}
   323  
   324  // NewManager returns a new credential manager.
   325  func NewManager(logger *zap.Logger) *Manager {
   326  	return &Manager{
   327  		logger:  logger,
   328  		servers: make(map[string]*ManagedServer),
   329  	}
   330  }
   331  
   332  // ReloadAll asks all managed servers to reload credentials from files.
   333  func (m *Manager) ReloadAll() {
   334  	for name, s := range m.servers {
   335  		if err := s.LoadFromFile(); err != nil {
   336  			m.logger.Warn("Failed to reload credentials", zap.String("server", name), zap.Error(err))
   337  			continue
   338  		}
   339  		m.logger.Info("Reloaded credentials", zap.String("server", name))
   340  	}
   341  }
   342  
   343  // LoadAll loads credentials for all managed servers.
   344  func (m *Manager) LoadAll() error {
   345  	for name, s := range m.servers {
   346  		if err := s.LoadFromFile(); err != nil {
   347  			return fmt.Errorf("failed to load credentials for server %s: %w", name, err)
   348  		}
   349  		m.logger.Debug("Loaded credentials", zap.String("server", name))
   350  	}
   351  	return nil
   352  }
   353  
// String implements the service.Service String method.
// It returns the fixed, human-readable name of this service.
func (m *Manager) String() string {
	return "credential manager"
}
   358  
   359  // Start starts all managed servers and registers to reload on SIGUSR1.
   360  func (m *Manager) Start(ctx context.Context) error {
   361  	for _, s := range m.servers {
   362  		s.Start(ctx)
   363  	}
   364  	m.registerSIGUSR1()
   365  	return nil
   366  }
   367  
   368  // Stop gracefully stops all managed servers.
   369  func (m *Manager) Stop() error {
   370  	for _, s := range m.servers {
   371  		s.Stop()
   372  	}
   373  	return nil
   374  }
   375  
   376  // RegisterServer registers a server to the manager.
   377  func (m *Manager) RegisterServer(name, path string, pskLength int, tcpCredStore, udpCredStore *ss2022.CredStore) (*ManagedServer, error) {
   378  	s := m.servers[name]
   379  	if s != nil {
   380  		return nil, fmt.Errorf("server already registered: %s", name)
   381  	}
   382  	s = &ManagedServer{
   383  		pskLength: pskLength,
   384  		tcp:       tcpCredStore,
   385  		udp:       udpCredStore,
   386  		path:      path,
   387  		saveQueue: make(chan struct{}, 1),
   388  		logger:    m.logger,
   389  	}
   390  	if err := s.LoadFromFile(); err != nil {
   391  		return nil, fmt.Errorf("failed to load credentials for server %s: %w", name, err)
   392  	}
   393  	m.servers[name] = s
   394  	m.logger.Debug("Registered server", zap.String("server", name))
   395  	return s, nil
   396  }