github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/runner/runner.go (about)

     1  package runner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"os"
     7  	"os/signal"
     8  	"strings"
     9  	"sync"
    10  	"sync/atomic"
    11  )
    12  
    13  func NewRunner(ctx context.Context, name string, numWorkers int) *Runner {
    14  	if numWorkers < 1 {
    15  		numWorkers = 1
    16  	}
    17  	runner := &Runner{
    18  		Name:       name,
    19  		ctx:        ctx,
    20  		numWorkers: numWorkers,
    21  	}
    22  	return runner
    23  }
    24  
// Runner collects jobs through Add and executes them in Commit on a bounded
// number of worker goroutines.
type Runner struct {
	// Name is placed, in square brackets, in front of every job name built by Add.
	Name       string
	// ctx is passed to each task when it is dispatched.
	ctx        context.Context
	// numWorkers is the upper bound on the goroutines started by Commit.
	numWorkers int
	// jobs holds the jobs queued by Add; Commit resets it to nil.
	jobs       []*Job
}
    31  
    32  func (r *Runner) Add(task Task, desc ...string) {
    33  	r.jobs = append(r.jobs, &Job{
    34  		Task: task,
    35  		Name: fmt.Sprintf("[%s] %s", r.Name, strings.Join(desc, ";")),
    36  	})
    37  }
    38  
    39  func (r *Runner) Commit() (err error) {
    40  	total := int64(len(r.jobs))
    41  	if total == 0 {
    42  		return nil
    43  	}
    44  
    45  	errLocker := sync.RWMutex{}
    46  	setErr := func(e error) {
    47  		errLocker.Lock()
    48  		defer errLocker.Unlock()
    49  		if err == nil {
    50  			err = e
    51  		}
    52  	}
    53  
    54  	queue := make(chan *Job, total)
    55  	defer close(queue)
    56  
    57  	for i := int64(0); i < total; i++ {
    58  		queue <- r.jobs[i]
    59  	}
    60  	r.jobs = nil
    61  
    62  	numberWorks := int64(r.numWorkers)
    63  	if numberWorks > total {
    64  		numberWorks = total
    65  	}
    66  
    67  	goroutines := sync.WaitGroup{}
    68  
    69  	for i := int64(0); i < numberWorks; i++ {
    70  		goroutines.Add(1)
    71  		go func() {
    72  			defer goroutines.Done()
    73  			for job := range queue {
    74  				atomic.AddInt64(&total, -1)
    75  				e := dispatch(r.ctx, job)
    76  				if e != nil {
    77  					setErr(e)
    78  					break
    79  				}
    80  				// when task down
    81  				if atomic.LoadInt64(&total) == 0 {
    82  					break
    83  				}
    84  			}
    85  		}()
    86  	}
    87  
    88  	goroutines.Wait()
    89  	return
    90  }
    91  
    92  func dispatch(ctx context.Context, job *Job) (err error) {
    93  	defer func() {
    94  		if r := recover(); r != nil {
    95  			err = fmt.Errorf("%v", r)
    96  		}
    97  	}()
    98  	interrupt := make(chan os.Signal)
    99  	signal.Notify(interrupt, os.Interrupt)
   100  	defer signal.Stop(interrupt)
   101  
   102  	select {
   103  	case <-interrupt:
   104  		return &Error{
   105  			Name: job.Name,
   106  			Type: ErrTypeInterrupt,
   107  		}
   108  	case <-ctx.Done():
   109  		return &Error{
   110  			Name: job.Name,
   111  			Type: ErrTypeTimeout,
   112  		}
   113  	default:
   114  		return job.Task(ctx)
   115  	}
   116  }
   117  
// Task is the unit of work run for a Job. It receives the context handed to
// dispatch; a non-nil return value is reported as the job's error.
type Task func(ctx context.Context) error
   119  
// Job pairs a Task with the name used to identify it in errors.
type Job struct {
	// Name is copied into the Error returned by dispatch when the job is
	// skipped because of an interrupt or a finished context.
	Name string
	// Task is the function invoked when the job is dispatched.
	Task Task
}