github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/web/accounts/statestore.go (about)

     1  package accounts
     2  
import (
	"context"
	"encoding/hex"
	"encoding/json"
	"errors"
	"sync"
	"time"

	"github.com/cozy/cozy-stack/pkg/config/config"
	"github.com/cozy/cozy-stack/pkg/crypto"
	"github.com/cozy/cozy-stack/pkg/logger"
	"github.com/redis/go-redis/v9"
)
    15  
    16  const stateTTL = 15 * time.Minute
    17  
// stateHolder is the state saved under a random reference by a stateStorage
// and retrieved again with Find.
//
// The struct has no json tags, so redisStateStorage serializes it using the
// Go field names as JSON keys. Renaming or removing a field means states
// already stored in redis decode with that field left at its zero value.
type stateHolder struct {
	// NOTE(review): the meaning of the string fields below is defined by the
	// handlers that create and consume states, which are not shown here.
	InstanceDomain string
	AccountType    string
	ClientState    string
	Nonce          string
	Slug           string
	// ExpiresAt is a Unix timestamp in seconds. memStateStorage.Add sets it
	// and memStateStorage.Find checks it. redisStateStorage does not set it
	// and relies on the redis key TTL instead.
	ExpiresAt      int64
	WebviewFlow    bool
}
    27  
    28  type stateStorage interface {
    29  	Add(*stateHolder) (string, error)
    30  	Find(ref string) *stateHolder
    31  }
    32  
    33  type memStateStorage map[string]*stateHolder
    34  
    35  func (store memStateStorage) Add(state *stateHolder) (string, error) {
    36  	state.ExpiresAt = time.Now().UTC().Add(stateTTL).Unix()
    37  	ref := hex.EncodeToString(crypto.GenerateRandomBytes(16))
    38  	store[ref] = state
    39  	return ref, nil
    40  }
    41  
    42  func (store memStateStorage) Find(ref string) *stateHolder {
    43  	state, ok := store[ref]
    44  	if !ok {
    45  		return nil
    46  	}
    47  	if state.ExpiresAt < time.Now().UTC().Unix() {
    48  		delete(store, ref)
    49  		return nil
    50  	}
    51  	return state
    52  }
    53  
    54  type subRedisInterface interface {
    55  	Get(ctx context.Context, key string) *redis.StringCmd
    56  	Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *redis.StatusCmd
    57  }
    58  
// redisStateStorage is a stateStorage that keeps JSON-encoded states in redis.
// Expiry comes from the redis key TTL (stateTTL), not from ExpiresAt.
//
// NOTE(review): a context is stored in the struct because the stateStorage
// methods take none. getStorage fills it with context.Background().
type redisStateStorage struct {
	cl  subRedisInterface // redis client; only Get and Set are used
	ctx context.Context   // context passed to every redis call
}
    63  
    64  func (store *redisStateStorage) Add(s *stateHolder) (string, error) {
    65  	ref := hex.EncodeToString(crypto.GenerateRandomBytes(16))
    66  	bb, err := json.Marshal(s)
    67  	if err != nil {
    68  		return "", err
    69  	}
    70  	return ref, store.cl.Set(store.ctx, ref, bb, stateTTL).Err()
    71  }
    72  
    73  func (store *redisStateStorage) Find(ref string) *stateHolder {
    74  	bb, err := store.cl.Get(store.ctx, ref).Bytes()
    75  	if err != nil {
    76  		return nil
    77  	}
    78  	var s stateHolder
    79  	err = json.Unmarshal(bb, &s)
    80  	if err != nil {
    81  		logger.WithNamespace("redis-state").Errorf(
    82  			"bad state in redis %s", string(bb))
    83  		return nil
    84  	}
    85  	return &s
    86  }
    87  
    88  var globalStorage stateStorage
    89  var globalStorageMutex sync.Mutex
    90  
    91  func getStorage() stateStorage {
    92  	globalStorageMutex.Lock()
    93  	defer globalStorageMutex.Unlock()
    94  	if globalStorage != nil {
    95  		return globalStorage
    96  	}
    97  	cli := config.GetConfig().OauthStateStorage
    98  	if cli == nil {
    99  		globalStorage = &memStateStorage{}
   100  	} else {
   101  		ctx := context.Background()
   102  		globalStorage = &redisStateStorage{cl: cli, ctx: ctx}
   103  	}
   104  	return globalStorage
   105  }