github.com/vmware/transport-go@v1.3.4/service/service_registry_test.go (about)

     1  // Copyright 2019-2020 VMware, Inc.
     2  // SPDX-License-Identifier: BSD-2-Clause
     3  
     4  package service
     5  
     6  import (
     7  	"errors"
     8  	"github.com/google/uuid"
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/vmware/transport-go/bus"
    11  	"github.com/vmware/transport-go/model"
    12  	"net/http"
    13  	"sync"
    14  	"testing"
    15  )
    16  
    17  func newTestServiceRegistry() *serviceRegistry {
    18  	eventBus := bus.NewEventBusInstance()
    19  	return newServiceRegistry(eventBus).(*serviceRegistry)
    20  }
    21  
    22  func newTestServiceLifecycleManager(sr ServiceRegistry) ServiceLifecycleManager {
    23  	return newServiceLifecycleManager(sr)
    24  }
    25  
    26  type mockFabricService struct {
    27  	processedRequests []*model.Request
    28  	core              FabricServiceCore
    29  	wg                sync.WaitGroup
    30  }
    31  
    32  func (fs *mockFabricService) HandleServiceRequest(request *model.Request, core FabricServiceCore) {
    33  	fs.processedRequests = append(fs.processedRequests, request)
    34  	fs.core = core
    35  	fs.wg.Done()
    36  }
    37  
    38  type mockLifecycleHookEnabledService struct {
    39  	initChan chan bool
    40  	core     FabricServiceCore
    41  	shutdown bool
    42  }
    43  
    44  func (s *mockLifecycleHookEnabledService) HandleServiceRequest(request *model.Request, core FabricServiceCore) {
    45  }
    46  
    47  func (s *mockLifecycleHookEnabledService) OnServiceReady() chan bool {
    48  	s.initChan = make(chan bool, 1)
    49  	s.initChan <- true
    50  	return s.initChan
    51  }
    52  
    53  func (s *mockLifecycleHookEnabledService) OnServerShutdown() {
    54  	s.shutdown = true
    55  }
    56  
    57  func (s *mockLifecycleHookEnabledService) GetRESTBridgeConfig() []*RESTBridgeConfig {
    58  	return []*RESTBridgeConfig{
    59  		{
    60  			ServiceChannel: "another-test-channel",
    61  			Uri:            "/rest/test",
    62  			Method:         http.MethodGet,
    63  			AllowHead:      true,
    64  			AllowOptions:   true,
    65  			FabricRequestBuilder: func(w http.ResponseWriter, r *http.Request) model.Request {
    66  				return model.Request{
    67  					Id:      &uuid.UUID{},
    68  					Payload: "test",
    69  				}
    70  			},
    71  		},
    72  	}
    73  }
    74  
    75  type mockInitializableService struct {
    76  	initialized bool
    77  	core        FabricServiceCore
    78  	initError   error
    79  }
    80  
    81  func (fs *mockInitializableService) Init(core FabricServiceCore) error {
    82  	fs.core = core
    83  	fs.initialized = true
    84  	return fs.initError
    85  }
    86  
    87  func (fs *mockInitializableService) HandleServiceRequest(request *model.Request, core FabricServiceCore) {
    88  }
    89  
    90  func TestGetServiceRegistry(t *testing.T) {
    91  	sr := GetServiceRegistry()
    92  	sr2 := GetServiceRegistry()
    93  	assert.NotNil(t, sr)
    94  	assert.Equal(t, sr, sr2)
    95  }
    96  
    97  func TestServiceRegistry_RegisterService(t *testing.T) {
    98  	registry := newTestServiceRegistry()
    99  	mockService := &mockFabricService{}
   100  
   101  	assert.Nil(t, registry.RegisterService(mockService, "test-channel"))
   102  	assert.True(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel"))
   103  
   104  	id := uuid.New()
   105  	req := model.Request{
   106  		Id:      &id,
   107  		Request: "test-request",
   108  		Payload: "request-payload",
   109  	}
   110  
   111  	mockService.wg.Add(1)
   112  	registry.bus.SendRequestMessage("test-channel", req, nil)
   113  	mockService.wg.Wait()
   114  
   115  	assert.Equal(t, len(mockService.processedRequests), 1)
   116  	assert.Equal(t, *mockService.processedRequests[0], req)
   117  	assert.NotNil(t, mockService.core)
   118  
   119  	registry.bus.SendRequestMessage("test-channel", "invalid-request", nil)
   120  	registry.bus.SendRequestMessage("test-channel", nil, nil)
   121  	registry.bus.SendResponseMessage("test-channel", req, nil)
   122  	registry.bus.SendErrorMessage("test-channel", errors.New("test-error"), nil)
   123  
   124  	mockService.wg.Add(1)
   125  	registry.bus.SendRequestMessage("test-channel", &req, nil)
   126  	mockService.wg.Wait()
   127  
   128  	assert.Equal(t, len(mockService.processedRequests), 2)
   129  	assert.Equal(t, mockService.processedRequests[1], &req)
   130  	assert.NotNil(t, mockService.core)
   131  
   132  	mockService.wg.Add(1)
   133  	uuid := uuid.New()
   134  	registry.bus.SendRequestMessage("test-channel", model.Request{
   135  		Request: "test-request-2",
   136  		Payload: "request-payload",
   137  	}, &uuid)
   138  	mockService.wg.Wait()
   139  
   140  	assert.Equal(t, len(mockService.processedRequests), 3)
   141  	assert.Equal(t, mockService.processedRequests[2].Id, &uuid)
   142  
   143  	assert.EqualError(t, registry.RegisterService(&mockFabricService{}, "test-channel"),
   144  		"unable to register service: service channel name is already used: test-channel")
   145  
   146  	assert.EqualError(t, registry.RegisterService(nil, "test-channel2"),
   147  		"unable to register service: nil service")
   148  
   149  	assert.False(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel2"))
   150  }
   151  
   152  func TestServiceRegistry_RegisterInitializableService(t *testing.T) {
   153  	registry := newTestServiceRegistry()
   154  	mockService := &mockInitializableService{}
   155  	assert.Nil(t, registry.RegisterService(mockService, "test-channel"))
   156  
   157  	assert.True(t, mockService.initialized)
   158  	assert.NotNil(t, mockService.core)
   159  
   160  	assert.EqualError(t,
   161  		registry.RegisterService(&mockInitializableService{initError: errors.New("init-error")}, "test-channel2"),
   162  		"init-error")
   163  }
   164  
   165  func TestServiceRegistry_UnregisterService(t *testing.T) {
   166  	registry := newTestServiceRegistry()
   167  	mockService := &mockFabricService{}
   168  
   169  	assert.Nil(t, registry.RegisterService(mockService, "test-channel"))
   170  	assert.True(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel"))
   171  
   172  	id := uuid.New()
   173  	req := model.Request{
   174  		Id:      &id,
   175  		Request: "test-request",
   176  		Payload: "request-payload",
   177  	}
   178  
   179  	assert.Nil(t, registry.UnregisterService("test-channel"))
   180  	registry.bus.SendRequestMessage("test-channel", req, nil)
   181  
   182  	assert.Equal(t, len(mockService.processedRequests), 0)
   183  	assert.EqualError(t, registry.UnregisterService("test-channel"),
   184  		"unable to unregister service: no service is registered for channel \"test-channel\"")
   185  }
   186  
   187  func TestServiceRegistry_SetGlobalRestServiceBaseHost(t *testing.T) {
   188  	registry := newTestServiceRegistry()
   189  	registry.SetGlobalRestServiceBaseHost("localhost:9999")
   190  	assert.Equal(t, "localhost:9999",
   191  		registry.services[restServiceChannel].service.(*restService).baseHost)
   192  }
   193  
   194  func TestServiceRegistry_GetAllServiceChannels(t *testing.T) {
   195  	registry := newTestServiceRegistry()
   196  	mockService := &mockFabricService{}
   197  
   198  	registry.RegisterService(mockService, "test-channel")
   199  	chans := registry.GetAllServiceChannels()
   200  
   201  	assert.Len(t, chans, 1)
   202  	assert.EqualValues(t, "test-channel", chans[0])
   203  }
   204  
   205  func TestServiceRegistry_RegisterService_LifecycleHookEnabled(t *testing.T) {
   206  	svc := &mockLifecycleHookEnabledService{}
   207  	registry := newTestServiceRegistry()
   208  	registry.RegisterService(svc, "another-test-channel")
   209  
   210  	assert.True(t, <-svc.OnServiceReady())
   211  
   212  	svc.OnServerShutdown()
   213  	assert.True(t, svc.shutdown)
   214  
   215  	restBridgeConfig := svc.GetRESTBridgeConfig()
   216  	assert.NotNil(t, restBridgeConfig)
   217  }