github.com/vmware/transport-go@v1.3.4/bus/transaction_test.go (about)

     1  // Copyright 2019-2020 VMware, Inc.
     2  // SPDX-License-Identifier: BSD-2-Clause
     3  
     4  package bus
     5  
     6  import (
     7  	"errors"
     8  	"github.com/stretchr/testify/assert"
     9  	"github.com/vmware/transport-go/model"
    10  	"sync"
    11  	"sync/atomic"
    12  	"testing"
    13  )
    14  
    15  func TestBusTransaction_OnCompleteSync(t *testing.T) {
    16  
    17  	bus := newTestEventBus()
    18  
    19  	bus.GetChannelManager().CreateChannel("test-channel")
    20  
    21  	var channelReqMessage *model.Message
    22  	var requestCounter = 0
    23  
    24  	wg := sync.WaitGroup{}
    25  
    26  	mh, _ := bus.ListenRequestStream("test-channel")
    27  	mh.Handle(func(message *model.Message) {
    28  		requestCounter++
    29  		channelReqMessage = message
    30  		wg.Done()
    31  	}, func(e error) {
    32  		assert.Fail(t, "unexpected error")
    33  	})
    34  
    35  	tr := newBusTransaction(bus, syncTransaction)
    36  
    37  	bus.GetStoreManager().CreateStore("testStore")
    38  	assert.Nil(t, tr.WaitForStoreReady("testStore"))
    39  	assert.Nil(t, tr.SendRequest("test-channel", "sample-request"))
    40  
    41  	var completeCounter int64
    42  
    43  	tr.OnComplete(func(responses []*model.Message) {
    44  		atomic.AddInt64(&completeCounter, 1)
    45  		wg.Done()
    46  	})
    47  
    48  	tr.OnError(func(e error) {
    49  		assert.Fail(t, "unexpected error")
    50  	})
    51  
    52  	tr.OnComplete(func(responses []*model.Message) {
    53  		atomic.AddInt64(&completeCounter, 1)
    54  		assert.Equal(t, len(responses), 2)
    55  		assert.Equal(t, responses[1].Channel, "test-channel")
    56  		assert.Equal(t, responses[1].Payload, "sample-response")
    57  		wg.Done()
    58  	})
    59  
    60  	assert.Equal(t, requestCounter, 0)
    61  
    62  	wg.Add(1)
    63  
    64  	assert.Nil(t, tr.Commit())
    65  
    66  	go bus.GetStoreManager().CreateStore("testStore").Initialize()
    67  
    68  	wg.Wait()
    69  
    70  	assert.Equal(t, requestCounter, 1)
    71  	assert.NotNil(t, channelReqMessage)
    72  
    73  	assert.Equal(t, channelReqMessage.Payload, "sample-request")
    74  
    75  	for i := 0; i < 50; i++ {
    76  		bus.SendResponseMessage("test-channel", "general-message", nil)
    77  	}
    78  
    79  	assert.Equal(t, completeCounter, int64(0))
    80  
    81  	wg.Add(2)
    82  	bus.SendResponseMessage("test-channel", "sample-response", channelReqMessage.DestinationId)
    83  
    84  	wg.Wait()
    85  
    86  	assert.Equal(t, tr.(*busTransaction).state, completedState)
    87  
    88  	assert.Equal(t, completeCounter, int64(2))
    89  
    90  	bus.SendResponseMessage("test-channel", "sample-response2", channelReqMessage.DestinationId)
    91  	assert.Equal(t, completeCounter, int64(2))
    92  }
    93  
    94  func TestBusTransaction_OnCompleteErrorHandling(t *testing.T) {
    95  
    96  	bus := newTestEventBus()
    97  
    98  	tr := newBusTransaction(bus, syncTransaction)
    99  
   100  	assert.EqualError(t, tr.Commit(), "cannot commit empty transaction")
   101  
   102  	assert.Equal(t, tr.(*busTransaction).state, uncommittedState)
   103  
   104  	bus.GetStoreManager().CreateStore("testStore")
   105  	assert.Nil(t, tr.WaitForStoreReady("testStore"))
   106  
   107  	assert.EqualError(t, tr.WaitForStoreReady("invalid-store"), "cannot find store 'invalid-store'")
   108  
   109  	tr.Commit()
   110  
   111  	assert.EqualError(t, tr.OnComplete(func(responses []*model.Message) {}), "transaction has already been committed")
   112  
   113  	assert.Equal(t, tr.(*busTransaction).state, committedState)
   114  	assert.EqualError(t, tr.Commit(), "transaction has already been committed")
   115  
   116  	assert.EqualError(t, tr.WaitForStoreReady("test"), "transaction has already been committed")
   117  	assert.EqualError(t, tr.SendRequest("test", "test"), "transaction has already been committed")
   118  }
   119  
   120  func TestBusTransaction_OnErrorSync(t *testing.T) {
   121  
   122  	bus := newTestEventBus()
   123  
   124  	tr := newBusTransaction(bus, syncTransaction)
   125  
   126  	bus.GetStoreManager().CreateStore("testStore")
   127  	assert.Nil(t, tr.WaitForStoreReady("testStore"))
   128  
   129  	bus.GetChannelManager().CreateChannel("test-channel")
   130  
   131  	var channelReqMessage *model.Message
   132  	var requestCounter = 0
   133  
   134  	wg := sync.WaitGroup{}
   135  
   136  	mh, _ := bus.ListenRequestStream("test-channel")
   137  	mh.Handle(func(message *model.Message) {
   138  		requestCounter++
   139  		channelReqMessage = message
   140  		wg.Done()
   141  	}, func(e error) {
   142  	})
   143  
   144  	tr.SendRequest("test-channel", "sample-request")
   145  	tr.SendRequest("test-channel", "sample-request")
   146  	tr.SendRequest("test-channel", "sample-request")
   147  
   148  	tr.OnComplete(func(responses []*model.Message) {
   149  		assert.Fail(t, "invalid state")
   150  	})
   151  
   152  	var errorHandlerCount int64 = 0
   153  	tr.OnError(func(e error) {
   154  		atomic.AddInt64(&errorHandlerCount, 1)
   155  		wg.Done()
   156  	})
   157  
   158  	tr.OnError(func(e error) {
   159  		atomic.AddInt64(&errorHandlerCount, 1)
   160  		assert.EqualError(t, e, "test-error")
   161  		wg.Done()
   162  	})
   163  
   164  	tr.Commit()
   165  
   166  	assert.Equal(t, tr.(*busTransaction).state, committedState)
   167  
   168  	wg.Add(1)
   169  
   170  	bus.GetStoreManager().GetStore("testStore").Initialize()
   171  
   172  	wg.Wait()
   173  
   174  	assert.Equal(t, requestCounter, 1)
   175  	assert.NotNil(t, channelReqMessage)
   176  
   177  	wg.Add(2)
   178  	bus.SendErrorMessage("test-channel", errors.New("test-error"), channelReqMessage.DestinationId)
   179  
   180  	wg.Wait()
   181  
   182  	assert.Equal(t, tr.(*busTransaction).state, abortedState)
   183  
   184  	assert.Equal(t, requestCounter, 1)
   185  	assert.Equal(t, errorHandlerCount, int64(2))
   186  
   187  	assert.EqualError(t, tr.Commit(), "transaction has already been committed")
   188  }
   189  
   190  func TestBusTransaction_OnCompleteAsync(t *testing.T) {
   191  
   192  	bus := newTestEventBus()
   193  
   194  	bus.GetChannelManager().CreateChannel("test-channel")
   195  
   196  	var channelReqMessage *model.Message
   197  	var requestCounter = 0
   198  
   199  	wg := sync.WaitGroup{}
   200  
   201  	mh, _ := bus.ListenRequestStream("test-channel")
   202  	mh.Handle(func(message *model.Message) {
   203  		requestCounter++
   204  		channelReqMessage = message
   205  		wg.Done()
   206  	}, func(e error) {
   207  		assert.Fail(t, "unexpected error")
   208  	})
   209  
   210  	tr := newBusTransaction(bus, asyncTransaction)
   211  
   212  	bus.GetStoreManager().CreateStore("testStore")
   213  	assert.Nil(t, tr.WaitForStoreReady("testStore"))
   214  	assert.Nil(t, tr.WaitForStoreReady("testStore"))
   215  	bus.GetStoreManager().CreateStore("testStore2")
   216  	assert.Nil(t, tr.WaitForStoreReady("testStore2"))
   217  	bus.GetStoreManager().CreateStore("testStore3")
   218  	assert.Nil(t, tr.WaitForStoreReady("testStore3"))
   219  	assert.Nil(t, tr.SendRequest("test-channel", "sample-request"))
   220  
   221  	var completeCounter int64
   222  
   223  	tr.OnComplete(func(responses []*model.Message) {
   224  		atomic.AddInt64(&completeCounter, 1)
   225  		wg.Done()
   226  	})
   227  
   228  	tr.OnComplete(func(responses []*model.Message) {
   229  		atomic.AddInt64(&completeCounter, 1)
   230  		assert.Equal(t, len(responses), 5)
   231  		assert.Equal(t, responses[4].Channel, "test-channel")
   232  		assert.Equal(t, responses[4].Payload, "sample-response")
   233  		wg.Done()
   234  	})
   235  
   236  	wg.Add(1)
   237  	assert.Nil(t, tr.Commit())
   238  	wg.Wait()
   239  
   240  	assert.NotNil(t, bus.GetStoreManager().GetStore("testStore"))
   241  	assert.NotNil(t, bus.GetStoreManager().GetStore("testStore2"))
   242  	assert.NotNil(t, bus.GetStoreManager().GetStore("testStore3"))
   243  	assert.Equal(t, requestCounter, 1)
   244  	assert.NotNil(t, channelReqMessage)
   245  	assert.Equal(t, channelReqMessage.Payload, "sample-request")
   246  
   247  	for i := 0; i < 20; i++ {
   248  		bus.SendResponseMessage("test-channel", "general-message", nil)
   249  	}
   250  
   251  	assert.Equal(t, completeCounter, int64(0))
   252  
   253  	wg.Add(2)
   254  
   255  	bus.SendResponseMessage("test-channel", "sample-response", channelReqMessage.DestinationId)
   256  	bus.GetStoreManager().GetStore("testStore").Initialize()
   257  	bus.GetStoreManager().GetStore("testStore2").Initialize()
   258  	bus.GetStoreManager().GetStore("testStore3").Initialize()
   259  
   260  	wg.Wait()
   261  
   262  	assert.Equal(t, completeCounter, int64(2))
   263  }
   264  
   265  func TestBusTransaction_OnErrorAsync(t *testing.T) {
   266  
   267  	bus := newTestEventBus()
   268  
   269  	tr := newBusTransaction(bus, asyncTransaction)
   270  
   271  	bus.GetChannelManager().CreateChannel("test-channel")
   272  	bus.GetChannelManager().CreateChannel("test-channel2")
   273  
   274  	var channelReqMessage, channelReqMessage2 *model.Message
   275  
   276  	wg := sync.WaitGroup{}
   277  
   278  	mh, _ := bus.ListenRequestStream("test-channel")
   279  	mh.Handle(func(message *model.Message) {
   280  		channelReqMessage = message
   281  		wg.Done()
   282  	}, func(e error) {
   283  	})
   284  
   285  	mh2, _ := bus.ListenRequestStream("test-channel2")
   286  	mh2.Handle(func(message *model.Message) {
   287  		channelReqMessage2 = message
   288  		wg.Done()
   289  	}, func(e error) {
   290  	})
   291  
   292  	tr.OnComplete(func(responses []*model.Message) {
   293  		assert.Fail(t, "invalid state")
   294  	})
   295  
   296  	var errorHandlerCount int64 = 0
   297  	tr.OnError(func(e error) {
   298  		atomic.AddInt64(&errorHandlerCount, 1)
   299  		assert.EqualError(t, e, "test-error")
   300  		wg.Done()
   301  	})
   302  
   303  	tr.SendRequest("test-channel", "sample-request")
   304  	tr.SendRequest("test-channel2", "sample-request2")
   305  
   306  	wg.Add(2)
   307  	tr.Commit()
   308  	wg.Wait()
   309  
   310  	wg.Add(1)
   311  	bus.SendErrorMessage("test-channel2", errors.New("test-error"), channelReqMessage2.DestinationId)
   312  
   313  	wg.Wait()
   314  
   315  	assert.Equal(t, errorHandlerCount, int64(1))
   316  
   317  	for i := 0; i < 50; i++ {
   318  		bus.SendErrorMessage("test-channel", errors.New("test-error-2"), channelReqMessage.DestinationId)
   319  	}
   320  
   321  	assert.Equal(t, errorHandlerCount, int64(1))
   322  }