github.com/richardwilkes/toolbox@v1.121.0/taskqueue/taskqueue.go (about)

     1  // Copyright (c) 2016-2024 by Richard A. Wilkes. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, version 2.0. If a copy of the MPL was not distributed with
     5  // this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     6  //
     7  // This Source Code Form is "Incompatible With Secondary Licenses", as
     8  // defined by the Mozilla Public License, version 2.0.
     9  
    10  // Package taskqueue provides a simple asynchronous task queue.
    11  package taskqueue
    12  
    13  import (
    14  	"runtime"
    15  
    16  	"github.com/richardwilkes/toolbox/errs"
    17  )
    18  
    19  // Logger provides a way to log panics caused by workers in a queue.
    20  type Logger func(v ...any)
    21  
    22  // Task defines a unit of work.
    23  type Task func()
    24  
    25  // Option defines an option for the queue.
    26  type Option func(*Queue)
    27  
    28  // Queue holds the queue information.
    29  type Queue struct {
    30  	in              chan Task
    31  	done            chan bool
    32  	recoveryHandler errs.RecoveryHandler
    33  	depth           int
    34  	workers         int
    35  }
    36  
    37  // RecoveryHandler sets the recovery handler to use for tasks that panic. Defaults to none, which silently ignores the
    38  // panic.
    39  func RecoveryHandler(recoveryHandler errs.RecoveryHandler) Option {
    40  	return func(q *Queue) { q.recoveryHandler = recoveryHandler }
    41  }
    42  
    43  // Depth sets the depth of the queue. Calls to Submit() will block when this number of tasks are already pending
    44  // execution. Pass in a negative number to use an unbounded queue. Defaults to unbounded.
    45  func Depth(depth int) Option {
    46  	return func(q *Queue) { q.depth = depth }
    47  }
    48  
    49  // Workers sets the number of workers that will simultaneously process tasks. If this is set to 1, tasks submitted to
    50  // the queue will be executed serially. Defaults to one plus the number of CPUs.
    51  func Workers(workers int) Option {
    52  	return func(q *Queue) { q.workers = workers }
    53  }
    54  
    55  // New creates a queue which executes the tasks submitted to it.
    56  func New(options ...Option) *Queue {
    57  	numCPU := runtime.NumCPU()
    58  	q := &Queue{
    59  		in:    make(chan Task, numCPU*2),
    60  		done:  make(chan bool),
    61  		depth: -1,
    62  	}
    63  	for _, option := range options {
    64  		option(q)
    65  	}
    66  	if q.workers < 1 {
    67  		q.workers = 1 + numCPU
    68  	}
    69  	go q.process()
    70  	return q
    71  }
    72  
    73  // Submit a task to be run.
    74  func (q *Queue) Submit(task Task) {
    75  	q.in <- task
    76  }
    77  
    78  // Shutdown the queue. Does not return until all pending tasks have completed.
    79  func (q *Queue) Shutdown() {
    80  	close(q.in)
    81  	<-q.done
    82  }
    83  
    84  func (q *Queue) process() {
    85  	var received, processed uint64
    86  
    87  	// Setup backlog
    88  	var backlog []Task
    89  	if q.depth > 0 {
    90  		backlog = make([]Task, 0, q.depth)
    91  	}
    92  
    93  	// Setup workers
    94  	ready := make(chan bool, q.workers)
    95  	tasks := make(chan Task, q.workers)
    96  	for i := 0; i < q.workers; i++ {
    97  		go q.work(tasks, ready)
    98  	}
    99  
   100  	// Main processing loop
   101  outer:
   102  	for {
   103  	inner:
   104  		select {
   105  		case task := <-q.in:
   106  			if task == nil {
   107  				break outer
   108  			}
   109  			received++
   110  			if len(backlog) == 0 {
   111  				select {
   112  				case tasks <- task:
   113  					break inner
   114  				default:
   115  				}
   116  			}
   117  			if q.depth < 0 || len(backlog) < q.depth {
   118  				backlog = append(backlog, task)
   119  			} else {
   120  				<-ready
   121  				processed++
   122  				tasks <- backlog[0]
   123  				copy(backlog, backlog[1:])
   124  				backlog[len(backlog)-1] = task
   125  			}
   126  		case <-ready:
   127  			processed++
   128  			if len(backlog) > 0 {
   129  				tasks <- backlog[0]
   130  				copy(backlog, backlog[1:])
   131  				backlog[len(backlog)-1] = nil
   132  				backlog = backlog[:len(backlog)-1]
   133  			}
   134  		}
   135  	}
   136  
   137  	// Finish any remaining tasks
   138  	for _, task := range backlog {
   139  	drain:
   140  		for {
   141  			select {
   142  			case tasks <- task:
   143  				break drain
   144  			case <-ready:
   145  				processed++
   146  			}
   147  		}
   148  	}
   149  	for received != processed {
   150  		<-ready
   151  		processed++
   152  	}
   153  	close(tasks)
   154  	q.done <- true
   155  }
   156  
   157  func (q *Queue) work(tasks <-chan Task, ready chan<- bool) {
   158  	for task := range tasks {
   159  		q.runTask(task)
   160  		ready <- true
   161  	}
   162  }
   163  
   164  func (q *Queue) runTask(task Task) {
   165  	defer errs.Recovery(q.recoveryHandler)
   166  	task()
   167  }