github.com/timandy/routine@v1.1.4-0.20240507073150-e4a3e1fe2ba5/future_task.go (about) 1 package routine 2 3 import ( 4 "fmt" 5 "sync" 6 "sync/atomic" 7 "time" 8 ) 9 10 type taskState = int32 11 12 const ( 13 taskStateNew taskState = iota 14 taskStateRunning 15 taskStateCompleted 16 taskStateCanceled 17 taskStateFailed 18 ) 19 20 type futureTask[TResult any] struct { 21 await sync.WaitGroup 22 state taskState 23 callable FutureCallable[TResult] 24 result TResult 25 error RuntimeError 26 } 27 28 func (task *futureTask[TResult]) IsDone() bool { 29 state := atomic.LoadInt32(&task.state) 30 return state == taskStateCompleted || state == taskStateCanceled || state == taskStateFailed 31 } 32 33 func (task *futureTask[TResult]) IsCanceled() bool { 34 return atomic.LoadInt32(&task.state) == taskStateCanceled 35 } 36 37 func (task *futureTask[TResult]) IsFailed() bool { 38 return atomic.LoadInt32(&task.state) == taskStateFailed 39 } 40 41 func (task *futureTask[TResult]) Complete(result TResult) { 42 if atomic.CompareAndSwapInt32(&task.state, taskStateNew, taskStateCompleted) || 43 atomic.CompareAndSwapInt32(&task.state, taskStateRunning, taskStateCompleted) { 44 task.result = result 45 task.await.Done() 46 } 47 } 48 49 func (task *futureTask[TResult]) Cancel() { 50 if atomic.CompareAndSwapInt32(&task.state, taskStateNew, taskStateCanceled) || 51 atomic.CompareAndSwapInt32(&task.state, taskStateRunning, taskStateCanceled) { 52 task.error = NewRuntimeError("Task was canceled.") 53 task.await.Done() 54 } 55 } 56 57 func (task *futureTask[TResult]) Fail(error any) { 58 if atomic.CompareAndSwapInt32(&task.state, taskStateNew, taskStateFailed) || 59 atomic.CompareAndSwapInt32(&task.state, taskStateRunning, taskStateFailed) { 60 runtimeErr, isRuntimeErr := error.(RuntimeError) 61 if !isRuntimeErr { 62 runtimeErr = NewRuntimeError(error) 63 } 64 task.error = runtimeErr 65 task.await.Done() 66 } 67 } 68 69 func (task *futureTask[TResult]) Get() TResult { 70 task.await.Wait() 71 if atomic.LoadInt32(&task.state) == taskStateCompleted { 72 return task.result 73 } 74 panic(task.error) 75 } 76 77 func (task *futureTask[TResult]) GetWithTimeout(timeout time.Duration) TResult { 78 waitChan := make(chan struct{}) 79 go func() { 80 task.await.Wait() 81 close(waitChan) 82 }() 83 timer := time.NewTimer(timeout) 84 defer timer.Stop() 85 select { 86 case <-waitChan: 87 if atomic.LoadInt32(&task.state) == taskStateCompleted { 88 return task.result 89 } 90 panic(task.error) 91 case <-timer.C: 92 task.timeout(timeout) 93 task.await.Wait() 94 if atomic.LoadInt32(&task.state) == taskStateCompleted { 95 return task.result 96 } 97 panic(task.error) 98 } 99 } 100 101 func (task *futureTask[TResult]) Run() { 102 if atomic.CompareAndSwapInt32(&task.state, taskStateNew, taskStateRunning) { 103 defer func() { 104 if cause := recover(); cause != nil { 105 task.Fail(cause) 106 } 107 }() 108 result := task.callable(task) 109 task.Complete(result) 110 } 111 } 112 113 func (task *futureTask[TResult]) timeout(timeout time.Duration) { 114 if atomic.CompareAndSwapInt32(&task.state, taskStateNew, taskStateCanceled) || 115 atomic.CompareAndSwapInt32(&task.state, taskStateRunning, taskStateCanceled) { 116 task.error = NewRuntimeError(fmt.Sprintf("Task execution timeout after %v.", timeout)) 117 task.await.Done() 118 } 119 }