github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/web/oidc/statestore.go (about)

     1  package oidc
     2  
     3  import (
     4  	"context"
     5  	"encoding/hex"
     6  	"encoding/json"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/cozy/cozy-stack/pkg/config/config"
    11  	"github.com/cozy/cozy-stack/pkg/crypto"
    12  	"github.com/cozy/cozy-stack/pkg/logger"
    13  	"github.com/redis/go-redis/v9"
    14  )
    15  
    16  const (
    17  	stateTTL = 15 * time.Minute
    18  	codeTTL  = 3 * time.Hour
    19  )
    20  
    21  type stateHolder struct {
    22  	id        string
    23  	expiresAt int64
    24  	Provider  ProviderOIDC
    25  	Instance  string
    26  	Redirect  string
    27  	Nonce     string
    28  	Confirm   string
    29  }
    30  
    31  type ProviderOIDC int
    32  
    33  const (
    34  	GenericProvider ProviderOIDC = iota
    35  	FranceConnectProvider
    36  )
    37  
    38  func newStateHolder(domain, redirect, confirm string, provider ProviderOIDC) *stateHolder {
    39  	id := hex.EncodeToString(crypto.GenerateRandomBytes(24))
    40  	nonce := hex.EncodeToString(crypto.GenerateRandomBytes(24))
    41  	return &stateHolder{
    42  		id:       id,
    43  		Provider: provider,
    44  		Instance: domain,
    45  		Redirect: redirect,
    46  		Confirm:  confirm,
    47  		Nonce:    nonce,
    48  	}
    49  }
    50  
    51  type stateStorage interface {
    52  	Add(*stateHolder) error
    53  	Find(id string) *stateHolder
    54  	CreateCode(sub string) string
    55  	GetSub(code string) string
    56  }
    57  
    58  type memStateStorage struct {
    59  	states map[string]*stateHolder
    60  	codes  map[string]string // delegated code -> sub
    61  }
    62  
    63  func (store memStateStorage) Add(state *stateHolder) error {
    64  	state.expiresAt = time.Now().UTC().Add(stateTTL).Unix()
    65  	store.states[state.id] = state
    66  	return nil
    67  }
    68  
    69  func (store memStateStorage) Find(id string) *stateHolder {
    70  	state, ok := store.states[id]
    71  	if !ok {
    72  		return nil
    73  	}
    74  	if state.expiresAt < time.Now().UTC().Unix() {
    75  		delete(store.states, id)
    76  		return nil
    77  	}
    78  	return state
    79  }
    80  
    81  func (store memStateStorage) CreateCode(sub string) string {
    82  	code := makeCode()
    83  	store.codes[code] = sub
    84  	return code
    85  }
    86  
    87  func (store memStateStorage) GetSub(code string) string {
    88  	return store.codes[code]
    89  }
    90  
    91  type subRedisInterface interface {
    92  	Get(ctx context.Context, key string) *redis.StringCmd
    93  	Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd
    94  }
    95  
    96  type redisStateStorage struct {
    97  	cl  subRedisInterface
    98  	ctx context.Context
    99  }
   100  
   101  func (store *redisStateStorage) Add(s *stateHolder) error {
   102  	serialized, err := json.Marshal(s)
   103  	if err != nil {
   104  		return err
   105  	}
   106  	return store.cl.Set(store.ctx, s.id, serialized, stateTTL).Err()
   107  }
   108  
   109  func (store *redisStateStorage) Find(id string) *stateHolder {
   110  	serialized, err := store.cl.Get(store.ctx, id).Bytes()
   111  	if err != nil {
   112  		return nil
   113  	}
   114  	var s stateHolder
   115  	err = json.Unmarshal(serialized, &s)
   116  	if err != nil {
   117  		logger.WithNamespace("redis-state").Errorf(
   118  			"Bad state in redis %s", string(serialized))
   119  		return nil
   120  	}
   121  	return &s
   122  }
   123  
   124  func (store *redisStateStorage) CreateCode(sub string) string {
   125  	code := makeCode()
   126  	store.cl.Set(store.ctx, code, sub, codeTTL)
   127  	return code
   128  }
   129  
   130  func (store *redisStateStorage) GetSub(code string) string {
   131  	return store.cl.Get(store.ctx, code).Val()
   132  }
   133  
   134  var globalStorage stateStorage
   135  var globalStorageMutex sync.Mutex
   136  
   137  func getStorage() stateStorage {
   138  	globalStorageMutex.Lock()
   139  	defer globalStorageMutex.Unlock()
   140  	if globalStorage != nil {
   141  		return globalStorage
   142  	}
   143  	cli := config.GetConfig().OauthStateStorage
   144  	if cli == nil {
   145  		globalStorage = &memStateStorage{
   146  			states: make(map[string]*stateHolder),
   147  			codes:  make(map[string]string),
   148  		}
   149  	} else {
   150  		ctx := context.Background()
   151  		globalStorage = &redisStateStorage{cl: cli, ctx: ctx}
   152  	}
   153  	return globalStorage
   154  }
   155  
   156  func makeCode() string {
   157  	return hex.EncodeToString(crypto.GenerateRandomBytes(12))
   158  }