github.com/ngicks/gokugen@v0.0.5/task_storage/single_node.go (about)

     1  package taskstorage
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/ngicks/gokugen"
    11  	"github.com/ngicks/gokugen/common"
    12  	"github.com/ngicks/type-param-common/set"
    13  )
    14  
    15  var (
    16  	ErrMiddlewareOrder   = errors.New("invalid middleware order")
    17  	ErrNonexistentWorkId = errors.New("nonexistent work id")
    18  )
    19  
// ExternalStateChangeErr is used to tell that an error is caused by
// external repository manipulations.
//
// id is the id of the affected task, and state is the state
// the task was found in when the change was detected.
type ExternalStateChangeErr struct {
	id    string
	state TaskState
}
    26  
    27  // Error implements Error interface.
    28  func (e ExternalStateChangeErr) Error() string {
    29  	return fmt.Sprintf("The state is changed externally: id = %s, state = %s", e.id, e.state)
    30  }
    31  
// WorkFn is an alias of gokugen.WorkFn, re-exported for convenience.
type WorkFn = gokugen.WorkFn

// WorkFnWParam is an alias of gokugen.WorkFnWParam,
// a work function that additionally receives the task param.
type WorkFnWParam = gokugen.WorkFnWParam
    34  
// WorkRegistry is used to retrieve work function by workId.
// impl/work_registry.ParamUnmarshaller will be good enough for almost all users.
//
// Load returns the work function registered under key.
// ok is false when no function is registered for that key;
// callers in this file treat that case as ErrNonexistentWorkId.
type WorkRegistry interface {
	Load(key string) (value WorkFnWParam, ok bool)
}
    40  
