github.com/vmware/transport-go@v1.3.4/bus/fabric_endpoint_test.go (about)

     1  // Copyright 2019-2020 VMware, Inc.
     2  // SPDX-License-Identifier: BSD-2-Clause
     3  
     4  package bus
     5  
     6  import (
     7  	"encoding/json"
     8  	"errors"
     9  	"github.com/google/uuid"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/vmware/transport-go/model"
    12  	"github.com/vmware/transport-go/stompserver"
    13  	"sync"
    14  	"testing"
    15  )
    16  
// MockStompServerMessage records one message handed to MockStompServer via
// SendMessage or SendMessageToClient.
//
// Destination and Payload carry JSON tags, so the struct can also be used as a
// message payload in tests. conId is unexported and therefore skipped by
// encoding/json; within this file it is only populated by SendMessageToClient.
type MockStompServerMessage struct {
	Destination string `json:"destination"`
	Payload     []byte `json:"payload"`
	conId       string
}
    22  
// MockStompServer is an in-memory stand-in for the STOMP server used by the
// fabric endpoint tests. It does no networking: it records the messages it is
// asked to send in sentMessages, tracks Start/Stop in started, and stores the
// handler callbacks registered on it so tests can invoke them directly to
// simulate client activity.
//
// wg is optional. When non-nil, every SendMessage/SendMessageToClient call
// invokes wg.Done(), which lets a test block until an expected message has
// been delivered.
type MockStompServer struct {
	started                           bool
	sentMessages                      []MockStompServerMessage
	subscribeHandlerFunction          stompserver.SubscribeHandlerFunction
	connectionEventCallbacks          map[stompserver.StompSessionEventType]func(event *stompserver.ConnEvent)
	unsubscribeHandlerFunction        stompserver.UnsubscribeHandlerFunction
	applicationRequestHandlerFunction stompserver.ApplicationRequestHandlerFunction
	wg                                *sync.WaitGroup
}
    32  
// Start marks the mock server as started; it only flips the started flag.
func (s *MockStompServer) Start() {
	s.started = true
}
    36  
// Stop marks the mock server as stopped; it only clears the started flag.
func (s *MockStompServer) Stop() {
	s.started = false
}
    40  
    41  func (s *MockStompServer) SendMessage(destination string, messageBody []byte) {
    42  	s.sentMessages = append(s.sentMessages,
    43  		MockStompServerMessage{Destination: destination, Payload: messageBody})
    44  
    45  	if s.wg != nil {
    46  		s.wg.Done()
    47  	}
    48  }
    49  
    50  func (s *MockStompServer) SendMessageToClient(conId string, destination string, messageBody []byte) {
    51  	s.sentMessages = append(s.sentMessages,
    52  		MockStompServerMessage{Destination: destination, Payload: messageBody, conId: conId})
    53  
    54  	if s.wg != nil {
    55  		s.wg.Done()
    56  	}
    57  }
    58  
// OnUnsubscribeEvent stores the unsubscribe handler so that tests can call it
// directly (via unsubscribeHandlerFunction) to simulate a client unsubscribing.
func (s *MockStompServer) OnUnsubscribeEvent(callback stompserver.UnsubscribeHandlerFunction) {
	s.unsubscribeHandlerFunction = callback
}
    62  
// OnApplicationRequest stores the application request handler so that tests
// can call it directly (via applicationRequestHandlerFunction) to simulate an
// incoming client request.
func (s *MockStompServer) OnApplicationRequest(callback stompserver.ApplicationRequestHandlerFunction) {
	s.applicationRequestHandlerFunction = callback
}
    66  
