github.com/projecteru2/core@v0.0.0-20240321043226-06bcc1c23f58/cluster/calcium/service_test.go (about)

     1  package calcium
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/projecteru2/core/discovery/helium"
    11  	storemocks "github.com/projecteru2/core/store/mocks"
    12  
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/mock"
    15  )
    16  
    17  func TestServiceStatusStream(t *testing.T) {
    18  	c := NewTestCluster()
    19  	c.config.Bind = ":5001"
    20  	c.config.GRPCConfig.ServiceHeartbeatInterval = 100 * time.Millisecond
    21  	c.config.GRPCConfig.ServiceDiscoveryPushInterval = 10 * time.Second
    22  	store := c.store.(*storemocks.Store)
    23  
    24  	var unregistered bool
    25  	unregister := func() { unregistered = true }
    26  	expiry := make(<-chan struct{})
    27  	store.On("RegisterService", mock.Anything, mock.Anything, mock.Anything).Return(expiry, unregister, nil).Once()
    28  
    29  	ctx, cancel := context.WithCancel(context.Background())
    30  	defer cancel()
    31  	unregisterService, err := c.RegisterService(ctx)
    32  	assert.NoError(t, err)
    33  
    34  	unregisterService()
    35  	assert.True(t, unregistered)
    36  }
    37  
    38  func TestServiceStatusStreamWithMultipleRegisteringAsExpired(t *testing.T) {
    39  	c := NewTestCluster()
    40  	c.config.Bind = ":5001"
    41  	c.config.GRPCConfig.ServiceHeartbeatInterval = 100 * time.Millisecond
    42  	c.config.GRPCConfig.ServiceDiscoveryPushInterval = 10 * time.Second
    43  	store := c.store.(*storemocks.Store)
    44  
    45  	raw := make(chan struct{})
    46  	var expiry <-chan struct{} = raw
    47  	store.On("RegisterService", mock.Anything, mock.Anything, mock.Anything).Return(expiry, func() {}, nil).Once()
    48  	// Once the original one expired, the new calling's expiry must also be a brand new <-chan.
    49  	store.On("RegisterService", mock.Anything, mock.Anything, mock.Anything).Return(make(<-chan struct{}), func() {}, nil).Once()
    50  
    51  	ctx, cancel := context.WithCancel(context.Background())
    52  	defer cancel()
    53  	_, err := c.RegisterService(ctx)
    54  	assert.NoError(t, err)
    55  
    56  	// Triggers the original one expired.
    57  	close(raw)
    58  	// Waiting for the second calling of store.RegisterService.
    59  	time.Sleep(time.Millisecond)
    60  	store.AssertExpectations(t)
    61  }
    62  
    63  func TestRegisterServiceFailed(t *testing.T) {
    64  	c := NewTestCluster()
    65  	c.config.Bind = ":5001"
    66  	c.config.GRPCConfig.ServiceHeartbeatInterval = 100 * time.Millisecond
    67  	c.config.GRPCConfig.ServiceDiscoveryPushInterval = 10 * time.Second
    68  	store := c.store.(*storemocks.Store)
    69  
    70  	experr := fmt.Errorf("error")
    71  	store.On("RegisterService", mock.Anything, mock.Anything, mock.Anything).Return(make(<-chan struct{}), func() {}, experr).Once()
    72  
    73  	ctx, cancel := context.WithCancel(context.Background())
    74  	defer cancel()
    75  
    76  	_, err := c.RegisterService(ctx)
    77  	assert.EqualError(t, err, "error")
    78  }
    79  
    80  func TestWatchServiceStatus(t *testing.T) {
    81  	c := NewTestCluster()
    82  	c.config.GRPCConfig.ServiceDiscoveryPushInterval = 500 * time.Millisecond
    83  	store := c.store.(*storemocks.Store)
    84  	store.On("ServiceStatusStream", mock.AnythingOfType("*context.emptyCtx")).Return(
    85  		func(_ context.Context) chan []string {
    86  			ch := make(chan []string)
    87  			go func() {
    88  				ticker := time.NewTicker(50 * time.Millisecond)
    89  				cnt := 0
    90  				for range ticker.C {
    91  					if cnt == 2 {
    92  						break
    93  					}
    94  					ch <- []string{fmt.Sprintf("127.0.0.1:500%d", cnt)}
    95  					cnt++
    96  				}
    97  			}()
    98  			return ch
    99  		}, nil,
   100  	)
   101  	c.watcher = helium.New(context.TODO(), c.config.GRPCConfig, c.store)
   102  
   103  	ch, err := c.WatchServiceStatus(context.Background())
   104  	assert.NoError(t, err)
   105  	ch2, err := c.WatchServiceStatus(context.Background())
   106  	assert.NoError(t, err)
   107  	wg := sync.WaitGroup{}
   108  	wg.Add(2)
   109  	go func() {
   110  		defer wg.Done()
   111  		assert.Equal(t, (<-ch).Addresses, []string{"127.0.0.1:5000"})
   112  		assert.Equal(t, (<-ch).Addresses, []string{"127.0.0.1:5001"})
   113  		assert.Equal(t, (<-ch).Addresses, []string{"127.0.0.1:5001"})
   114  	}()
   115  	go func() {
   116  		defer wg.Done()
   117  		assert.Equal(t, (<-ch2).Addresses, []string{"127.0.0.1:5000"})
   118  		assert.Equal(t, (<-ch2).Addresses, []string{"127.0.0.1:5001"})
   119  		assert.Equal(t, (<-ch2).Addresses, []string{"127.0.0.1:5001"})
   120  	}()
   121  	wg.Wait()
   122  }