go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/fu/mask.go (about)

     1  package fu
     2  
     3  import (
     4  	"runtime"
     5  	"sync"
     6  )
     7  
     8  type AtomicMask_ struct {
     9  	width      int
    10  	value      uint64
    11  	mu         sync.Mutex
    12  	cond       sync.Cond
    13  	extendable bool
    14  	probe      func(int) bool
    15  }
    16  
    17  func AtomicMask(width int) *AtomicMask_ {
    18  	if width > 64 {
    19  		width = 64
    20  	}
    21  	a := &AtomicMask_{width: width}
    22  	a.cond.L = &a.mu
    23  	return a
    24  }
    25  
    26  func ExtendableAtomicMask(canextend func(int) bool) *AtomicMask_ {
    27  	a := &AtomicMask_{width: 0}
    28  	a.cond.L = &a.mu
    29  	a.extendable = true
    30  	a.probe = canextend
    31  	return a
    32  }
    33  
    34  var numcpu = runtime.NumCPU()
    35  
    36  func (a *AtomicMask_) Lock() int {
    37  	n := -1
    38  	a.mu.Lock()
    39  l:
    40  	for n == -1 {
    41  		x := ^uint64(0) >> (64 - a.width)
    42  		if x != 0 && a.value&x == 0 {
    43  			n = 0
    44  			for a.value&(uint64(1)<<n) != 0 {
    45  				n++
    46  			}
    47  			a.value |= (uint64(1) << n)
    48  			break l
    49  		}
    50  
    51  		if a.extendable {
    52  			if a.width < 64 && a.probe(a.width) {
    53  				n = a.width
    54  				a.value |= (uint64(1) << n)
    55  				a.width++
    56  				break l
    57  			}
    58  			a.extendable = false
    59  		}
    60  
    61  		a.cond.Wait()
    62  	}
    63  	a.mu.Unlock()
    64  	return n
    65  }
    66  
    67  func (a *AtomicMask_) Unlock(i int) {
    68  	a.mu.Lock()
    69  	if a.value&(uint64(1)<<i) != 0 {
    70  		a.value = a.value &^ (uint64(1) << i)
    71  		a.mu.Unlock()
    72  		a.cond.Broadcast()
    73  		return
    74  	}
    75  	a.mu.Unlock() // ?
    76  	panic("opps")
    77  }
    78  
    79  func (a *AtomicMask_) FinCallForAll(f func(no int)) {
    80  	a.mu.Lock()
    81  	a.extendable = false
    82  	mask := ^uint64(0) >> (64 - a.width)
    83  	for mask != 0 {
    84  		for i := 0; i < a.width; i++ {
    85  			x := uint64(1) << i
    86  			if mask&x != 0 && a.value&x == 0 {
    87  				mask &= ^x
    88  				a.value |= x
    89  				f(i)
    90  			}
    91  		}
    92  		if mask != 0 {
    93  			a.cond.Wait()
    94  		}
    95  	}
    96  	a.mu.Unlock()
    97  }