github.com/lirm/aeron-go@v0.0.0-20230415210743-920325491dc4/systests/counter_test.go (about) 1 // Copyright 2022 Steven Stern 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package systests 16 17 import ( 18 "github.com/lirm/aeron-go/aeron" 19 "github.com/lirm/aeron-go/aeron/atomic" 20 "github.com/lirm/aeron-go/aeron/counters" 21 "github.com/lirm/aeron-go/aeron/testdata" 22 "github.com/lirm/aeron-go/systests/driver" 23 "github.com/stretchr/testify/mock" 24 "github.com/stretchr/testify/suite" 25 "testing" 26 "time" 27 ) 28 29 const ( 30 counterTypeId = 1101 31 counterLabel = "counter label" 32 ) 33 34 type CounterTestSuite struct { 35 suite.Suite 36 mediaDriver *driver.MediaDriver 37 clientA *aeron.Aeron 38 clientB *aeron.Aeron 39 keyBuffer *atomic.Buffer 40 labelBuffer *atomic.Buffer 41 } 42 43 func (s *CounterTestSuite) SetupTest() { 44 mediaDriver, err := driver.StartMediaDriver() 45 s.Require().NoError(err, "Couldn't start Media Driver") 46 s.mediaDriver = mediaDriver 47 48 clientA, errA := aeron.Connect(aeron.NewContext().AeronDir(s.mediaDriver.TempDir)) 49 clientB, errB := aeron.Connect(aeron.NewContext().AeronDir(s.mediaDriver.TempDir)) 50 if errA != nil || errB != nil { 51 // Testify does not run TearDownTest if SetupTest fails. We have to manually stop Media Driver. 52 s.mediaDriver.StopMediaDriver() 53 s.Require().NoError(errA, "aeron couldn't connect") 54 s.Require().NoError(errB, "aeron couldn't connect") 55 } 56 s.clientA = clientA 57 s.clientB = clientB 58 59 s.keyBuffer = atomic.MakeBuffer(make([]byte, 8)) 60 s.labelBuffer = atomic.MakeBuffer([]byte(counterLabel)) 61 } 62 63 func (s *CounterTestSuite) TearDownTest() { 64 s.clientA.Close() 65 s.clientB.Close() 66 s.mediaDriver.StopMediaDriver() 67 } 68 69 func (s *CounterTestSuite) TestShouldBeAbleToAddCounter() { 70 chanA := make(chan int32, 10) 71 availableCounterHandlerClientA := testdata.NewMockAvailableCounterHandler(s.T()) 72 s.clientA.AddAvailableCounterHandler(availableCounterHandlerClientA) 73 availableCounterHandlerClientA.On("Handle", 74 mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { 75 // Java has a verify(call, timeout). Go does not. This allows us to wait for the right call before asserting 76 // that the right call happened. As redundant as it is, this is probably the easiest way to test this 77 // functionality. 78 chanA <- args.Get(2).(int32) 79 }) 80 81 chanB := make(chan int32, 10) 82 availableCounterHandlerClientB := testdata.NewMockAvailableCounterHandler(s.T()) 83 s.clientB.AddAvailableCounterHandler(availableCounterHandlerClientB) 84 availableCounterHandlerClientB.On("Handle", 85 mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { 86 chanB <- args.Get(2).(int32) 87 }) 88 89 regId, err := s.clientA.AddCounter( 90 counterTypeId, 91 s.keyBuffer, 92 0, 93 s.keyBuffer.Capacity(), 94 s.labelBuffer, 95 0, 96 int32(len(counterLabel))) 97 s.Require().NoError(err) 98 var counter *aeron.Counter 99 for counter == nil && err == nil { 100 counter, err = s.clientA.FindCounter(regId) 101 } 102 s.Require().NoError(err) 103 104 s.Assert().False(counter.IsClosed()) 105 s.Assert().Equal(counter.RegistrationId(), s.clientA.CounterReader().GetCounterRegistrationId(counter.Id())) 106 s.Assert().Equal(s.clientA.ClientID(), s.clientA.CounterReader().GetCounterOwnerId(counter.Id())) 107 108 GetCounterFromChanOrFail(counter.Id(), chanA, s.Suite) 109 GetCounterFromChanOrFail(counter.Id(), chanB, s.Suite) 110 availableCounterHandlerClientA.AssertCalled(s.T(), "Handle", 111 mock.Anything, counter.RegistrationId(), counter.Id()) 112 availableCounterHandlerClientB.AssertCalled(s.T(), "Handle", 113 mock.Anything, counter.RegistrationId(), counter.Id()) 114 } 115 116 func GetCounterFromChanOrFail(counterId int32, ch chan int32, s suite.Suite) { 117 for { 118 select { 119 case id := <-ch: 120 if id == counterId { 121 return 122 } 123 case <-time.After(5 * time.Second): 124 s.FailNow("Timed out waiting for avaiable counter handler call") 125 } 126 } 127 } 128 129 func (s *CounterTestSuite) TestShouldBeAbleToAddReadableCounterWithinHandler() { 130 ch, handler := s.createReadableCounterHandler() 131 s.clientB.AddAvailableCounterHandler(handler) 132 133 regId, err := s.clientA.AddCounter( 134 counterTypeId, 135 s.keyBuffer, 136 0, 137 s.keyBuffer.Capacity(), 138 s.labelBuffer, 139 0, 140 int32(len(counterLabel))) 141 s.Require().NoError(err) 142 var counter *aeron.Counter 143 for counter == nil && err == nil { 144 counter, err = s.clientA.FindCounter(regId) 145 } 146 s.Require().NoError(err) 147 148 var readableCounter *counters.ReadableCounter 149 select { 150 case readableCounter = <-ch: 151 case <-time.After(5 * time.Second): 152 s.Fail("Timed out waiting for a ReadableCounter") 153 } 154 s.Assert().Equal(counters.RecordAllocated, readableCounter.State()) 155 s.Assert().Equal(counter.Id(), readableCounter.CounterId) 156 s.Assert().Equal(counter.RegistrationId(), readableCounter.RegistrationId) 157 s.Assert().Equal(counterLabel, readableCounter.Label()) 158 } 159 160 func (s *CounterTestSuite) TestShouldBeAbleToAddReadableCounterWithinHandlerWithAddCounterByLabel() { 161 ch, handler := s.createReadableCounterHandler() 162 s.clientB.AddAvailableCounterHandler(handler) 163 164 regId, err := s.clientA.AddCounterByLabel(counterTypeId, counterLabel) 165 s.Require().NoError(err) 166 var counter *aeron.Counter 167 for counter == nil && err == nil { 168 counter, err = s.clientA.FindCounter(regId) 169 } 170 s.Require().NoError(err) 171 172 var readableCounter *counters.ReadableCounter 173 select { 174 case readableCounter = <-ch: 175 case <-time.After(5 * time.Second): 176 s.Fail("Timed out waiting for a ReadableCounter") 177 } 178 s.Assert().Equal(counters.RecordAllocated, readableCounter.State()) 179 s.Assert().Equal(counter.Id(), readableCounter.CounterId) 180 s.Assert().Equal(counter.RegistrationId(), readableCounter.RegistrationId) 181 s.Assert().Equal(counterLabel, readableCounter.Label()) 182 } 183 184 func (s *CounterTestSuite) TestShouldBeAbleToAddReadableCounterAndGetCounterReads() { 185 ch, handler := s.createReadableCounterHandler() 186 s.clientB.AddAvailableCounterHandler(handler) 187 188 regId, err := s.clientA.AddCounter( 189 counterTypeId, 190 s.keyBuffer, 191 0, 192 s.keyBuffer.Capacity(), 193 s.labelBuffer, 194 0, 195 int32(len(counterLabel))) 196 s.Require().NoError(err) 197 var counter *aeron.Counter 198 for counter == nil && err == nil { 199 counter, err = s.clientA.FindCounter(regId) 200 } 201 s.Require().NoError(err) 202 203 var readableCounter *counters.ReadableCounter 204 select { 205 case readableCounter = <-ch: 206 case <-time.After(5 * time.Second): 207 s.Fail("Timed out waiting for a ReadableCounter") 208 } 209 s.Assert().Equal(counters.RecordAllocated, readableCounter.State()) 210 s.Assert().Equal(counter.Id(), readableCounter.CounterId) 211 s.Assert().Equal(counter.RegistrationId(), readableCounter.RegistrationId) 212 s.Assert().Equal(counterLabel, readableCounter.Label()) 213 214 s.Assert().Equal(counter.Counter().Get(), readableCounter.Get()) 215 counter.Counter().Set(42) 216 s.Assert().Equal(counter.Counter().Get(), readableCounter.Get()) 217 } 218 219 func (s *CounterTestSuite) TestShouldCloseReadableCounterOnUnavailableCounter() { 220 ch, readableHandler := s.createReadableCounterHandler() 221 s.clientB.AddAvailableCounterHandler(readableHandler) 222 223 regId, err := s.clientA.AddCounter( 224 counterTypeId, 225 s.keyBuffer, 226 0, 227 s.keyBuffer.Capacity(), 228 s.labelBuffer, 229 0, 230 int32(len(counterLabel))) 231 s.Require().NoError(err) 232 var counter *aeron.Counter 233 for counter == nil && err == nil { 234 counter, err = s.clientA.FindCounter(regId) 235 } 236 s.Require().NoError(err) 237 238 var readableCounter *counters.ReadableCounter 239 select { 240 case readableCounter = <-ch: 241 case <-time.After(5 * time.Second): 242 s.Fail("Timed out waiting for a ReadableCounter") 243 } 244 unavailableHandler := s.createUnavailableCounterHandler(readableCounter) 245 s.clientB.AddUnavailableCounterHandler(unavailableHandler) 246 247 s.Require().False(readableCounter.IsClosed()) 248 s.Require().Equal(counters.RecordAllocated, readableCounter.State()) 249 250 counter.Close() 251 252 start := time.Now() 253 for !readableCounter.IsClosed() { 254 if time.Now().After(start.Add(5 * time.Second)) { 255 s.Fail("Timed out waiting for ReadableCounter to close") 256 } 257 time.Sleep(10 * time.Millisecond) 258 } 259 } 260 261 func (s *CounterTestSuite) TestShouldGetUnavailableCounterWhenOwningClientIsClosed() { 262 ch, readableHandler := s.createReadableCounterHandler() 263 s.clientB.AddAvailableCounterHandler(readableHandler) 264 265 regId, err := s.clientA.AddCounter( 266 counterTypeId, 267 s.keyBuffer, 268 0, 269 s.keyBuffer.Capacity(), 270 s.labelBuffer, 271 0, 272 int32(len(counterLabel))) 273 s.Require().NoError(err) 274 var counter *aeron.Counter 275 for counter == nil && err == nil { 276 counter, err = s.clientA.FindCounter(regId) 277 } 278 s.Require().NoError(err) 279 280 var readableCounter *counters.ReadableCounter 281 select { 282 case readableCounter = <-ch: 283 case <-time.After(5 * time.Second): 284 s.FailNow("Timed out waiting for a ReadableCounter") 285 } 286 unavailableHandler := s.createUnavailableCounterHandler(readableCounter) 287 s.clientB.AddUnavailableCounterHandler(unavailableHandler) 288 289 s.Require().False(readableCounter.IsClosed()) 290 s.Require().Equal(counters.RecordAllocated, readableCounter.State()) 291 292 s.clientA.Close() 293 294 start := time.Now() 295 for !readableCounter.IsClosed() { 296 if time.Now().After(start.Add(5 * time.Second)) { 297 s.FailNow("Timed out waiting for ReadableCounter to close") 298 return 299 } 300 time.Sleep(10 * time.Millisecond) 301 } 302 } 303 304 type ReadableCounterHandler struct { 305 suite *suite.Suite 306 ch chan *counters.ReadableCounter 307 } 308 309 func (r ReadableCounterHandler) Handle(reader *counters.Reader, registrationId int64, counterId int32) { 310 if counterTypeId == reader.GetCounterTypeId(counterId) { 311 counter, err := counters.NewReadableRegisteredCounter(reader, registrationId, counterId) 312 r.suite.Require().NoError(err) 313 r.ch <- counter 314 } 315 } 316 317 func (s *CounterTestSuite) createReadableCounterHandler() (chan *counters.ReadableCounter, aeron.AvailableCounterHandler) { 318 ch := make(chan *counters.ReadableCounter, 1) 319 handler := ReadableCounterHandler{suite: &s.Suite, ch: ch} 320 return ch, handler 321 } 322 323 type UnavailableCounterHandler struct { 324 suite *suite.Suite 325 counter *counters.ReadableCounter 326 } 327 328 func (r UnavailableCounterHandler) Handle(_ *counters.Reader, registrationId int64, _ int32) { 329 if r.counter.RegistrationId == registrationId { 330 r.counter.Close() 331 } 332 } 333 334 func (s *CounterTestSuite) createUnavailableCounterHandler(counter *counters.ReadableCounter) aeron.UnavailableCounterHandler { 335 return UnavailableCounterHandler{suite: &s.Suite, counter: counter} 336 } 337 338 func TestCounter(t *testing.T) { 339 suite.Run(t, new(CounterTestSuite)) 340 }