github.com/status-im/status-go@v1.1.0/services/wallet/async/async.go (about)

     1  package async
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/ethereum/go-ethereum/log"
     9  )
    10  
    11  type Command func(context.Context) error
    12  
    13  type Commander interface {
    14  	Command(inteval ...time.Duration) Command
    15  }
    16  
    17  type Runner interface {
    18  	Run(context.Context) error
    19  }
    20  
    21  // SingleShotCommand runs once.
    22  type SingleShotCommand struct {
    23  	Interval time.Duration
    24  	Init     func(context.Context) error
    25  	Runable  func(context.Context) error
    26  }
    27  
    28  func (c SingleShotCommand) Run(ctx context.Context) error {
    29  	timer := time.NewTimer(c.Interval)
    30  	if c.Init != nil {
    31  		err := c.Init(ctx)
    32  		if err != nil {
    33  			return err
    34  		}
    35  	}
    36  	for {
    37  		select {
    38  		case <-ctx.Done():
    39  			return ctx.Err()
    40  		case <-timer.C:
    41  			_ = c.Runable(ctx)
    42  		}
    43  	}
    44  }
    45  
    46  // FiniteCommand terminates when error is nil.
    47  type FiniteCommand struct {
    48  	Interval time.Duration
    49  	Runable  func(context.Context) error
    50  }
    51  
    52  func (c FiniteCommand) Run(ctx context.Context) error {
    53  	err := c.Runable(ctx)
    54  	if err == nil {
    55  		return nil
    56  	}
    57  	ticker := time.NewTicker(c.Interval)
    58  	for {
    59  		select {
    60  		case <-ctx.Done():
    61  			return ctx.Err()
    62  		case <-ticker.C:
    63  			err := c.Runable(ctx)
    64  			if err == nil {
    65  				return nil
    66  			}
    67  		}
    68  	}
    69  }
    70  
    71  // InfiniteCommand runs until context is closed.
    72  type InfiniteCommand struct {
    73  	Interval time.Duration
    74  	Runable  func(context.Context) error
    75  }
    76  
    77  func (c InfiniteCommand) Run(ctx context.Context) error {
    78  	_ = c.Runable(ctx)
    79  	ticker := time.NewTicker(c.Interval)
    80  	for {
    81  		select {
    82  		case <-ctx.Done():
    83  			return ctx.Err()
    84  		case <-ticker.C:
    85  			_ = c.Runable(ctx)
    86  		}
    87  	}
    88  }
    89  
    90  func NewGroup(parent context.Context) *Group {
    91  	ctx, cancel := context.WithCancel(parent)
    92  	return &Group{
    93  		ctx:    ctx,
    94  		cancel: cancel,
    95  	}
    96  }
    97  
    98  type Group struct {
    99  	ctx    context.Context
   100  	cancel func()
   101  	wg     sync.WaitGroup
   102  }
   103  
   104  func (g *Group) Add(cmd Command) {
   105  	g.wg.Add(1)
   106  	go func() {
   107  		_ = cmd(g.ctx)
   108  		g.wg.Done()
   109  	}()
   110  }
   111  
   112  func (g *Group) Stop() {
   113  	g.cancel()
   114  }
   115  
   116  func (g *Group) Wait() {
   117  	g.wg.Wait()
   118  }
   119  
   120  func (g *Group) WaitAsync() <-chan struct{} {
   121  	ch := make(chan struct{})
   122  	go func() {
   123  		g.Wait()
   124  		close(ch)
   125  	}()
   126  	return ch
   127  }
   128  
   129  func NewAtomicGroup(parent context.Context) *AtomicGroup {
   130  	ctx, cancel := context.WithCancel(parent)
   131  	ag := &AtomicGroup{ctx: ctx, cancel: cancel}
   132  	ag.done = ag.onFinish
   133  	return ag
   134  }
   135  
   136  // AtomicGroup terminates as soon as first goroutine terminates with error.
   137  type AtomicGroup struct {
   138  	ctx    context.Context
   139  	cancel func()
   140  	done   func()
   141  	wg     sync.WaitGroup
   142  
   143  	mu    sync.Mutex
   144  	error error
   145  }
   146  
   147  type AtomicGroupKey string
   148  
   149  func (d *AtomicGroup) SetName(name string) {
   150  	d.ctx = context.WithValue(d.ctx, AtomicGroupKey("name"), name)
   151  }
   152  
   153  func (d *AtomicGroup) Name() string {
   154  	val := d.ctx.Value(AtomicGroupKey("name"))
   155  	if val != nil {
   156  		return val.(string)
   157  	}
   158  	return ""
   159  }
   160  
   161  // Go spawns function in a goroutine and stores results or errors.
   162  func (d *AtomicGroup) Add(cmd Command) {
   163  	d.wg.Add(1)
   164  	go func() {
   165  		defer d.done()
   166  		err := cmd(d.ctx)
   167  		d.mu.Lock()
   168  		defer d.mu.Unlock()
   169  		if err != nil {
   170  			// do not overwrite original error by context errors
   171  			if d.error != nil {
   172  				log.Info("async.Command failed", "error", err, "d.error", d.error, "group", d.Name())
   173  				return
   174  			}
   175  			d.error = err
   176  			d.cancel()
   177  			return
   178  		}
   179  	}()
   180  }
   181  
   182  // Wait for all downloaders to finish.
   183  func (d *AtomicGroup) Wait() {
   184  	d.wg.Wait()
   185  	if d.Error() == nil {
   186  		d.mu.Lock()
   187  		defer d.mu.Unlock()
   188  		d.cancel()
   189  	}
   190  }
   191  
   192  func (d *AtomicGroup) WaitAsync() <-chan struct{} {
   193  	ch := make(chan struct{})
   194  	go func() {
   195  		d.Wait()
   196  		close(ch)
   197  	}()
   198  	return ch
   199  }
   200  
   201  // Error stores an error that was reported by any of the downloader. Should be called after Wait.
   202  func (d *AtomicGroup) Error() error {
   203  	d.mu.Lock()
   204  	defer d.mu.Unlock()
   205  	return d.error
   206  }
   207  
   208  func (d *AtomicGroup) Stop() {
   209  	d.cancel()
   210  }
   211  
   212  func (d *AtomicGroup) onFinish() {
   213  	d.wg.Done()
   214  }
   215  
   216  func NewQueuedAtomicGroup(parent context.Context, limit uint32) *QueuedAtomicGroup {
   217  	qag := &QueuedAtomicGroup{NewAtomicGroup(parent), limit, 0, []Command{}, sync.Mutex{}}
   218  	baseDoneFunc := qag.done // save original done function
   219  	qag.AtomicGroup.done = func() {
   220  		baseDoneFunc()
   221  		qag.onFinish()
   222  	}
   223  	return qag
   224  }
   225  
   226  type QueuedAtomicGroup struct {
   227  	*AtomicGroup
   228  	limit       uint32
   229  	count       uint32
   230  	pendingCmds []Command
   231  	mu          sync.Mutex
   232  }
   233  
   234  func (d *QueuedAtomicGroup) Add(cmd Command) {
   235  
   236  	d.mu.Lock()
   237  	if d.limit > 0 && d.count >= d.limit {
   238  		d.pendingCmds = append(d.pendingCmds, cmd)
   239  		d.mu.Unlock()
   240  		return
   241  	}
   242  
   243  	d.mu.Unlock()
   244  	d.run(cmd)
   245  }
   246  
   247  func (d *QueuedAtomicGroup) run(cmd Command) {
   248  	d.mu.Lock()
   249  	d.count++
   250  	d.mu.Unlock()
   251  	d.AtomicGroup.Add(cmd)
   252  }
   253  
   254  func (d *QueuedAtomicGroup) onFinish() {
   255  	d.mu.Lock()
   256  	d.count--
   257  
   258  	if d.count < d.limit && len(d.pendingCmds) > 0 {
   259  		cmd := d.pendingCmds[0]
   260  		d.pendingCmds = d.pendingCmds[1:]
   261  		d.mu.Unlock()
   262  		d.run(cmd)
   263  		return
   264  	}
   265  
   266  	d.mu.Unlock()
   267  }
   268  
   269  func NewErrorCounter(maxErrors int, msg string) *ErrorCounter {
   270  	return &ErrorCounter{maxErrors: maxErrors, msg: msg}
   271  }
   272  
   273  type ErrorCounter struct {
   274  	cnt       int
   275  	maxErrors int
   276  	err       error
   277  	msg       string
   278  }
   279  
   280  // Returns false in case of counter overflow
   281  func (ec *ErrorCounter) SetError(err error) bool {
   282  	log.Debug("ErrorCounter setError", "msg", ec.msg, "err", err, "cnt", ec.cnt)
   283  
   284  	ec.cnt++
   285  
   286  	// do not overwrite the first error
   287  	if ec.err == nil {
   288  		ec.err = err
   289  	}
   290  
   291  	if ec.cnt >= ec.maxErrors {
   292  		log.Error("ErrorCounter overflow", "msg", ec.msg)
   293  		return false
   294  	}
   295  
   296  	return true
   297  }
   298  
   299  func (ec *ErrorCounter) Error() error {
   300  	return ec.err
   301  }
   302  
   303  func (ec *ErrorCounter) MaxErrors() int {
   304  	return ec.maxErrors
   305  }
   306  
   307  type FiniteCommandWithErrorCounter struct {
   308  	FiniteCommand
   309  	*ErrorCounter
   310  }
   311  
   312  func (c FiniteCommandWithErrorCounter) Run(ctx context.Context) error {
   313  	f := func(ctx context.Context) (quit bool, err error) {
   314  		err = c.Runable(ctx)
   315  		if err == nil {
   316  			return true, err
   317  		}
   318  
   319  		if c.ErrorCounter.SetError(err) {
   320  			return false, err
   321  		}
   322  		return true, err
   323  	}
   324  
   325  	quit, err := f(ctx)
   326  	if quit {
   327  		return err
   328  	}
   329  
   330  	ticker := time.NewTicker(c.Interval)
   331  	for {
   332  		select {
   333  		case <-ctx.Done():
   334  			return ctx.Err()
   335  		case <-ticker.C:
   336  			quit, err := f(ctx)
   337  			if quit {
   338  				return err
   339  			}
   340  		}
   341  	}
   342  }