github.com/status-im/status-go@v1.1.0/discovery/muxer_test.go (about) 1 package discovery 2 3 import ( 4 "errors" 5 "sync" 6 "testing" 7 "time" 8 9 "github.com/stretchr/testify/require" 10 11 "github.com/ethereum/go-ethereum/p2p/discv5" 12 ) 13 14 func newRegistry() *registry { 15 return ®istry{ 16 storage: map[string][]int{}, 17 } 18 } 19 20 type registry struct { 21 mu sync.Mutex 22 storage map[string][]int 23 } 24 25 func (r *registry) Add(topic string, id int) { 26 r.mu.Lock() 27 defer r.mu.Unlock() 28 r.storage[topic] = append(r.storage[topic], id) 29 } 30 31 func (r *registry) Get(topic string) []int { 32 r.mu.Lock() 33 defer r.mu.Unlock() 34 return r.storage[topic] 35 } 36 37 type fake struct { 38 started bool 39 err error 40 id int 41 registry *registry 42 } 43 44 func (f *fake) Start() error { 45 if f.err != nil { 46 return f.err 47 } 48 f.started = true 49 return nil 50 } 51 52 func (f *fake) Stop() error { 53 f.started = false 54 if f.err != nil { 55 return f.err 56 } 57 return nil 58 } 59 60 func (f *fake) Running() bool { 61 return f.started 62 } 63 64 func (f *fake) Register(topic string, stop chan struct{}) error { 65 if f.err != nil { 66 return f.err 67 } 68 f.registry.Add(topic, f.id) 69 return nil 70 } 71 72 func (f *fake) Discover(topic string, period <-chan time.Duration, found chan<- *discv5.Node, lookup chan<- bool) error { 73 if f.err != nil { 74 return f.err 75 } 76 for _, n := range f.registry.Get(topic) { 77 found <- discv5.NewNode(discv5.NodeID{byte(n)}, nil, 0, 0) 78 } 79 return nil 80 } 81 82 type testErrorCase struct { 83 desc string 84 errors []error 85 } 86 87 func errorCases() []testErrorCase { 88 return []testErrorCase{ 89 {desc: "SingleError", errors: []error{nil, errors.New("test")}}, 90 {desc: "NoErrors", errors: []error{nil, nil}}, 91 {desc: "AllErrors", errors: []error{errors.New("test"), errors.New("test")}}, 92 } 93 } 94 95 func TestMuxerStart(t *testing.T) { 96 for _, tc := range errorCases() { 97 t.Run(tc.desc, func(t *testing.T) { 98 discoveries := make([]Discovery, len(tc.errors)) 99 erred := false 100 for i, err := range tc.errors { 101 if err != nil { 102 erred = true 103 } 104 discoveries[i] = &fake{err: err} 105 } 106 muxer := NewMultiplexer(discoveries) 107 if erred { 108 require.Error(t, muxer.Start()) 109 } else { 110 require.NoError(t, muxer.Start()) 111 } 112 for _, d := range discoveries { 113 require.Equal(t, !erred, d.Running()) 114 } 115 }) 116 } 117 } 118 119 func TestMuxerStop(t *testing.T) { 120 for _, tc := range errorCases() { 121 t.Run(tc.desc, func(t *testing.T) { 122 discoveries := make([]Discovery, len(tc.errors)) 123 erred := false 124 for i, err := range tc.errors { 125 if err != nil { 126 erred = true 127 } 128 discoveries[i] = &fake{started: true, err: err} 129 } 130 muxer := NewMultiplexer(discoveries) 131 if erred { 132 require.Error(t, muxer.Stop()) 133 } else { 134 require.NoError(t, muxer.Stop()) 135 } 136 for _, d := range discoveries { 137 require.False(t, d.Running()) 138 } 139 }) 140 } 141 } 142 143 func TestMuxerRunning(t *testing.T) { 144 for _, tc := range []struct { 145 desc string 146 started []bool 147 }{ 148 {desc: "FirstRunning", started: []bool{false, true}}, 149 {desc: "SecondRunning", started: []bool{true, false}}, 150 {desc: "AllRunning", started: []bool{true, true}}, 151 {desc: "NoRunning", started: []bool{false, false}}, 152 } { 153 t.Run(tc.desc, func(t *testing.T) { 154 discoveries := make([]Discovery, len(tc.started)) 155 allstarted := false 156 for i, start := range tc.started { 157 allstarted = start || allstarted 158 discoveries[i] = &fake{started: start} 159 } 160 require.Equal(t, allstarted, NewMultiplexer(discoveries).Running()) 161 }) 162 } 163 } 164 165 func TestMuxerRegister(t *testing.T) { 166 for _, tc := range []struct { 167 desc string 168 errors []error 169 topics []string 170 }{ 171 {"NoErrors", []error{nil, nil, nil}, []string{"a"}}, 172 {"MultipleTopics", []error{nil, nil, nil}, []string{"a", "b", "c"}}, 173 {"SingleError", []error{nil, errors.New("test"), nil}, []string{"a"}}, 174 {"AllErrors", []error{errors.New("test"), errors.New("test"), errors.New("test")}, []string{"a"}}, 175 } { 176 t.Run(tc.desc, func(t *testing.T) { 177 reg := newRegistry() 178 discoveries := make([]Discovery, len(tc.errors)) 179 erred := 0 180 for i := range discoveries { 181 if tc.errors[i] != nil { 182 erred++ 183 } 184 discoveries[i] = &fake{id: i, err: tc.errors[i], registry: reg} 185 } 186 muxer := NewMultiplexer(discoveries) 187 for _, topic := range tc.topics { 188 if erred != 0 { 189 require.Error(t, muxer.Register(topic, nil)) 190 } else { 191 require.NoError(t, muxer.Register(topic, nil)) 192 } 193 require.Equal(t, len(discoveries)-erred, len(reg.Get(topic))) 194 } 195 }) 196 } 197 } 198 199 func TestMuxerDiscovery(t *testing.T) { 200 for _, tc := range []struct { 201 desc string 202 errors []error 203 topics []string 204 ids [][]int 205 }{ 206 {"EqualNoErrors", []error{nil, nil}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}}, 207 {"MultiTopicsSingleSource", []error{nil, nil}, []string{"a", "b"}, [][]int{{11, 22, 33}, {}}}, 208 {"SingleError", []error{nil, errors.New("test")}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}}, 209 {"AllErrors", []error{errors.New("test"), errors.New("test")}, []string{"a"}, [][]int{{11, 22, 33}, {44, 55, 66}}}, 210 } { 211 t.Run(tc.desc, func(t *testing.T) { 212 discoveries := make([]Discovery, len(tc.errors)) 213 erred := false 214 expected := 0 215 for i := range discoveries { 216 if tc.errors[i] == nil { 217 expected += len(tc.ids[i]) 218 } else { 219 erred = true 220 } 221 reg := newRegistry() 222 discoveries[i] = &fake{id: i, err: tc.errors[i], registry: reg} 223 for _, topic := range tc.topics { 224 for _, id := range tc.ids[i] { 225 reg.Add(topic, id) 226 } 227 } 228 } 229 muxer := NewMultiplexer(discoveries) 230 for _, topic := range tc.topics { 231 found := make(chan *discv5.Node, expected) 232 period := make(chan time.Duration) 233 close(period) 234 if erred { 235 // TODO test period channel 236 require.Error(t, muxer.Discover(topic, period, found, nil)) 237 } else { 238 require.NoError(t, muxer.Discover(topic, period, found, nil)) 239 } 240 close(found) 241 count := 0 242 for range found { 243 count++ 244 } 245 require.Equal(t, expected, count) 246 } 247 }) 248 } 249 }