github.com/xmidt-org/webpa-common@v1.11.9/service/consul/registrar_test.go (about) 1 package consul 2 3 import ( 4 "errors" 5 "testing" 6 "time" 7 8 "github.com/hashicorp/consul/api" 9 "github.com/stretchr/testify/assert" 10 "github.com/stretchr/testify/mock" 11 "github.com/stretchr/testify/require" 12 "github.com/xmidt-org/webpa-common/logging" 13 ) 14 15 func TestDefaultTickerFactory(t *testing.T) { 16 var ( 17 assert = assert.New(t) 18 require = require.New(t) 19 ) 20 21 assert.Panics(func() { 22 defaultTickerFactory(-123123) 23 }) 24 25 ticker, stop := defaultTickerFactory(20 * time.Second) 26 assert.NotNil(ticker) 27 require.NotNil(stop) 28 stop() 29 } 30 31 func testNewRegistrarNoChecks(t *testing.T) { 32 defer resetTickerFactory() 33 34 var ( 35 require = require.New(t) 36 37 logger = logging.NewTestLogger(nil, t) 38 client = new(mockClient) 39 ttlUpdater = new(mockTTLUpdater) 40 tickerFactory = prepareMockTickerFactory() 41 42 registration = &api.AgentServiceRegistration{ 43 ID: "service1", 44 Address: "somehost.com", 45 Port: 1111, 46 } 47 ) 48 49 client.On("Register", 50 mock.MatchedBy(func(r *api.AgentServiceRegistration) bool { 51 return r.ID == "service1" 52 }), 53 ).Return(error(nil)).Once() 54 55 client.On("Deregister", 56 mock.MatchedBy(func(r *api.AgentServiceRegistration) bool { 57 return r.ID == "service1" 58 }), 59 ).Return(error(nil)).Once() 60 61 r, err := NewRegistrar(client, ttlUpdater, registration, logger) 62 require.NoError(err) 63 require.NotNil(r) 64 65 r.Register() 66 r.Deregister() 67 68 client.AssertExpectations(t) 69 ttlUpdater.AssertExpectations(t) 70 tickerFactory.AssertExpectations(t) 71 } 72 73 func testNewRegistrarNoTTL(t *testing.T) { 74 defer resetTickerFactory() 75 76 var ( 77 require = require.New(t) 78 79 logger = logging.NewTestLogger(nil, t) 80 client = new(mockClient) 81 ttlUpdater = new(mockTTLUpdater) 82 tickerFactory = prepareMockTickerFactory() 83 84 registration = &api.AgentServiceRegistration{ 85 ID: "service1", 86 Address: "somehost.com", 87 Port: 1111, 88 Check: &api.AgentServiceCheck{ 89 CheckID: "check1", 90 HTTP: "https://foobar.com/foo", 91 }, 92 Checks: []*api.AgentServiceCheck{ 93 { 94 CheckID: "check2", 95 HTTP: "https://foobar.com/moo", 96 }, 97 }, 98 } 99 ) 100 101 client.On("Register", 102 mock.MatchedBy(func(r *api.AgentServiceRegistration) bool { 103 return r.ID == "service1" 104 }), 105 ).Return(error(nil)).Once() 106 107 client.On("Deregister", 108 mock.MatchedBy(func(r *api.AgentServiceRegistration) bool { 109 return r.ID == "service1" 110 }), 111 ).Return(error(nil)).Once() 112 113 r, err := NewRegistrar(client, ttlUpdater, registration, logger) 114 require.NoError(err) 115 require.NotNil(r) 116 117 r.Register() 118 r.Deregister() 119 120 client.AssertExpectations(t) 121 ttlUpdater.AssertExpectations(t) 122 tickerFactory.AssertExpectations(t) 123 } 124 125 func testNewRegistrarCheckMalformedTTL(t *testing.T) { 126 defer resetTickerFactory() 127 128 var ( 129 assert = assert.New(t) 130 131 logger = logging.NewTestLogger(nil, t) 132 client = new(mockClient) 133 ttlUpdater = new(mockTTLUpdater) 134 tickerFactory = prepareMockTickerFactory() 135 136 registration = &api.AgentServiceRegistration{ 137 ID: "service1", 138 Address: "somehost.com", 139 Port: 1111, 140 Check: &api.AgentServiceCheck{ 141 CheckID: "check1", 142 TTL: "this is not valid", 143 }, 144 } 145 ) 146 147 r, err := NewRegistrar(client, ttlUpdater, registration, logger) 148 assert.Error(err) 149 assert.Nil(r) 150 151 client.AssertExpectations(t) 152 ttlUpdater.AssertExpectations(t) 153 tickerFactory.AssertExpectations(t) 154 } 155 156 func testNewRegistrarCheckTTLTooSmall(t *testing.T) { 157 defer resetTickerFactory() 158 159 var ( 160 assert = assert.New(t) 161 162 logger = logging.NewTestLogger(nil, t) 163 client = new(mockClient) 164 ttlUpdater = new(mockTTLUpdater) 165 tickerFactory = prepareMockTickerFactory() 166 167 registration = &api.AgentServiceRegistration{ 168 ID: "service1", 169 Address: "somehost.com", 170 Port: 1111, 171 Check: &api.AgentServiceCheck{ 172 CheckID: "check1", 173 TTL: "1ns", 174 }, 175 } 176 ) 177 178 r, err := NewRegistrar(client, ttlUpdater, registration, logger) 179 assert.Error(err) 180 assert.Nil(r) 181 182 client.AssertExpectations(t) 183 ttlUpdater.AssertExpectations(t) 184 tickerFactory.AssertExpectations(t) 185 } 186 187 func testNewRegistrarChecksMalformedTTL(t *testing.T) { 188 defer resetTickerFactory() 189 190 var ( 191 assert = assert.New(t) 192 193 logger = logging.NewTestLogger(nil, t) 194 client = new(mockClient) 195 ttlUpdater = new(mockTTLUpdater) 196 tickerFactory = prepareMockTickerFactory() 197 198 registration = &api.AgentServiceRegistration{ 199 ID: "service1", 200 Address: "somehost.com", 201 Port: 1111, 202 Checks: []*api.AgentServiceCheck{ 203 { 204 CheckID: "check1", 205 TTL: "this is not valid", 206 }, 207 }, 208 } 209 ) 210 211 r, err := NewRegistrar(client, ttlUpdater, registration, logger) 212 assert.Error(err) 213 assert.Nil(r) 214 215 client.AssertExpectations(t) 216 ttlUpdater.AssertExpectations(t) 217 tickerFactory.AssertExpectations(t) 218 } 219 220 func testNewRegistrarChecksTTLTooSmall(t *testing.T) { 221 defer resetTickerFactory() 222 223 var ( 224 assert = assert.New(t) 225 226 logger = logging.NewTestLogger(nil, t) 227 client = new(mockClient) 228 ttlUpdater = new(mockTTLUpdater) 229 tickerFactory = prepareMockTickerFactory() 230 231 registration = &api.AgentServiceRegistration{ 232 ID: "service1", 233 Address: "somehost.com", 234 Port: 1111, 235 Checks: []*api.AgentServiceCheck{ 236 { 237 CheckID: "check1", 238 TTL: "1ns", 239 }, 240 }, 241 } 242 ) 243 244 r, err := NewRegistrar(client, ttlUpdater, registration, logger) 245 assert.Error(err) 246 assert.Nil(r) 247 248 client.AssertExpectations(t) 249 ttlUpdater.AssertExpectations(t) 250 tickerFactory.AssertExpectations(t) 251 } 252 253 func testNewRegistrarTTL(t *testing.T) { 254 defer resetTickerFactory() 255 256 var ( 257 assert = assert.New(t) 258 require = require.New(t) 259 260 logger = logging.NewTestLogger(nil, t) 261 client = new(mockClient) 262 ttlUpdater = new(mockTTLUpdater) 263 tickerFactory = prepareMockTickerFactory() 264 265 timer1 = make(chan time.Time, 1) 266 timer1Ack = make(chan struct{}, 1) 267 timer1AckRun = func(mock.Arguments) { timer1Ack <- struct{}{} } 268 update1Done = make(chan struct{}) 269 stop1 = func() { 270 close(update1Done) 271 } 272 273 timer2 = make(chan time.Time, 1) 274 timer2Ack = make(chan struct{}, 1) 275 timer2AckRun = func(mock.Arguments) { timer2Ack <- struct{}{} } 276 update2Done = make(chan struct{}) 277 stop2 = func() { 278 close(update2Done) 279 } 280 281 registration = &api.AgentServiceRegistration{ 282 ID: "service1", 283 Address: "somehost.com", 284 Port: 1111, 285 Check: &api.AgentServiceCheck{ 286 CheckID: "check1", 287 TTL: "15s", 288 }, 289 Checks: []*api.AgentServiceCheck{ 290 { 291 CheckID: "check2", 292 TTL: "30s", 293 }, 294 }, 295 } 296 ) 297 298 ttlUpdater.On("UpdateTTL", "check1", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "pass").Return(error(nil)).Once().Run(timer1AckRun) 299 ttlUpdater.On("UpdateTTL", "check1", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "pass").Return(errors.New("expected check1 error")).Once().Run(timer1AckRun) 300 ttlUpdater.On("UpdateTTL", "check1", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "pass").Return(error(nil)).Once().Run(timer1AckRun) 301 ttlUpdater.On("UpdateTTL", "check1", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "fail").Return(error(nil)).Once() 302 303 ttlUpdater.On("UpdateTTL", "check2", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "pass").Return(error(nil)).Once().Run(timer2AckRun) 304 ttlUpdater.On("UpdateTTL", "check2", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "pass").Return(errors.New("expected check2 error")).Once().Run(timer2AckRun) 305 ttlUpdater.On("UpdateTTL", "check2", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "pass").Return(error(nil)).Once().Run(timer2AckRun) 306 ttlUpdater.On("UpdateTTL", "check2", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "fail").Return(errors.New("expected check2 fail error")).Once() 307 308 tickerFactory.On("NewTicker", (15*time.Second)/2).Return((<-chan time.Time)(timer1), stop1) 309 tickerFactory.On("NewTicker", (30*time.Second)/2).Return((<-chan time.Time)(timer2), stop2) 310 311 client.On("Register", 312 mock.MatchedBy(func(r *api.AgentServiceRegistration) bool { 313 return r.ID == "service1" 314 }), 315 ).Return(error(nil)).Once() 316 317 client.On("Deregister", 318 mock.MatchedBy(func(r *api.AgentServiceRegistration) bool { 319 return r.ID == "service1" 320 }), 321 ).Return(error(nil)).Once() 322 323 r, err := NewRegistrar(client, ttlUpdater, registration, logger) 324 require.NoError(err) 325 require.NotNil(r) 326 327 r.Register() 328 r.Register() // idempotent 329 330 // simulate some updates 331 now := time.Now() 332 333 // we have 3 pass updates expected for each TTL check above 334 for repeat := 0; repeat < 3; repeat++ { 335 timer1 <- now 336 select { 337 case <-timer1Ack: 338 // passing 339 case <-time.After(2 * time.Second): 340 require.Fail("Time event was not processed") 341 } 342 343 timer2 <- now 344 select { 345 case <-timer2Ack: 346 // passing 347 case <-time.After(2 * time.Second): 348 require.Fail("Time event was not processed") 349 } 350 } 351 352 r.Deregister() 353 r.Deregister() // idempotent 354 355 select { 356 case <-update1Done: 357 // passing 358 case <-time.After(2 * time.Second): 359 assert.Fail("TTL update goroutine did not fail the TTL") 360 } 361 362 select { 363 case <-update2Done: 364 // passing 365 case <-time.After(2 * time.Second): 366 assert.Fail("TTL update goroutine did not fail the TTL") 367 } 368 369 client.AssertExpectations(t) 370 ttlUpdater.AssertExpectations(t) 371 tickerFactory.AssertExpectations(t) 372 } 373 374 func TestNewRegistrar(t *testing.T) { 375 t.Run("NoChecks", testNewRegistrarNoChecks) 376 t.Run("NoTTL", testNewRegistrarNoTTL) 377 378 t.Run("Check", func(t *testing.T) { 379 t.Run("MalformedTTL", testNewRegistrarCheckMalformedTTL) 380 t.Run("TTLTooSmall", testNewRegistrarCheckTTLTooSmall) 381 }) 382 383 t.Run("Checks", func(t *testing.T) { 384 t.Run("MalformedTTL", testNewRegistrarChecksMalformedTTL) 385 t.Run("TTLTooSmall", testNewRegistrarChecksTTLTooSmall) 386 }) 387 388 t.Run("TTL", testNewRegistrarTTL) 389 }