github.com/database64128/shadowsocks-go@v1.7.0/cred/manager.go (about)

     1  package cred
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"os"
     9  	"strings"
    10  	"sync"
    11  	"time"
    12  	"unsafe"
    13  
    14  	"github.com/database64128/shadowsocks-go/maps"
    15  	"github.com/database64128/shadowsocks-go/mmap"
    16  	"github.com/database64128/shadowsocks-go/slices"
    17  	"github.com/database64128/shadowsocks-go/ss2022"
    18  	"go.uber.org/zap"
    19  )
    20  
    21  var (
    22  	ErrEmptyUsername   = errors.New("empty username")
    23  	ErrNonexistentUser = errors.New("nonexistent user")
    24  )
    25  
    26  // ManagedServer stores information about a server whose credentials are managed by the credential manager.
    27  type ManagedServer struct {
    28  	pskLength           int
    29  	tcp                 *ss2022.CredStore
    30  	udp                 *ss2022.CredStore
    31  	path                string
    32  	cachedContent       string
    33  	cachedCredMap       map[string]*cachedUserCredential
    34  	cachedUserLookupMap ss2022.UserLookupMap
    35  	mu                  sync.RWMutex
    36  	wg                  sync.WaitGroup
    37  	saveQueue           chan struct{}
    38  	done                chan struct{}
    39  	logger              *zap.Logger
    40  }
    41  
    42  // UserCredential stores a user's credential.
    43  type UserCredential struct {
    44  	Name string `json:"username"`
    45  	UPSK []byte `json:"uPSK"`
    46  }
    47  
    48  // Less is useful for sorting user credentials by username.
    49  func (uc UserCredential) Less(other UserCredential) bool {
    50  	return uc.Name < other.Name
    51  }
    52  
    53  type cachedUserCredential struct {
    54  	uPSK     []byte
    55  	uPSKHash [ss2022.IdentityHeaderLength]byte
    56  }
    57  
    58  // Credentials returns the server credentials.
    59  func (s *ManagedServer) Credentials() []UserCredential {
    60  	s.mu.RLock()
    61  	ucs := make([]UserCredential, 0, len(s.cachedCredMap))
    62  	for username, cachedCred := range s.cachedCredMap {
    63  		ucs = append(ucs, UserCredential{
    64  			Name: username,
    65  			UPSK: cachedCred.uPSK,
    66  		})
    67  	}
    68  	s.mu.RUnlock()
    69  	slices.SortFunc(ucs, UserCredential.Less)
    70  	return ucs
    71  }
    72  
    73  // GetCredential returns the user credential.
    74  func (s *ManagedServer) GetCredential(username string) (UserCredential, bool) {
    75  	s.mu.RLock()
    76  	cachedCred := s.cachedCredMap[username]
    77  	s.mu.RUnlock()
    78  	if cachedCred == nil {
    79  		return UserCredential{}, false
    80  	}
    81  	return UserCredential{
    82  		Name: username,
    83  		UPSK: cachedCred.uPSK,
    84  	}, true
    85  }
    86  
    87  func (s *ManagedServer) saveToFile() error {
    88  	uPSKMap := make(map[string][]byte, len(s.cachedCredMap))
    89  	for username, uc := range s.cachedCredMap {
    90  		uPSKMap[username] = uc.uPSK
    91  	}
    92  
    93  	b, err := json.MarshalIndent(uPSKMap, "", "    ")
    94  	if err != nil {
    95  		return err
    96  	}
    97  
    98  	if err = os.WriteFile(s.path, b, 0644); err != nil {
    99  		return err
   100  	}
   101  
   102  	s.cachedContent = unsafe.String(&b[0], len(b))
   103  	return nil
   104  }
   105  
   106  func (s *ManagedServer) dequeueSave() {
   107  	for {
   108  		// Wait for incoming save job.
   109  		select {
   110  		case <-s.saveQueue:
   111  		case <-s.done:
   112  			return
   113  		}
   114  
   115  		// Wait for cooldown.
   116  		select {
   117  		case <-time.After(5 * time.Second):
   118  		case <-s.done:
   119  		}
   120  
   121  		// Clear save queue after cooldown.
   122  		select {
   123  		case <-s.saveQueue:
   124  		default:
   125  		}
   126  
   127  		// The save operation only reads cachedCredMap and writes cachedContent.
   128  		// It is without doubt that taking the read lock is enough for cachedCredMap.
   129  		// As for cachedContent, the only other place that reads and writes it is LoadFromFile,
   130  		// which takes the write lock. So it is safe to take just the read lock here.
   131  		s.mu.RLock()
   132  		if err := s.saveToFile(); err != nil {
   133  			s.logger.Warn("Failed to save credentials", zap.Error(err))
   134  		}
   135  		s.mu.RUnlock()
   136  	}
   137  }
   138  
   139  // Start starts the managed server.
   140  func (s *ManagedServer) Start() {
   141  	s.wg.Add(1)
   142  	go func() {
   143  		s.dequeueSave()
   144  		s.wg.Done()
   145  	}()
   146  }
   147  
   148  // Stop stops the managed server.
   149  func (s *ManagedServer) Stop() {
   150  	close(s.done)
   151  	s.wg.Wait()
   152  }
   153  
   154  func (s *ManagedServer) enqueueSave() {
   155  	select {
   156  	case s.saveQueue <- struct{}{}:
   157  	default:
   158  	}
   159  }
   160  
   161  func (s *ManagedServer) updateProdULM(f func(ss2022.UserLookupMap)) {
   162  	if s.tcp != nil {
   163  		s.tcp.UpdateUserLookupMap(f)
   164  	}
   165  	if s.udp != nil {
   166  		s.udp.UpdateUserLookupMap(f)
   167  	}
   168  }
   169  
   170  // AddCredential adds a user credential.
   171  func (s *ManagedServer) AddCredential(username string, uPSK []byte) error {
   172  	if username == "" {
   173  		return ErrEmptyUsername
   174  	}
   175  	if len(uPSK) != s.pskLength {
   176  		return &ss2022.PSKLengthError{PSK: uPSK, ExpectedLength: s.pskLength}
   177  	}
   178  	s.mu.Lock()
   179  	if s.cachedCredMap[username] != nil {
   180  		s.mu.Unlock()
   181  		return fmt.Errorf("user %s already exists", username)
   182  	}
   183  	c, err := ss2022.NewServerUserCipherConfig(username, uPSK, s.udp != nil)
   184  	if err != nil {
   185  		s.mu.Unlock()
   186  		return err
   187  	}
   188  	uc := &cachedUserCredential{
   189  		uPSK:     uPSK,
   190  		uPSKHash: ss2022.PSKHash(uPSK),
   191  	}
   192  	s.cachedCredMap[username] = uc
   193  	s.cachedUserLookupMap[uc.uPSKHash] = c
   194  	s.mu.Unlock()
   195  	s.enqueueSave()
   196  	s.updateProdULM(func(ulm ss2022.UserLookupMap) {
   197  		ulm[uc.uPSKHash] = c
   198  	})
   199  	return nil
   200  }
   201  
   202  // UpdateCredential updates a user credential.
   203  func (s *ManagedServer) UpdateCredential(username string, uPSK []byte) error {
   204  	if len(uPSK) != s.pskLength {
   205  		return &ss2022.PSKLengthError{PSK: uPSK, ExpectedLength: s.pskLength}
   206  	}
   207  	s.mu.Lock()
   208  	uc := s.cachedCredMap[username]
   209  	if uc == nil {
   210  		s.mu.Unlock()
   211  		return fmt.Errorf("%w: %s", ErrNonexistentUser, username)
   212  	}
   213  	if bytes.Equal(uc.uPSK, uPSK) {
   214  		s.mu.Unlock()
   215  		return fmt.Errorf("user %s already has the same uPSK", username)
   216  	}
   217  	c, err := ss2022.NewServerUserCipherConfig(username, uPSK, s.udp != nil)
   218  	if err != nil {
   219  		s.mu.Unlock()
   220  		return err
   221  	}
   222  	oldUPSKHash := uc.uPSKHash
   223  	uc.uPSK = uPSK
   224  	uc.uPSKHash = ss2022.PSKHash(uPSK)
   225  	delete(s.cachedUserLookupMap, oldUPSKHash)
   226  	s.cachedUserLookupMap[uc.uPSKHash] = c
   227  	s.mu.Unlock()
   228  	s.enqueueSave()
   229  	s.updateProdULM(func(ulm ss2022.UserLookupMap) {
   230  		delete(ulm, oldUPSKHash)
   231  		ulm[uc.uPSKHash] = c
   232  	})
   233  	return nil
   234  }
   235  
   236  // DeleteCredential deletes a user credential.
   237  func (s *ManagedServer) DeleteCredential(username string) error {
   238  	s.mu.Lock()
   239  	uc := s.cachedCredMap[username]
   240  	if uc == nil {
   241  		s.mu.Unlock()
   242  		return fmt.Errorf("%w: %s", ErrNonexistentUser, username)
   243  	}
   244  	delete(s.cachedCredMap, username)
   245  	delete(s.cachedUserLookupMap, uc.uPSKHash)
   246  	s.mu.Unlock()
   247  	s.enqueueSave()
   248  	s.updateProdULM(func(ulm ss2022.UserLookupMap) {
   249  		delete(ulm, uc.uPSKHash)
   250  	})
   251  	return nil
   252  }
   253  
   254  // LoadFromFile loads credentials from the configured credential file
   255  // and applies the changes to the associated credential stores.
   256  func (s *ManagedServer) LoadFromFile() error {
   257  	content, err := mmap.ReadFile[string](s.path)
   258  	if err != nil {
   259  		return err
   260  	}
   261  	defer mmap.Unmap(content)
   262  
   263  	s.mu.Lock()
   264  	// Skip if the file content is unchanged.
   265  	if content == s.cachedContent {
   266  		s.mu.Unlock()
   267  		return nil
   268  	}
   269  
   270  	r := strings.NewReader(content)
   271  	d := json.NewDecoder(r)
   272  	d.DisallowUnknownFields()
   273  	var uPSKMap map[string][]byte
   274  	if err = d.Decode(&uPSKMap); err != nil {
   275  		s.mu.Unlock()
   276  		return err
   277  	}
   278  
   279  	userLookupMap := make(ss2022.UserLookupMap, len(uPSKMap))
   280  	credMap := make(map[string]*cachedUserCredential, len(uPSKMap))
   281  	for username, uPSK := range uPSKMap {
   282  		if len(uPSK) != s.pskLength {
   283  			s.mu.Unlock()
   284  			return &ss2022.PSKLengthError{PSK: uPSK, ExpectedLength: s.pskLength}
   285  		}
   286  
   287  		uPSKHash := ss2022.PSKHash(uPSK)
   288  		c := userLookupMap[uPSKHash]
   289  		if c != nil {
   290  			s.mu.Unlock()
   291  			return fmt.Errorf("duplicate uPSK for user %s and %s", c.Name, username)
   292  		}
   293  		c, err := ss2022.NewServerUserCipherConfig(username, uPSK, s.udp != nil)
   294  		if err != nil {
   295  			s.mu.Unlock()
   296  			return err
   297  		}
   298  
   299  		userLookupMap[uPSKHash] = c
   300  		credMap[username] = &cachedUserCredential{uPSK, uPSKHash}
   301  	}
   302  
   303  	s.cachedContent = strings.Clone(content)
   304  	s.cachedUserLookupMap = userLookupMap
   305  	s.cachedCredMap = credMap
   306  	s.mu.Unlock()
   307  
   308  	if s.tcp != nil {
   309  		s.tcp.ReplaceUserLookupMap(maps.Clone(s.cachedUserLookupMap))
   310  	}
   311  	if s.udp != nil {
   312  		s.udp.ReplaceUserLookupMap(maps.Clone(s.cachedUserLookupMap))
   313  	}
   314  
   315  	return nil
   316  }
   317  
   318  // Manager manages credentials for servers of supported protocols.
   319  type Manager struct {
   320  	logger  *zap.Logger
   321  	servers map[string]*ManagedServer
   322  }
   323  
   324  // NewManager returns a new credential manager.
   325  func NewManager(logger *zap.Logger) *Manager {
   326  	return &Manager{
   327  		logger:  logger,
   328  		servers: make(map[string]*ManagedServer),
   329  	}
   330  }
   331  
   332  // ReloadAll asks all managed servers to reload credentials from files.
   333  func (m *Manager) ReloadAll() {
   334  	for name, s := range m.servers {
   335  		if err := s.LoadFromFile(); err != nil {
   336  			m.logger.Warn("Failed to reload credentials", zap.String("server", name), zap.Error(err))
   337  			continue
   338  		}
   339  		m.logger.Info("Reloaded credentials", zap.String("server", name))
   340  	}
   341  }
   342  
   343  // LoadAll loads credentials for all managed servers.
   344  func (m *Manager) LoadAll() error {
   345  	for name, s := range m.servers {
   346  		if err := s.LoadFromFile(); err != nil {
   347  			return fmt.Errorf("failed to load credentials for server %s: %w", name, err)
   348  		}
   349  		m.logger.Debug("Loaded credentials", zap.String("server", name))
   350  	}
   351  	return nil
   352  }
   353  
   354  // String implements the service.Service String method.
   355  func (m *Manager) String() string {
   356  	return "credential manager"
   357  }
   358  
   359  // Start starts all managed servers and registers to reload on SIGUSR1.
   360  func (m *Manager) Start() error {
   361  	for _, s := range m.servers {
   362  		s.Start()
   363  	}
   364  	m.registerSIGUSR1()
   365  	return nil
   366  }
   367  
   368  // Stop gracefully stops all managed servers.
   369  func (m *Manager) Stop() error {
   370  	for _, s := range m.servers {
   371  		s.Stop()
   372  	}
   373  	return nil
   374  }
   375  
   376  // RegisterServer registers a server to the manager.
   377  func (m *Manager) RegisterServer(name, path string, pskLength int, tcpCredStore, udpCredStore *ss2022.CredStore) (*ManagedServer, error) {
   378  	s := m.servers[name]
   379  	if s != nil {
   380  		return nil, fmt.Errorf("server already registered: %s", name)
   381  	}
   382  	s = &ManagedServer{
   383  		pskLength: pskLength,
   384  		tcp:       tcpCredStore,
   385  		udp:       udpCredStore,
   386  		path:      path,
   387  		saveQueue: make(chan struct{}, 1),
   388  		done:      make(chan struct{}),
   389  		logger:    m.logger,
   390  	}
   391  	if err := s.LoadFromFile(); err != nil {
   392  		return nil, fmt.Errorf("failed to load credentials for server %s: %w", name, err)
   393  	}
   394  	m.servers[name] = s
   395  	m.logger.Debug("Registered server", zap.String("server", name))
   396  	return s, nil
   397  }