github.com/mailru/activerecord@v1.12.2/pkg/iproto/syncutil/taskrunner.go (about)

     1  package syncutil
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  	"sync"
     7  
     8  	"golang.org/x/net/context"
     9  )
    10  
    11  var ErrTaskPanic = fmt.Errorf("task panic occurred")
    12  
    13  // TaskRunner runs only one task. Caller can
    14  // subscribe to current task or in case if no task
    15  // is running initiate a new one via Do method.
    16  //
    17  // Check MAILX-1585 for details.
    18  type TaskRunner struct {
    19  	mu sync.RWMutex
    20  
    21  	rcvrs []chan error
    22  
    23  	cancel func()
    24  }
    25  
    26  // Do returns channel from which the result of the current task will
    27  // be returned.
    28  //
    29  // In case if task is not running, it creates one.
    30  func (t *TaskRunner) Do(ctx context.Context, task func(context.Context) error) <-chan error {
    31  	result := make(chan error, 1)
    32  
    33  	t.mu.Lock()
    34  	defer t.mu.Unlock()
    35  
    36  	if t.rcvrs == nil {
    37  		t.initTask(ctx, task)
    38  	}
    39  
    40  	t.rcvrs = append(t.rcvrs, result)
    41  
    42  	return result
    43  }
    44  
    45  func (t *TaskRunner) Cancel() {
    46  	t.mu.RLock()
    47  	defer t.mu.RUnlock()
    48  
    49  	if t.cancel != nil {
    50  		t.cancel()
    51  	}
    52  }
    53  
    54  func (t *TaskRunner) initTask(ctx context.Context, task func(context.Context) error) {
    55  	ctx, cancel := context.WithCancel(ctx)
    56  	t.cancel = cancel
    57  
    58  	go func() {
    59  		defer func() {
    60  			t.makeRecover(recover())
    61  			cancel()
    62  		}()
    63  
    64  		err := task(ctx)
    65  
    66  		t.broadcastErr(err)
    67  	}()
    68  }
    69  
    70  func (t *TaskRunner) broadcastErr(err error) {
    71  	t.mu.Lock()
    72  	rcvrs := t.rcvrs
    73  	t.rcvrs = nil
    74  	t.mu.Unlock()
    75  
    76  	if rcvrs == nil {
    77  		return
    78  	}
    79  
    80  	for _, subscriber := range rcvrs {
    81  		subscriber <- err
    82  	}
    83  }
    84  
    85  func (t *TaskRunner) makeRecover(rec interface{}) {
    86  	if rec != nil {
    87  		log.Printf("[internal_error] panic occurred in TaskRunner: %v", rec)
    88  		t.broadcastErr(ErrTaskPanic)
    89  	}
    90  }