github.com/lirm/aeron-go@v0.0.0-20230415210743-920325491dc4/systests/counter_test.go (about)

     1  // Copyright 2022 Steven Stern
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package systests
    16  
    17  import (
    18  	"github.com/lirm/aeron-go/aeron"
    19  	"github.com/lirm/aeron-go/aeron/atomic"
    20  	"github.com/lirm/aeron-go/aeron/counters"
    21  	"github.com/lirm/aeron-go/aeron/testdata"
    22  	"github.com/lirm/aeron-go/systests/driver"
    23  	"github.com/stretchr/testify/mock"
    24  	"github.com/stretchr/testify/suite"
    25  	"testing"
    26  	"time"
    27  )
    28  
    29  const (
    30  	counterTypeId = 1101
    31  	counterLabel  = "counter label"
    32  )
    33  
    34  type CounterTestSuite struct {
    35  	suite.Suite
    36  	mediaDriver *driver.MediaDriver
    37  	clientA     *aeron.Aeron
    38  	clientB     *aeron.Aeron
    39  	keyBuffer   *atomic.Buffer
    40  	labelBuffer *atomic.Buffer
    41  }
    42  
    43  func (s *CounterTestSuite) SetupTest() {
    44  	mediaDriver, err := driver.StartMediaDriver()
    45  	s.Require().NoError(err, "Couldn't start Media Driver")
    46  	s.mediaDriver = mediaDriver
    47  
    48  	clientA, errA := aeron.Connect(aeron.NewContext().AeronDir(s.mediaDriver.TempDir))
    49  	clientB, errB := aeron.Connect(aeron.NewContext().AeronDir(s.mediaDriver.TempDir))
    50  	if errA != nil || errB != nil {
    51  		// Testify does not run TearDownTest if SetupTest fails.  We have to manually stop Media Driver.
    52  		s.mediaDriver.StopMediaDriver()
    53  		s.Require().NoError(errA, "aeron couldn't connect")
    54  		s.Require().NoError(errB, "aeron couldn't connect")
    55  	}
    56  	s.clientA = clientA
    57  	s.clientB = clientB
    58  
    59  	s.keyBuffer = atomic.MakeBuffer(make([]byte, 8))
    60  	s.labelBuffer = atomic.MakeBuffer([]byte(counterLabel))
    61  }
    62  
    63  func (s *CounterTestSuite) TearDownTest() {
    64  	s.clientA.Close()
    65  	s.clientB.Close()
    66  	s.mediaDriver.StopMediaDriver()
    67  }
    68  
    69  func (s *CounterTestSuite) TestShouldBeAbleToAddCounter() {
    70  	chanA := make(chan int32, 10)
    71  	availableCounterHandlerClientA := testdata.NewMockAvailableCounterHandler(s.T())
    72  	s.clientA.AddAvailableCounterHandler(availableCounterHandlerClientA)
    73  	availableCounterHandlerClientA.On("Handle",
    74  		mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
    75  		// Java has a verify(call, timeout).  Go does not.  This allows us to wait for the right call before asserting
    76  		// that the right call happened.  As redundant as it is, this is probably the easiest way to test this
    77  		// functionality.
    78  		chanA <- args.Get(2).(int32)
    79  	})
    80  
    81  	chanB := make(chan int32, 10)
    82  	availableCounterHandlerClientB := testdata.NewMockAvailableCounterHandler(s.T())
    83  	s.clientB.AddAvailableCounterHandler(availableCounterHandlerClientB)
    84  	availableCounterHandlerClientB.On("Handle",
    85  		mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
    86  		chanB <- args.Get(2).(int32)
    87  	})
    88  
    89  	regId, err := s.clientA.AddCounter(
    90  		counterTypeId,
    91  		s.keyBuffer,
    92  		0,
    93  		s.keyBuffer.Capacity(),
    94  		s.labelBuffer,
    95  		0,
    96  		int32(len(counterLabel)))
    97  	s.Require().NoError(err)
    98  	var counter *aeron.Counter
    99  	for counter == nil && err == nil {
   100  		counter, err = s.clientA.FindCounter(regId)
   101  	}
   102  	s.Require().NoError(err)
   103  
   104  	s.Assert().False(counter.IsClosed())
   105  	s.Assert().Equal(counter.RegistrationId(), s.clientA.CounterReader().GetCounterRegistrationId(counter.Id()))
   106  	s.Assert().Equal(s.clientA.ClientID(), s.clientA.CounterReader().GetCounterOwnerId(counter.Id()))
   107  
   108  	GetCounterFromChanOrFail(counter.Id(), chanA, s.Suite)
   109  	GetCounterFromChanOrFail(counter.Id(), chanB, s.Suite)
   110  	availableCounterHandlerClientA.AssertCalled(s.T(), "Handle",
   111  		mock.Anything, counter.RegistrationId(), counter.Id())
   112  	availableCounterHandlerClientB.AssertCalled(s.T(), "Handle",
   113  		mock.Anything, counter.RegistrationId(), counter.Id())
   114  }
   115  
   116  func GetCounterFromChanOrFail(counterId int32, ch chan int32, s suite.Suite) {
   117  	for {
   118  		select {
   119  		case id := <-ch:
   120  			if id == counterId {
   121  				return
   122  			}
   123  		case <-time.After(5 * time.Second):
   124  			s.FailNow("Timed out waiting for avaiable counter handler call")
   125  		}
   126  	}
   127  }
   128  
   129  func (s *CounterTestSuite) TestShouldBeAbleToAddReadableCounterWithinHandler() {
   130  	ch, handler := s.createReadableCounterHandler()
   131  	s.clientB.AddAvailableCounterHandler(handler)
   132  
   133  	regId, err := s.clientA.AddCounter(
   134  		counterTypeId,
   135  		s.keyBuffer,
   136  		0,
   137  		s.keyBuffer.Capacity(),
   138  		s.labelBuffer,
   139  		0,
   140  		int32(len(counterLabel)))
   141  	s.Require().NoError(err)
   142  	var counter *aeron.Counter
   143  	for counter == nil && err == nil {
   144  		counter, err = s.clientA.FindCounter(regId)
   145  	}
   146  	s.Require().NoError(err)
   147  
   148  	var readableCounter *counters.ReadableCounter
   149  	select {
   150  	case readableCounter = <-ch:
   151  	case <-time.After(5 * time.Second):
   152  		s.Fail("Timed out waiting for a ReadableCounter")
   153  	}
   154  	s.Assert().Equal(counters.RecordAllocated, readableCounter.State())
   155  	s.Assert().Equal(counter.Id(), readableCounter.CounterId)
   156  	s.Assert().Equal(counter.RegistrationId(), readableCounter.RegistrationId)
   157  	s.Assert().Equal(counterLabel, readableCounter.Label())
   158  }
   159  
   160  func (s *CounterTestSuite) TestShouldBeAbleToAddReadableCounterWithinHandlerWithAddCounterByLabel() {
   161  	ch, handler := s.createReadableCounterHandler()
   162  	s.clientB.AddAvailableCounterHandler(handler)
   163  
   164  	regId, err := s.clientA.AddCounterByLabel(counterTypeId, counterLabel)
   165  	s.Require().NoError(err)
   166  	var counter *aeron.Counter
   167  	for counter == nil && err == nil {
   168  		counter, err = s.clientA.FindCounter(regId)
   169  	}
   170  	s.Require().NoError(err)
   171  
   172  	var readableCounter *counters.ReadableCounter
   173  	select {
   174  	case readableCounter = <-ch:
   175  	case <-time.After(5 * time.Second):
   176  		s.Fail("Timed out waiting for a ReadableCounter")
   177  	}
   178  	s.Assert().Equal(counters.RecordAllocated, readableCounter.State())
   179  	s.Assert().Equal(counter.Id(), readableCounter.CounterId)
   180  	s.Assert().Equal(counter.RegistrationId(), readableCounter.RegistrationId)
   181  	s.Assert().Equal(counterLabel, readableCounter.Label())
   182  }
   183  
   184  func (s *CounterTestSuite) TestShouldBeAbleToAddReadableCounterAndGetCounterReads() {
   185  	ch, handler := s.createReadableCounterHandler()
   186  	s.clientB.AddAvailableCounterHandler(handler)
   187  
   188  	regId, err := s.clientA.AddCounter(
   189  		counterTypeId,
   190  		s.keyBuffer,
   191  		0,
   192  		s.keyBuffer.Capacity(),
   193  		s.labelBuffer,
   194  		0,
   195  		int32(len(counterLabel)))
   196  	s.Require().NoError(err)
   197  	var counter *aeron.Counter
   198  	for counter == nil && err == nil {
   199  		counter, err = s.clientA.FindCounter(regId)
   200  	}
   201  	s.Require().NoError(err)
   202  
   203  	var readableCounter *counters.ReadableCounter
   204  	select {
   205  	case readableCounter = <-ch:
   206  	case <-time.After(5 * time.Second):
   207  		s.Fail("Timed out waiting for a ReadableCounter")
   208  	}
   209  	s.Assert().Equal(counters.RecordAllocated, readableCounter.State())
   210  	s.Assert().Equal(counter.Id(), readableCounter.CounterId)
   211  	s.Assert().Equal(counter.RegistrationId(), readableCounter.RegistrationId)
   212  	s.Assert().Equal(counterLabel, readableCounter.Label())
   213  
   214  	s.Assert().Equal(counter.Counter().Get(), readableCounter.Get())
   215  	counter.Counter().Set(42)
   216  	s.Assert().Equal(counter.Counter().Get(), readableCounter.Get())
   217  }
   218  
   219  func (s *CounterTestSuite) TestShouldCloseReadableCounterOnUnavailableCounter() {
   220  	ch, readableHandler := s.createReadableCounterHandler()
   221  	s.clientB.AddAvailableCounterHandler(readableHandler)
   222  
   223  	regId, err := s.clientA.AddCounter(
   224  		counterTypeId,
   225  		s.keyBuffer,
   226  		0,
   227  		s.keyBuffer.Capacity(),
   228  		s.labelBuffer,
   229  		0,
   230  		int32(len(counterLabel)))
   231  	s.Require().NoError(err)
   232  	var counter *aeron.Counter
   233  	for counter == nil && err == nil {
   234  		counter, err = s.clientA.FindCounter(regId)
   235  	}
   236  	s.Require().NoError(err)
   237  
   238  	var readableCounter *counters.ReadableCounter
   239  	select {
   240  	case readableCounter = <-ch:
   241  	case <-time.After(5 * time.Second):
   242  		s.Fail("Timed out waiting for a ReadableCounter")
   243  	}
   244  	unavailableHandler := s.createUnavailableCounterHandler(readableCounter)
   245  	s.clientB.AddUnavailableCounterHandler(unavailableHandler)
   246  
   247  	s.Require().False(readableCounter.IsClosed())
   248  	s.Require().Equal(counters.RecordAllocated, readableCounter.State())
   249  
   250  	counter.Close()
   251  
   252  	start := time.Now()
   253  	for !readableCounter.IsClosed() {
   254  		if time.Now().After(start.Add(5 * time.Second)) {
   255  			s.Fail("Timed out waiting for ReadableCounter to close")
   256  		}
   257  		time.Sleep(10 * time.Millisecond)
   258  	}
   259  }
   260  
   261  func (s *CounterTestSuite) TestShouldGetUnavailableCounterWhenOwningClientIsClosed() {
   262  	ch, readableHandler := s.createReadableCounterHandler()
   263  	s.clientB.AddAvailableCounterHandler(readableHandler)
   264  
   265  	regId, err := s.clientA.AddCounter(
   266  		counterTypeId,
   267  		s.keyBuffer,
   268  		0,
   269  		s.keyBuffer.Capacity(),
   270  		s.labelBuffer,
   271  		0,
   272  		int32(len(counterLabel)))
   273  	s.Require().NoError(err)
   274  	var counter *aeron.Counter
   275  	for counter == nil && err == nil {
   276  		counter, err = s.clientA.FindCounter(regId)
   277  	}
   278  	s.Require().NoError(err)
   279  
   280  	var readableCounter *counters.ReadableCounter
   281  	select {
   282  	case readableCounter = <-ch:
   283  	case <-time.After(5 * time.Second):
   284  		s.FailNow("Timed out waiting for a ReadableCounter")
   285  	}
   286  	unavailableHandler := s.createUnavailableCounterHandler(readableCounter)
   287  	s.clientB.AddUnavailableCounterHandler(unavailableHandler)
   288  
   289  	s.Require().False(readableCounter.IsClosed())
   290  	s.Require().Equal(counters.RecordAllocated, readableCounter.State())
   291  
   292  	s.clientA.Close()
   293  
   294  	start := time.Now()
   295  	for !readableCounter.IsClosed() {
   296  		if time.Now().After(start.Add(5 * time.Second)) {
   297  			s.FailNow("Timed out waiting for ReadableCounter to close")
   298  			return
   299  		}
   300  		time.Sleep(10 * time.Millisecond)
   301  	}
   302  }
   303  
   304  type ReadableCounterHandler struct {
   305  	suite *suite.Suite
   306  	ch    chan *counters.ReadableCounter
   307  }
   308  
   309  func (r ReadableCounterHandler) Handle(reader *counters.Reader, registrationId int64, counterId int32) {
   310  	if counterTypeId == reader.GetCounterTypeId(counterId) {
   311  		counter, err := counters.NewReadableRegisteredCounter(reader, registrationId, counterId)
   312  		r.suite.Require().NoError(err)
   313  		r.ch <- counter
   314  	}
   315  }
   316  
   317  func (s *CounterTestSuite) createReadableCounterHandler() (chan *counters.ReadableCounter, aeron.AvailableCounterHandler) {
   318  	ch := make(chan *counters.ReadableCounter, 1)
   319  	handler := ReadableCounterHandler{suite: &s.Suite, ch: ch}
   320  	return ch, handler
   321  }
   322  
   323  type UnavailableCounterHandler struct {
   324  	suite   *suite.Suite
   325  	counter *counters.ReadableCounter
   326  }
   327  
   328  func (r UnavailableCounterHandler) Handle(_ *counters.Reader, registrationId int64, _ int32) {
   329  	if r.counter.RegistrationId == registrationId {
   330  		r.counter.Close()
   331  	}
   332  }
   333  
   334  func (s *CounterTestSuite) createUnavailableCounterHandler(counter *counters.ReadableCounter) aeron.UnavailableCounterHandler {
   335  	return UnavailableCounterHandler{suite: &s.Suite, counter: counter}
   336  }
   337  
   338  func TestCounter(t *testing.T) {
   339  	suite.Run(t, new(CounterTestSuite))
   340  }