github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/network/p2p/protocols/protocol_test.go (about) 1 package protocols 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "testing" 8 "time" 9 10 "github.com/neatlab/neatio/network/p2p" 11 "github.com/neatlab/neatio/network/p2p/discover" 12 "github.com/neatlab/neatio/network/p2p/simulations/adapters" 13 p2ptest "github.com/neatlab/neatio/network/p2p/testing" 14 ) 15 16 type hs0 struct { 17 C uint 18 } 19 20 type kill struct { 21 C discover.NodeID 22 } 23 24 type drop struct { 25 } 26 27 type protoHandshake struct { 28 Version uint 29 NetworkID string 30 } 31 32 func checkProtoHandshake(testVersion uint, testNetworkID string) func(interface{}) error { 33 return func(rhs interface{}) error { 34 remote := rhs.(*protoHandshake) 35 if remote.NetworkID != testNetworkID { 36 return fmt.Errorf("%s (!= %s)", remote.NetworkID, testNetworkID) 37 } 38 39 if remote.Version != testVersion { 40 return fmt.Errorf("%d (!= %d)", remote.Version, testVersion) 41 } 42 return nil 43 } 44 } 45 46 func newProtocol(pp *p2ptest.TestPeerPool) func(*p2p.Peer, p2p.MsgReadWriter) error { 47 spec := &Spec{ 48 Name: "test", 49 Version: 42, 50 MaxMsgSize: 10 * 1024, 51 Messages: []interface{}{ 52 protoHandshake{}, 53 hs0{}, 54 kill{}, 55 drop{}, 56 }, 57 } 58 return func(p *p2p.Peer, rw p2p.MsgReadWriter) error { 59 peer := NewPeer(p, rw, spec) 60 61 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 62 defer cancel() 63 phs := &protoHandshake{42, "420"} 64 hsCheck := checkProtoHandshake(phs.Version, phs.NetworkID) 65 _, err := peer.Handshake(ctx, phs, hsCheck) 66 if err != nil { 67 return err 68 } 69 70 lhs := &hs0{42} 71 72 hs, err := peer.Handshake(ctx, lhs, nil) 73 if err != nil { 74 return err 75 } 76 77 if rmhs := hs.(*hs0); rmhs.C > lhs.C { 78 return fmt.Errorf("handshake mismatch remote %v > local %v", rmhs.C, lhs.C) 79 } 80 81 handle := func(msg interface{}) error { 82 switch msg := msg.(type) { 83 84 case *protoHandshake: 85 return errors.New("duplicate handshake") 86 87 case *hs0: 88 rhs := msg 89 if rhs.C > lhs.C { 90 return fmt.Errorf("handshake mismatch remote %v > local %v", rhs.C, lhs.C) 91 } 92 lhs.C += rhs.C 93 return peer.Send(lhs) 94 95 case *kill: 96 97 id := msg.C 98 pp.Get(id).Drop(errors.New("killed")) 99 return nil 100 101 case *drop: 102 103 return errors.New("dropped") 104 105 default: 106 return fmt.Errorf("unknown message type: %T", msg) 107 } 108 } 109 110 pp.Add(peer) 111 defer pp.Remove(peer) 112 return peer.Run(handle) 113 } 114 } 115 116 func protocolTester(t *testing.T, pp *p2ptest.TestPeerPool) *p2ptest.ProtocolTester { 117 conf := adapters.RandomNodeConfig() 118 return p2ptest.NewProtocolTester(t, conf.ID, 2, newProtocol(pp)) 119 } 120 121 func protoHandshakeExchange(id discover.NodeID, proto *protoHandshake) []p2ptest.Exchange { 122 123 return []p2ptest.Exchange{ 124 { 125 Expects: []p2ptest.Expect{ 126 { 127 Code: 0, 128 Msg: &protoHandshake{42, "420"}, 129 Peer: id, 130 }, 131 }, 132 }, 133 { 134 Triggers: []p2ptest.Trigger{ 135 { 136 Code: 0, 137 Msg: proto, 138 Peer: id, 139 }, 140 }, 141 }, 142 } 143 } 144 145 func runProtoHandshake(t *testing.T, proto *protoHandshake, errs ...error) { 146 pp := p2ptest.NewTestPeerPool() 147 s := protocolTester(t, pp) 148 149 id := s.IDs[0] 150 if err := s.TestExchanges(protoHandshakeExchange(id, proto)...); err != nil { 151 t.Fatal(err) 152 } 153 var disconnects []*p2ptest.Disconnect 154 for i, err := range errs { 155 disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) 156 } 157 if err := s.TestDisconnected(disconnects...); err != nil { 158 t.Fatal(err) 159 } 160 } 161 162 func TestProtoHandshakeVersionMismatch(t *testing.T) { 163 runProtoHandshake(t, &protoHandshake{41, "420"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 41 (!= 42)").Error())) 164 } 165 166 func TestProtoHandshakeNetworkIDMismatch(t *testing.T) { 167 runProtoHandshake(t, &protoHandshake{42, "421"}, errorf(ErrHandshake, errorf(ErrHandler, "(msg code 0): 421 (!= 420)").Error())) 168 } 169 170 func TestProtoHandshakeSuccess(t *testing.T) { 171 runProtoHandshake(t, &protoHandshake{42, "420"}) 172 } 173 174 func moduleHandshakeExchange(id discover.NodeID, resp uint) []p2ptest.Exchange { 175 176 return []p2ptest.Exchange{ 177 { 178 Expects: []p2ptest.Expect{ 179 { 180 Code: 1, 181 Msg: &hs0{42}, 182 Peer: id, 183 }, 184 }, 185 }, 186 { 187 Triggers: []p2ptest.Trigger{ 188 { 189 Code: 1, 190 Msg: &hs0{resp}, 191 Peer: id, 192 }, 193 }, 194 }, 195 } 196 } 197 198 func runModuleHandshake(t *testing.T, resp uint, errs ...error) { 199 pp := p2ptest.NewTestPeerPool() 200 s := protocolTester(t, pp) 201 id := s.IDs[0] 202 if err := s.TestExchanges(protoHandshakeExchange(id, &protoHandshake{42, "420"})...); err != nil { 203 t.Fatal(err) 204 } 205 if err := s.TestExchanges(moduleHandshakeExchange(id, resp)...); err != nil { 206 t.Fatal(err) 207 } 208 var disconnects []*p2ptest.Disconnect 209 for i, err := range errs { 210 disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) 211 } 212 if err := s.TestDisconnected(disconnects...); err != nil { 213 t.Fatal(err) 214 } 215 } 216 217 func TestModuleHandshakeError(t *testing.T) { 218 runModuleHandshake(t, 43, fmt.Errorf("handshake mismatch remote 43 > local 42")) 219 } 220 221 func TestModuleHandshakeSuccess(t *testing.T) { 222 runModuleHandshake(t, 42) 223 } 224 225 func testMultiPeerSetup(a, b discover.NodeID) []p2ptest.Exchange { 226 227 return []p2ptest.Exchange{ 228 { 229 Label: "primary handshake", 230 Expects: []p2ptest.Expect{ 231 { 232 Code: 0, 233 Msg: &protoHandshake{42, "420"}, 234 Peer: a, 235 }, 236 { 237 Code: 0, 238 Msg: &protoHandshake{42, "420"}, 239 Peer: b, 240 }, 241 }, 242 }, 243 { 244 Label: "module handshake", 245 Triggers: []p2ptest.Trigger{ 246 { 247 Code: 0, 248 Msg: &protoHandshake{42, "420"}, 249 Peer: a, 250 }, 251 { 252 Code: 0, 253 Msg: &protoHandshake{42, "420"}, 254 Peer: b, 255 }, 256 }, 257 Expects: []p2ptest.Expect{ 258 { 259 Code: 1, 260 Msg: &hs0{42}, 261 Peer: a, 262 }, 263 { 264 Code: 1, 265 Msg: &hs0{42}, 266 Peer: b, 267 }, 268 }, 269 }, 270 271 {Label: "alternative module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{41}, Peer: a}, 272 {Code: 1, Msg: &hs0{41}, Peer: b}}}, 273 {Label: "repeated module handshake", Triggers: []p2ptest.Trigger{{Code: 1, Msg: &hs0{1}, Peer: a}}}, 274 {Label: "receiving repeated module handshake", Expects: []p2ptest.Expect{{Code: 1, Msg: &hs0{43}, Peer: a}}}} 275 } 276 277 func runMultiplePeers(t *testing.T, peer int, errs ...error) { 278 pp := p2ptest.NewTestPeerPool() 279 s := protocolTester(t, pp) 280 281 if err := s.TestExchanges(testMultiPeerSetup(s.IDs[0], s.IDs[1])...); err != nil { 282 t.Fatal(err) 283 } 284 285 tick := time.NewTicker(10 * time.Millisecond) 286 timeout := time.NewTimer(1 * time.Second) 287 WAIT: 288 for { 289 select { 290 case <-tick.C: 291 if pp.Has(s.IDs[0]) { 292 break WAIT 293 } 294 case <-timeout.C: 295 t.Fatal("timeout") 296 } 297 } 298 if !pp.Has(s.IDs[1]) { 299 t.Fatalf("missing peer test-1: %v (%v)", pp, s.IDs) 300 } 301 302 err := s.TestExchanges(p2ptest.Exchange{ 303 Triggers: []p2ptest.Trigger{ 304 { 305 Code: 2, 306 Msg: &kill{s.IDs[peer]}, 307 Peer: s.IDs[0], 308 }, 309 }, 310 }) 311 312 if err != nil { 313 t.Fatal(err) 314 } 315 316 err = s.TestExchanges(p2ptest.Exchange{ 317 Triggers: []p2ptest.Trigger{ 318 { 319 Code: 3, 320 Msg: &drop{}, 321 Peer: s.IDs[(peer+1)%2], 322 }, 323 }, 324 }) 325 326 if err != nil { 327 t.Fatal(err) 328 } 329 330 var disconnects []*p2ptest.Disconnect 331 for i, err := range errs { 332 disconnects = append(disconnects, &p2ptest.Disconnect{Peer: s.IDs[i], Error: err}) 333 } 334 if err := s.TestDisconnected(disconnects...); err != nil { 335 t.Fatal(err) 336 } 337 338 if pp.Has(s.IDs[peer]) { 339 t.Fatalf("peer test-%v not dropped: %v (%v)", peer, pp, s.IDs) 340 } 341 342 } 343 344 func TestMultiplePeersDropSelf(t *testing.T) { 345 runMultiplePeers(t, 0, 346 fmt.Errorf("subprotocol error"), 347 fmt.Errorf("Message handler error: (msg code 3): dropped"), 348 ) 349 } 350 351 func TestMultiplePeersDropOther(t *testing.T) { 352 runMultiplePeers(t, 1, 353 fmt.Errorf("Message handler error: (msg code 3): dropped"), 354 fmt.Errorf("subprotocol error"), 355 ) 356 }