// SingleNodeTaskStorage provides ability to store task information to,
// and restore from persistent data storage.
//
// Use NewSingleNodeTaskStorage to create an instance.
//
// Fields:
//   - repo: the persistent data storage tasks are inserted into, read from and marked in.
//   - failedIds: task ids whose state marking on repo failed, paired with the state
//     that should have been written. RetryMarking replays them.
//   - shouldRestore: consulted in Sync to decide whether a fetched task is re-scheduled.
//   - workRegistry: resolves a work id to its work function.
//   - taskMap: scheduled tasks keyed by task id.
//   - mu: held for the whole duration of Sync.
//   - getNow: time source, set to common.GetNowImpl by the constructor.
//   - lastSynced: the latest LastModified seen by Sync; passed to repo.GetUpdatedSince on the next Sync.
//   - knownIdForTime: ids Sync has already processed for the current latest timestamp,
//     used to avoid syncing the same entry twice.
//   - syncCtxWrapper: if non nil, wraps the context Sync builds before it is scheduled.
type SingleNodeTaskStorage struct {
	repo           Repository
	failedIds      *SyncStateStore
	shouldRestore  func(TaskInfo) bool
	workRegistry   WorkRegistry
	taskMap        *TaskMap
	mu             sync.Mutex
	getNow         common.GetNower // this field can be swapped out in test codes.
	lastSynced     time.Time
	knownIdForTime set.Set[string]
	syncCtxWrapper func(gokugen.SchedulerContext) gokugen.SchedulerContext
}
    55  
    56  // NewSingleNodeTaskStorage creates new SingleNodeTaskStorage instance.
    57  //
    58  // repo is Repository, interface to manipulate persistent data storage.
    59  //
    60  // shouldRestore is used in Sync, to decide if task should be restored and re-scheduled in internal scheduler.
    61  // (e.g. ignore tasks if they are too old and overdue.)
    62  //
    63  // workRegistry is used to retrieve work function associated to workId.
    64  // User must register functions to registry beforehand.
    65  //
    66  // syncCtxWrapper is used in Sync. Sync tries to schedule newly craeted context.
    67  // this context will be wrapped with syncCtxWrapper if non nil.
    68  func NewSingleNodeTaskStorage(
    69  	repo Repository,
    70  	shouldRestore func(TaskInfo) bool,
    71  	workRegistry WorkRegistry,
    72  	syncCtxWrapper func(gokugen.SchedulerContext) gokugen.SchedulerContext,
    73  ) *SingleNodeTaskStorage {
    74  	return &SingleNodeTaskStorage{
    75  		repo:           repo,
    76  		shouldRestore:  shouldRestore,
    77  		failedIds:      NewSyncStateStore(),
    78  		workRegistry:   workRegistry,
    79  		taskMap:        NewTaskMap(),
    80  		getNow:         common.GetNowImpl{},
    81  		knownIdForTime: set.Set[string]{},
    82  		syncCtxWrapper: syncCtxWrapper,
    83  	}
    84  }
    85  
    86  func (ts *SingleNodeTaskStorage) paramLoad(handler gokugen.ScheduleHandlerFn) gokugen.ScheduleHandlerFn {
    87  	return func(ctx gokugen.SchedulerContext) (gokugen.Task, error) {
    88  		taskId, err := gokugen.GetTaskId(ctx)
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  		loadable := gokugen.WrapContext(ctx,
    93  			gokugen.WithParamLoader(
    94  				func() (any, error) {
    95  					info, err := ts.repo.GetById(taskId)
    96  					if err != nil {
    97  						return nil, err
    98  					}
    99  					return info.Param, nil
   100  				},
   101  			),
   102  		)
   103  		return handler(loadable)
   104  	}
   105  }
   106  
   107  func (ts *SingleNodeTaskStorage) storeTask(handler gokugen.ScheduleHandlerFn) gokugen.ScheduleHandlerFn {
   108  	return func(ctx gokugen.SchedulerContext) (task gokugen.Task, err error) {
   109  		param, err := gokugen.GetParam(ctx)
   110  		if err != nil {
   111  			return
   112  		}
   113  		workId, err := gokugen.GetWorkId(ctx)
   114  		if err != nil {
   115  			return
   116  		}
   117  		scheduledTime := ctx.ScheduledTime()
   118  		workWithParam, ok := ts.workRegistry.Load(workId)
   119  		if !ok {
   120  			err = fmt.Errorf("%w: unknown work id = %s", ErrNonexistentWorkId, workId)
   121  			return
   122  		}
   123  
   124  		taskId, err := gokugen.GetTaskId(ctx)
   125  		if err != nil && !errors.Is(err, gokugen.ErrValueNotFound) {
   126  			return
   127  		}
   128  
   129  		hadTaskId := true
   130  		if taskId == "" {
   131  			// ctx does not contain task id.
   132  			// needs to create new entry in repository.
   133  			hadTaskId = false
   134  			taskId, err = ts.repo.Insert(TaskInfo{
   135  				WorkId:        workId,
   136  				Param:         param,
   137  				ScheduledTime: scheduledTime,
   138  				State:         Initialized,
   139  			})
   140  			if err != nil {
   141  				return
   142  			}
   143  		}
   144  
   145  		var newCtx gokugen.SchedulerContext = ctx
   146  		if !hadTaskId {
   147  			newCtx = gokugen.WrapContext(
   148  				ctx,
   149  				gokugen.WithTaskId(taskId),
   150  			)
   151  		}
   152  		if workSet := ctx.Work(); workSet == nil {
   153  			newCtx = gokugen.WrapContext(
   154  				newCtx,
   155  				gokugen.WithWorkFnWrapper(
   156  					func(self gokugen.SchedulerContext, _ WorkFn) WorkFn {
   157  						return func(taskCtx context.Context, scheduled time.Time) (any, error) {
   158  							param, err := gokugen.GetParam(self)
   159  							if err != nil {
   160  								return nil, err
   161  							}
   162  							ret, err := workWithParam(taskCtx, scheduled, param)
   163  							markDoneTask(err, ts, taskId)
   164  							return ret, err
   165  						}
   166  					},
   167  				),
   168  			)
   169  		}
   170  
   171  		task, err = handler(newCtx)
   172  		if err != nil {
   173  			return
   174  		}
   175  		task = wrapCancel(ts.repo, ts.failedIds, ts.taskMap, taskId, task)
   176  		ts.taskMap.LoadOrStore(taskId, task)
   177  		return
   178  	}
   179  }
   180  
   181  func markDoneTask(result error, ts *SingleNodeTaskStorage, taskId string) {
   182  	ts.taskMap.Delete(taskId)
   183  	if errors.Is(result, ErrOtherNodeWorkingOnTheTask) {
   184  		return
   185  	} else if result != nil {
   186  		_, err := ts.repo.MarkAsFailed(taskId)
   187  		if err != nil {
   188  			ts.failedIds.Put(taskId, Failed)
   189  		}
   190  	} else {
   191  		_, err := ts.repo.MarkAsDone(taskId)
   192  		if err != nil {
   193  			ts.failedIds.Put(taskId, Done)
   194  		}
   195  	}
   196  }
   197  
   198  // Middleware returns gokugen.MiddlewareFunc's. Order must be maintained.
   199  // Though these middleware(s), task context info is stored in external persistent data storage.
   200  //
   201  // If freeParam is true, param free up functionality is enabled.
   202  // It let those middlewares to forget param until needed.
   203  // Setting freeParam true adds one middleware
   204  // that loads up param from repository right before work execution.
   205  func (ts *SingleNodeTaskStorage) Middleware(freeParam bool) []gokugen.MiddlewareFunc {
   206  	if freeParam {
   207  		return []gokugen.MiddlewareFunc{ts.storeTask, ts.paramLoad}
   208  	}
   209  	return []gokugen.MiddlewareFunc{ts.storeTask}
   210  }
   211  
   212  // Sync syncs itnernal state with an external data storage.
   213  // Normally TaskStorage does it reversely through middlewares, mirroring internal state to the external data storage.
   214  // But after rebooting the system, or repository is changed externally, Sync is needed to fetch back external data.
   215  func (ts *SingleNodeTaskStorage) Sync(
   216  	schedule func(ctx gokugen.SchedulerContext) (gokugen.Task, error),
   217  ) (rescheduled map[string]gokugen.Task, schedulingErr map[string]error, err error) {
   218  	ts.mu.Lock()
   219  	defer ts.mu.Unlock()
   220  
   221  	syncedAt := ts.lastSynced
   222  	fetchedIds, err := ts.repo.GetUpdatedSince(ts.lastSynced)
   223  	if err != nil {
   224  		return
   225  	}
   226  
   227  	rescheduled = make(map[string]gokugen.Task)
   228  	schedulingErr = make(map[string]error)
   229  
   230  	for _, fetched := range fetchedIds {
   231  		if fetched.LastModified.After(syncedAt) {
   232  			// Latest time of fetched tasks is next last-synced time.
   233  			// We do want to set it correctly to avoid doubly syncing same entry.
   234  			//
   235  			// And also, GetUpdatedSince implemention may limit the number of fetched entries.
   236  			// There could still be non-synced tasks that is modified at same time as syncedAt.
   237  			syncedAt = fetched.LastModified
   238  			// Also clear knownId for this time.
   239  			// knowId storage is needed to avoid doubly syncing.
   240  			// Racy clients may add or modify tasks for same second after this agent fetched lastly.
   241  			//
   242  			// Some of database support only one second presicion.
   243  			// (e.g. strftime('%s') of sqlite3. you can also use `strftime('%s','now') || substr(strftime('%f','now'),4)`
   244  			//   to enable milli second precision. But there could still be race conditions.)
   245  			ts.knownIdForTime.Clear()
   246  		}
   247  
   248  		if ts.knownIdForTime.Has(fetched.Id) {
   249  			continue
   250  		} else {
   251  			ts.knownIdForTime.Add(fetched.Id)
   252  		}
   253  
   254  		task, err := ts.sync(schedule, fetched)
   255  		if err != nil {
   256  			schedulingErr[fetched.Id] = err
   257  		} else if task != nil {
   258  			rescheduled[fetched.Id] = task
   259  		}
   260  	}
   261  
   262  	if syncedAt.After(ts.lastSynced) {
   263  		ts.lastSynced = syncedAt
   264  	}
   265  	return
   266  }
   267  
   268  func (ts *SingleNodeTaskStorage) sync(
   269  	schedule func(ctx gokugen.SchedulerContext) (gokugen.Task, error),
   270  	fetched TaskInfo,
   271  ) (task gokugen.Task, schedulingErr error) {
   272  
   273  	_, ok := ts.workRegistry.Load(fetched.WorkId)
   274  	if !ok {
   275  		return nil, fmt.Errorf("%w: unknown work id = %s", ErrNonexistentWorkId, fetched.WorkId)
   276  	}
   277  
   278  	switch fetched.State {
   279  	case Working, Done, Cancelled, Failed:
   280  		if inTaskMap, loaded := ts.taskMap.LoadAndDelete(fetched.Id); loaded {
   281  			inTaskMap.CancelWithReason(ExternalStateChangeErr{
   282  				id:    fetched.Id,
   283  				state: fetched.State,
   284  			})
   285  		}
   286  		return
   287  	default:
   288  		if ts.taskMap.Has(fetched.Id) {
   289  			return
   290  		}
   291  	}
   292  
   293  	if !ts.shouldRestore(fetched) {
   294  		return
   295  	}
   296  
   297  	param := fetched.Param
   298  	var ctx gokugen.SchedulerContext = gokugen.BuildContext(
   299  		fetched.ScheduledTime,
   300  		nil,
   301  		make(map[any]any),
   302  		gokugen.WithTaskId(fetched.Id),
   303  		gokugen.WithWorkId(fetched.WorkId),
   304  		gokugen.WithParam(param),
   305  	)
   306  
   307  	if ts.syncCtxWrapper != nil {
   308  		ctx = ts.syncCtxWrapper(ctx)
   309  	}
   310  
   311  	task, err := schedule(ctx)
   312  	if err != nil {
   313  		return nil, err
   314  	}
   315  	ts.taskMap.LoadOrStore(fetched.Id, task)
   316  	return
   317  }
   318  
   319  // RetryMarking retries to mark of failed marking.
   320  func (s *SingleNodeTaskStorage) RetryMarking() (allRemoved bool) {
   321  	for _, set := range s.failedIds.GetAll() {
   322  		if !s.failedIds.Remove(set.Key) {
   323  			// race condition.
   324  			continue
   325  		}
   326  		var err error
   327  		switch set.Value {
   328  		case Done:
   329  			_, err = s.repo.MarkAsDone(set.Key)
   330  		case Cancelled:
   331  			_, err = s.repo.MarkAsCancelled(set.Key)
   332  		case Failed:
   333  			_, err = s.repo.MarkAsFailed(set.Key)
   334  		}
   335  
   336  		if err != nil {
   337  			s.failedIds.Put(set.Key, set.Value)
   338  		}
   339  	}
   340  	return s.failedIds.Len() == 0
   341  }
   342  
   343  func wrapCancel(repo Repository, failedIds *SyncStateStore, taskMap *TaskMap, id string, t gokugen.Task) gokugen.Task {
   344  	return &taskWrapper{
   345  		Task: t,
   346  		cancel: func(baseCanceller func(err error) (cancelled bool)) func(err error) (cancelled bool) {
   347  			return func(err error) (cancelled bool) {
   348  				cancelled = baseCanceller(err)
   349  				taskMap.Delete(id)
   350  				if _, ok := err.(ExternalStateChangeErr); ok {
   351  					// no marking is needed since it is already changed by external source!
   352  					return
   353  				}
   354  				if cancelled && !t.IsDone() {
   355  					// if it's done, we dont want to call heavy marking method.
   356  					_, err := repo.MarkAsCancelled(id)
   357  					if err != nil {
   358  						failedIds.Put(id, Cancelled)
   359  					}
   360  				}
   361  				return
   362  			}
   363  		},
   364  	}
   365  }