github.com/xmidt-org/webpa-common@v1.11.9/service/consul/registrar_test.go

     1  package consul
     2  
     3  import (
     4  	"errors"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/hashicorp/consul/api"
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/stretchr/testify/mock"
    11  	"github.com/stretchr/testify/require"
    12  	"github.com/xmidt-org/webpa-common/logging"
    13  )
    14  
    15  func TestDefaultTickerFactory(t *testing.T) {
    16  	var (
    17  		assert  = assert.New(t)
    18  		require = require.New(t)
    19  	)
    20  
    21  	assert.Panics(func() {
    22  		defaultTickerFactory(-123123)
    23  	})
    24  
    25  	ticker, stop := defaultTickerFactory(20 * time.Second)
    26  	assert.NotNil(ticker)
    27  	require.NotNil(stop)
    28  	stop()
    29  }
    30  
    31  func testNewRegistrarNoChecks(t *testing.T) {
    32  	defer resetTickerFactory()
    33  
    34  	var (
    35  		require = require.New(t)
    36  
    37  		logger        = logging.NewTestLogger(nil, t)
    38  		client        = new(mockClient)
    39  		ttlUpdater    = new(mockTTLUpdater)
    40  		tickerFactory = prepareMockTickerFactory()
    41  
    42  		registration = &api.AgentServiceRegistration{
    43  			ID:      "service1",
    44  			Address: "somehost.com",
    45  			Port:    1111,
    46  		}
    47  	)
    48  
    49  	client.On("Register",
    50  		mock.MatchedBy(func(r *api.AgentServiceRegistration) bool {
    51  			return r.ID == "service1"
    52  		}),
    53  	).Return(error(nil)).Once()
    54  
    55  	client.On("Deregister",
    56  		mock.MatchedBy(func(r *api.AgentServiceRegistration) bool {
    57  			return r.ID == "service1"
    58  		}),
    59  	).Return(error(nil)).Once()
    60  
    61  	r, err := NewRegistrar(client, ttlUpdater, registration, logger)
    62  	require.NoError(err)
    63  	require.NotNil(r)
    64  
    65  	r.Register()
    66  	r.Deregister()
    67  
    68  	client.AssertExpectations(t)
    69  	ttlUpdater.AssertExpectations(t)
    70  	tickerFactory.AssertExpectations(t)
    71  }
    72  
    73  func testNewRegistrarNoTTL(t *testing.T) {
    74  	defer resetTickerFactory()
    75  
    76  	var (
    77  		require = require.New(t)
    78  
    79  		logger        = logging.NewTestLogger(nil, t)
    80  		client        = new(mockClient)
    81  		ttlUpdater    = new(mockTTLUpdater)
    82  		tickerFactory = prepareMockTickerFactory()
    83  
    84  		registration = &api.AgentServiceRegistration{
    85  			ID:      "service1",
    86  			Address: "somehost.com",
    87  			Port:    1111,
    88  			Check: &api.AgentServiceCheck{
    89  				CheckID: "check1",
    90  				HTTP:    "https://foobar.com/foo",
    91  			},
    92  			Checks: []*api.AgentServiceCheck{
    93  				{
    94  					CheckID: "check2",
    95  					HTTP:    "https://foobar.com/moo",
    96  				},
    97  			},
    98  		}
    99  	)
   100  
   101  	client.On("Register",
   102  		mock.MatchedBy(func(r *api.AgentServiceRegistration) bool {
   103  			return r.ID == "service1"
   104  		}),
   105  	).Return(error(nil)).Once()
   106  
   107  	client.On("Deregister",
   108  		mock.MatchedBy(func(r *api.AgentServiceRegistration) bool {
   109  			return r.ID == "service1"
   110  		}),
   111  	).Return(error(nil)).Once()
   112  
   113  	r, err := NewRegistrar(client, ttlUpdater, registration, logger)
   114  	require.NoError(err)
   115  	require.NotNil(r)
   116  
   117  	r.Register()
   118  	r.Deregister()
   119  
   120  	client.AssertExpectations(t)
   121  	ttlUpdater.AssertExpectations(t)
   122  	tickerFactory.AssertExpectations(t)
   123  }
   124  
   125  func testNewRegistrarCheckMalformedTTL(t *testing.T) {
   126  	defer resetTickerFactory()
   127  
   128  	var (
   129  		assert = assert.New(t)
   130  
   131  		logger        = logging.NewTestLogger(nil, t)
   132  		client        = new(mockClient)
   133  		ttlUpdater    = new(mockTTLUpdater)
   134  		tickerFactory = prepareMockTickerFactory()
   135  
   136  		registration = &api.AgentServiceRegistration{
   137  			ID:      "service1",
   138  			Address: "somehost.com",
   139  			Port:    1111,
   140  			Check: &api.AgentServiceCheck{
   141  				CheckID: "check1",
   142  				TTL:     "this is not valid",
   143  			},
   144  		}
   145  	)
   146  
   147  	r, err := NewRegistrar(client, ttlUpdater, registration, logger)
   148  	assert.Error(err)
   149  	assert.Nil(r)
   150  
   151  	client.AssertExpectations(t)
   152  	ttlUpdater.AssertExpectations(t)
   153  	tickerFactory.AssertExpectations(t)
   154  }
   155  
   156  func testNewRegistrarCheckTTLTooSmall(t *testing.T) {
   157  	defer resetTickerFactory()
   158  
   159  	var (
   160  		assert = assert.New(t)
   161  
   162  		logger        = logging.NewTestLogger(nil, t)
   163  		client        = new(mockClient)
   164  		ttlUpdater    = new(mockTTLUpdater)
   165  		tickerFactory = prepareMockTickerFactory()
   166  
   167  		registration = &api.AgentServiceRegistration{
   168  			ID:      "service1",
   169  			Address: "somehost.com",
   170  			Port:    1111,
   171  			Check: &api.AgentServiceCheck{
   172  				CheckID: "check1",
   173  				TTL:     "1ns",
   174  			},
   175  		}
   176  	)
   177  
   178  	r, err := NewRegistrar(client, ttlUpdater, registration, logger)
   179  	assert.Error(err)
   180  	assert.Nil(r)
   181  
   182  	client.AssertExpectations(t)
   183  	ttlUpdater.AssertExpectations(t)
   184  	tickerFactory.AssertExpectations(t)
   185  }
   186  
   187  func testNewRegistrarChecksMalformedTTL(t *testing.T) {
   188  	defer resetTickerFactory()
   189  
   190  	var (
   191  		assert = assert.New(t)
   192  
   193  		logger        = logging.NewTestLogger(nil, t)
   194  		client        = new(mockClient)
   195  		ttlUpdater    = new(mockTTLUpdater)
   196  		tickerFactory = prepareMockTickerFactory()
   197  
   198  		registration = &api.AgentServiceRegistration{
   199  			ID:      "service1",
   200  			Address: "somehost.com",
   201  			Port:    1111,
   202  			Checks: []*api.AgentServiceCheck{
   203  				{
   204  					CheckID: "check1",
   205  					TTL:     "this is not valid",
   206  				},
   207  			},
   208  		}
   209  	)
   210  
   211  	r, err := NewRegistrar(client, ttlUpdater, registration, logger)
   212  	assert.Error(err)
   213  	assert.Nil(r)
   214  
   215  	client.AssertExpectations(t)
   216  	ttlUpdater.AssertExpectations(t)
   217  	tickerFactory.AssertExpectations(t)
   218  }
   219  
   220  func testNewRegistrarChecksTTLTooSmall(t *testing.T) {
   221  	defer resetTickerFactory()
   222  
   223  	var (
   224  		assert = assert.New(t)
   225  
   226  		logger        = logging.NewTestLogger(nil, t)
   227  		client        = new(mockClient)
   228  		ttlUpdater    = new(mockTTLUpdater)
   229  		tickerFactory = prepareMockTickerFactory()
   230  
   231  		registration = &api.AgentServiceRegistration{
   232  			ID:      "service1",
   233  			Address: "somehost.com",
   234  			Port:    1111,
   235  			Checks: []*api.AgentServiceCheck{
   236  				{
   237  					CheckID: "check1",
   238  					TTL:     "1ns",
   239  				},
   240  			},
   241  		}
   242  	)
   243  
   244  	r, err := NewRegistrar(client, ttlUpdater, registration, logger)
   245  	assert.Error(err)
   246  	assert.Nil(r)
   247  
   248  	client.AssertExpectations(t)
   249  	ttlUpdater.AssertExpectations(t)
   250  	tickerFactory.AssertExpectations(t)
   251  }
   252  
   253  func testNewRegistrarTTL(t *testing.T) {
   254  	defer resetTickerFactory()
   255  
   256  	var (
   257  		assert  = assert.New(t)
   258  		require = require.New(t)
   259  
   260  		logger        = logging.NewTestLogger(nil, t)
   261  		client        = new(mockClient)
   262  		ttlUpdater    = new(mockTTLUpdater)
   263  		tickerFactory = prepareMockTickerFactory()
   264  
   265  		timer1       = make(chan time.Time, 1)
   266  		timer1Ack    = make(chan struct{}, 1)
   267  		timer1AckRun = func(mock.Arguments) { timer1Ack <- struct{}{} }
   268  		update1Done  = make(chan struct{})
   269  		stop1        = func() {
   270  			close(update1Done)
   271  		}
   272  
   273  		timer2       = make(chan time.Time, 1)
   274  		timer2Ack    = make(chan struct{}, 1)
   275  		timer2AckRun = func(mock.Arguments) { timer2Ack <- struct{}{} }
   276  		update2Done  = make(chan struct{})
   277  		stop2        = func() {
   278  			close(update2Done)
   279  		}
   280  
   281  		registration = &api.AgentServiceRegistration{
   282  			ID:      "service1",
   283  			Address: "somehost.com",
   284  			Port:    1111,
   285  			Check: &api.AgentServiceCheck{
   286  				CheckID: "check1",
   287  				TTL:     "15s",
   288  			},
   289  			Checks: []*api.AgentServiceCheck{
   290  				{
   291  					CheckID: "check2",
   292  					TTL:     "30s",
   293  				},
   294  			},
   295  		}
   296  	)
   297  
   298  	ttlUpdater.On("UpdateTTL", "check1", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "pass").Return(error(nil)).Once().Run(timer1AckRun)
   299  	ttlUpdater.On("UpdateTTL", "check1", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "pass").Return(errors.New("expected check1 error")).Once().Run(timer1AckRun)
   300  	ttlUpdater.On("UpdateTTL", "check1", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "pass").Return(error(nil)).Once().Run(timer1AckRun)
   301  	ttlUpdater.On("UpdateTTL", "check1", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "fail").Return(error(nil)).Once()
   302  
   303  	ttlUpdater.On("UpdateTTL", "check2", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "pass").Return(error(nil)).Once().Run(timer2AckRun)
   304  	ttlUpdater.On("UpdateTTL", "check2", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "pass").Return(errors.New("expected check2 error")).Once().Run(timer2AckRun)
   305  	ttlUpdater.On("UpdateTTL", "check2", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "pass").Return(error(nil)).Once().Run(timer2AckRun)
   306  	ttlUpdater.On("UpdateTTL", "check2", mock.MatchedBy(func(v string) bool { return len(v) > 0 }), "fail").Return(errors.New("expected check2 fail error")).Once()
   307  
   308  	tickerFactory.On("NewTicker", (15*time.Second)/2).Return((<-chan time.Time)(timer1), stop1)
   309  	tickerFactory.On("NewTicker", (30*time.Second)/2).Return((<-chan time.Time)(timer2), stop2)
   310  
   311  	client.On("Register",
   312  		mock.MatchedBy(func(r *api.AgentServiceRegistration) bool {
   313  			return r.ID == "service1"
   314  		}),
   315  	).Return(error(nil)).Once()
   316  
   317  	client.On("Deregister",
   318  		mock.MatchedBy(func(r *api.AgentServiceRegistration) bool {
   319  			return r.ID == "service1"
   320  		}),
   321  	).Return(error(nil)).Once()
   322  
   323  	r, err := NewRegistrar(client, ttlUpdater, registration, logger)
   324  	require.NoError(err)
   325  	require.NotNil(r)
   326  
   327  	r.Register()
   328  	r.Register() // idempotent
   329  
   330  	// simulate some updates
   331  	now := time.Now()
   332  
   333  	// we have 3 pass updates expected for each TTL check above
   334  	for repeat := 0; repeat < 3; repeat++ {
   335  		timer1 <- now
   336  		select {
   337  		case <-timer1Ack:
   338  			// passing
   339  		case <-time.After(2 * time.Second):
   340  			require.Fail("Time event was not processed")
   341  		}
   342  
   343  		timer2 <- now
   344  		select {
   345  		case <-timer2Ack:
   346  			// passing
   347  		case <-time.After(2 * time.Second):
   348  			require.Fail("Time event was not processed")
   349  		}
   350  	}
   351  
   352  	r.Deregister()
   353  	r.Deregister() // idempotent
   354  
   355  	select {
   356  	case <-update1Done:
   357  		// passing
   358  	case <-time.After(2 * time.Second):
   359  		assert.Fail("TTL update goroutine did not fail the TTL")
   360  	}
   361  
   362  	select {
   363  	case <-update2Done:
   364  		// passing
   365  	case <-time.After(2 * time.Second):
   366  		assert.Fail("TTL update goroutine did not fail the TTL")
   367  	}
   368  
   369  	client.AssertExpectations(t)
   370  	ttlUpdater.AssertExpectations(t)
   371  	tickerFactory.AssertExpectations(t)
   372  }
   373  
   374  func TestNewRegistrar(t *testing.T) {
   375  	t.Run("NoChecks", testNewRegistrarNoChecks)
   376  	t.Run("NoTTL", testNewRegistrarNoTTL)
   377  
   378  	t.Run("Check", func(t *testing.T) {
   379  		t.Run("MalformedTTL", testNewRegistrarCheckMalformedTTL)
   380  		t.Run("TTLTooSmall", testNewRegistrarCheckTTLTooSmall)
   381  	})
   382  
   383  	t.Run("Checks", func(t *testing.T) {
   384  		t.Run("MalformedTTL", testNewRegistrarChecksMalformedTTL)
   385  		t.Run("TTLTooSmall", testNewRegistrarChecksTTLTooSmall)
   386  	})
   387  
   388  	t.Run("TTL", testNewRegistrarTTL)
   389  }