github.com/insolar/vanilla@v0.0.0-20201023172447-248fdf805322/synckit/versioned_signal.go (about)

     1  // Copyright 2020 Insolar Network Ltd.
     2  // All rights reserved.
     3  // This material is licensed under the Insolar License version 1.0,
     4  // available at https://github.com/insolar/assured-ledger/blob/master/LICENSE.md.
     5  
     6  package synckit
     7  
     8  import (
     9  	"sync"
    10  	"sync/atomic"
    11  	"unsafe"
    12  )
    13  
    14  func NewVersionedSignal() VersionedSignal {
    15  	return VersionedSignal{}
    16  }
    17  
    18  type VersionedSignal struct {
    19  	signalVersion *SignalVersion // atomic
    20  }
    21  
    22  func (p *VersionedSignal) _signalVersion() *unsafe.Pointer {
    23  	return (*unsafe.Pointer)((unsafe.Pointer)(&p.signalVersion))
    24  }
    25  
    26  func (p *VersionedSignal) NextBroadcast() {
    27  	sv := (*SignalVersion)(atomic.SwapPointer(p._signalVersion(), nil))
    28  	sv.signal()
    29  }
    30  
    31  func (p *VersionedSignal) BroadcastAndMark() *SignalVersion {
    32  	nsv := newSignalVersion()
    33  	sv := (*SignalVersion)(atomic.SwapPointer(p._signalVersion(), (unsafe.Pointer)(nsv)))
    34  	sv.signal()
    35  	return nsv
    36  }
    37  
    38  func (p *VersionedSignal) Mark() *SignalVersion {
    39  	var nsv *SignalVersion
    40  	for {
    41  		sv := (*SignalVersion)(atomic.LoadPointer(p._signalVersion()))
    42  		switch {
    43  		case sv != nil:
    44  			return sv
    45  		case nsv == nil: // avoid repetitive new
    46  			nsv = newSignalVersion()
    47  		}
    48  		if atomic.CompareAndSwapPointer(p._signalVersion(), nil, (unsafe.Pointer)(nsv)) {
    49  			return nsv
    50  		}
    51  	}
    52  }
    53  
    54  func NewNeverSignal() *SignalVersion {
    55  	return newSignalVersion()
    56  }
    57  
    58  func newSignalVersion() *SignalVersion {
    59  	sv := SignalVersion{}
    60  	sv.wg.Add(1)
    61  	return &sv
    62  }
    63  
    64  type signalChannel = chan struct{}
    65  
    66  type SignalVersion struct {
    67  	next *SignalVersion
    68  	wg   sync.WaitGroup // is cheaper than channel and doesn't need additional heap allocation
    69  	c    *signalChannel // atomic
    70  }
    71  
    72  func (p *SignalVersion) _signalChannel() *unsafe.Pointer {
    73  	return (*unsafe.Pointer)((unsafe.Pointer)(&p.c))
    74  }
    75  
    76  func (p *SignalVersion) getSignalChannel() *signalChannel {
    77  	return (*signalChannel)(atomic.LoadPointer(p._signalChannel()))
    78  }
    79  
    80  func (p *SignalVersion) signal() {
    81  	if p == nil {
    82  		return
    83  	}
    84  	p.next.signal() // older signals must fire first
    85  
    86  	var closedSignal *SignalChannel // explicit type decl to avoid passing of something wrong into unsafe.Pointer conversion
    87  	closedSignal = &closedChan
    88  
    89  	atomic.CompareAndSwapPointer(p._signalChannel(), nil, (unsafe.Pointer)(closedSignal))
    90  	p.wg.Done()
    91  }
    92  
    93  func (p *SignalVersion) Wait() {
    94  	if p == nil {
    95  		return
    96  	}
    97  
    98  	p.wg.Wait()
    99  }
   100  
   101  func (p *SignalVersion) ChannelIf(choice bool, def <-chan struct{}) <-chan struct{} {
   102  	if choice {
   103  		return p.Channel()
   104  	}
   105  	return def
   106  }
   107  
   108  func (p *SignalVersion) Channel() <-chan struct{} {
   109  	if p == nil {
   110  		return ClosedChannel()
   111  	}
   112  
   113  	var wcp *signalChannel
   114  	for {
   115  		switch sc := p.getSignalChannel(); {
   116  		case sc != nil:
   117  			return *sc
   118  		case wcp == nil:
   119  			wcp = new(signalChannel)
   120  			*wcp = make(signalChannel)
   121  		}
   122  
   123  		if atomic.CompareAndSwapPointer(p._signalChannel(), nil, (unsafe.Pointer)(wcp)) {
   124  			go p.waitClose(wcp)
   125  			return *wcp
   126  		}
   127  	}
   128  }
   129  
   130  func (p *SignalVersion) waitClose(wcp *signalChannel) {
   131  	p.wg.Wait()
   132  	close(*wcp)
   133  }
   134  
   135  func (p *SignalVersion) HasSignal() bool {
   136  	if p == nil {
   137  		return true
   138  	}
   139  
   140  	sc := p.getSignalChannel()
   141  	if sc == nil {
   142  		return false
   143  	}
   144  	select {
   145  	case <-*sc:
   146  		return true
   147  	default:
   148  		return false
   149  	}
   150  }