github.com/cozy/cozy-stack@v0.0.0-20240603063001-31110fa4cae1/model/job/redis_scheduler.go (about)

     1  package job
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"strconv"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/cozy/cozy-stack/model/permission"
    13  	"github.com/cozy/cozy-stack/pkg/consts"
    14  	"github.com/cozy/cozy-stack/pkg/couchdb"
    15  	"github.com/cozy/cozy-stack/pkg/couchdb/mango"
    16  	"github.com/cozy/cozy-stack/pkg/limits"
    17  	"github.com/cozy/cozy-stack/pkg/logger"
    18  	"github.com/cozy/cozy-stack/pkg/prefixer"
    19  	"github.com/cozy/cozy-stack/pkg/realtime"
    20  	"github.com/redis/go-redis/v9"
    21  )
    22  
    23  // TriggersKey is the the key of the sorted set in redis used for triggers
    24  // waiting to be activated
    25  const TriggersKey = "triggers"
    26  
    27  // SchedKey is the the key of the sorted set in redis used for triggers
    28  // currently being executed
    29  const SchedKey = "scheduling"
    30  
    31  // pollInterval is the time interval between 2 redis polling
    32  const pollInterval = 1 * time.Second
    33  
    34  // eventLoopSize is the number of goroutines handling @events and triggering
    35  // jobs.
    36  const eventLoopSize = 50
    37  
    38  // luaPoll returns the lua script used for polling triggers in redis.
    39  // If a trigger is in the scheduling key for more than 10 seconds, it is
    40  // an error and we can try again to schedule it.
    41  const luaPoll = `
    42  local w = KEYS[1] - 10
    43  local s = redis.call("ZRANGEBYSCORE", "` + SchedKey + `", 0, w, "WITHSCORES", "LIMIT", 0, 1)
    44  if #s > 0 then
    45    redis.call("ZADD", "` + SchedKey + `", KEYS[1], s[1])
    46    return s
    47  end
    48  local t = redis.call("ZRANGEBYSCORE", "` + TriggersKey + `", 0, KEYS[1], "WITHSCORES", "LIMIT", 0, 1)
    49  if #t > 0 then
    50    redis.call("ZREM", "` + TriggersKey + `", t[1])
    51    redis.call("ZADD", "` + SchedKey + `", t[2], t[1])
    52  end
    53  return t`
    54  
    55  // redisScheduler is a centralized scheduler of many triggers. It starts all of
    56  // them and schedules jobs accordingly.
    57  type redisScheduler struct {
    58  	broker  Broker
    59  	client  redis.UniversalClient
    60  	ctx     context.Context
    61  	thumb   *ThumbnailTrigger
    62  	share   *ShareGroupTrigger
    63  	closed  chan struct{}
    64  	stopped chan struct{}
    65  	log     *logger.Entry
    66  }
    67  
    68  // NewRedisScheduler creates a new scheduler that use redis to synchronize with
    69  // other cozy-stack processes to schedule jobs.
    70  func NewRedisScheduler(client redis.UniversalClient) Scheduler {
    71  	return &redisScheduler{
    72  		client:  client,
    73  		ctx:     context.Background(),
    74  		log:     logger.WithNamespace("scheduler-redis"),
    75  		stopped: make(chan struct{}),
    76  	}
    77  }
    78  
    79  func redisKey(t Trigger) string {
    80  	prefix := t.DBPrefix()
    81  	if cluster := t.DBCluster(); cluster > 0 {
    82  		prefix = fmt.Sprintf("%s%%%d", prefix, cluster)
    83  	}
    84  	return prefix + "/" + t.Infos().TID
    85  }
    86  
    87  func payloadKey(t Trigger) string {
    88  	return "payload-" + t.DBPrefix() + "/" + t.Infos().TID
    89  }
    90  
    91  func eventsKey(db prefixer.Prefixer) string {
    92  	return "events-" + db.DBPrefix()
    93  }
    94  
    95  // StartScheduler a goroutine that will fetch triggers in redis to schedule
    96  // their jobs.
    97  func (s *redisScheduler) StartScheduler(b Broker) error {
    98  	s.broker = b
    99  	s.closed = make(chan struct{})
   100  	s.startEventDispatcher()
   101  	s.thumb = NewThumbnailTrigger(s.broker)
   102  	go s.thumb.Schedule()
   103  	s.share = NewShareGroupTrigger(s.broker)
   104  	go s.share.Schedule()
   105  	go s.pollLoop()
   106  	return nil
   107  }
   108  
   109  func (s *redisScheduler) pollLoop() {
   110  	ticker := time.NewTicker(pollInterval)
   111  	for {
   112  		select {
   113  		case <-s.closed:
   114  			ticker.Stop()
   115  			s.stopped <- struct{}{}
   116  			return
   117  		case <-ticker.C:
   118  			now := time.Now().UTC().Unix()
   119  			if err := s.PollScheduler(now); err != nil {
   120  				s.log.Warnf("Failed to poll redis: %s", err)
   121  			}
   122  		}
   123  	}
   124  }
   125  
   126  func (s *redisScheduler) startEventDispatcher() {
   127  	eventsCh := make(chan *realtime.Event, 100)
   128  	go func() {
   129  		c := realtime.GetHub().SubscribeFirehose()
   130  		defer func() {
   131  			c.Close()
   132  			close(eventsCh)
   133  		}()
   134  		for {
   135  			select {
   136  			case <-s.closed:
   137  				return
   138  			case event := <-c.Channel:
   139  				eventsCh <- event
   140  			}
   141  		}
   142  	}()
   143  	for i := 0; i < eventLoopSize; i++ {
   144  		go s.eventLoop(eventsCh)
   145  	}
   146  }
   147  
   148  func (s *redisScheduler) eventLoop(eventsCh <-chan *realtime.Event) {
   149  	for event := range eventsCh {
   150  		key := eventsKey(event)
   151  		m, err := s.client.HGetAll(s.ctx, key).Result()
   152  		if err != nil {
   153  			s.log.Errorf("Could not fetch redis set %s: %s",
   154  				key, err.Error())
   155  			continue
   156  		}
   157  		for triggerID, arguments := range m {
   158  			found := false
   159  			for _, args := range strings.Split(arguments, " ") {
   160  				rule, err := permission.UnmarshalRuleString(args)
   161  				if err != nil {
   162  					s.log.Warnf("Coud not unmarshal rule %s: %s",
   163  						key, err.Error())
   164  					continue
   165  				}
   166  				if eventMatchRule(event, &rule) {
   167  					found = true
   168  					break
   169  				}
   170  			}
   171  			if !found {
   172  				continue
   173  			}
   174  			t, err := s.GetTrigger(event, triggerID)
   175  			if err != nil {
   176  				s.log.Warnf("Could not fetch @event trigger %s (%d) %s: %s",
   177  					event.Domain, event.Cluster, triggerID, err.Error())
   178  				continue
   179  			}
   180  			if t.Infos().WorkerType == "thumbnail" {
   181  				// Remove the legacy @event trigger for thumbnail, it is now hardcoded
   182  				_ = s.deleteTrigger(t)
   183  				continue
   184  			}
   185  			et := t.(*EventTrigger)
   186  			if et.Infos().Debounce != "" {
   187  				var d time.Duration
   188  				if d, err = time.ParseDuration(et.Infos().Debounce); err == nil {
   189  					timestamp := time.Now().Add(d)
   190  					s.client.ZAddNX(s.ctx, TriggersKey, redis.Z{
   191  						Score:  float64(timestamp.UTC().Unix()),
   192  						Member: redisKey(t),
   193  					})
   194  					continue
   195  				} else {
   196  					s.log.Warnf("Trigger %s %s has an invalid debounce: %s",
   197  						et.Infos().Domain, et.Infos().TID, et.Infos().Debounce)
   198  					continue
   199  				}
   200  			}
   201  			jobRequest, err := et.Infos().JobRequestWithEvent(event)
   202  			if err != nil {
   203  				s.log.Warnf("Could not encode realtime event %s %s: %s",
   204  					event.Domain, triggerID, err.Error())
   205  				continue
   206  			}
   207  			_, err = s.broker.PushJob(t, jobRequest)
   208  			if err != nil {
   209  				s.log.Warnf("Could not push job trigger by event %s %s: %s",
   210  					event.Domain, triggerID, err.Error())
   211  				continue
   212  			}
   213  		}
   214  	}
   215  }
   216  
   217  // fire is called when a webhook is fired.
   218  func (s *redisScheduler) fire(trigger Trigger, request *JobRequest) {
   219  	infos := trigger.Infos()
   220  	if infos.Debounce == "" {
   221  		if _, err := s.broker.PushJob(trigger, request); err != nil {
   222  			s.log.Warnf("Could not push job trigger by webhook %s %s: %s",
   223  				infos.Domain, infos.TID, err.Error())
   224  		}
   225  		return
   226  	}
   227  
   228  	d, err := time.ParseDuration(infos.Debounce)
   229  	if err != nil {
   230  		s.log.Warnf("Trigger %s %s has an invalid debounce: %s",
   231  			infos.Domain, infos.TID, infos.Debounce)
   232  	}
   233  	timestamp := time.Now().Add(d)
   234  	pipe := s.client.Pipeline()
   235  	switch trigger.CombineRequest() {
   236  	case appendPayload:
   237  		pipe.RPush(s.ctx, payloadKey(trigger), string(request.Payload))
   238  	case keepOriginalRequest:
   239  		pipe.SetNX(s.ctx, payloadKey(trigger), string(request.Payload), 30*24*time.Hour)
   240  	}
   241  	pipe.ZAddNX(s.ctx, TriggersKey, redis.Z{
   242  		Score:  float64(timestamp.UTC().Unix()),
   243  		Member: redisKey(trigger),
   244  	})
   245  	if _, err := pipe.Exec(s.ctx); err != nil {
   246  		s.log.Warnf("Cannot fire trigger because of redis error: %s", err)
   247  	}
   248  }
   249  
   250  // ShutdownScheduler shuts down the the scheduling of triggers
   251  func (s *redisScheduler) ShutdownScheduler(ctx context.Context) error {
   252  	if s.closed == nil {
   253  		return nil
   254  	}
   255  	fmt.Print("  shutting down redis scheduler...")
   256  	close(s.closed)
   257  	s.thumb.Unschedule()
   258  	s.share.Unschedule()
   259  	select {
   260  	case <-ctx.Done():
   261  		fmt.Println("failed: ", ctx.Err())
   262  		return ctx.Err()
   263  	case <-s.stopped:
   264  		fmt.Println("ok.")
   265  	}
   266  	return nil
   267  }
   268  
   269  // PollScheduler polls redis to see if there are some triggers ready.
   270  func (s *redisScheduler) PollScheduler(now int64) error {
   271  	keys := []string{strconv.FormatInt(now, 10)}
   272  	for {
   273  		res, err := s.client.Eval(s.ctx, luaPoll, keys).Result()
   274  		if err != nil || res == nil {
   275  			return err
   276  		}
   277  		results, ok := res.([]interface{})
   278  		if !ok {
   279  			return errors.New("Unexpected response from redis")
   280  		}
   281  		if len(results) < 2 {
   282  			return nil
   283  		}
   284  		parts := strings.SplitN(results[0].(string), "/", 2)
   285  		if len(parts) != 2 {
   286  			s.client.ZRem(s.ctx, SchedKey, results[0])
   287  			return fmt.Errorf("Invalid key %s", res)
   288  		}
   289  
   290  		triggerID := parts[1]
   291  		parts = strings.SplitN(parts[0], "%", 2)
   292  		prefix := parts[0]
   293  		var cluster int
   294  		if len(parts) > 1 {
   295  			cluster, _ = strconv.Atoi(parts[1])
   296  		}
   297  		t, err := s.GetTrigger(prefixer.NewPrefixer(cluster, "", prefix), triggerID)
   298  		if err != nil {
   299  			if errors.Is(err, ErrNotFoundTrigger) || errors.Is(err, ErrMalformedTrigger) {
   300  				s.client.ZRem(s.ctx, SchedKey, results[0])
   301  			}
   302  			return err
   303  		}
   304  		switch t := t.(type) {
   305  		case *EventTrigger, *WebhookTrigger: // Debounced
   306  			job := t.Infos().JobRequest()
   307  			job.Debounced = true
   308  			if err = s.client.ZRem(s.ctx, SchedKey, results[0]).Err(); err != nil {
   309  				return err
   310  			}
   311  			switch t.CombineRequest() {
   312  			case appendPayload:
   313  				pipe := s.client.Pipeline()
   314  				lrange := pipe.LRange(s.ctx, payloadKey(t), 0, -1)
   315  				pipe.Del(s.ctx, payloadKey(t))
   316  				if _, err := pipe.Exec(s.ctx); err == nil {
   317  					payloads := strings.Join(lrange.Val(), ",")
   318  					job.Payload = Payload(`{"payloads":[` + payloads + "]}")
   319  				}
   320  			case keepOriginalRequest:
   321  				pipe := s.client.Pipeline()
   322  				get := pipe.Get(s.ctx, payloadKey(t))
   323  				pipe.Del(s.ctx, payloadKey(t))
   324  				if _, err := pipe.Exec(s.ctx); err == nil {
   325  					job.Payload = Payload(get.Val())
   326  				}
   327  			}
   328  			if _, err = s.broker.PushJob(t, job); err != nil {
   329  				return err
   330  			}
   331  		case *AtTrigger:
   332  			job := t.Infos().JobRequest()
   333  			if _, err = s.broker.PushJob(t, job); err != nil {
   334  				if limits.IsLimitReachedOrExceeded(err) {
   335  					s.client.ZRem(s.ctx, SchedKey, results[0])
   336  				}
   337  				return err
   338  			}
   339  			if err = s.deleteTrigger(t); err != nil {
   340  				return err
   341  			}
   342  		case *CronTrigger:
   343  			job := t.Infos().JobRequest()
   344  			if _, err = s.broker.PushJob(t, job); err != nil {
   345  				// Remove the cron trigger from redis if it is invalid, as it
   346  				// may block other cron triggers
   347  				if errors.Is(err, ErrUnknownWorker) || limits.IsLimitReachedOrExceeded(err) {
   348  					s.client.ZRem(s.ctx, SchedKey, results[0])
   349  					continue
   350  				}
   351  				return err
   352  			}
   353  			score, err := strconv.ParseInt(results[1].(string), 10, 64)
   354  			var prev time.Time
   355  			if err != nil {
   356  				prev = time.Now()
   357  			} else {
   358  				prev = time.Unix(score, 0)
   359  			}
   360  			if err := s.addToRedis(t, prev); err != nil {
   361  				return err
   362  			}
   363  		default:
   364  			return errors.New("Not implemented yet")
   365  		}
   366  	}
   367  }
   368  
   369  // AddTrigger a trigger to the system, by persisting it and using redis for
   370  // scheduling its jobs
   371  func (s *redisScheduler) AddTrigger(t Trigger) error {
   372  	if err := createTrigger(t); err != nil {
   373  		return err
   374  	}
   375  	return s.addToRedis(t, time.Now())
   376  }
   377  
   378  func (s *redisScheduler) addToRedis(t Trigger, prev time.Time) error {
   379  	var timestamp time.Time
   380  	switch t := t.(type) {
   381  	case *EventTrigger:
   382  		hKey := eventsKey(t)
   383  		return s.client.HSet(s.ctx, hKey, t.ID(), t.Infos().Arguments).Err()
   384  	case *AtTrigger:
   385  		timestamp = t.at
   386  	case *CronTrigger:
   387  		timestamp = t.NextExecution(prev)
   388  		now := time.Now()
   389  		if timestamp.Before(now) {
   390  			timestamp = t.NextExecution(now)
   391  		}
   392  	case *WebhookTrigger, *ClientTrigger:
   393  		return nil
   394  	default:
   395  		return errors.New("Not implemented yet")
   396  	}
   397  	pipe := s.client.Pipeline()
   398  	err := pipe.ZAdd(s.ctx, TriggersKey, redis.Z{
   399  		Score:  float64(timestamp.UTC().Unix()),
   400  		Member: redisKey(t),
   401  	}).Err()
   402  	if err != nil {
   403  		return err
   404  	}
   405  	err = pipe.ZRem(s.ctx, SchedKey, redisKey(t)).Err()
   406  	if err != nil {
   407  		return err
   408  	}
   409  	_, err = pipe.Exec(s.ctx)
   410  	return err
   411  }
   412  
   413  // GetTrigger returns the trigger with the specified ID.
   414  func (s *redisScheduler) GetTrigger(db prefixer.Prefixer, id string) (Trigger, error) {
   415  	var infos TriggerInfos
   416  	if err := couchdb.GetDoc(db, consts.Triggers, id, &infos); err != nil {
   417  		if couchdb.IsNotFoundError(err) {
   418  			return nil, ErrNotFoundTrigger
   419  		}
   420  		return nil, err
   421  	}
   422  	t, err := fromTriggerInfos(&infos)
   423  	if err != nil {
   424  		return nil, err
   425  	}
   426  	if webhook, ok := t.(*WebhookTrigger); ok {
   427  		webhook.SetCallback(s)
   428  	}
   429  	return t, nil
   430  }
   431  
   432  // UpdateMessage changes the message for the given trigger.
   433  func (s *redisScheduler) UpdateMessage(db prefixer.Prefixer, trigger Trigger, message json.RawMessage) error {
   434  	infos := trigger.Infos()
   435  	infos.Message = Message(message)
   436  	return couchdb.UpdateDoc(db, infos)
   437  }
   438  
   439  // UpdateCron will change the frequency of execution for the given trigger.
   440  func (s *redisScheduler) UpdateCron(db prefixer.Prefixer, trigger Trigger, arguments string) error {
   441  	if trigger.Type() != "@cron" {
   442  		return ErrNotCronTrigger
   443  	}
   444  	infos := trigger.Infos()
   445  	infos.Arguments = arguments
   446  	updated, err := NewCronTrigger(infos)
   447  	if err != nil {
   448  		return err
   449  	}
   450  	if err := couchdb.UpdateDoc(db, infos); err != nil {
   451  		return err
   452  	}
   453  	timestamp := updated.NextExecution(time.Now())
   454  	pipe := s.client.Pipeline()
   455  	pipe.ZRem(s.ctx, TriggersKey, redisKey(updated))
   456  	pipe.ZRem(s.ctx, SchedKey, redisKey(updated))
   457  	pipe.ZAdd(s.ctx, TriggersKey, redis.Z{
   458  		Score:  float64(timestamp.UTC().Unix()),
   459  		Member: redisKey(updated),
   460  	})
   461  	_, err = pipe.Exec(s.ctx)
   462  	return err
   463  }
   464  
   465  // DeleteTrigger removes the trigger with the specified ID. The trigger is
   466  // unscheduled and remove from the storage.
   467  func (s *redisScheduler) DeleteTrigger(db prefixer.Prefixer, id string) error {
   468  	t, err := s.GetTrigger(db, id)
   469  	if err != nil {
   470  		return err
   471  	}
   472  	return s.deleteTrigger(t)
   473  }
   474  
   475  func (s *redisScheduler) deleteTrigger(t Trigger) error {
   476  	if err := couchdb.DeleteDoc(t, t.Infos()); err != nil {
   477  		return err
   478  	}
   479  	switch t.(type) {
   480  	case *EventTrigger:
   481  		return s.client.HDel(s.ctx, eventsKey(t), t.ID()).Err()
   482  	case *AtTrigger, *CronTrigger:
   483  		pipe := s.client.Pipeline()
   484  		pipe.ZRem(s.ctx, TriggersKey, redisKey(t))
   485  		pipe.ZRem(s.ctx, SchedKey, redisKey(t))
   486  		_, err := pipe.Exec(s.ctx)
   487  		return err
   488  	}
   489  	return nil
   490  }
   491  
   492  // GetAllTriggers returns all the triggers for a domain, from couch.
   493  func (s *redisScheduler) GetAllTriggers(db prefixer.Prefixer) ([]Trigger, error) {
   494  	var infos []*TriggerInfos
   495  	err := couchdb.ForeachDocs(db, consts.Triggers, func(_ string, data json.RawMessage) error {
   496  		var t *TriggerInfos
   497  		if err := json.Unmarshal(data, &t); err != nil {
   498  			return err
   499  		}
   500  		infos = append(infos, t)
   501  		return nil
   502  	})
   503  	if err != nil {
   504  		if couchdb.IsNoDatabaseError(err) {
   505  			return nil, nil
   506  		}
   507  		return nil, err
   508  	}
   509  	v := make([]Trigger, 0, len(infos))
   510  	for _, info := range infos {
   511  		t, err := fromTriggerInfos(info)
   512  		if err != nil {
   513  			return nil, err
   514  		}
   515  		v = append(v, t)
   516  	}
   517  	return v, nil
   518  }
   519  
   520  // HasEventTrigger returns true if the given trigger already exists. Only the
   521  // type (@event, @cron...), worker, and arguments (if not empty) are looked at.
   522  func (s *redisScheduler) HasTrigger(db prefixer.Prefixer, infos TriggerInfos) bool {
   523  	var candidates []*TriggerInfos
   524  	limit := 1000
   525  	if infos.Arguments == "" {
   526  		limit = 1
   527  	}
   528  	req := &couchdb.FindRequest{
   529  		UseIndex: "by-worker-and-type",
   530  		Selector: mango.And(
   531  			mango.Equal("worker", infos.WorkerType),
   532  			mango.Equal("type", infos.Type),
   533  		),
   534  		Limit: limit,
   535  	}
   536  	err := couchdb.FindDocs(db, consts.Triggers, req, &candidates)
   537  	if err != nil {
   538  		s.log.Errorf("Could not fetch triggers: %s", err)
   539  		return false
   540  	}
   541  	if infos.Arguments == "" && len(candidates) > 0 {
   542  		return true
   543  	}
   544  	for _, candidate := range candidates {
   545  		if infos.Arguments == candidate.Arguments {
   546  			return true
   547  		}
   548  	}
   549  	return false
   550  }
   551  
   552  // CleanRedis removes clean redis by removing the two sets holding the triggers
   553  // states.
   554  func (s *redisScheduler) CleanRedis() error {
   555  	return s.client.Del(s.ctx, TriggersKey, SchedKey).Err()
   556  }
   557  
   558  // RebuildRedis puts all the triggers in redis (idempotent)
   559  func (s *redisScheduler) RebuildRedis(db prefixer.Prefixer) error {
   560  	triggers, err := s.GetAllTriggers(db)
   561  	if err != nil {
   562  		joblog.Errorf("Error when rebuilding redis for domain %q: %s",
   563  			db.DomainName(), err)
   564  		return err
   565  	}
   566  	for _, t := range triggers {
   567  		if err = s.addToRedis(t, time.Now()); err != nil {
   568  			joblog.Errorf("Error when rebuilding redis for domain %q: %s (%v)",
   569  				db.DomainName(), err, t)
   570  			return err
   571  		}
   572  	}
   573  	joblog.Infof("Redis rebuilt for domain %q with %d triggers created",
   574  		db.DomainName(), len(triggers))
   575  	return nil
   576  }
   577  
   578  var _ Scheduler = &redisScheduler{}