github.com/timandy/routine@v1.1.4-0.20240507073150-e4a3e1fe2ba5/future_task.go (about)

     1  package routine
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  	"sync/atomic"
     7  	"time"
     8  )
     9  
    10  type taskState = int32
    11  
    12  const (
    13  	taskStateNew taskState = iota
    14  	taskStateRunning
    15  	taskStateCompleted
    16  	taskStateCanceled
    17  	taskStateFailed
    18  )
    19  
    20  type futureTask[TResult any] struct {
    21  	await    sync.WaitGroup
    22  	state    taskState
    23  	callable FutureCallable[TResult]
    24  	result   TResult
    25  	error    RuntimeError
    26  }
    27  
    28  func (task *futureTask[TResult]) IsDone() bool {
    29  	state := atomic.LoadInt32(&task.state)
    30  	return state == taskStateCompleted || state == taskStateCanceled || state == taskStateFailed
    31  }
    32  
    33  func (task *futureTask[TResult]) IsCanceled() bool {
    34  	return atomic.LoadInt32(&task.state) == taskStateCanceled
    35  }
    36  
    37  func (task *futureTask[TResult]) IsFailed() bool {
    38  	return atomic.LoadInt32(&task.state) == taskStateFailed
    39  }
    40  
    41  func (task *futureTask[TResult]) Complete(result TResult) {
    42  	if atomic.CompareAndSwapInt32(&task.state, taskStateNew, taskStateCompleted) ||
    43  		atomic.CompareAndSwapInt32(&task.state, taskStateRunning, taskStateCompleted) {
    44  		task.result = result
    45  		task.await.Done()
    46  	}
    47  }
    48  
    49  func (task *futureTask[TResult]) Cancel() {
    50  	if atomic.CompareAndSwapInt32(&task.state, taskStateNew, taskStateCanceled) ||
    51  		atomic.CompareAndSwapInt32(&task.state, taskStateRunning, taskStateCanceled) {
    52  		task.error = NewRuntimeError("Task was canceled.")
    53  		task.await.Done()
    54  	}
    55  }
    56  
    57  func (task *futureTask[TResult]) Fail(error any) {
    58  	if atomic.CompareAndSwapInt32(&task.state, taskStateNew, taskStateFailed) ||
    59  		atomic.CompareAndSwapInt32(&task.state, taskStateRunning, taskStateFailed) {
    60  		runtimeErr, isRuntimeErr := error.(RuntimeError)
    61  		if !isRuntimeErr {
    62  			runtimeErr = NewRuntimeError(error)
    63  		}
    64  		task.error = runtimeErr
    65  		task.await.Done()
    66  	}
    67  }
    68  
    69  func (task *futureTask[TResult]) Get() TResult {
    70  	task.await.Wait()
    71  	if atomic.LoadInt32(&task.state) == taskStateCompleted {
    72  		return task.result
    73  	}
    74  	panic(task.error)
    75  }
    76  
    77  func (task *futureTask[TResult]) GetWithTimeout(timeout time.Duration) TResult {
    78  	waitChan := make(chan struct{})
    79  	go func() {
    80  		task.await.Wait()
    81  		close(waitChan)
    82  	}()
    83  	timer := time.NewTimer(timeout)
    84  	defer timer.Stop()
    85  	select {
    86  	case <-waitChan:
    87  		if atomic.LoadInt32(&task.state) == taskStateCompleted {
    88  			return task.result
    89  		}
    90  		panic(task.error)
    91  	case <-timer.C:
    92  		task.timeout(timeout)
    93  		task.await.Wait()
    94  		if atomic.LoadInt32(&task.state) == taskStateCompleted {
    95  			return task.result
    96  		}
    97  		panic(task.error)
    98  	}
    99  }
   100  
   101  func (task *futureTask[TResult]) Run() {
   102  	if atomic.CompareAndSwapInt32(&task.state, taskStateNew, taskStateRunning) {
   103  		defer func() {
   104  			if cause := recover(); cause != nil {
   105  				task.Fail(cause)
   106  			}
   107  		}()
   108  		result := task.callable(task)
   109  		task.Complete(result)
   110  	}
   111  }
   112  
   113  func (task *futureTask[TResult]) timeout(timeout time.Duration) {
   114  	if atomic.CompareAndSwapInt32(&task.state, taskStateNew, taskStateCanceled) ||
   115  		atomic.CompareAndSwapInt32(&task.state, taskStateRunning, taskStateCanceled) {
   116  		task.error = NewRuntimeError(fmt.Sprintf("Task execution timeout after %v.", timeout))
   117  		task.await.Done()
   118  	}
   119  }