github.com/database64128/shadowsocks-go@v1.7.0/cred/manager.go (about) 1 package cred 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "os" 9 "strings" 10 "sync" 11 "time" 12 "unsafe" 13 14 "github.com/database64128/shadowsocks-go/maps" 15 "github.com/database64128/shadowsocks-go/mmap" 16 "github.com/database64128/shadowsocks-go/slices" 17 "github.com/database64128/shadowsocks-go/ss2022" 18 "go.uber.org/zap" 19 ) 20 21 var ( 22 ErrEmptyUsername = errors.New("empty username") 23 ErrNonexistentUser = errors.New("nonexistent user") 24 ) 25 26 // ManagedServer stores information about a server whose credentials are managed by the credential manager. 27 type ManagedServer struct { 28 pskLength int 29 tcp *ss2022.CredStore 30 udp *ss2022.CredStore 31 path string 32 cachedContent string 33 cachedCredMap map[string]*cachedUserCredential 34 cachedUserLookupMap ss2022.UserLookupMap 35 mu sync.RWMutex 36 wg sync.WaitGroup 37 saveQueue chan struct{} 38 done chan struct{} 39 logger *zap.Logger 40 } 41 42 // UserCredential stores a user's credential. 43 type UserCredential struct { 44 Name string `json:"username"` 45 UPSK []byte `json:"uPSK"` 46 } 47 48 // Less is useful for sorting user credentials by username. 49 func (uc UserCredential) Less(other UserCredential) bool { 50 return uc.Name < other.Name 51 } 52 53 type cachedUserCredential struct { 54 uPSK []byte 55 uPSKHash [ss2022.IdentityHeaderLength]byte 56 } 57 58 // Credentials returns the server credentials. 59 func (s *ManagedServer) Credentials() []UserCredential { 60 s.mu.RLock() 61 ucs := make([]UserCredential, 0, len(s.cachedCredMap)) 62 for username, cachedCred := range s.cachedCredMap { 63 ucs = append(ucs, UserCredential{ 64 Name: username, 65 UPSK: cachedCred.uPSK, 66 }) 67 } 68 s.mu.RUnlock() 69 slices.SortFunc(ucs, UserCredential.Less) 70 return ucs 71 } 72 73 // GetCredential returns the user credential. 74 func (s *ManagedServer) GetCredential(username string) (UserCredential, bool) { 75 s.mu.RLock() 76 cachedCred := s.cachedCredMap[username] 77 s.mu.RUnlock() 78 if cachedCred == nil { 79 return UserCredential{}, false 80 } 81 return UserCredential{ 82 Name: username, 83 UPSK: cachedCred.uPSK, 84 }, true 85 } 86 87 func (s *ManagedServer) saveToFile() error { 88 uPSKMap := make(map[string][]byte, len(s.cachedCredMap)) 89 for username, uc := range s.cachedCredMap { 90 uPSKMap[username] = uc.uPSK 91 } 92 93 b, err := json.MarshalIndent(uPSKMap, "", " ") 94 if err != nil { 95 return err 96 } 97 98 if err = os.WriteFile(s.path, b, 0644); err != nil { 99 return err 100 } 101 102 s.cachedContent = unsafe.String(&b[0], len(b)) 103 return nil 104 } 105 106 func (s *ManagedServer) dequeueSave() { 107 for { 108 // Wait for incoming save job. 109 select { 110 case <-s.saveQueue: 111 case <-s.done: 112 return 113 } 114 115 // Wait for cooldown. 116 select { 117 case <-time.After(5 * time.Second): 118 case <-s.done: 119 } 120 121 // Clear save queue after cooldown. 122 select { 123 case <-s.saveQueue: 124 default: 125 } 126 127 // The save operation only reads cachedCredMap and writes cachedContent. 128 // It is without doubt that taking the read lock is enough for cachedCredMap. 129 // As for cachedContent, the only other place that reads and writes it is LoadFromFile, 130 // which takes the write lock. So it is safe to take just the read lock here. 131 s.mu.RLock() 132 if err := s.saveToFile(); err != nil { 133 s.logger.Warn("Failed to save credentials", zap.Error(err)) 134 } 135 s.mu.RUnlock() 136 } 137 } 138 139 // Start starts the managed server. 140 func (s *ManagedServer) Start() { 141 s.wg.Add(1) 142 go func() { 143 s.dequeueSave() 144 s.wg.Done() 145 }() 146 } 147 148 // Stop stops the managed server. 149 func (s *ManagedServer) Stop() { 150 close(s.done) 151 s.wg.Wait() 152 } 153 154 func (s *ManagedServer) enqueueSave() { 155 select { 156 case s.saveQueue <- struct{}{}: 157 default: 158 } 159 } 160 161 func (s *ManagedServer) updateProdULM(f func(ss2022.UserLookupMap)) { 162 if s.tcp != nil { 163 s.tcp.UpdateUserLookupMap(f) 164 } 165 if s.udp != nil { 166 s.udp.UpdateUserLookupMap(f) 167 } 168 } 169 170 // AddCredential adds a user credential. 171 func (s *ManagedServer) AddCredential(username string, uPSK []byte) error { 172 if username == "" { 173 return ErrEmptyUsername 174 } 175 if len(uPSK) != s.pskLength { 176 return &ss2022.PSKLengthError{PSK: uPSK, ExpectedLength: s.pskLength} 177 } 178 s.mu.Lock() 179 if s.cachedCredMap[username] != nil { 180 s.mu.Unlock() 181 return fmt.Errorf("user %s already exists", username) 182 } 183 c, err := ss2022.NewServerUserCipherConfig(username, uPSK, s.udp != nil) 184 if err != nil { 185 s.mu.Unlock() 186 return err 187 } 188 uc := &cachedUserCredential{ 189 uPSK: uPSK, 190 uPSKHash: ss2022.PSKHash(uPSK), 191 } 192 s.cachedCredMap[username] = uc 193 s.cachedUserLookupMap[uc.uPSKHash] = c 194 s.mu.Unlock() 195 s.enqueueSave() 196 s.updateProdULM(func(ulm ss2022.UserLookupMap) { 197 ulm[uc.uPSKHash] = c 198 }) 199 return nil 200 } 201 202 // UpdateCredential updates a user credential. 203 func (s *ManagedServer) UpdateCredential(username string, uPSK []byte) error { 204 if len(uPSK) != s.pskLength { 205 return &ss2022.PSKLengthError{PSK: uPSK, ExpectedLength: s.pskLength} 206 } 207 s.mu.Lock() 208 uc := s.cachedCredMap[username] 209 if uc == nil { 210 s.mu.Unlock() 211 return fmt.Errorf("%w: %s", ErrNonexistentUser, username) 212 } 213 if bytes.Equal(uc.uPSK, uPSK) { 214 s.mu.Unlock() 215 return fmt.Errorf("user %s already has the same uPSK", username) 216 } 217 c, err := ss2022.NewServerUserCipherConfig(username, uPSK, s.udp != nil) 218 if err != nil { 219 s.mu.Unlock() 220 return err 221 } 222 oldUPSKHash := uc.uPSKHash 223 uc.uPSK = uPSK 224 uc.uPSKHash = ss2022.PSKHash(uPSK) 225 delete(s.cachedUserLookupMap, oldUPSKHash) 226 s.cachedUserLookupMap[uc.uPSKHash] = c 227 s.mu.Unlock() 228 s.enqueueSave() 229 s.updateProdULM(func(ulm ss2022.UserLookupMap) { 230 delete(ulm, oldUPSKHash) 231 ulm[uc.uPSKHash] = c 232 }) 233 return nil 234 } 235 236 // DeleteCredential deletes a user credential. 237 func (s *ManagedServer) DeleteCredential(username string) error { 238 s.mu.Lock() 239 uc := s.cachedCredMap[username] 240 if uc == nil { 241 s.mu.Unlock() 242 return fmt.Errorf("%w: %s", ErrNonexistentUser, username) 243 } 244 delete(s.cachedCredMap, username) 245 delete(s.cachedUserLookupMap, uc.uPSKHash) 246 s.mu.Unlock() 247 s.enqueueSave() 248 s.updateProdULM(func(ulm ss2022.UserLookupMap) { 249 delete(ulm, uc.uPSKHash) 250 }) 251 return nil 252 } 253 254 // LoadFromFile loads credentials from the configured credential file 255 // and applies the changes to the associated credential stores. 256 func (s *ManagedServer) LoadFromFile() error { 257 content, err := mmap.ReadFile[string](s.path) 258 if err != nil { 259 return err 260 } 261 defer mmap.Unmap(content) 262 263 s.mu.Lock() 264 // Skip if the file content is unchanged. 265 if content == s.cachedContent { 266 s.mu.Unlock() 267 return nil 268 } 269 270 r := strings.NewReader(content) 271 d := json.NewDecoder(r) 272 d.DisallowUnknownFields() 273 var uPSKMap map[string][]byte 274 if err = d.Decode(&uPSKMap); err != nil { 275 s.mu.Unlock() 276 return err 277 } 278 279 userLookupMap := make(ss2022.UserLookupMap, len(uPSKMap)) 280 credMap := make(map[string]*cachedUserCredential, len(uPSKMap)) 281 for username, uPSK := range uPSKMap { 282 if len(uPSK) != s.pskLength { 283 s.mu.Unlock() 284 return &ss2022.PSKLengthError{PSK: uPSK, ExpectedLength: s.pskLength} 285 } 286 287 uPSKHash := ss2022.PSKHash(uPSK) 288 c := userLookupMap[uPSKHash] 289 if c != nil { 290 s.mu.Unlock() 291 return fmt.Errorf("duplicate uPSK for user %s and %s", c.Name, username) 292 } 293 c, err := ss2022.NewServerUserCipherConfig(username, uPSK, s.udp != nil) 294 if err != nil { 295 s.mu.Unlock() 296 return err 297 } 298 299 userLookupMap[uPSKHash] = c 300 credMap[username] = &cachedUserCredential{uPSK, uPSKHash} 301 } 302 303 s.cachedContent = strings.Clone(content) 304 s.cachedUserLookupMap = userLookupMap 305 s.cachedCredMap = credMap 306 s.mu.Unlock() 307 308 if s.tcp != nil { 309 s.tcp.ReplaceUserLookupMap(maps.Clone(s.cachedUserLookupMap)) 310 } 311 if s.udp != nil { 312 s.udp.ReplaceUserLookupMap(maps.Clone(s.cachedUserLookupMap)) 313 } 314 315 return nil 316 } 317 318 // Manager manages credentials for servers of supported protocols. 319 type Manager struct { 320 logger *zap.Logger 321 servers map[string]*ManagedServer 322 } 323 324 // NewManager returns a new credential manager. 325 func NewManager(logger *zap.Logger) *Manager { 326 return &Manager{ 327 logger: logger, 328 servers: make(map[string]*ManagedServer), 329 } 330 } 331 332 // ReloadAll asks all managed servers to reload credentials from files. 333 func (m *Manager) ReloadAll() { 334 for name, s := range m.servers { 335 if err := s.LoadFromFile(); err != nil { 336 m.logger.Warn("Failed to reload credentials", zap.String("server", name), zap.Error(err)) 337 continue 338 } 339 m.logger.Info("Reloaded credentials", zap.String("server", name)) 340 } 341 } 342 343 // LoadAll loads credentials for all managed servers. 344 func (m *Manager) LoadAll() error { 345 for name, s := range m.servers { 346 if err := s.LoadFromFile(); err != nil { 347 return fmt.Errorf("failed to load credentials for server %s: %w", name, err) 348 } 349 m.logger.Debug("Loaded credentials", zap.String("server", name)) 350 } 351 return nil 352 } 353 354 // String implements the service.Service String method. 355 func (m *Manager) String() string { 356 return "credential manager" 357 } 358 359 // Start starts all managed servers and registers to reload on SIGUSR1. 360 func (m *Manager) Start() error { 361 for _, s := range m.servers { 362 s.Start() 363 } 364 m.registerSIGUSR1() 365 return nil 366 } 367 368 // Stop gracefully stops all managed servers. 369 func (m *Manager) Stop() error { 370 for _, s := range m.servers { 371 s.Stop() 372 } 373 return nil 374 } 375 376 // RegisterServer registers a server to the manager. 377 func (m *Manager) RegisterServer(name, path string, pskLength int, tcpCredStore, udpCredStore *ss2022.CredStore) (*ManagedServer, error) { 378 s := m.servers[name] 379 if s != nil { 380 return nil, fmt.Errorf("server already registered: %s", name) 381 } 382 s = &ManagedServer{ 383 pskLength: pskLength, 384 tcp: tcpCredStore, 385 udp: udpCredStore, 386 path: path, 387 saveQueue: make(chan struct{}, 1), 388 done: make(chan struct{}), 389 logger: m.logger, 390 } 391 if err := s.LoadFromFile(); err != nil { 392 return nil, fmt.Errorf("failed to load credentials for server %s: %w", name, err) 393 } 394 m.servers[name] = s 395 m.logger.Debug("Registered server", zap.String("server", name)) 396 return s, nil 397 }