// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

24 25 package locks 26 27 import ( 28 "math/rand" 29 "sync" 30 "testing" 31 32 "github.com/stretchr/testify/require" 33 "github.com/stretchr/testify/suite" 34 ) 35 36 type ( 37 conditionVariableSuite struct { 38 *require.Assertions 39 suite.Suite 40 41 lock sync.Locker 42 cv *ConditionVariableImpl 43 } 44 ) 45 46 func TestConditionVariableSuite(t *testing.T) { 47 s := new(conditionVariableSuite) 48 suite.Run(t, s) 49 } 50 51 func (s *conditionVariableSuite) SetupSuite() { 52 s.Assertions = require.New(s.T()) 53 } 54 55 func (s *conditionVariableSuite) TearDownSuite() { 56 } 57 58 func (s *conditionVariableSuite) SetupTest() { 59 s.lock = &sync.Mutex{} 60 s.cv = NewConditionVariable(s.lock) 61 } 62 63 func (s *conditionVariableSuite) TearDownTest() { 64 65 } 66 67 func (s *conditionVariableSuite) TestChannelSize_New() { 68 s.testChannelSize(s.cv.channel) 69 } 70 71 func (s *conditionVariableSuite) TestChannelSize_Broadcast() { 72 s.cv.Broadcast() 73 s.testChannelSize(s.cv.channel) 74 } 75 76 func (s *conditionVariableSuite) testChannelSize( 77 channel chan struct{}, 78 ) { 79 // assert channel size == 1 80 select { 81 case channel <- struct{}{}: 82 // noop 83 default: 84 s.Fail("conditional variable size should be 1") 85 } 86 87 select { 88 case channel <- struct{}{}: 89 s.Fail("conditional variable size should be 1") 90 default: 91 // noop 92 } 93 } 94 95 func (s *conditionVariableSuite) TestSignal() { 96 signalWaitGroup := sync.WaitGroup{} 97 signalWaitGroup.Add(1) 98 99 waitGroup := sync.WaitGroup{} 100 waitGroup.Add(1) 101 102 waitFn := func() { 103 defer waitGroup.Done() 104 105 s.lock.Lock() 106 defer s.lock.Unlock() 107 108 signalWaitGroup.Done() 109 s.cv.Wait(nil) 110 } 111 go waitFn() 112 113 signalWaitGroup.Wait() 114 s.lock.Lock() 115 func() {}() 116 s.lock.Unlock() 117 s.cv.Signal() 118 waitGroup.Wait() 119 } 120 121 func (s *conditionVariableSuite) TestInterrupt() { 122 interruptWaitGroup := sync.WaitGroup{} 123 interruptWaitGroup.Add(1) 124 125 
waitGroup := sync.WaitGroup{} 126 waitGroup.Add(1) 127 interruptChan := make(chan struct{}) 128 129 waitFn := func() { 130 defer waitGroup.Done() 131 132 s.lock.Lock() 133 defer s.lock.Unlock() 134 135 interruptWaitGroup.Done() 136 s.cv.Wait(interruptChan) 137 } 138 go waitFn() 139 140 interruptWaitGroup.Wait() 141 s.lock.Lock() 142 func() {}() 143 s.lock.Unlock() 144 interruptChan <- struct{}{} 145 waitGroup.Wait() 146 } 147 148 func (s *conditionVariableSuite) TestBroadcast() { 149 waitThreads := 256 150 151 broadcastWaitGroup := sync.WaitGroup{} 152 broadcastWaitGroup.Add(waitThreads) 153 154 waitGroup := sync.WaitGroup{} 155 waitGroup.Add(waitThreads) 156 157 waitFn := func() { 158 defer waitGroup.Done() 159 160 s.lock.Lock() 161 defer s.lock.Unlock() 162 163 broadcastWaitGroup.Done() 164 s.cv.Wait(nil) 165 } 166 for i := 0; i < waitThreads; i++ { 167 go waitFn() 168 } 169 170 broadcastWaitGroup.Wait() 171 s.lock.Lock() 172 func() {}() 173 s.lock.Unlock() 174 s.cv.Broadcast() 175 waitGroup.Wait() 176 } 177 178 func (s *conditionVariableSuite) TestCase_ProducerConsumer() { 179 signalRatio := 0.8 180 numProducer := 256 181 numConsumer := 256 182 totalToken := numProducer * numConsumer * 10 183 tokenPerProducer := totalToken / numProducer 184 tokenPerConsumer := totalToken / numConsumer 185 186 lock := &sync.Mutex{} 187 tokens := 0 188 notifyProducerCV := NewConditionVariable(lock) 189 notifyConsumerCV := NewConditionVariable(lock) 190 191 waitGroup := sync.WaitGroup{} 192 waitGroup.Add(numProducer + numConsumer) 193 194 produceFn := func() { 195 defer waitGroup.Done() 196 remainingToken := tokenPerProducer 197 198 lock.Lock() 199 defer lock.Unlock() 200 201 for remainingToken > 0 { 202 for tokens > 0 { 203 randSignalBroadcast(notifyConsumerCV, signalRatio) 204 notifyProducerCV.Wait(nil) 205 } 206 207 produce := rand.Intn(remainingToken + 1) 208 tokens += produce 209 remainingToken -= produce 210 } 211 randSignalBroadcast(notifyConsumerCV, signalRatio) 212 } 213 
214 consumerFn := func() { 215 defer waitGroup.Done() 216 remainingToken := 0 217 218 lock.Lock() 219 defer lock.Unlock() 220 221 for remainingToken < tokenPerConsumer { 222 for tokens == 0 { 223 randSignalBroadcast(notifyProducerCV, signalRatio) 224 notifyConsumerCV.Wait(nil) 225 } 226 227 consume := min(tokens, tokenPerConsumer-remainingToken) 228 tokens -= consume 229 remainingToken += consume 230 } 231 randSignalBroadcast(notifyProducerCV, signalRatio) 232 if tokens > 0 { 233 randSignalBroadcast(notifyConsumerCV, signalRatio) 234 } 235 } 236 237 for i := 0; i < numConsumer; i++ { 238 go consumerFn() 239 } 240 for i := 0; i < numProducer; i++ { 241 go produceFn() 242 } 243 244 waitGroup.Wait() 245 } 246 247 func randSignalBroadcast( 248 cv ConditionVariable, 249 signalRatio float64, 250 ) { 251 if rand.Float64() <= signalRatio { 252 cv.Signal() 253 } else { 254 cv.Broadcast() 255 } 256 } 257 258 func min(left int, right int) int { 259 if left < right { 260 return left 261 } else if left > right { 262 return right 263 } else { 264 return left 265 } 266 }