// OnSubscribeEvent stores the subscribe handler so that tests can call it
// directly (via subscribeHandlerFunction) to simulate a client subscribing.
func (s *MockStompServer) OnSubscribeEvent(callback stompserver.SubscribeHandlerFunction) {
	s.subscribeHandlerFunction = callback
}
    70  
    71  func (s *MockStompServer) SetConnectionEventCallback(connEventType stompserver.StompSessionEventType, cb func(connEvent *stompserver.ConnEvent)) {
    72  	s.connectionEventCallbacks[connEventType] = cb
    73  	cb(&stompserver.ConnEvent{ConnId: "id"})
    74  }
    75  
    76  func newTestFabricEndpoint(bus EventBus, config EndpointConfig) (*fabricEndpoint, *MockStompServer) {
    77  
    78  	fe := newFabricEndpoint(bus, nil, config).(*fabricEndpoint)
    79  	ms := &MockStompServer{connectionEventCallbacks: make(map[stompserver.StompSessionEventType]func(event *stompserver.ConnEvent))}
    80  
    81  	fe.server = ms
    82  	fe.initHandlers()
    83  
    84  	return fe, ms
    85  }
    86  
    87  func TestFabricEndpoint_newFabricEndpoint(t *testing.T) {
    88  	fe, _ := newTestFabricEndpoint(nil, EndpointConfig{
    89  		TopicPrefix:      "/topic",
    90  		AppRequestPrefix: "/pub",
    91  		Heartbeat:        0,
    92  	})
    93  
    94  	assert.NotNil(t, fe)
    95  	assert.Equal(t, fe.config.TopicPrefix, "/topic/")
    96  	assert.Equal(t, fe.config.AppRequestPrefix, "/pub/")
    97  
    98  	fe, _ = newTestFabricEndpoint(nil, EndpointConfig{
    99  		TopicPrefix:      "/topic/",
   100  		AppRequestPrefix: "",
   101  		Heartbeat:        0,
   102  	})
   103  
   104  	assert.Equal(t, fe.config.TopicPrefix, "/topic/")
   105  	assert.Equal(t, fe.config.AppRequestPrefix, "")
   106  }
   107  
   108  func TestFabricEndpoint_StartAndStop(t *testing.T) {
   109  	fe, mockServer := newTestFabricEndpoint(nil, EndpointConfig{})
   110  	assert.Equal(t, mockServer.started, false)
   111  	fe.Start()
   112  	assert.Equal(t, mockServer.started, true)
   113  	fe.Stop()
   114  	assert.Equal(t, mockServer.started, false)
   115  }
   116  
   117  func TestFabricEndpoint_SubscribeEvent(t *testing.T) {
   118  
   119  	bus := newTestEventBus()
   120  	bus.GetChannelManager().CreateChannel(STOMP_SESSION_NOTIFY_CHANNEL) // used for internal channel protection test
   121  	fe, mockServer := newTestFabricEndpoint(bus,
   122  		EndpointConfig{TopicPrefix: "/topic", UserQueuePrefix: "/user/queue"})
   123  
   124  	bus.GetChannelManager().CreateChannel("test-service")
   125  
   126  	monitorWg := sync.WaitGroup{}
   127  	var monitorEvents []*MonitorEvent
   128  	bus.AddMonitorEventListener(func(monitorEvt *MonitorEvent) {
   129  		monitorEvents = append(monitorEvents, monitorEvt)
   130  		monitorWg.Done()
   131  	}, FabricEndpointSubscribeEvt)
   132  
   133  	// subscribe to invalid topic
   134  	mockServer.subscribeHandlerFunction("con1", "sub1", "/topic2/test-service", nil)
   135  	assert.Equal(t, len(fe.chanMappings), 0)
   136  
   137  	bus.SendResponseMessage("test-service", "test-message", nil)
   138  	assert.Equal(t, len(mockServer.sentMessages), 0)
   139  
   140  	// subscribe to valid channel
   141  	monitorWg.Add(1)
   142  	mockServer.subscribeHandlerFunction("con1", "sub1", "/topic/test-service", nil)
   143  	monitorWg.Wait()
   144  	assert.Equal(t, len(monitorEvents), 1)
   145  	assert.Equal(t, monitorEvents[0].EventType, FabricEndpointSubscribeEvt)
   146  	assert.Equal(t, monitorEvents[0].EntityName, "test-service")
   147  
   148  	assert.Equal(t, len(fe.chanMappings), 1)
   149  	assert.Equal(t, len(fe.chanMappings["test-service"].subs), 1)
   150  	assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub1"], true)
   151  
   152  	// subscribe again to the same channel
   153  	monitorWg.Add(1)
   154  	mockServer.subscribeHandlerFunction("con1", "sub2", "/topic/test-service", nil)
   155  	monitorWg.Wait()
   156  
   157  	assert.Equal(t, len(monitorEvents), 2)
   158  	assert.Equal(t, monitorEvents[1].EventType, FabricEndpointSubscribeEvt)
   159  	assert.Equal(t, monitorEvents[1].EntityName, "test-service")
   160  
   161  	assert.Equal(t, len(fe.chanMappings), 1)
   162  	assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2)
   163  	assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub2"], true)
   164  
   165  	// subscribe to queue channel
   166  	monitorWg.Add(1)
   167  	mockServer.subscribeHandlerFunction("con1", "sub3", "/user/queue/test-service", nil)
   168  	monitorWg.Wait()
   169  	assert.Equal(t, len(monitorEvents), 3)
   170  	assert.Equal(t, monitorEvents[2].EventType, FabricEndpointSubscribeEvt)
   171  	assert.Equal(t, monitorEvents[2].EntityName, "test-service")
   172  
   173  	assert.Equal(t, len(fe.chanMappings), 1)
   174  	assert.Equal(t, len(fe.chanMappings["test-service"].subs), 3)
   175  	assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub3"], true)
   176  
   177  	// attempt to subscribe to a protected destination
   178  	mockServer.subscribeHandlerFunction("con1", "sub4", "/topic/"+STOMP_SESSION_NOTIFY_CHANNEL, nil)
   179  	_, chanMapCreated := fe.chanMappings[STOMP_SESSION_NOTIFY_CHANNEL]
   180  	assert.False(t, chanMapCreated)
   181  
   182  	mockServer.wg = &sync.WaitGroup{}
   183  	mockServer.wg.Add(1)
   184  
   185  	bus.SendResponseMessage("test-service", "test-message", nil)
   186  
   187  	mockServer.wg.Wait()
   188  
   189  	mockServer.wg.Add(1)
   190  	bus.SendResponseMessage("test-service", []byte{1, 2, 3}, nil)
   191  	mockServer.wg.Wait()
   192  
   193  	mockServer.wg.Add(1)
   194  	msg := MockStompServerMessage{Destination: "test", Payload: []byte("test-message")}
   195  	bus.SendResponseMessage("test-service", msg, nil)
   196  	mockServer.wg.Wait()
   197  
   198  	mockServer.wg.Add(1)
   199  	bus.SendErrorMessage("test-service", errors.New("test-error"), nil)
   200  	mockServer.wg.Wait()
   201  
   202  	assert.Equal(t, len(mockServer.sentMessages), 4)
   203  	assert.Equal(t, mockServer.sentMessages[0].Destination, "/topic/test-service")
   204  	assert.Equal(t, string(mockServer.sentMessages[0].Payload), "test-message")
   205  	assert.Equal(t, mockServer.sentMessages[1].Payload, []byte{1, 2, 3})
   206  
   207  	var sentMsg MockStompServerMessage
   208  	json.Unmarshal(mockServer.sentMessages[2].Payload, &sentMsg)
   209  	assert.Equal(t, msg, sentMsg)
   210  
   211  	assert.Equal(t, string(mockServer.sentMessages[3].Payload), "test-error")
   212  
   213  	mockServer.wg.Add(1)
   214  	bus.SendResponseMessage("test-service", model.Response{
   215  		BrokerDestination: &model.BrokerDestinationConfig{
   216  			Destination:  "/user/queue/test-service",
   217  			ConnectionId: "con1",
   218  		},
   219  		Payload: "test-private-message",
   220  	}, nil)
   221  
   222  	mockServer.wg.Wait()
   223  
   224  	assert.Equal(t, len(mockServer.sentMessages), 5)
   225  	assert.Equal(t, mockServer.sentMessages[4].Destination, "/user/queue/test-service")
   226  	var sentResponse model.Response
   227  	json.Unmarshal(mockServer.sentMessages[4].Payload, &sentResponse)
   228  	assert.Equal(t, sentResponse.Payload, "test-private-message")
   229  
   230  	mockServer.wg.Add(1)
   231  	bus.SendResponseMessage("test-service", &model.Response{
   232  		BrokerDestination: &model.BrokerDestinationConfig{
   233  			Destination:  "/user/queue/test-service",
   234  			ConnectionId: "con1",
   235  		},
   236  		Payload: "test-private-message-ptr",
   237  	}, nil)
   238  
   239  	mockServer.wg.Wait()
   240  
   241  	assert.Equal(t, len(mockServer.sentMessages), 6)
   242  	assert.Equal(t, mockServer.sentMessages[5].Destination, "/user/queue/test-service")
   243  	json.Unmarshal(mockServer.sentMessages[5].Payload, &sentResponse)
   244  	assert.Equal(t, sentResponse.Payload, "test-private-message-ptr")
   245  }
   246  
   247  func TestFabricEndpoint_UnsubscribeEvent(t *testing.T) {
   248  	bus := newTestEventBus()
   249  	fe, mockServer := newTestFabricEndpoint(bus, EndpointConfig{TopicPrefix: "/topic"})
   250  
   251  	bus.GetChannelManager().CreateChannel("test-service")
   252  
   253  	monitorWg := sync.WaitGroup{}
   254  	var monitorEvents []*MonitorEvent
   255  	bus.AddMonitorEventListener(func(monitorEvt *MonitorEvent) {
   256  		monitorEvents = append(monitorEvents, monitorEvt)
   257  		monitorWg.Done()
   258  	}, FabricEndpointUnsubscribeEvt)
   259  
   260  	// subscribe to valid channel
   261  	mockServer.subscribeHandlerFunction("con1", "sub1", "/topic/test-service", nil)
   262  	mockServer.subscribeHandlerFunction("con1", "sub2", "/topic/test-service", nil)
   263  
   264  	assert.Equal(t, len(fe.chanMappings), 1)
   265  	assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2)
   266  
   267  	mockServer.wg = &sync.WaitGroup{}
   268  	mockServer.wg.Add(1)
   269  	bus.SendResponseMessage("test-service", "test-message", nil)
   270  	mockServer.wg.Wait()
   271  	assert.Equal(t, len(mockServer.sentMessages), 1)
   272  
   273  	mockServer.unsubscribeHandlerFunction("con1", "sub2", "/invalid-topic/test-service")
   274  	assert.Equal(t, len(fe.chanMappings), 1)
   275  	assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2)
   276  
   277  	mockServer.unsubscribeHandlerFunction("invalid-con1", "sub2", "/topic/test-service")
   278  	assert.Equal(t, len(fe.chanMappings), 1)
   279  	assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2)
   280  
   281  	monitorWg.Add(1)
   282  	mockServer.unsubscribeHandlerFunction("con1", "sub2", "/topic/test-service")
   283  	monitorWg.Wait()
   284  
   285  	assert.Equal(t, len(monitorEvents), 1)
   286  	assert.Equal(t, monitorEvents[0].EventType, FabricEndpointUnsubscribeEvt)
   287  	assert.Equal(t, monitorEvents[0].EntityName, "test-service")
   288  
   289  	assert.Equal(t, len(fe.chanMappings), 1)
   290  	assert.Equal(t, len(fe.chanMappings["test-service"].subs), 1)
   291  
   292  	mockServer.wg = &sync.WaitGroup{}
   293  	mockServer.wg.Add(1)
   294  	bus.SendResponseMessage("test-service", "test-message", nil)
   295  	mockServer.wg.Wait()
   296  	assert.Equal(t, len(mockServer.sentMessages), 2)
   297  
   298  	monitorWg.Add(1)
   299  	mockServer.unsubscribeHandlerFunction("con1", "sub1", "/topic/test-service")
   300  	monitorWg.Wait()
   301  
   302  	assert.Equal(t, len(monitorEvents), 2)
   303  	assert.Equal(t, monitorEvents[1].EventType, FabricEndpointUnsubscribeEvt)
   304  	assert.Equal(t, monitorEvents[1].EntityName, "test-service")
   305  
   306  	assert.Equal(t, len(fe.chanMappings), 0)
   307  	bus.SendResponseMessage("test-service", "test-message", nil)
   308  
   309  	// subscribe to non-existing channel
   310  	mockServer.subscribeHandlerFunction("con3", "sub1", "/topic/non-existing-channel", nil)
   311  	assert.Equal(t, len(fe.chanMappings), 1)
   312  	assert.Equal(t, len(fe.chanMappings["non-existing-channel"].subs), 1)
   313  	assert.Equal(t, fe.chanMappings["non-existing-channel"].autoCreated, true)
   314  	assert.True(t, bus.GetChannelManager().CheckChannelExists("non-existing-channel"))
   315  
   316  	monitorWg.Add(1)
   317  	mockServer.unsubscribeHandlerFunction("con3", "sub1", "/topic/non-existing-channel")
   318  	monitorWg.Wait()
   319  
   320  	assert.Equal(t, len(monitorEvents), 3)
   321  	assert.Equal(t, monitorEvents[2].EventType, FabricEndpointUnsubscribeEvt)
   322  	assert.Equal(t, monitorEvents[2].EntityName, "non-existing-channel")
   323  
   324  	assert.Equal(t, len(fe.chanMappings), 0)
   325  	assert.False(t, bus.GetChannelManager().CheckChannelExists("non-existing-channel"))
   326  }
   327  
   328  func TestFabricEndpoint_BridgeMessage(t *testing.T) {
   329  	bus := newTestEventBus()
   330  	_, mockServer := newTestFabricEndpoint(bus, EndpointConfig{TopicPrefix: "/topic", AppRequestPrefix: "/pub",
   331  		AppRequestQueuePrefix: "/pub/queue", UserQueuePrefix: "/user/queue"})
   332  
   333  	bus.GetChannelManager().CreateChannel("request-channel")
   334  	mh, _ := bus.ListenRequestStream("request-channel")
   335  	assert.NotNil(t, mh)
   336  
   337  	wg := sync.WaitGroup{}
   338  
   339  	var messages []*model.Message
   340  
   341  	mh.Handle(func(message *model.Message) {
   342  		messages = append(messages, message)
   343  		wg.Done()
   344  	}, func(e error) {
   345  		assert.Fail(t, "unexpected error")
   346  	})
   347  
   348  	id1 := uuid.New()
   349  	req1, _ := json.Marshal(model.Request{
   350  		Request: "test-request",
   351  		Payload: "test-rq",
   352  		Id:      &id1,
   353  	})
   354  
   355  	wg.Add(1)
   356  
   357  	mockServer.applicationRequestHandlerFunction("/pub/request-channel", req1, "con1")
   358  
   359  	mockServer.applicationRequestHandlerFunction("/pub2/request-channel", req1, "con1")
   360  	mockServer.applicationRequestHandlerFunction("/pub/request-channel-2", req1, "con1")
   361  
   362  	mockServer.applicationRequestHandlerFunction("/pub/request-channel", []byte("invalid-request-json"), "con1")
   363  
   364  	id2 := uuid.New()
   365  	req2, _ := json.Marshal(model.Request{
   366  		Request: "test-request2",
   367  		Payload: "test-rq2",
   368  		Id:      &id2,
   369  	})
   370  
   371  	wg.Wait()
   372  
   373  	wg.Add(1)
   374  	mockServer.applicationRequestHandlerFunction("/pub/queue/request-channel", req2, "con2")
   375  	wg.Wait()
   376  
   377  	assert.Equal(t, len(messages), 2)
   378  
   379  	receivedReq := messages[0].Payload.(*model.Request)
   380  
   381  	assert.Equal(t, receivedReq.Request, "test-request")
   382  	assert.Equal(t, receivedReq.Payload, "test-rq")
   383  	assert.Equal(t, *receivedReq.Id, id1)
   384  	assert.Nil(t, receivedReq.BrokerDestination)
   385  
   386  	receivedReq2 := messages[1].Payload.(*model.Request)
   387  
   388  	assert.Equal(t, receivedReq2.Request, "test-request2")
   389  	assert.Equal(t, receivedReq2.Payload, "test-rq2")
   390  	assert.Equal(t, *receivedReq2.Id, id2)
   391  	assert.Equal(t, receivedReq2.BrokerDestination.ConnectionId, "con2")
   392  	assert.Equal(t, receivedReq2.BrokerDestination.Destination, "/user/queue/request-channel")
   393  }