github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/web/oidc/statestore.go (about) 1 package oidc 2 3 import ( 4 "context" 5 "encoding/hex" 6 "encoding/json" 7 "sync" 8 "time" 9 10 "github.com/cozy/cozy-stack/pkg/config/config" 11 "github.com/cozy/cozy-stack/pkg/crypto" 12 "github.com/cozy/cozy-stack/pkg/logger" 13 "github.com/redis/go-redis/v9" 14 ) 15 16 const ( 17 stateTTL = 15 * time.Minute 18 codeTTL = 3 * time.Hour 19 ) 20 21 type stateHolder struct { 22 id string 23 expiresAt int64 24 Provider ProviderOIDC 25 Instance string 26 Redirect string 27 Nonce string 28 Confirm string 29 } 30 31 type ProviderOIDC int 32 33 const ( 34 GenericProvider ProviderOIDC = iota 35 FranceConnectProvider 36 ) 37 38 func newStateHolder(domain, redirect, confirm string, provider ProviderOIDC) *stateHolder { 39 id := hex.EncodeToString(crypto.GenerateRandomBytes(24)) 40 nonce := hex.EncodeToString(crypto.GenerateRandomBytes(24)) 41 return &stateHolder{ 42 id: id, 43 Provider: provider, 44 Instance: domain, 45 Redirect: redirect, 46 Confirm: confirm, 47 Nonce: nonce, 48 } 49 } 50 51 type stateStorage interface { 52 Add(*stateHolder) error 53 Find(id string) *stateHolder 54 CreateCode(sub string) string 55 GetSub(code string) string 56 } 57 58 type memStateStorage struct { 59 states map[string]*stateHolder 60 codes map[string]string // delegated code -> sub 61 } 62 63 func (store memStateStorage) Add(state *stateHolder) error { 64 state.expiresAt = time.Now().UTC().Add(stateTTL).Unix() 65 store.states[state.id] = state 66 return nil 67 } 68 69 func (store memStateStorage) Find(id string) *stateHolder { 70 state, ok := store.states[id] 71 if !ok { 72 return nil 73 } 74 if state.expiresAt < time.Now().UTC().Unix() { 75 delete(store.states, id) 76 return nil 77 } 78 return state 79 } 80 81 func (store memStateStorage) CreateCode(sub string) string { 82 code := makeCode() 83 store.codes[code] = sub 84 return code 85 } 86 87 func (store memStateStorage) GetSub(code string) string { 88 return store.codes[code] 89 } 90 91 type subRedisInterface interface { 92 Get(ctx context.Context, key string) *redis.StringCmd 93 Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd 94 } 95 96 type redisStateStorage struct { 97 cl subRedisInterface 98 ctx context.Context 99 } 100 101 func (store *redisStateStorage) Add(s *stateHolder) error { 102 serialized, err := json.Marshal(s) 103 if err != nil { 104 return err 105 } 106 return store.cl.Set(store.ctx, s.id, serialized, stateTTL).Err() 107 } 108 109 func (store *redisStateStorage) Find(id string) *stateHolder { 110 serialized, err := store.cl.Get(store.ctx, id).Bytes() 111 if err != nil { 112 return nil 113 } 114 var s stateHolder 115 err = json.Unmarshal(serialized, &s) 116 if err != nil { 117 logger.WithNamespace("redis-state").Errorf( 118 "Bad state in redis %s", string(serialized)) 119 return nil 120 } 121 return &s 122 } 123 124 func (store *redisStateStorage) CreateCode(sub string) string { 125 code := makeCode() 126 store.cl.Set(store.ctx, code, sub, codeTTL) 127 return code 128 } 129 130 func (store *redisStateStorage) GetSub(code string) string { 131 return store.cl.Get(store.ctx, code).Val() 132 } 133 134 var globalStorage stateStorage 135 var globalStorageMutex sync.Mutex 136 137 func getStorage() stateStorage { 138 globalStorageMutex.Lock() 139 defer globalStorageMutex.Unlock() 140 if globalStorage != nil { 141 return globalStorage 142 } 143 cli := config.GetConfig().OauthStateStorage 144 if cli == nil { 145 globalStorage = &memStateStorage{ 146 states: make(map[string]*stateHolder), 147 codes: make(map[string]string), 148 } 149 } else { 150 ctx := context.Background() 151 globalStorage = &redisStateStorage{cl: cli, ctx: ctx} 152 } 153 return globalStorage 154 } 155 156 func makeCode() string { 157 return hex.EncodeToString(crypto.GenerateRandomBytes(12)) 158 }