github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/serviceregistration/nsd/nsd_test.go (about)

     1  package nsd
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"sync"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/hashicorp/go-hclog"
    12  	"github.com/hashicorp/nomad/client/serviceregistration"
    13  	"github.com/hashicorp/nomad/nomad/structs"
    14  	"github.com/shoenig/test"
    15  	"github.com/stretchr/testify/assert"
    16  	"github.com/stretchr/testify/require"
    17  )
    18  
    19  type mockCheckWatcher struct {
    20  	lock sync.Mutex
    21  
    22  	watchCalls   int
    23  	unWatchCalls int
    24  }
    25  
    26  func (cw *mockCheckWatcher) Run(_ context.Context) {
    27  	// Run runs async; just assume it ran
    28  }
    29  
    30  func (cw *mockCheckWatcher) Watch(_, _, _ string, _ *structs.ServiceCheck, _ serviceregistration.WorkloadRestarter) {
    31  	cw.lock.Lock()
    32  	defer cw.lock.Unlock()
    33  	cw.watchCalls++
    34  }
    35  
    36  func (cw *mockCheckWatcher) Unwatch(_ string) {
    37  	cw.lock.Lock()
    38  	defer cw.lock.Unlock()
    39  	cw.unWatchCalls++
    40  }
    41  
    42  func (cw *mockCheckWatcher) assert(t *testing.T, watchCalls, unWatchCalls int) {
    43  	cw.lock.Lock()
    44  	defer cw.lock.Unlock()
    45  	test.Eq(t, watchCalls, cw.watchCalls, test.Sprintf("expected %d Watch() calls but got %d", watchCalls, cw.watchCalls))
    46  	test.Eq(t, unWatchCalls, cw.unWatchCalls, test.Sprintf("expected %d Unwatch() calls but got %d", unWatchCalls, cw.unWatchCalls))
    47  }
    48  
    49  func TestServiceRegistrationHandler_RegisterWorkload(t *testing.T) {
    50  	testCases := []struct {
    51  		name                 string
    52  		inputCfg             *ServiceRegistrationHandlerCfg
    53  		inputWorkload        *serviceregistration.WorkloadServices
    54  		expectedRPCs         map[string]int
    55  		expectedError        error
    56  		expWatch, expUnWatch int
    57  	}{
    58  		{
    59  			name: "registration disabled",
    60  			inputCfg: &ServiceRegistrationHandlerCfg{
    61  				Enabled:      false,
    62  				CheckWatcher: new(mockCheckWatcher),
    63  			},
    64  			inputWorkload: mockWorkload(),
    65  			expectedRPCs:  map[string]int{},
    66  			expectedError: errors.New(`service registration provider "nomad" not enabled`),
    67  			expWatch:      0,
    68  			expUnWatch:    0,
    69  		},
    70  		{
    71  			name: "registration enabled",
    72  			inputCfg: &ServiceRegistrationHandlerCfg{
    73  				Enabled:      true,
    74  				CheckWatcher: new(mockCheckWatcher),
    75  			},
    76  			inputWorkload: mockWorkload(),
    77  			expectedRPCs:  map[string]int{structs.ServiceRegistrationUpsertRPCMethod: 1},
    78  			expectedError: nil,
    79  			expWatch:      1,
    80  			expUnWatch:    0,
    81  		},
    82  	}
    83  
    84  	// Create a logger we can use for all tests.
    85  	log := hclog.NewNullLogger()
    86  
    87  	for _, tc := range testCases {
    88  		t.Run(tc.name, func(t *testing.T) {
    89  
    90  			// Add the mock RPC functionality.
    91  			mockRPC := mockRPC{callCounts: map[string]int{}}
    92  			tc.inputCfg.RPCFn = mockRPC.RPC
    93  
    94  			// Create the handler and run the tests.
    95  			h := NewServiceRegistrationHandler(log, tc.inputCfg)
    96  
    97  			actualErr := h.RegisterWorkload(tc.inputWorkload)
    98  			require.Equal(t, tc.expectedError, actualErr)
    99  			require.Equal(t, tc.expectedRPCs, mockRPC.calls())
   100  			tc.inputCfg.CheckWatcher.(*mockCheckWatcher).assert(t, tc.expWatch, tc.expUnWatch)
   101  		})
   102  	}
   103  }
   104  
   105  func TestServiceRegistrationHandler_RemoveWorkload(t *testing.T) {
   106  	testCases := []struct {
   107  		name                 string
   108  		inputCfg             *ServiceRegistrationHandlerCfg
   109  		inputWorkload        *serviceregistration.WorkloadServices
   110  		expectedRPCs         map[string]int
   111  		expectedError        error
   112  		expWatch, expUnWatch int
   113  	}{
   114  		{
   115  			name: "registration disabled multiple services",
   116  			inputCfg: &ServiceRegistrationHandlerCfg{
   117  				Enabled:      false,
   118  				CheckWatcher: new(mockCheckWatcher),
   119  			},
   120  			inputWorkload: mockWorkload(),
   121  			expectedRPCs:  map[string]int{structs.ServiceRegistrationDeleteByIDRPCMethod: 2},
   122  			expectedError: nil,
   123  			expWatch:      0,
   124  			expUnWatch:    2, // RemoveWorkload works regardless if provider is enabled
   125  		},
   126  		{
   127  			name: "registration enabled multiple services",
   128  			inputCfg: &ServiceRegistrationHandlerCfg{
   129  				Enabled:      true,
   130  				CheckWatcher: new(mockCheckWatcher),
   131  			},
   132  			inputWorkload: mockWorkload(),
   133  			expectedRPCs:  map[string]int{structs.ServiceRegistrationDeleteByIDRPCMethod: 2},
   134  			expectedError: nil,
   135  			expWatch:      0,
   136  			expUnWatch:    2,
   137  		},
   138  	}
   139  
   140  	// Create a logger we can use for all tests.
   141  	log := hclog.NewNullLogger()
   142  
   143  	for _, tc := range testCases {
   144  		t.Run(tc.name, func(t *testing.T) {
   145  
   146  			// Add the mock RPC functionality.
   147  			mockRPC := mockRPC{callCounts: map[string]int{}}
   148  			tc.inputCfg.RPCFn = mockRPC.RPC
   149  
   150  			// Create the handler and run the tests.
   151  			h := NewServiceRegistrationHandler(log, tc.inputCfg)
   152  
   153  			h.RemoveWorkload(tc.inputWorkload)
   154  
   155  			require.Eventually(t, func() bool {
   156  				return assert.Equal(t, tc.expectedRPCs, mockRPC.calls())
   157  			}, 100*time.Millisecond, 10*time.Millisecond)
   158  			tc.inputCfg.CheckWatcher.(*mockCheckWatcher).assert(t, tc.expWatch, tc.expUnWatch)
   159  		})
   160  	}
   161  }
   162  
   163  func TestServiceRegistrationHandler_UpdateWorkload(t *testing.T) {
   164  	testCases := []struct {
   165  		name                 string
   166  		inputCfg             *ServiceRegistrationHandlerCfg
   167  		inputOldWorkload     *serviceregistration.WorkloadServices
   168  		inputNewWorkload     *serviceregistration.WorkloadServices
   169  		expectedRPCs         map[string]int
   170  		expectedError        error
   171  		expWatch, expUnWatch int
   172  	}{
   173  		{
   174  			name: "delete and upsert",
   175  			inputCfg: &ServiceRegistrationHandlerCfg{
   176  				Enabled:      true,
   177  				CheckWatcher: new(mockCheckWatcher),
   178  			},
   179  			inputOldWorkload: mockWorkload(),
   180  			inputNewWorkload: &serviceregistration.WorkloadServices{
   181  				AllocInfo: structs.AllocInfo{
   182  					AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c",
   183  					Task:    "redis",
   184  					Group:   "cache",
   185  					JobID:   "example",
   186  				},
   187  				Canary:            false,
   188  				ProviderNamespace: "default",
   189  				Services: []*structs.Service{
   190  					{
   191  						Name:        "changed-redis-db",
   192  						AddressMode: structs.AddressModeHost,
   193  						PortLabel:   "db",
   194  						Checks: []*structs.ServiceCheck{
   195  							{
   196  								Name:         "changed-check-redis-db",
   197  								CheckRestart: &structs.CheckRestart{Limit: 1},
   198  							},
   199  						},
   200  					},
   201  					{
   202  						Name:        "changed-redis-http",
   203  						AddressMode: structs.AddressModeHost,
   204  						PortLabel:   "http",
   205  						// No check restart this time
   206  					},
   207  				},
   208  				Ports: []structs.AllocatedPortMapping{
   209  					{
   210  						Label:  "db",
   211  						HostIP: "10.10.13.2",
   212  						Value:  23098,
   213  					},
   214  					{
   215  						Label:  "http",
   216  						HostIP: "10.10.13.2",
   217  						Value:  24098,
   218  					},
   219  				},
   220  			},
   221  			expectedRPCs: map[string]int{
   222  				structs.ServiceRegistrationUpsertRPCMethod:     1,
   223  				structs.ServiceRegistrationDeleteByIDRPCMethod: 2,
   224  			},
   225  			expectedError: nil,
   226  			expWatch:      1,
   227  			expUnWatch:    2,
   228  		},
   229  		{
   230  			name: "upsert only",
   231  			inputCfg: &ServiceRegistrationHandlerCfg{
   232  				Enabled:      true,
   233  				CheckWatcher: new(mockCheckWatcher),
   234  			},
   235  			inputOldWorkload: mockWorkload(),
   236  			inputNewWorkload: &serviceregistration.WorkloadServices{
   237  				AllocInfo: structs.AllocInfo{
   238  					AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c",
   239  					Task:    "redis",
   240  					Group:   "cache",
   241  					JobID:   "example",
   242  				},
   243  				Canary:            false,
   244  				ProviderNamespace: "default",
   245  				Services: []*structs.Service{
   246  					{
   247  						Name:        "redis-db",
   248  						AddressMode: structs.AddressModeHost,
   249  						PortLabel:   "db",
   250  						Tags:        []string{"foo"},
   251  						Checks: []*structs.ServiceCheck{
   252  							{
   253  								Name:         "redis-db-check-1",
   254  								CheckRestart: &structs.CheckRestart{Limit: 1},
   255  							},
   256  							{
   257  								Name: "redis-db-check-2",
   258  								// No check restart on this one
   259  							},
   260  						},
   261  					},
   262  					{
   263  						Name:        "redis-http",
   264  						AddressMode: structs.AddressModeHost,
   265  						PortLabel:   "http",
   266  						Tags:        []string{"bar"},
   267  						Checks: []*structs.ServiceCheck{
   268  							{
   269  								Name:         "redis-http-check-1",
   270  								CheckRestart: &structs.CheckRestart{Limit: 1},
   271  							},
   272  							{
   273  								Name:         "redis-http-check-2",
   274  								CheckRestart: &structs.CheckRestart{Limit: 1},
   275  							},
   276  						},
   277  					},
   278  				},
   279  				Ports: []structs.AllocatedPortMapping{
   280  					{
   281  						Label:  "db",
   282  						HostIP: "10.10.13.2",
   283  						Value:  23098,
   284  					},
   285  					{
   286  						Label:  "http",
   287  						HostIP: "10.10.13.2",
   288  						Value:  24098,
   289  					},
   290  				},
   291  			},
   292  			expectedRPCs: map[string]int{
   293  				structs.ServiceRegistrationUpsertRPCMethod: 1,
   294  			},
   295  			expectedError: nil,
   296  			expWatch:      3,
   297  			expUnWatch:    0,
   298  		},
   299  	}
   300  
   301  	// Create a logger we can use for all tests.
   302  	log := hclog.NewNullLogger()
   303  
   304  	for _, tc := range testCases {
   305  		t.Run(tc.name, func(t *testing.T) {
   306  
   307  			// Add the mock RPC functionality.
   308  			mockRPC := mockRPC{callCounts: map[string]int{}}
   309  			tc.inputCfg.RPCFn = mockRPC.RPC
   310  
   311  			// Create the handler and run the tests.
   312  			h := NewServiceRegistrationHandler(log, tc.inputCfg)
   313  
   314  			require.Equal(t, tc.expectedError, h.UpdateWorkload(tc.inputOldWorkload, tc.inputNewWorkload))
   315  
   316  			require.Eventually(t, func() bool {
   317  				return assert.Equal(t, tc.expectedRPCs, mockRPC.calls())
   318  			}, 100*time.Millisecond, 10*time.Millisecond)
   319  			tc.inputCfg.CheckWatcher.(*mockCheckWatcher).assert(t, tc.expWatch, tc.expUnWatch)
   320  		})
   321  	}
   322  
   323  }
   324  
   325  func TestServiceRegistrationHandler_dedupUpdatedWorkload(t *testing.T) {
   326  	testCases := []struct {
   327  		inputOldWorkload  *serviceregistration.WorkloadServices
   328  		inputNewWorkload  *serviceregistration.WorkloadServices
   329  		expectedOldOutput *serviceregistration.WorkloadServices
   330  		expectedNewOutput *serviceregistration.WorkloadServices
   331  		name              string
   332  	}{
   333  		{
   334  			inputOldWorkload: mockWorkload(),
   335  			inputNewWorkload: &serviceregistration.WorkloadServices{
   336  				AllocInfo: structs.AllocInfo{
   337  					AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c",
   338  					Task:    "redis",
   339  					Group:   "cache",
   340  					JobID:   "example",
   341  				},
   342  				Canary:            false,
   343  				ProviderNamespace: "default",
   344  				Services: []*structs.Service{
   345  					{
   346  						Name:        "changed-redis-db",
   347  						AddressMode: structs.AddressModeHost,
   348  						PortLabel:   "db",
   349  					},
   350  					{
   351  						Name:        "changed-redis-http",
   352  						AddressMode: structs.AddressModeHost,
   353  						PortLabel:   "http",
   354  					},
   355  				},
   356  				Ports: []structs.AllocatedPortMapping{
   357  					{
   358  						Label:  "db",
   359  						HostIP: "10.10.13.2",
   360  						Value:  23098,
   361  					},
   362  					{
   363  						Label:  "http",
   364  						HostIP: "10.10.13.2",
   365  						Value:  24098,
   366  					},
   367  				},
   368  			},
   369  			expectedOldOutput: mockWorkload(),
   370  			expectedNewOutput: &serviceregistration.WorkloadServices{
   371  				AllocInfo: structs.AllocInfo{
   372  					AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c",
   373  					Task:    "redis",
   374  					Group:   "cache",
   375  					JobID:   "example",
   376  				},
   377  				Canary:            false,
   378  				ProviderNamespace: "default",
   379  				Services: []*structs.Service{
   380  					{
   381  						Name:        "changed-redis-db",
   382  						AddressMode: structs.AddressModeHost,
   383  						PortLabel:   "db",
   384  					},
   385  					{
   386  						Name:        "changed-redis-http",
   387  						AddressMode: structs.AddressModeHost,
   388  						PortLabel:   "http",
   389  					},
   390  				},
   391  				Ports: []structs.AllocatedPortMapping{
   392  					{
   393  						Label:  "db",
   394  						HostIP: "10.10.13.2",
   395  						Value:  23098,
   396  					},
   397  					{
   398  						Label:  "http",
   399  						HostIP: "10.10.13.2",
   400  						Value:  24098,
   401  					},
   402  				},
   403  			},
   404  			name: "service names changed",
   405  		},
   406  		{
   407  			inputOldWorkload: mockWorkload(),
   408  			inputNewWorkload: &serviceregistration.WorkloadServices{
   409  				AllocInfo: structs.AllocInfo{
   410  					AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c",
   411  					Task:    "redis",
   412  					Group:   "cache",
   413  					JobID:   "example",
   414  				},
   415  				Canary:            false,
   416  				ProviderNamespace: "default",
   417  				Services: []*structs.Service{
   418  					{
   419  						Name:        "redis-db",
   420  						AddressMode: structs.AddressModeHost,
   421  						PortLabel:   "db",
   422  						Tags:        []string{"foo"},
   423  					},
   424  					{
   425  						Name:        "redis-http",
   426  						AddressMode: structs.AddressModeHost,
   427  						PortLabel:   "http",
   428  						Tags:        []string{"bar"},
   429  					},
   430  				},
   431  				Ports: []structs.AllocatedPortMapping{
   432  					{
   433  						Label:  "db",
   434  						HostIP: "10.10.13.2",
   435  						Value:  23098,
   436  					},
   437  					{
   438  						Label:  "http",
   439  						HostIP: "10.10.13.2",
   440  						Value:  24098,
   441  					},
   442  				},
   443  			},
   444  			expectedOldOutput: &serviceregistration.WorkloadServices{
   445  				AllocInfo: structs.AllocInfo{
   446  					AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c",
   447  					Task:    "redis",
   448  					Group:   "cache",
   449  					JobID:   "example",
   450  				},
   451  				Canary:            false,
   452  				ProviderNamespace: "default",
   453  				Services:          []*structs.Service{},
   454  				Ports: []structs.AllocatedPortMapping{
   455  					{
   456  						Label:  "db",
   457  						HostIP: "10.10.13.2",
   458  						Value:  23098,
   459  					},
   460  					{
   461  						Label:  "http",
   462  						HostIP: "10.10.13.2",
   463  						Value:  24098,
   464  					},
   465  				},
   466  			},
   467  			expectedNewOutput: &serviceregistration.WorkloadServices{
   468  				AllocInfo: structs.AllocInfo{
   469  					AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c",
   470  					Task:    "redis",
   471  					Group:   "cache",
   472  					JobID:   "example",
   473  				},
   474  				Canary:            false,
   475  				ProviderNamespace: "default",
   476  				Services: []*structs.Service{
   477  					{
   478  						Name:        "redis-db",
   479  						AddressMode: structs.AddressModeHost,
   480  						PortLabel:   "db",
   481  						Tags:        []string{"foo"},
   482  					},
   483  					{
   484  						Name:        "redis-http",
   485  						AddressMode: structs.AddressModeHost,
   486  						PortLabel:   "http",
   487  						Tags:        []string{"bar"},
   488  					},
   489  				},
   490  				Ports: []structs.AllocatedPortMapping{
   491  					{
   492  						Label:  "db",
   493  						HostIP: "10.10.13.2",
   494  						Value:  23098,
   495  					},
   496  					{
   497  						Label:  "http",
   498  						HostIP: "10.10.13.2",
   499  						Value:  24098,
   500  					},
   501  				},
   502  			},
   503  			name: "tags updated",
   504  		},
   505  		{
   506  			inputOldWorkload: mockWorkload(),
   507  			inputNewWorkload: &serviceregistration.WorkloadServices{
   508  				AllocInfo: structs.AllocInfo{
   509  					AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c",
   510  					Task:    "redis",
   511  					Group:   "cache",
   512  					JobID:   "example",
   513  				},
   514  				Canary:            false,
   515  				ProviderNamespace: "default",
   516  				Services: []*structs.Service{
   517  					{
   518  						Name:        "redis-db",
   519  						AddressMode: structs.AddressModeHost,
   520  						PortLabel:   "dbs",
   521  					},
   522  					{
   523  						Name:        "redis-http",
   524  						AddressMode: structs.AddressModeHost,
   525  						PortLabel:   "https",
   526  					},
   527  				},
   528  				Ports: []structs.AllocatedPortMapping{
   529  					{
   530  						Label:  "dbs",
   531  						HostIP: "10.10.13.2",
   532  						Value:  23098,
   533  					},
   534  					{
   535  						Label:  "https",
   536  						HostIP: "10.10.13.2",
   537  						Value:  24098,
   538  					},
   539  				},
   540  			},
   541  			expectedOldOutput: mockWorkload(),
   542  			expectedNewOutput: &serviceregistration.WorkloadServices{
   543  				AllocInfo: structs.AllocInfo{
   544  					AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c",
   545  					Task:    "redis",
   546  					Group:   "cache",
   547  					JobID:   "example",
   548  				},
   549  				Canary:            false,
   550  				ProviderNamespace: "default",
   551  				Services: []*structs.Service{
   552  					{
   553  						Name:        "redis-db",
   554  						AddressMode: structs.AddressModeHost,
   555  						PortLabel:   "dbs",
   556  					},
   557  					{
   558  						Name:        "redis-http",
   559  						AddressMode: structs.AddressModeHost,
   560  						PortLabel:   "https",
   561  					},
   562  				},
   563  				Ports: []structs.AllocatedPortMapping{
   564  					{
   565  						Label:  "dbs",
   566  						HostIP: "10.10.13.2",
   567  						Value:  23098,
   568  					},
   569  					{
   570  						Label:  "https",
   571  						HostIP: "10.10.13.2",
   572  						Value:  24098,
   573  					},
   574  				},
   575  			},
   576  			name: "canary tags updated",
   577  		},
   578  	}
   579  
   580  	s := &ServiceRegistrationHandler{}
   581  
   582  	for _, tc := range testCases {
   583  		t.Run(tc.name, func(t *testing.T) {
   584  			actualOld, actualNew := s.dedupUpdatedWorkload(tc.inputOldWorkload, tc.inputNewWorkload)
   585  			require.ElementsMatch(t, tc.expectedOldOutput.Services, actualOld.Services)
   586  			require.ElementsMatch(t, tc.expectedNewOutput.Services, actualNew.Services)
   587  		})
   588  	}
   589  }
   590  
   591  func mockWorkload() *serviceregistration.WorkloadServices {
   592  	return &serviceregistration.WorkloadServices{
   593  		AllocInfo: structs.AllocInfo{
   594  			AllocID: "98ea220b-7ebe-4662-6d74-9868e797717c",
   595  			Task:    "redis",
   596  			Group:   "cache",
   597  			JobID:   "example",
   598  		},
   599  		Canary:            false,
   600  		ProviderNamespace: "default",
   601  		Services: []*structs.Service{
   602  			{
   603  				Name:        "redis-db",
   604  				AddressMode: structs.AddressModeHost,
   605  				PortLabel:   "db",
   606  			},
   607  			{
   608  				Name:        "redis-http",
   609  				AddressMode: structs.AddressModeHost,
   610  				PortLabel:   "http",
   611  				Checks: []*structs.ServiceCheck{
   612  					{
   613  						Name:     "check1",
   614  						Type:     "http",
   615  						Interval: 5 * time.Second,
   616  						Timeout:  1 * time.Second,
   617  						CheckRestart: &structs.CheckRestart{
   618  							Limit: 1,
   619  							Grace: 1,
   620  						},
   621  					},
   622  				},
   623  			},
   624  		},
   625  		Ports: []structs.AllocatedPortMapping{
   626  			{
   627  				Label:  "db",
   628  				HostIP: "10.10.13.2",
   629  				Value:  23098,
   630  			},
   631  			{
   632  				Label:  "http",
   633  				HostIP: "10.10.13.2",
   634  				Value:  24098,
   635  			},
   636  		},
   637  	}
   638  }
   639  
   640  // mockRPC mocks and tracks RPC calls made for testing.
   641  type mockRPC struct {
   642  
   643  	// callCounts tracks how many times each RPC method has been called. The
   644  	// lock should be used to access this.
   645  	callCounts map[string]int
   646  	l          sync.RWMutex
   647  }
   648  
   649  // calls returns the mapping counting the number of calls made to each RPC
   650  // method.
   651  func (mr *mockRPC) calls() map[string]int {
   652  	mr.l.RLock()
   653  	defer mr.l.RUnlock()
   654  	return mr.callCounts
   655  }
   656  
   657  // RPC mocks the server RPCs, acting as though any request succeeds.
   658  func (mr *mockRPC) RPC(method string, _, _ interface{}) error {
   659  	switch method {
   660  	case structs.ServiceRegistrationUpsertRPCMethod, structs.ServiceRegistrationDeleteByIDRPCMethod:
   661  		mr.l.Lock()
   662  		mr.callCounts[method]++
   663  		mr.l.Unlock()
   664  		return nil
   665  	default:
   666  		return fmt.Errorf("unexpected RPC method: %v", method)
   667  	}
   668  }