github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/model/session/registration_logins.go (about)

     1  package session
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/cozy/cozy-stack/model/instance/lifecycle"
    10  	"github.com/cozy/cozy-stack/pkg/config/config"
    11  	"github.com/cozy/cozy-stack/pkg/consts"
    12  	"github.com/cozy/cozy-stack/pkg/couchdb"
    13  	"github.com/cozy/cozy-stack/pkg/logger"
    14  	"github.com/cozy/cozy-stack/pkg/prefixer"
    15  	"github.com/cozy/cozy-stack/pkg/utils"
    16  )
    17  
// log is the logger used by this file, under the "sessions" namespace.
var log = logger.WithNamespace("sessions")

// redisRegistationKey (sic) is the name of the Redis hash that holds the
// pending registration entries when a session storage is configured. Its
// fields are registrationEntry keys (Domain + "|" + ClientID) and its values
// are the JSON-encoded entries.
const redisRegistationKey = "registration-logins"
    21  
    22  type registrationEntry struct {
    23  	Domain       string
    24  	ClientID     string
    25  	LoginEntryID string
    26  	Expire       time.Time
    27  }
    28  
    29  func (r *registrationEntry) Key() string {
    30  	return r.Domain + "|" + r.ClientID
    31  }
    32  
    33  var (
    34  	registrationExpirationDuration = 5 * time.Minute
    35  
    36  	registrationsMap     map[string]registrationEntry
    37  	registrationsMapLock sync.Mutex
    38  )
    39  
    40  // SweepLoginRegistrations starts the login registration process.
    41  //
    42  // This process involving a queue of registration login entries is necessary to
    43  // distinguish "normal" logins from logins to give right to an OAuth
    44  // application.
    45  //
    46  // Since we cannot really distinguish between them other than trusting the
    47  // user, we send a notification to the user by following this process:
    48  //   - if we identify a login for a device registration — by looking at the
    49  //     redirection address — we push an entry onto the queue
    50  //   - if we do not receive the activation of the device by the user in 5
    51  //     minutes, we send a notification for a "normal" login
    52  //   - otherwise we send a notification for the activation of a new device.
    53  func SweepLoginRegistrations() utils.Shutdowner {
    54  	closed := make(chan struct{})
    55  	go func() {
    56  		waitDuration := registrationExpirationDuration / 2
    57  		for {
    58  			select {
    59  			case <-time.After(waitDuration):
    60  				var err error
    61  				waitDuration, err = sweepRegistrations()
    62  				if err != nil {
    63  					log.Errorf("Could not sweep registration queue: %s", err)
    64  				}
    65  				if waitDuration <= 0 {
    66  					waitDuration = registrationExpirationDuration
    67  				}
    68  			case <-closed:
    69  				return
    70  			}
    71  		}
    72  	}()
    73  	return &sweeper{closed}
    74  }
    75  
    76  type sweeper struct {
    77  	closed chan struct{}
    78  }
    79  
    80  func (s *sweeper) Shutdown(ctx context.Context) error {
    81  	select {
    82  	case s.closed <- struct{}{}:
    83  	case <-ctx.Done():
    84  	}
    85  	return nil
    86  }
    87  
    88  // PushLoginRegistration pushes a new login into the registration queue.
    89  func PushLoginRegistration(db prefixer.Prefixer, login *LoginEntry, clientID string) error {
    90  	entry := registrationEntry{
    91  		Domain:       db.DomainName(),
    92  		ClientID:     clientID,
    93  		LoginEntryID: login.ID(),
    94  		Expire:       time.Now(),
    95  	}
    96  
    97  	if cli := config.GetConfig().SessionStorage; cli != nil {
    98  		b, err := json.Marshal(entry)
    99  		if err != nil {
   100  			return err
   101  		}
   102  		ctx := context.Background()
   103  		return cli.HSet(ctx, redisRegistationKey, entry.Key(), b).Err()
   104  	}
   105  
   106  	registrationsMapLock.Lock()
   107  	if registrationsMap == nil {
   108  		registrationsMap = make(map[string]registrationEntry)
   109  	}
   110  	registrationsMap[entry.Key()] = entry
   111  	registrationsMapLock.Unlock()
   112  	return nil
   113  }
   114  
   115  // RemoveLoginRegistration removes a login from the registration map.
   116  func RemoveLoginRegistration(domain, clientID string) error {
   117  	var entryPtr *registrationEntry
   118  	key := domain + "|" + clientID
   119  	if cli := config.GetConfig().SessionStorage; cli != nil {
   120  		ctx := context.Background()
   121  		b, err := cli.HGet(ctx, redisRegistationKey, key).Result()
   122  		if err != nil {
   123  			return err
   124  		}
   125  		var entry registrationEntry
   126  		if err = json.Unmarshal([]byte(b), &entry); err != nil {
   127  			return err
   128  		}
   129  		if err = cli.HDel(ctx, redisRegistationKey, key).Err(); err != nil {
   130  			return err
   131  		}
   132  		entryPtr = &entry
   133  	} else {
   134  		registrationsMapLock.Lock()
   135  		entry, ok := registrationsMap[key]
   136  		if ok {
   137  			delete(registrationsMap, key)
   138  			entryPtr = &entry
   139  		}
   140  		registrationsMapLock.Unlock()
   141  	}
   142  	if entryPtr != nil {
   143  		_ = sendRegistrationNotification(entryPtr, true)
   144  	}
   145  	return nil
   146  }
   147  
   148  func sweepRegistrations() (waitDuration time.Duration, err error) {
   149  	var expiredLogins []registrationEntry
   150  
   151  	now := time.Now()
   152  	if cli := config.GetConfig().SessionStorage; cli != nil {
   153  		ctx := context.Background()
   154  		var vals map[string]string
   155  		vals, err = cli.HGetAll(ctx, redisRegistationKey).Result()
   156  		if err != nil {
   157  			return
   158  		}
   159  
   160  		var deletedKeys []string
   161  		for key, data := range vals {
   162  			var entry registrationEntry
   163  			if err = json.Unmarshal([]byte(data), &entry); err != nil {
   164  				deletedKeys = append(deletedKeys, key)
   165  				continue
   166  			}
   167  			diff := entry.Expire.Sub(now)
   168  			if diff < -24*time.Hour {
   169  				// skip too old entries
   170  				deletedKeys = append(deletedKeys, entry.Key())
   171  			} else if diff <= 10*time.Second {
   172  				expiredLogins = append(expiredLogins, entry)
   173  				deletedKeys = append(deletedKeys, entry.Key())
   174  			} else if waitDuration == 0 || waitDuration > diff {
   175  				waitDuration = diff
   176  			}
   177  		}
   178  
   179  		if len(deletedKeys) > 0 {
   180  			err = cli.HDel(ctx, redisRegistationKey, deletedKeys...).Err()
   181  		}
   182  	} else {
   183  		registrationsMapLock.Lock()
   184  
   185  		var deletedKeys []string
   186  		for _, entry := range registrationsMap {
   187  			diff := entry.Expire.Sub(now)
   188  			if diff < -24*time.Hour {
   189  				// skip too old entries
   190  				deletedKeys = append(deletedKeys, entry.Key())
   191  			} else if diff <= 10*time.Second {
   192  				expiredLogins = append(expiredLogins, entry)
   193  				deletedKeys = append(deletedKeys, entry.Key())
   194  			} else if waitDuration == 0 || waitDuration > diff {
   195  				waitDuration = diff
   196  			}
   197  		}
   198  
   199  		for _, key := range deletedKeys {
   200  			delete(registrationsMap, key)
   201  		}
   202  
   203  		registrationsMapLock.Unlock()
   204  	}
   205  
   206  	if len(expiredLogins) > 0 {
   207  		sendExpiredRegistrationNotifications(expiredLogins)
   208  	}
   209  
   210  	return
   211  }
   212  
   213  func sendRegistrationNotification(entry *registrationEntry, registrationNotification bool) error {
   214  	i, err := lifecycle.GetInstance(entry.Domain)
   215  	if err != nil {
   216  		return err
   217  	}
   218  
   219  	var clientID string
   220  	if registrationNotification {
   221  		clientID = entry.ClientID
   222  	}
   223  	if clientID != "" {
   224  		return SendNewRegistrationNotification(i, clientID)
   225  	}
   226  
   227  	var login LoginEntry
   228  	err = couchdb.GetDoc(i, consts.SessionsLogins, entry.LoginEntryID, &login)
   229  	if err != nil {
   230  		return err
   231  	}
   232  	return sendLoginNotification(i, &login)
   233  }
   234  
   235  func sendExpiredRegistrationNotifications(entries []registrationEntry) {
   236  	for _, entry := range entries {
   237  		_ = sendRegistrationNotification(&entry, false)
   238  	}
   239  }