github.com/status-im/status-go@v1.1.0/discovery/muxer_test.go (about)

     1  package discovery
     2  
     3  import (
     4  	"errors"
     5  	"sync"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/stretchr/testify/require"
    10  
    11  	"github.com/ethereum/go-ethereum/p2p/discv5"
    12  )
    13  
    14  func newRegistry() *registry {
    15  	return &registry{
    16  		storage: map[string][]int{},
    17  	}
    18  }
    19  
    20  type registry struct {
    21  	mu      sync.Mutex
    22  	storage map[string][]int
    23  }
    24  
    25  func (r *registry) Add(topic string, id int) {
    26  	r.mu.Lock()
    27  	defer r.mu.Unlock()
    28  	r.storage[topic] = append(r.storage[topic], id)
    29  }
    30  
    31  func (r *registry) Get(topic string) []int {
    32  	r.mu.Lock()
    33  	defer r.mu.Unlock()
    34  	return r.storage[topic]
    35  }
    36  
    37  type fake struct {
    38  	started  bool
    39  	err      error
    40  	id       int
    41  	registry *registry
    42  }
    43  
    44  func (f *fake) Start() error {
    45  	if f.err != nil {
    46  		return f.err
    47  	}
    48  	f.started = true
    49  	return nil
    50  }
    51  
    52  func (f *fake) Stop() error {
    53  	f.started = false
    54  	if f.err != nil {
    55  		return f.err
    56  	}
    57  	return nil
    58  }
    59  
    60  func (f *fake) Running() bool {
    61  	return f.started
    62  }
    63  
    64  func (f *fake) Register(topic string, stop chan struct{}) error {
    65  	if f.err != nil {
    66  		return f.err
    67  	}
    68  	f.registry.Add(topic, f.id)
    69  	return nil
    70  }
    71  
    72  func (f *fake) Discover(topic string, period <-chan time.Duration, found chan<- *discv5.Node, lookup chan<- bool) error {
    73  	if f.err != nil {
    74  		return f.err
    75  	}
    76  	for _, n := range f.registry.Get(topic) {
    77  		found <- discv5.NewNode(discv5.NodeID{byte(n)}, nil, 0, 0)
    78  	}
    79  	return nil
    80  }
    81  
    82  type testErrorCase struct {
    83  	desc   string
    84  	errors []error
    85  }
    86  
    87  func errorCases() []testErrorCase {
    88  	return []testErrorCase{
    89  		{desc: "SingleError", errors: []error{nil, errors.New("test")}},
    90  		{desc: "NoErrors", errors: []error{nil, nil}},
    91  		{desc: "AllErrors", errors: []error{errors.New("test"), errors.New("test")}},
    92  	}
    93  }
    94  
    95  func TestMuxerStart(t *testing.T) {
    96  	for _, tc := range errorCases() {
    97  		t.Run(tc.desc, func(t *testing.T) {
    98  			discoveries := make([]Discovery, len(tc.errors))
    99  			erred := false
   100  			for i, err := range tc.errors {
   101  				if err != nil {
   102  					erred = true
   103  				}
   104  				discoveries[i] = &fake{err: err}
   105  			}
   106  			muxer := NewMultiplexer(discoveries)
   107  			if erred {
   108  				require.Error(t, muxer.Start())
   109  			} else {
   110  				require.NoError(t, muxer.Start())
   111  			}
   112  			for _, d := range discoveries {
   113  				require.Equal(t, !erred, d.Running())
   114  			}
   115  		})
   116  	}
   117  }
   118  
   119  func TestMuxerStop(t *testing.T) {
   120  	for _, tc := range errorCases() {
   121  		t.Run(tc.desc, func(t *testing.T) {
   122  			discoveries := make([]Discovery, len(tc.errors))
   123  			erred := false
   124  			for i, err := range tc.errors {
   125  				if err != nil {
   126  					erred = true
   127  				}
   128  				discoveries[i] = &fake{started: true, err: err}
   129  			}
   130  			muxer := NewMultiplexer(discoveries)
   131  			if erred {
   132  				require.Error(t, muxer.Stop())
   133  			} else {
   134  				require.NoError(t, muxer.Stop())
   135  			}
   136  			for _, d := range discoveries {
   137  				require.False(t, d.Running())
   138  			}
   139  		})
   140  	}
   141  }
   142  
   143  func TestMuxerRunning(t *testing.T) {
   144  	for _, tc := range []struct {
   145  		desc    string
   146  		started []bool
   147  	}{
   148  		{desc: "FirstRunning", started: []bool{false, true}},
   149  		{desc: "SecondRunning", started: []bool{true, false}},
   150  		{desc: "AllRunning", started: []bool{true, true}},
   151  		{desc: "NoRunning", started: []bool{false, false}},
   152  	} {
   153  		t.Run(tc.desc, func(t *testing.T) {
   154  			discoveries := make([]Discovery, len(tc.started))
   155  			allstarted := false
   156  			for i, start := range tc.started {
   157  				allstarted = start || allstarted
   158  				discoveries[i] = &fake{started: start}
   159  			}
   160  			require.Equal(t, allstarted, NewMultiplexer(discoveries).Running())
   161  		})
   162  	}
   163  }
   164  
   165  func TestMuxerRegister(t *testing.T) {
   166  	for _, tc := range []struct {
   167  		desc   string
   168  		errors []error
   169  		topics []string
   170  	}{
   171  		{"NoErrors", []error{nil, nil, nil}, []string{"a"}},
   172  		{"MultipleTopics", []error{nil, nil, nil}, []string{"a", "b", "c"}},
   173  		{"SingleError", []error{nil, errors.New("test"), nil}, []string{"a"}},
   174  		{"AllErrors", []error{errors.New("test"), errors.New("test"), errors.New("test")}, []string{"a"}},
   175  	} {
   176  		t.Run(tc.desc, func(t *testing.T) {
   177  			reg := newRegistry()
   178  			discoveries := make([]Discovery, len(tc.errors))
   179  			erred := 0
   180  			for i := range discoveries {
   181  				if tc.errors[i] != nil {
   182  					erred++
   183  				}
   184  				discoveries[i] = &fake{id: i, err: tc.errors[i], registry: reg}
   185  			}
   186  			muxer := NewMultiplexer(discoveries)
   187  			for _, topic := range tc.topics {
   188  				if erred != 0 {
   189  					require.Error(t, muxer.Register(topic, nil))
   190  				} else {
   191  					require.NoError(t, muxer.Register(topic, nil))
   192  				}
   193  				require.Equal(t, len(discoveries)-erred, len(reg.Get(topic)))
   194  			}
   195  		})
   196  	}
   197  }
   198  
   199  func TestMuxerDiscovery(t *testing.T) {
   200  	for _, tc := range []struct {
   201  		desc   string
   202  		errors []error
   203  		topics []string
   204  		ids    [][]int
   205  	}{
   206  		{"EqualNoErrors", []error{nil, nil}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}},
   207  		{"MultiTopicsSingleSource", []error{nil, nil}, []string{"a", "b"}, [][]int{{11, 22, 33}, {}}},
   208  		{"SingleError", []error{nil, errors.New("test")}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}},
   209  		{"AllErrors", []error{errors.New("test"), errors.New("test")}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}},
   210  	} {
   211  		t.Run(tc.desc, func(t *testing.T) {
   212  			discoveries := make([]Discovery, len(tc.errors))
   213  			erred := false
   214  			expected := 0
   215  			for i := range discoveries {
   216  				if tc.errors[i] == nil {
   217  					expected += len(tc.ids[i])
   218  				} else {
   219  					erred = true
   220  				}
   221  				reg := newRegistry()
   222  				discoveries[i] = &fake{id: i, err: tc.errors[i], registry: reg}
   223  				for _, topic := range tc.topics {
   224  					for _, id := range tc.ids[i] {
   225  						reg.Add(topic, id)
   226  					}
   227  				}
   228  			}
   229  			muxer := NewMultiplexer(discoveries)
   230  			for _, topic := range tc.topics {
   231  				found := make(chan *discv5.Node, expected)
   232  				period := make(chan time.Duration)
   233  				close(period)
   234  				if erred {
   235  					// TODO test period channel
   236  					require.Error(t, muxer.Discover(topic, period, found, nil))
   237  				} else {
   238  					require.NoError(t, muxer.Discover(topic, period, found, nil))
   239  				}
   240  				close(found)
   241  				count := 0
   242  				for range found {
   243  					count++
   244  				}
   245  				require.Equal(t, expected, count)
   246  			}
   247  		})
   248  	}
   249  }