github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/cred/manager.go (about) 1 package cred 2 3 import ( 4 "bytes" 5 "cmp" 6 "context" 7 "encoding/json" 8 "errors" 9 "fmt" 10 "maps" 11 "os" 12 "slices" 13 "strings" 14 "sync" 15 "time" 16 "unsafe" 17 18 "github.com/database64128/shadowsocks-go/mmap" 19 "github.com/database64128/shadowsocks-go/ss2022" 20 "go.uber.org/zap" 21 ) 22 23 var ( 24 ErrEmptyUsername = errors.New("empty username") 25 ErrNonexistentUser = errors.New("nonexistent user") 26 ) 27 28 // ManagedServer stores information about a server whose credentials are managed by the credential manager. 29 type ManagedServer struct { 30 pskLength int 31 tcp *ss2022.CredStore 32 udp *ss2022.CredStore 33 path string 34 cachedContent string 35 cachedCredMap map[string]*cachedUserCredential 36 cachedUserLookupMap ss2022.UserLookupMap 37 mu sync.RWMutex 38 wg sync.WaitGroup 39 saveQueue chan struct{} 40 logger *zap.Logger 41 } 42 43 // UserCredential stores a user's credential. 44 type UserCredential struct { 45 Name string `json:"username"` 46 UPSK []byte `json:"uPSK"` 47 } 48 49 // Compare is useful for sorting user credentials by username. 50 func (uc UserCredential) Compare(other UserCredential) int { 51 return cmp.Compare(uc.Name, other.Name) 52 } 53 54 type cachedUserCredential struct { 55 uPSK []byte 56 uPSKHash [ss2022.IdentityHeaderLength]byte 57 } 58 59 // Credentials returns the server credentials. 60 func (s *ManagedServer) Credentials() []UserCredential { 61 s.mu.RLock() 62 ucs := make([]UserCredential, 0, len(s.cachedCredMap)) 63 for username, cachedCred := range s.cachedCredMap { 64 ucs = append(ucs, UserCredential{ 65 Name: username, 66 UPSK: cachedCred.uPSK, 67 }) 68 } 69 s.mu.RUnlock() 70 slices.SortFunc(ucs, UserCredential.Compare) 71 return ucs 72 } 73 74 // GetCredential returns the user credential. 75 func (s *ManagedServer) GetCredential(username string) (UserCredential, bool) { 76 s.mu.RLock() 77 cachedCred := s.cachedCredMap[username] 78 s.mu.RUnlock() 79 if cachedCred == nil { 80 return UserCredential{}, false 81 } 82 return UserCredential{ 83 Name: username, 84 UPSK: cachedCred.uPSK, 85 }, true 86 } 87 88 func (s *ManagedServer) saveToFile() error { 89 uPSKMap := make(map[string][]byte, len(s.cachedCredMap)) 90 for username, uc := range s.cachedCredMap { 91 uPSKMap[username] = uc.uPSK 92 } 93 94 b, err := json.MarshalIndent(uPSKMap, "", " ") 95 if err != nil { 96 return err 97 } 98 99 if err = os.WriteFile(s.path, b, 0644); err != nil { 100 return err 101 } 102 103 s.cachedContent = unsafe.String(unsafe.SliceData(b), len(b)) 104 return nil 105 } 106 107 func (s *ManagedServer) dequeueSave(ctx context.Context) { 108 for { 109 // Wait for incoming save job. 110 select { 111 case <-s.saveQueue: 112 case <-ctx.Done(): 113 return 114 } 115 116 // Wait for cooldown. 117 select { 118 case <-time.After(5 * time.Second): 119 case <-ctx.Done(): 120 } 121 122 // Clear save queue after cooldown. 123 select { 124 case <-s.saveQueue: 125 default: 126 } 127 128 // The save operation only reads cachedCredMap and writes cachedContent. 129 // It is without doubt that taking the read lock is enough for cachedCredMap. 130 // As for cachedContent, the only other place that reads and writes it is LoadFromFile, 131 // which takes the write lock. So it is safe to take just the read lock here. 132 s.mu.RLock() 133 if err := s.saveToFile(); err != nil { 134 s.logger.Warn("Failed to save credentials", zap.Error(err)) 135 } 136 s.mu.RUnlock() 137 } 138 } 139 140 // Start starts the managed server. 141 func (s *ManagedServer) Start(ctx context.Context) { 142 s.wg.Add(1) 143 go func() { 144 s.dequeueSave(ctx) 145 s.wg.Done() 146 }() 147 } 148 149 // Stop stops the managed server. 150 func (s *ManagedServer) Stop() { 151 s.wg.Wait() 152 } 153 154 func (s *ManagedServer) enqueueSave() { 155 select { 156 case s.saveQueue <- struct{}{}: 157 default: 158 } 159 } 160 161 func (s *ManagedServer) updateProdULM(f func(ss2022.UserLookupMap)) { 162 if s.tcp != nil { 163 s.tcp.UpdateUserLookupMap(f) 164 } 165 if s.udp != nil { 166 s.udp.UpdateUserLookupMap(f) 167 } 168 } 169 170 // AddCredential adds a user credential. 171 func (s *ManagedServer) AddCredential(username string, uPSK []byte) error { 172 if username == "" { 173 return ErrEmptyUsername 174 } 175 if len(uPSK) != s.pskLength { 176 return &ss2022.PSKLengthError{PSK: uPSK, ExpectedLength: s.pskLength} 177 } 178 s.mu.Lock() 179 if s.cachedCredMap[username] != nil { 180 s.mu.Unlock() 181 return fmt.Errorf("user %s already exists", username) 182 } 183 c, err := ss2022.NewServerUserCipherConfig(username, uPSK, s.udp != nil) 184 if err != nil { 185 s.mu.Unlock() 186 return err 187 } 188 uc := &cachedUserCredential{ 189 uPSK: uPSK, 190 uPSKHash: ss2022.PSKHash(uPSK), 191 } 192 s.cachedCredMap[username] = uc 193 s.cachedUserLookupMap[uc.uPSKHash] = c 194 s.mu.Unlock() 195 s.enqueueSave() 196 s.updateProdULM(func(ulm ss2022.UserLookupMap) { 197 ulm[uc.uPSKHash] = c 198 }) 199 return nil 200 } 201 202 // UpdateCredential updates a user credential. 203 func (s *ManagedServer) UpdateCredential(username string, uPSK []byte) error { 204 if len(uPSK) != s.pskLength { 205 return &ss2022.PSKLengthError{PSK: uPSK, ExpectedLength: s.pskLength} 206 } 207 s.mu.Lock() 208 uc := s.cachedCredMap[username] 209 if uc == nil { 210 s.mu.Unlock() 211 return fmt.Errorf("%w: %s", ErrNonexistentUser, username) 212 } 213 if bytes.Equal(uc.uPSK, uPSK) { 214 s.mu.Unlock() 215 return fmt.Errorf("user %s already has the same uPSK", username) 216 } 217 c, err := ss2022.NewServerUserCipherConfig(username, uPSK, s.udp != nil) 218 if err != nil { 219 s.mu.Unlock() 220 return err 221 } 222 oldUPSKHash := uc.uPSKHash 223 uc.uPSK = uPSK 224 uc.uPSKHash = ss2022.PSKHash(uPSK) 225 delete(s.cachedUserLookupMap, oldUPSKHash) 226 s.cachedUserLookupMap[uc.uPSKHash] = c 227 s.mu.Unlock() 228 s.enqueueSave() 229 s.updateProdULM(func(ulm ss2022.UserLookupMap) { 230 delete(ulm, oldUPSKHash) 231 ulm[uc.uPSKHash] = c 232 }) 233 return nil 234 } 235 236 // DeleteCredential deletes a user credential. 237 func (s *ManagedServer) DeleteCredential(username string) error { 238 s.mu.Lock() 239 uc := s.cachedCredMap[username] 240 if uc == nil { 241 s.mu.Unlock() 242 return fmt.Errorf("%w: %s", ErrNonexistentUser, username) 243 } 244 delete(s.cachedCredMap, username) 245 delete(s.cachedUserLookupMap, uc.uPSKHash) 246 s.mu.Unlock() 247 s.enqueueSave() 248 s.updateProdULM(func(ulm ss2022.UserLookupMap) { 249 delete(ulm, uc.uPSKHash) 250 }) 251 return nil 252 } 253 254 // LoadFromFile loads credentials from the configured credential file 255 // and applies the changes to the associated credential stores. 256 func (s *ManagedServer) LoadFromFile() error { 257 content, err := mmap.ReadFile[string](s.path) 258 if err != nil { 259 return err 260 } 261 defer mmap.Unmap(content) 262 263 s.mu.Lock() 264 // Skip if the file content is unchanged. 265 if content == s.cachedContent { 266 s.mu.Unlock() 267 return nil 268 } 269 270 r := strings.NewReader(content) 271 d := json.NewDecoder(r) 272 d.DisallowUnknownFields() 273 var uPSKMap map[string][]byte 274 if err = d.Decode(&uPSKMap); err != nil { 275 s.mu.Unlock() 276 return err 277 } 278 279 userLookupMap := make(ss2022.UserLookupMap, len(uPSKMap)) 280 credMap := make(map[string]*cachedUserCredential, len(uPSKMap)) 281 for username, uPSK := range uPSKMap { 282 if len(uPSK) != s.pskLength { 283 s.mu.Unlock() 284 return &ss2022.PSKLengthError{PSK: uPSK, ExpectedLength: s.pskLength} 285 } 286 287 uPSKHash := ss2022.PSKHash(uPSK) 288 c := userLookupMap[uPSKHash] 289 if c != nil { 290 s.mu.Unlock() 291 return fmt.Errorf("duplicate uPSK for user %s and %s", c.Name, username) 292 } 293 c, err := ss2022.NewServerUserCipherConfig(username, uPSK, s.udp != nil) 294 if err != nil { 295 s.mu.Unlock() 296 return err 297 } 298 299 userLookupMap[uPSKHash] = c 300 credMap[username] = &cachedUserCredential{uPSK, uPSKHash} 301 } 302 303 s.cachedContent = strings.Clone(content) 304 s.cachedUserLookupMap = userLookupMap 305 s.cachedCredMap = credMap 306 s.mu.Unlock() 307 308 if s.tcp != nil { 309 s.tcp.ReplaceUserLookupMap(maps.Clone(s.cachedUserLookupMap)) 310 } 311 if s.udp != nil { 312 s.udp.ReplaceUserLookupMap(maps.Clone(s.cachedUserLookupMap)) 313 } 314 315 return nil 316 } 317 318 // Manager manages credentials for servers of supported protocols. 319 type Manager struct { 320 logger *zap.Logger 321 servers map[string]*ManagedServer 322 } 323 324 // NewManager returns a new credential manager. 325 func NewManager(logger *zap.Logger) *Manager { 326 return &Manager{ 327 logger: logger, 328 servers: make(map[string]*ManagedServer), 329 } 330 } 331 332 // ReloadAll asks all managed servers to reload credentials from files. 333 func (m *Manager) ReloadAll() { 334 for name, s := range m.servers { 335 if err := s.LoadFromFile(); err != nil { 336 m.logger.Warn("Failed to reload credentials", zap.String("server", name), zap.Error(err)) 337 continue 338 } 339 m.logger.Info("Reloaded credentials", zap.String("server", name)) 340 } 341 } 342 343 // LoadAll loads credentials for all managed servers. 344 func (m *Manager) LoadAll() error { 345 for name, s := range m.servers { 346 if err := s.LoadFromFile(); err != nil { 347 return fmt.Errorf("failed to load credentials for server %s: %w", name, err) 348 } 349 m.logger.Debug("Loaded credentials", zap.String("server", name)) 350 } 351 return nil 352 } 353 354 // String implements the service.Service String method. 355 func (m *Manager) String() string { 356 return "credential manager" 357 } 358 359 // Start starts all managed servers and registers to reload on SIGUSR1. 360 func (m *Manager) Start(ctx context.Context) error { 361 for _, s := range m.servers { 362 s.Start(ctx) 363 } 364 m.registerSIGUSR1() 365 return nil 366 } 367 368 // Stop gracefully stops all managed servers. 369 func (m *Manager) Stop() error { 370 for _, s := range m.servers { 371 s.Stop() 372 } 373 return nil 374 } 375 376 // RegisterServer registers a server to the manager. 377 func (m *Manager) RegisterServer(name, path string, pskLength int, tcpCredStore, udpCredStore *ss2022.CredStore) (*ManagedServer, error) { 378 s := m.servers[name] 379 if s != nil { 380 return nil, fmt.Errorf("server already registered: %s", name) 381 } 382 s = &ManagedServer{ 383 pskLength: pskLength, 384 tcp: tcpCredStore, 385 udp: udpCredStore, 386 path: path, 387 saveQueue: make(chan struct{}, 1), 388 logger: m.logger, 389 } 390 if err := s.LoadFromFile(); err != nil { 391 return nil, fmt.Errorf("failed to load credentials for server %s: %w", name, err) 392 } 393 m.servers[name] = s 394 m.logger.Debug("Registered server", zap.String("server", name)) 395 return s, nil 396 }