github.com/jonasnick/go-ethereum@v0.7.12-0.20150216215225-22176f05d387/p2p/peer_test.go (about)

     1  package p2p
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"net"
     8  	"reflect"
     9  	"sort"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/jonasnick/go-ethereum/p2p/discover"
    14  	"github.com/jonasnick/go-ethereum/rlp"
    15  )
    16  
    17  var discard = Protocol{
    18  	Name:   "discard",
    19  	Length: 1,
    20  	Run: func(p *Peer, rw MsgReadWriter) error {
    21  		for {
    22  			msg, err := rw.ReadMsg()
    23  			if err != nil {
    24  				return err
    25  			}
    26  			if err = msg.Discard(); err != nil {
    27  				return err
    28  			}
    29  		}
    30  	},
    31  }
    32  
    33  func testPeer(noHandshake bool, protos []Protocol) (*frameRW, *Peer, <-chan DiscReason) {
    34  	conn1, conn2 := net.Pipe()
    35  	peer := newPeer(conn1, protos, "name", &discover.NodeID{}, &discover.NodeID{})
    36  	peer.noHandshake = noHandshake
    37  	errc := make(chan DiscReason, 1)
    38  	go func() { errc <- peer.run() }()
    39  	return newFrameRW(conn2, msgWriteTimeout), peer, errc
    40  }
    41  
    42  func TestPeerProtoReadMsg(t *testing.T) {
    43  	defer testlog(t).detach()
    44  
    45  	done := make(chan struct{})
    46  	proto := Protocol{
    47  		Name:   "a",
    48  		Length: 5,
    49  		Run: func(peer *Peer, rw MsgReadWriter) error {
    50  			if err := expectMsg(rw, 2, []uint{1}); err != nil {
    51  				t.Error(err)
    52  			}
    53  			if err := expectMsg(rw, 3, []uint{2}); err != nil {
    54  				t.Error(err)
    55  			}
    56  			if err := expectMsg(rw, 4, []uint{3}); err != nil {
    57  				t.Error(err)
    58  			}
    59  			close(done)
    60  			return nil
    61  		},
    62  	}
    63  
    64  	rw, peer, errc := testPeer(true, []Protocol{proto})
    65  	defer rw.Close()
    66  	peer.startSubprotocols([]Cap{proto.cap()})
    67  
    68  	EncodeMsg(rw, baseProtocolLength+2, 1)
    69  	EncodeMsg(rw, baseProtocolLength+3, 2)
    70  	EncodeMsg(rw, baseProtocolLength+4, 3)
    71  
    72  	select {
    73  	case <-done:
    74  	case err := <-errc:
    75  		t.Errorf("peer returned: %v", err)
    76  	case <-time.After(2 * time.Second):
    77  		t.Errorf("receive timeout")
    78  	}
    79  }
    80  
    81  func TestPeerProtoReadLargeMsg(t *testing.T) {
    82  	defer testlog(t).detach()
    83  
    84  	msgsize := uint32(10 * 1024 * 1024)
    85  	done := make(chan struct{})
    86  	proto := Protocol{
    87  		Name:   "a",
    88  		Length: 5,
    89  		Run: func(peer *Peer, rw MsgReadWriter) error {
    90  			msg, err := rw.ReadMsg()
    91  			if err != nil {
    92  				t.Errorf("read error: %v", err)
    93  			}
    94  			if msg.Size != msgsize+4 {
    95  				t.Errorf("incorrect msg.Size, got %d, expected %d", msg.Size, msgsize)
    96  			}
    97  			msg.Discard()
    98  			close(done)
    99  			return nil
   100  		},
   101  	}
   102  
   103  	rw, peer, errc := testPeer(true, []Protocol{proto})
   104  	defer rw.Close()
   105  	peer.startSubprotocols([]Cap{proto.cap()})
   106  
   107  	EncodeMsg(rw, 18, make([]byte, msgsize))
   108  	select {
   109  	case <-done:
   110  	case err := <-errc:
   111  		t.Errorf("peer returned: %v", err)
   112  	case <-time.After(2 * time.Second):
   113  		t.Errorf("receive timeout")
   114  	}
   115  }
   116  
   117  func TestPeerProtoEncodeMsg(t *testing.T) {
   118  	defer testlog(t).detach()
   119  
   120  	proto := Protocol{
   121  		Name:   "a",
   122  		Length: 2,
   123  		Run: func(peer *Peer, rw MsgReadWriter) error {
   124  			if err := EncodeMsg(rw, 2); err == nil {
   125  				t.Error("expected error for out-of-range msg code, got nil")
   126  			}
   127  			if err := EncodeMsg(rw, 1, "foo", "bar"); err != nil {
   128  				t.Errorf("write error: %v", err)
   129  			}
   130  			return nil
   131  		},
   132  	}
   133  	rw, peer, _ := testPeer(true, []Protocol{proto})
   134  	defer rw.Close()
   135  	peer.startSubprotocols([]Cap{proto.cap()})
   136  
   137  	if err := expectMsg(rw, 17, []string{"foo", "bar"}); err != nil {
   138  		t.Error(err)
   139  	}
   140  }
   141  
   142  func TestPeerWriteForBroadcast(t *testing.T) {
   143  	defer testlog(t).detach()
   144  
   145  	rw, peer, peerErr := testPeer(true, []Protocol{discard})
   146  	defer rw.Close()
   147  	peer.startSubprotocols([]Cap{discard.cap()})
   148  
   149  	// test write errors
   150  	if err := peer.writeProtoMsg("b", NewMsg(3)); err == nil {
   151  		t.Errorf("expected error for unknown protocol, got nil")
   152  	}
   153  	if err := peer.writeProtoMsg("discard", NewMsg(8)); err == nil {
   154  		t.Errorf("expected error for out-of-range msg code, got nil")
   155  	} else if perr, ok := err.(*peerError); !ok || perr.Code != errInvalidMsgCode {
   156  		t.Errorf("wrong error for out-of-range msg code, got %#v", err)
   157  	}
   158  
   159  	// setup for reading the message on the other end
   160  	read := make(chan struct{})
   161  	go func() {
   162  		if err := expectMsg(rw, 16, nil); err != nil {
   163  			t.Error()
   164  		}
   165  		close(read)
   166  	}()
   167  
   168  	// test successful write
   169  	if err := peer.writeProtoMsg("discard", NewMsg(0)); err != nil {
   170  		t.Errorf("expect no error for known protocol: %v", err)
   171  	}
   172  	select {
   173  	case <-read:
   174  	case err := <-peerErr:
   175  		t.Fatalf("peer stopped: %v", err)
   176  	}
   177  }
   178  
   179  func TestPeerPing(t *testing.T) {
   180  	defer testlog(t).detach()
   181  
   182  	rw, _, _ := testPeer(true, nil)
   183  	defer rw.Close()
   184  	if err := EncodeMsg(rw, pingMsg); err != nil {
   185  		t.Fatal(err)
   186  	}
   187  	if err := expectMsg(rw, pongMsg, nil); err != nil {
   188  		t.Error(err)
   189  	}
   190  }
   191  
   192  func TestPeerDisconnect(t *testing.T) {
   193  	defer testlog(t).detach()
   194  
   195  	rw, _, disc := testPeer(true, nil)
   196  	defer rw.Close()
   197  	if err := EncodeMsg(rw, discMsg, DiscQuitting); err != nil {
   198  		t.Fatal(err)
   199  	}
   200  	if err := expectMsg(rw, discMsg, []interface{}{DiscRequested}); err != nil {
   201  		t.Error(err)
   202  	}
   203  	rw.Close() // make test end faster
   204  	if reason := <-disc; reason != DiscRequested {
   205  		t.Errorf("run returned wrong reason: got %v, want %v", reason, DiscRequested)
   206  	}
   207  }
   208  
   209  func TestPeerHandshake(t *testing.T) {
   210  	defer testlog(t).detach()
   211  
   212  	// remote has two matching protocols: a and c
   213  	remote := NewPeer(randomID(), "", []Cap{{"a", 1}, {"b", 999}, {"c", 3}})
   214  	remoteID := randomID()
   215  	remote.ourID = &remoteID
   216  	remote.ourName = "remote peer"
   217  
   218  	start := make(chan string)
   219  	stop := make(chan struct{})
   220  	run := func(p *Peer, rw MsgReadWriter) error {
   221  		name := rw.(*proto).name
   222  		if name != "a" && name != "c" {
   223  			t.Errorf("protocol %q should not be started", name)
   224  		} else {
   225  			start <- name
   226  		}
   227  		<-stop
   228  		return nil
   229  	}
   230  	protocols := []Protocol{
   231  		{Name: "a", Version: 1, Length: 1, Run: run},
   232  		{Name: "b", Version: 2, Length: 1, Run: run},
   233  		{Name: "c", Version: 3, Length: 1, Run: run},
   234  		{Name: "d", Version: 4, Length: 1, Run: run},
   235  	}
   236  	rw, p, disc := testPeer(false, protocols)
   237  	p.remoteID = remote.ourID
   238  	defer rw.Close()
   239  
   240  	// run the handshake
   241  	remoteProtocols := []Protocol{protocols[0], protocols[2]}
   242  	if err := writeProtocolHandshake(rw, "remote peer", remoteID, remoteProtocols); err != nil {
   243  		t.Fatalf("handshake write error: %v", err)
   244  	}
   245  	if err := readProtocolHandshake(remote, rw); err != nil {
   246  		t.Fatalf("handshake read error: %v", err)
   247  	}
   248  
   249  	// check that all protocols have been started
   250  	var started []string
   251  	for i := 0; i < 2; i++ {
   252  		select {
   253  		case name := <-start:
   254  			started = append(started, name)
   255  		case <-time.After(100 * time.Millisecond):
   256  		}
   257  	}
   258  	sort.Strings(started)
   259  	if !reflect.DeepEqual(started, []string{"a", "c"}) {
   260  		t.Errorf("wrong protocols started: %v", started)
   261  	}
   262  
   263  	// check that metadata has been set
   264  	if p.ID() != remoteID {
   265  		t.Errorf("peer has wrong node ID: got %v, want %v", p.ID(), remoteID)
   266  	}
   267  	if p.Name() != remote.ourName {
   268  		t.Errorf("peer has wrong node name: got %q, want %q", p.Name(), remote.ourName)
   269  	}
   270  
   271  	close(stop)
   272  	expectMsg(rw, discMsg, nil)
   273  	t.Logf("disc reason: %v", <-disc)
   274  }
   275  
   276  func TestNewPeer(t *testing.T) {
   277  	name := "nodename"
   278  	caps := []Cap{{"foo", 2}, {"bar", 3}}
   279  	id := randomID()
   280  	p := NewPeer(id, name, caps)
   281  	if p.ID() != id {
   282  		t.Errorf("ID mismatch: got %v, expected %v", p.ID(), id)
   283  	}
   284  	if p.Name() != name {
   285  		t.Errorf("Name mismatch: got %v, expected %v", p.Name(), name)
   286  	}
   287  	if !reflect.DeepEqual(p.Caps(), caps) {
   288  		t.Errorf("Caps mismatch: got %v, expected %v", p.Caps(), caps)
   289  	}
   290  
   291  	p.Disconnect(DiscAlreadyConnected) // Should not hang
   292  }
   293  
   294  // expectMsg reads a message from r and verifies that its
   295  // code and encoded RLP content match the provided values.
   296  // If content is nil, the payload is discarded and not verified.
   297  func expectMsg(r MsgReader, code uint64, content interface{}) error {
   298  	msg, err := r.ReadMsg()
   299  	if err != nil {
   300  		return err
   301  	}
   302  	if msg.Code != code {
   303  		return fmt.Errorf("message code mismatch: got %d, expected %d", msg.Code, code)
   304  	}
   305  	if content == nil {
   306  		return msg.Discard()
   307  	} else {
   308  		contentEnc, err := rlp.EncodeToBytes(content)
   309  		if err != nil {
   310  			panic("content encode error: " + err.Error())
   311  		}
   312  		// skip over list header in encoded value. this is temporary.
   313  		contentEncR := bytes.NewReader(contentEnc)
   314  		if k, _, err := rlp.NewStream(contentEncR).Kind(); k != rlp.List || err != nil {
   315  			panic("content must encode as RLP list")
   316  		}
   317  		contentEnc = contentEnc[len(contentEnc)-contentEncR.Len():]
   318  
   319  		actualContent, err := ioutil.ReadAll(msg.Payload)
   320  		if err != nil {
   321  			return err
   322  		}
   323  		if !bytes.Equal(actualContent, contentEnc) {
   324  			return fmt.Errorf("message payload mismatch:\ngot:  %x\nwant: %x", actualContent, contentEnc)
   325  		}
   326  	}
   327  	return nil
   328  }