github.com/jonasnick/go-ethereum@v0.7.12-0.20150216215225-22176f05d387/p2p/peer_test.go (about) 1 package p2p 2 3 import ( 4 "bytes" 5 "fmt" 6 "io/ioutil" 7 "net" 8 "reflect" 9 "sort" 10 "testing" 11 "time" 12 13 "github.com/jonasnick/go-ethereum/p2p/discover" 14 "github.com/jonasnick/go-ethereum/rlp" 15 ) 16 17 var discard = Protocol{ 18 Name: "discard", 19 Length: 1, 20 Run: func(p *Peer, rw MsgReadWriter) error { 21 for { 22 msg, err := rw.ReadMsg() 23 if err != nil { 24 return err 25 } 26 if err = msg.Discard(); err != nil { 27 return err 28 } 29 } 30 }, 31 } 32 33 func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) { 34 conn1, conn2 := net.Pipe() 35 peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{}) 36 peer.noHandshake = noHandshake 37 errc := make(chan DiscReason, 1) 38 go func() { errc <- peer.run() }() 39 return newFrameRW(conn2, msgWriteTimeout), peer, errc 40 } 41 42 func TestPeerProtoReadMsg(t *testing.T) { 43 defer testlog(t).detach() 44 45 done := make(chan struct{}) 46 proto := Protocol{ 47 Name: "a", 48 Length: 5, 49 Run: func(peer *Peer, rw MsgReadWriter) error { 50 if err := expectMsg(rw, 2, []uint{1}); err != nil { 51 t.Error(err) 52 } 53 if err := expectMsg(rw, 3, []uint{2}); err != nil { 54 t.Error(err) 55 } 56 if err := expectMsg(rw, 4, []uint{3}); err != nil { 57 t.Error(err) 58 } 59 close(done) 60 return nil 61 }, 62 } 63 64 rw, peer, errc := testPeer(true, []Protocol{proto}) 65 defer rw.Close() 66 peer.startSubprotocols([]Cap{proto.cap()}) 67 68 EncodeMsg(rw, baseProtocolLength+2, 1) 69 EncodeMsg(rw, baseProtocolLength+3, 2) 70 EncodeMsg(rw, baseProtocolLength+4, 3) 71 72 select { 73 case <-done: 74 case err := <-errc: 75 t.Errorf("peer returned: %v", err) 76 case <-time.After(2 * time.Second): 77 t.Errorf("receive timeout") 78 } 79 } 80 81 func TestPeerProtoReadLargeMsg(t *testing.T) { 82 defer testlog(t).detach() 83 84 msgsize := uint32(10 * 1024 * 1024) 85 done := make(chan struct{}) 86 proto := Protocol{ 87 Name: "a", 88 Length: 5, 89 Run: func(peer *Peer, rw MsgReadWriter) error { 90 msg, err := rw.ReadMsg() 91 if err != nil { 92 t.Errorf("read error: %v", err) 93 } 94 if msg.Size != msgsize+4 { 95 t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize) 96 } 97 msg.Discard() 98 close(done) 99 return nil 100 }, 101 } 102 103 rw, peer, errc := testPeer(true, []Protocol{proto}) 104 defer rw.Close() 105 peer.startSubprotocols([]Cap{proto.cap()}) 106 107 EncodeMsg(rw, 18, make([]byte, msgsize)) 108 select { 109 case <-done: 110 case err := <-errc: 111 t.Errorf("peer returned: %v", err) 112 case <-time.After(2 * time.Second): 113 t.Errorf("receive timeout") 114 } 115 } 116 117 func TestPeerProtoEncodeMsg(t *testing.T) { 118 defer testlog(t).detach() 119 120 proto := Protocol{ 121 Name: "a", 122 Length: 2, 123 Run: func(peer *Peer, rw MsgReadWriter) error { 124 if err := EncodeMsg(rw, 2); err == nil { 125 t.Error("expected error for out-of-range msg code, got nil") 126 } 127 if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil { 128 t.Errorf("write error: %v", err) 129 } 130 return nil 131 }, 132 } 133 rw, peer, _ := testPeer(true, []Protocol{proto}) 134 defer rw.Close() 135 peer.startSubprotocols([]Cap{proto.cap()}) 136 137 if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil { 138 t.Error(err) 139 } 140 } 141 142 func TestPeerWriteForBroadcast(t *testing.T) { 143 defer testlog(t).detach() 144 145 rw, peer, peerErr := testPeer(true, []Protocol{discard}) 146 defer rw.Close() 147 peer.startSubprotocols([]Cap{discard.cap()}) 148 149 // test write errors 150 if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil { 151 t.Errorf("expected error for unknown protocol, got nil") 152 } 153 if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil { 154 t.Errorf("expected error for out-of-range msg code, got nil") 155 } else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode { 156 t.Errorf("wrong error for out-of-range msg code, got %#v", err) 157 } 158 159 // setup for reading the message on the other end 160 read := make(chan struct{}) 161 go func() { 162 if err := expectMsg(rw, 16, nil); err != nil { 163 t.Error() 164 } 165 close(read) 166 }() 167 168 // test successful write 169 if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil { 170 t.Errorf("expect no error for known protocol: %v", err) 171 } 172 select { 173 case <-read: 174 case err := <-peerErr: 175 t.Fatalf("peer stopped: %v", err) 176 } 177 } 178 179 func TestPeerPing(t *testing.T) { 180 defer testlog(t).detach() 181 182 rw, _, _ := testPeer(true, nil) 183 defer rw.Close() 184 if err := EncodeMsg(rw, pingMsg); err != nil { 185 t.Fatal(err) 186 } 187 if err := expectMsg(rw, pongMsg, nil); err != nil { 188 t.Error(err) 189 } 190 } 191 192 func TestPeerDisconnect(t *testing.T) { 193 defer testlog(t).detach() 194 195 rw, _, disc := testPeer(true, nil) 196 defer rw.Close() 197 if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil { 198 t.Fatal(err) 199 } 200 if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil { 201 t.Error(err) 202 } 203 rw.Close() // make test end faster 204 if reason := <-disc; reason != DiscRequested { 205 t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested) 206 } 207 } 208 209 func TestPeerHandshake(t *testing.T) { 210 defer testlog(t).detach() 211 212 // remote has two matching protocols: a and c 213 remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}}) 214 remoteID := randomID() 215 remote.ourID = &remoteID 216 remote.ourName = "remote peer" 217 218 start := make(chan string) 219 stop := make(chan struct{}) 220 run := func(p *Peer, rw MsgReadWriter) error { 221 name := rw.(*proto).name 222 if name != "a" && name != "c" { 223 t.Errorf("protocol %q should not be started", name) 224 } else { 225 start <- name 226 } 227 <-stop 228 return nil 229 } 230 protocols := []Protocol{ 231 {Name: "a", Version: 1, Length: 1, Run: run}, 232 {Name: "b", Version: 2, Length: 1, Run: run}, 233 {Name: "c", Version: 3, Length: 1, Run: run}, 234 {Name: "d", Version: 4, Length: 1, Run: run}, 235 } 236 rw, p, disc := testPeer(false, protocols) 237 p.remoteID = remote.ourID 238 defer rw.Close() 239 240 // run the handshake 241 remoteProtocols := []Protocol{protocols[0], protocols[2]} 242 if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil { 243 t.Fatalf("handshake write error: %v", err) 244 } 245 if err := readProtocolHandshake(remote, rw); err != nil { 246 t.Fatalf("handshake read error: %v", err) 247 } 248 249 // check that all protocols have been started 250 var started []string 251 for i := 0; i < 2; i++ { 252 select { 253 case name := <-start: 254 started = append(started, name) 255 case <-time.After(100 * time.Millisecond): 256 } 257 } 258 sort.Strings(started) 259 if !reflect.DeepEqual(started, []string{"a", "c"}) { 260 t.Errorf("wrong protocols started: %v", started) 261 } 262 263 // check that metadata has been set 264 if p.ID() != remoteID { 265 t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID) 266 } 267 if p.Name() != remote.ourName { 268 t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName) 269 } 270 271 close(stop) 272 expectMsg(rw, discMsg, nil) 273 t.Logf("disc reason: %v", <-disc) 274 } 275 276 func TestNewPeer(t *testing.T) { 277 name := "nodename" 278 caps := []Cap{{"foo", 2}, {"bar", 3}} 279 id := randomID() 280 p := NewPeer(id, name, caps) 281 if p.ID() != id { 282 t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id) 283 } 284 if p.Name() != name { 285 t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name) 286 } 287 if !reflect.DeepEqual(p.Caps(), caps) { 288 t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps) 289 } 290 291 p.Disconnect(DiscAlreadyConnected) // Should not hang 292 } 293 294 // expectMsg reads a message from r and verifies that its 295 // code and encoded RLP content match the provided values. 296 // If content is nil, the payload is discarded and not verified. 297 func expectMsg(r MsgReader, code uint64, content interface{}) error { 298 msg, err := r.ReadMsg() 299 if err != nil { 300 return err 301 } 302 if msg.Code != code { 303 return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code) 304 } 305 if content == nil { 306 return msg.Discard() 307 } else { 308 contentEnc, err := rlp.EncodeToBytes(content) 309 if err != nil { 310 panic("content encode error: " + err.Error()) 311 } 312 // skip over list header in encoded value. this is temporary. 313 contentEncR := bytes.NewReader(contentEnc) 314 if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil { 315 panic("content must encode as RLP list") 316 } 317 contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():] 318 319 actualContent, err := ioutil.ReadAll(msg.Payload) 320 if err != nil { 321 return err 322 } 323 if !bytes.Equal(actualContent, contentEnc) { 324 return fmt.Errorf("message payload mismatch:\ngot: %x\nwant: %x", actualContent, contentEnc) 325 } 326 } 327 return nil 328 }