go.mondoo.com/cnquery@v0.0.0-20231005093811-59568235f6ea/mql/internal/waitgroup.go (about)

     1  // Copyright (c) Mondoo, Inc.
     2  // SPDX-License-Identifier: BUSL-1.1
     3  
     4  package internal
     5  
     6  import (
     7  	"fmt"
     8  	"sync"
     9  )
    10  
    11  // WaitGroup is a synchronization primitive that allows waiting
    12  // for for a collection of goroutines similar to sync.WaitGroup
    13  // It differs in the following ways:
    14  //   - Add takes in a workID instead of an increment. This workID is
    15  //     passed to Done to finish it. This allows calling Done on
    16  //     the same workID twice, making the second one a noop.
    17  //   - There is a way to unblock all goroutines blocked on the
    18  //     waitgroup without a normal completion. This is done through
    19  //     Decommission
    20  type WaitGroup struct {
    21  	cond           *sync.Cond
    22  	activeWorkIDs  map[string]struct{}
    23  	seenWorkIDs    map[string]struct{}
    24  	decommissioned bool
    25  
    26  	numAdded int
    27  	numDone  int
    28  }
    29  
    30  // NewWaitGroup returns a new WaitGroup
    31  func NewWaitGroup() *WaitGroup {
    32  	mutex := &sync.Mutex{}
    33  
    34  	return &WaitGroup{
    35  		cond:           sync.NewCond(mutex),
    36  		activeWorkIDs:  make(map[string]struct{}),
    37  		seenWorkIDs:    make(map[string]struct{}),
    38  		decommissioned: false,
    39  	}
    40  }
    41  
    42  // Done removes the given workID from the set of active work IDs. If the workID
    43  // is not part of the active set, the call is a noop. Once removed, that workID
    44  // can be reused.
    45  // Passing a workID that was never added is an invalid operation and will cause a panic
    46  func (w *WaitGroup) Done(workID string) {
    47  	w.cond.L.Lock()
    48  	defer w.cond.L.Unlock()
    49  	if _, ok := w.seenWorkIDs[workID]; !ok {
    50  		// You are not allowed to complete an ID that has never been added
    51  		panic(fmt.Sprintf("workID %q not found", workID))
    52  	}
    53  	if _, ok := w.activeWorkIDs[workID]; ok {
    54  		delete(w.activeWorkIDs, workID)
    55  		w.numDone++
    56  	}
    57  	if len(w.activeWorkIDs) == 0 {
    58  		w.cond.Broadcast()
    59  	}
    60  }
    61  
    62  // Add adds the workID to the set of active workIDs. Providing a workID
    63  // that is already active is an invalid operation and will cause
    64  // a panic. You must first Done it before reusing it.
    65  func (w *WaitGroup) Add(workID string) {
    66  	w.cond.L.Lock()
    67  	defer w.cond.L.Unlock()
    68  	if _, ok := w.activeWorkIDs[workID]; ok {
    69  		// You are not allowed to add the same thing to the waitgroup
    70  		// multiple times without Doneing it
    71  		panic(fmt.Sprintf("duplicate codeID %q", workID))
    72  	}
    73  	w.seenWorkIDs[workID] = struct{}{}
    74  	w.activeWorkIDs[workID] = struct{}{}
    75  	w.numAdded++
    76  }
    77  
    78  // Wait blocks the caller until there are either no more active
    79  // workIDs in the wait group, or the wait group is decommissioned
    80  func (w *WaitGroup) Wait() {
    81  	w.cond.L.Lock()
    82  	defer w.cond.L.Unlock()
    83  
    84  	for {
    85  		if w.decommissioned || len(w.activeWorkIDs) == 0 {
    86  			return
    87  		} else {
    88  			w.cond.Wait()
    89  		}
    90  	}
    91  }
    92  
    93  // Decommission notifies all blocked goroutines that the waitgroup
    94  // is in a Done state, regardless of if there are still any active
    95  // workIDs
    96  func (w *WaitGroup) Decommission() []string {
    97  	w.cond.L.Lock()
    98  	defer w.cond.L.Unlock()
    99  	w.cond.Broadcast()
   100  	w.decommissioned = true
   101  	stillActivate := make([]string, len(w.activeWorkIDs))
   102  	i := 0
   103  	for w := range w.activeWorkIDs {
   104  		stillActivate[i] = w
   105  		i++
   106  	}
   107  	return stillActivate
   108  }
   109  
   110  func (w *WaitGroup) IsDecommissioned() bool {
   111  	w.cond.L.Lock()
   112  	defer w.cond.L.Unlock()
   113  	return w.decommissioned
   114  }
   115  
   116  type WaitGroupStats struct {
   117  	NumAdded  int
   118  	NumActive int
   119  	NumDone   int
   120  }
   121  
   122  func (w *WaitGroup) Stats() WaitGroupStats {
   123  	w.cond.L.Lock()
   124  	defer w.cond.L.Unlock()
   125  	return WaitGroupStats{
   126  		NumAdded:  w.numAdded,
   127  		NumActive: len(w.activeWorkIDs),
   128  		NumDone:   w.numDone,
   129  	}
   130  }