github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/web/accounts/statestore.go (about) 1 package accounts 2 3 import ( 4 "context" 5 "encoding/hex" 6 "encoding/json" 7 "sync" 8 "time" 9 10 "github.com/cozy/cozy-stack/pkg/config/config" 11 "github.com/cozy/cozy-stack/pkg/crypto" 12 "github.com/cozy/cozy-stack/pkg/logger" 13 "github.com/redis/go-redis/v9" 14 ) 15 16 const stateTTL = 15 * time.Minute 17 18 type stateHolder struct { 19 InstanceDomain string 20 AccountType string 21 ClientState string 22 Nonce string 23 Slug string 24 ExpiresAt int64 25 WebviewFlow bool 26 } 27 28 type stateStorage interface { 29 Add(*stateHolder) (string, error) 30 Find(ref string) *stateHolder 31 } 32 33 type memStateStorage map[string]*stateHolder 34 35 func (store memStateStorage) Add(state *stateHolder) (string, error) { 36 state.ExpiresAt = time.Now().UTC().Add(stateTTL).Unix() 37 ref := hex.EncodeToString(crypto.GenerateRandomBytes(16)) 38 store[ref] = state 39 return ref, nil 40 } 41 42 func (store memStateStorage) Find(ref string) *stateHolder { 43 state, ok := store[ref] 44 if !ok { 45 return nil 46 } 47 if state.ExpiresAt < time.Now().UTC().Unix() { 48 delete(store, ref) 49 return nil 50 } 51 return state 52 } 53 54 type subRedisInterface interface { 55 Get(ctx context.Context, key string) *redis.StringCmd 56 Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd 57 } 58 59 type redisStateStorage struct { 60 cl subRedisInterface 61 ctx context.Context 62 } 63 64 func (store *redisStateStorage) Add(s *stateHolder) (string, error) { 65 ref := hex.EncodeToString(crypto.GenerateRandomBytes(16)) 66 bb, err := json.Marshal(s) 67 if err != nil { 68 return "", err 69 } 70 return ref, store.cl.Set(store.ctx, ref, bb, stateTTL).Err() 71 } 72 73 func (store *redisStateStorage) Find(ref string) *stateHolder { 74 bb, err := store.cl.Get(store.ctx, ref).Bytes() 75 if err != nil { 76 return nil 77 } 78 var s stateHolder 79 err = json.Unmarshal(bb, &s) 80 if err != nil { 81 logger.WithNamespace("redis-state").Errorf( 82 "bad state in redis %s", string(bb)) 83 return nil 84 } 85 return &s 86 } 87 88 var globalStorage stateStorage 89 var globalStorageMutex sync.Mutex 90 91 func getStorage() stateStorage { 92 globalStorageMutex.Lock() 93 defer globalStorageMutex.Unlock() 94 if globalStorage != nil { 95 return globalStorage 96 } 97 cli := config.GetConfig().OauthStateStorage 98 if cli == nil { 99 globalStorage = &memStateStorage{} 100 } else { 101 ctx := context.Background() 102 globalStorage = &redisStateStorage{cl: cli, ctx: ctx} 103 } 104 return globalStorage 105 }