github.com/shrimpyuk/bor@v0.2.15-0.20220224151350-fb4ec6020bae/cmd/devp2p/internal/ethtest/helpers.go (about)

     1  // Copyright 2020 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package ethtest
    18  
    19  import (
    20  	"fmt"
    21  	"net"
    22  	"reflect"
    23  	"strings"
    24  	"time"
    25  
    26  	"github.com/davecgh/go-spew/spew"
    27  	"github.com/ethereum/go-ethereum/common"
    28  	"github.com/ethereum/go-ethereum/core/types"
    29  	"github.com/ethereum/go-ethereum/crypto"
    30  	"github.com/ethereum/go-ethereum/eth/protocols/eth"
    31  	"github.com/ethereum/go-ethereum/internal/utesting"
    32  	"github.com/ethereum/go-ethereum/p2p"
    33  	"github.com/ethereum/go-ethereum/p2p/rlpx"
    34  )
    35  
    36  var (
    37  	pretty = spew.ConfigState{
    38  		Indent:                  "  ",
    39  		DisableCapacities:       true,
    40  		DisablePointerAddresses: true,
    41  		SortKeys:                true,
    42  	}
    43  	timeout = 20 * time.Second
    44  )
    45  
    46  // Is_66 checks if the node supports the eth66 protocol version,
    47  // and if not, exists the test suite
    48  func (s *Suite) Is_66(t *utesting.T) {
    49  	conn, err := s.dial66()
    50  	if err != nil {
    51  		t.Fatalf("dial failed: %v", err)
    52  	}
    53  	if err := conn.handshake(); err != nil {
    54  		t.Fatalf("handshake failed: %v", err)
    55  	}
    56  	if conn.negotiatedProtoVersion < 66 {
    57  		t.Fail()
    58  	}
    59  }
    60  
    61  // dial attempts to dial the given node and perform a handshake,
    62  // returning the created Conn if successful.
    63  func (s *Suite) dial() (*Conn, error) {
    64  	// dial
    65  	fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP()))
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	conn := Conn{Conn: rlpx.NewConn(fd, s.Dest.Pubkey())}
    70  	// do encHandshake
    71  	conn.ourKey, _ = crypto.GenerateKey()
    72  	_, err = conn.Handshake(conn.ourKey)
    73  	if err != nil {
    74  		conn.Close()
    75  		return nil, err
    76  	}
    77  	// set default p2p capabilities
    78  	conn.caps = []p2p.Cap{
    79  		{Name: "eth", Version: 64},
    80  		{Name: "eth", Version: 65},
    81  	}
    82  	conn.ourHighestProtoVersion = 65
    83  	return &conn, nil
    84  }
    85  
    86  // dial66 attempts to dial the given node and perform a handshake,
    87  // returning the created Conn with additional eth66 capabilities if
    88  // successful
    89  func (s *Suite) dial66() (*Conn, error) {
    90  	conn, err := s.dial()
    91  	if err != nil {
    92  		return nil, fmt.Errorf("dial failed: %v", err)
    93  	}
    94  	conn.caps = append(conn.caps, p2p.Cap{Name: "eth", Version: 66})
    95  	conn.ourHighestProtoVersion = 66
    96  	return conn, nil
    97  }
    98  
    99  // peer performs both the protocol handshake and the status message
   100  // exchange with the node in order to peer with it.
   101  func (c *Conn) peer(chain *Chain, status *Status) error {
   102  	if err := c.handshake(); err != nil {
   103  		return fmt.Errorf("handshake failed: %v", err)
   104  	}
   105  	if _, err := c.statusExchange(chain, status); err != nil {
   106  		return fmt.Errorf("status exchange failed: %v", err)
   107  	}
   108  	return nil
   109  }
   110  
   111  // handshake performs a protocol handshake with the node.
   112  func (c *Conn) handshake() error {
   113  	defer c.SetDeadline(time.Time{})
   114  	c.SetDeadline(time.Now().Add(10 * time.Second))
   115  	// write hello to client
   116  	pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:]
   117  	ourHandshake := &Hello{
   118  		Version: 5,
   119  		Caps:    c.caps,
   120  		ID:      pub0,
   121  	}
   122  	if err := c.Write(ourHandshake); err != nil {
   123  		return fmt.Errorf("write to connection failed: %v", err)
   124  	}
   125  	// read hello from client
   126  	switch msg := c.Read().(type) {
   127  	case *Hello:
   128  		// set snappy if version is at least 5
   129  		if msg.Version >= 5 {
   130  			c.SetSnappy(true)
   131  		}
   132  		c.negotiateEthProtocol(msg.Caps)
   133  		if c.negotiatedProtoVersion == 0 {
   134  			return fmt.Errorf("unexpected eth protocol version")
   135  		}
   136  		return nil
   137  	default:
   138  		return fmt.Errorf("bad handshake: %#v", msg)
   139  	}
   140  }
   141  
   142  // negotiateEthProtocol sets the Conn's eth protocol version to highest
   143  // advertised capability from peer.
   144  func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) {
   145  	var highestEthVersion uint
   146  	for _, capability := range caps {
   147  		if capability.Name != "eth" {
   148  			continue
   149  		}
   150  		if capability.Version > highestEthVersion && capability.Version <= c.ourHighestProtoVersion {
   151  			highestEthVersion = capability.Version
   152  		}
   153  	}
   154  	c.negotiatedProtoVersion = highestEthVersion
   155  }
   156  
   157  // statusExchange performs a `Status` message exchange with the given node.
   158  func (c *Conn) statusExchange(chain *Chain, status *Status) (Message, error) {
   159  	defer c.SetDeadline(time.Time{})
   160  	c.SetDeadline(time.Now().Add(20 * time.Second))
   161  
   162  	// read status message from client
   163  	var message Message
   164  loop:
   165  	for {
   166  		switch msg := c.Read().(type) {
   167  		case *Status:
   168  			if have, want := msg.Head, chain.blocks[chain.Len()-1].Hash(); have != want {
   169  				return nil, fmt.Errorf("wrong head block in status, want:  %#x (block %d) have %#x",
   170  					want, chain.blocks[chain.Len()-1].NumberU64(), have)
   171  			}
   172  			if have, want := msg.TD.Cmp(chain.TD()), 0; have != want {
   173  				return nil, fmt.Errorf("wrong TD in status: have %v want %v", have, want)
   174  			}
   175  			if have, want := msg.ForkID, chain.ForkID(); !reflect.DeepEqual(have, want) {
   176  				return nil, fmt.Errorf("wrong fork ID in status: have %v, want %v", have, want)
   177  			}
   178  			if have, want := msg.ProtocolVersion, c.ourHighestProtoVersion; have != uint32(want) {
   179  				return nil, fmt.Errorf("wrong protocol version: have %v, want %v", have, want)
   180  			}
   181  			message = msg
   182  			break loop
   183  		case *Disconnect:
   184  			return nil, fmt.Errorf("disconnect received: %v", msg.Reason)
   185  		case *Ping:
   186  			c.Write(&Pong{}) // TODO (renaynay): in the future, this should be an error
   187  			// (PINGs should not be a response upon fresh connection)
   188  		default:
   189  			return nil, fmt.Errorf("bad status message: %s", pretty.Sdump(msg))
   190  		}
   191  	}
   192  	// make sure eth protocol version is set for negotiation
   193  	if c.negotiatedProtoVersion == 0 {
   194  		return nil, fmt.Errorf("eth protocol version must be set in Conn")
   195  	}
   196  	if status == nil {
   197  		// default status message
   198  		status = &Status{
   199  			ProtocolVersion: uint32(c.negotiatedProtoVersion),
   200  			NetworkID:       chain.chainConfig.ChainID.Uint64(),
   201  			TD:              chain.TD(),
   202  			Head:            chain.blocks[chain.Len()-1].Hash(),
   203  			Genesis:         chain.blocks[0].Hash(),
   204  			ForkID:          chain.ForkID(),
   205  		}
   206  	}
   207  	if err := c.Write(status); err != nil {
   208  		return nil, fmt.Errorf("write to connection failed: %v", err)
   209  	}
   210  	return message, nil
   211  }
   212  
   213  // createSendAndRecvConns creates two connections, one for sending messages to the
   214  // node, and one for receiving messages from the node.
   215  func (s *Suite) createSendAndRecvConns(isEth66 bool) (*Conn, *Conn, error) {
   216  	var (
   217  		sendConn *Conn
   218  		recvConn *Conn
   219  		err      error
   220  	)
   221  	if isEth66 {
   222  		sendConn, err = s.dial66()
   223  		if err != nil {
   224  			return nil, nil, fmt.Errorf("dial failed: %v", err)
   225  		}
   226  		recvConn, err = s.dial66()
   227  		if err != nil {
   228  			sendConn.Close()
   229  			return nil, nil, fmt.Errorf("dial failed: %v", err)
   230  		}
   231  	} else {
   232  		sendConn, err = s.dial()
   233  		if err != nil {
   234  			return nil, nil, fmt.Errorf("dial failed: %v", err)
   235  		}
   236  		recvConn, err = s.dial()
   237  		if err != nil {
   238  			sendConn.Close()
   239  			return nil, nil, fmt.Errorf("dial failed: %v", err)
   240  		}
   241  	}
   242  	return sendConn, recvConn, nil
   243  }
   244  
   245  func (c *Conn) readAndServe(chain *Chain, timeout time.Duration) Message {
   246  	if c.negotiatedProtoVersion == 66 {
   247  		_, msg := c.readAndServe66(chain, timeout)
   248  		return msg
   249  	}
   250  	return c.readAndServe65(chain, timeout)
   251  }
   252  
   253  // readAndServe serves GetBlockHeaders requests while waiting
   254  // on another message from the node.
   255  func (c *Conn) readAndServe65(chain *Chain, timeout time.Duration) Message {
   256  	start := time.Now()
   257  	for time.Since(start) < timeout {
   258  		c.SetReadDeadline(time.Now().Add(5 * time.Second))
   259  		switch msg := c.Read().(type) {
   260  		case *Ping:
   261  			c.Write(&Pong{})
   262  		case *GetBlockHeaders:
   263  			req := *msg
   264  			headers, err := chain.GetHeaders(req)
   265  			if err != nil {
   266  				return errorf("could not get headers for inbound header request: %v", err)
   267  			}
   268  			if err := c.Write(headers); err != nil {
   269  				return errorf("could not write to connection: %v", err)
   270  			}
   271  		default:
   272  			return msg
   273  		}
   274  	}
   275  	return errorf("no message received within %v", timeout)
   276  }
   277  
   278  // readAndServe66 serves eth66 GetBlockHeaders requests while waiting
   279  // on another message from the node.
   280  func (c *Conn) readAndServe66(chain *Chain, timeout time.Duration) (uint64, Message) {
   281  	start := time.Now()
   282  	for time.Since(start) < timeout {
   283  		c.SetReadDeadline(time.Now().Add(10 * time.Second))
   284  
   285  		reqID, msg := c.Read66()
   286  
   287  		switch msg := msg.(type) {
   288  		case *Ping:
   289  			c.Write(&Pong{})
   290  		case GetBlockHeaders:
   291  			headers, err := chain.GetHeaders(msg)
   292  			if err != nil {
   293  				return 0, errorf("could not get headers for inbound header request: %v", err)
   294  			}
   295  			resp := &eth.BlockHeadersPacket66{
   296  				RequestId:          reqID,
   297  				BlockHeadersPacket: eth.BlockHeadersPacket(headers),
   298  			}
   299  			if err := c.Write66(resp, BlockHeaders{}.Code()); err != nil {
   300  				return 0, errorf("could not write to connection: %v", err)
   301  			}
   302  		default:
   303  			return reqID, msg
   304  		}
   305  	}
   306  	return 0, errorf("no message received within %v", timeout)
   307  }
   308  
   309  // headersRequest executes the given `GetBlockHeaders` request.
   310  func (c *Conn) headersRequest(request *GetBlockHeaders, chain *Chain, isEth66 bool, reqID uint64) (BlockHeaders, error) {
   311  	defer c.SetReadDeadline(time.Time{})
   312  	c.SetReadDeadline(time.Now().Add(20 * time.Second))
   313  	// if on eth66 connection, perform eth66 GetBlockHeaders request
   314  	if isEth66 {
   315  		return getBlockHeaders66(chain, c, request, reqID)
   316  	}
   317  	if err := c.Write(request); err != nil {
   318  		return nil, err
   319  	}
   320  	switch msg := c.readAndServe(chain, timeout).(type) {
   321  	case *BlockHeaders:
   322  		return *msg, nil
   323  	default:
   324  		return nil, fmt.Errorf("invalid message: %s", pretty.Sdump(msg))
   325  	}
   326  }
   327  
   328  // getBlockHeaders66 executes the given `GetBlockHeaders` request over the eth66 protocol.
   329  func getBlockHeaders66(chain *Chain, conn *Conn, request *GetBlockHeaders, id uint64) (BlockHeaders, error) {
   330  	// write request
   331  	packet := eth.GetBlockHeadersPacket(*request)
   332  	req := &eth.GetBlockHeadersPacket66{
   333  		RequestId:             id,
   334  		GetBlockHeadersPacket: &packet,
   335  	}
   336  	if err := conn.Write66(req, GetBlockHeaders{}.Code()); err != nil {
   337  		return nil, fmt.Errorf("could not write to connection: %v", err)
   338  	}
   339  	// wait for response
   340  	msg := conn.waitForResponse(chain, timeout, req.RequestId)
   341  	headers, ok := msg.(BlockHeaders)
   342  	if !ok {
   343  		return nil, fmt.Errorf("unexpected message received: %s", pretty.Sdump(msg))
   344  	}
   345  	return headers, nil
   346  }
   347  
   348  // headersMatch returns whether the received headers match the given request
   349  func headersMatch(expected BlockHeaders, headers BlockHeaders) bool {
   350  	return reflect.DeepEqual(expected, headers)
   351  }
   352  
   353  // waitForResponse reads from the connection until a response with the expected
   354  // request ID is received.
   355  func (c *Conn) waitForResponse(chain *Chain, timeout time.Duration, requestID uint64) Message {
   356  	for {
   357  		id, msg := c.readAndServe66(chain, timeout)
   358  		if id == requestID {
   359  			return msg
   360  		}
   361  	}
   362  }
   363  
   364  // sendNextBlock broadcasts the next block in the chain and waits
   365  // for the node to propagate the block and import it into its chain.
   366  func (s *Suite) sendNextBlock(isEth66 bool) error {
   367  	// set up sending and receiving connections
   368  	sendConn, recvConn, err := s.createSendAndRecvConns(isEth66)
   369  	if err != nil {
   370  		return err
   371  	}
   372  	defer sendConn.Close()
   373  	defer recvConn.Close()
   374  	if err = sendConn.peer(s.chain, nil); err != nil {
   375  		return fmt.Errorf("peering failed: %v", err)
   376  	}
   377  	if err = recvConn.peer(s.chain, nil); err != nil {
   378  		return fmt.Errorf("peering failed: %v", err)
   379  	}
   380  	// create new block announcement
   381  	nextBlock := s.fullChain.blocks[s.chain.Len()]
   382  	blockAnnouncement := &NewBlock{
   383  		Block: nextBlock,
   384  		TD:    s.fullChain.TotalDifficultyAt(s.chain.Len()),
   385  	}
   386  	// send announcement and wait for node to request the header
   387  	if err = s.testAnnounce(sendConn, recvConn, blockAnnouncement); err != nil {
   388  		return fmt.Errorf("failed to announce block: %v", err)
   389  	}
   390  	// wait for client to update its chain
   391  	if err = s.waitForBlockImport(recvConn, nextBlock, isEth66); err != nil {
   392  		return fmt.Errorf("failed to receive confirmation of block import: %v", err)
   393  	}
   394  	// update test suite chain
   395  	s.chain.blocks = append(s.chain.blocks, nextBlock)
   396  	return nil
   397  }
   398  
   399  // testAnnounce writes a block announcement to the node and waits for the node
   400  // to propagate it.
   401  func (s *Suite) testAnnounce(sendConn, receiveConn *Conn, blockAnnouncement *NewBlock) error {
   402  	if err := sendConn.Write(blockAnnouncement); err != nil {
   403  		return fmt.Errorf("could not write to connection: %v", err)
   404  	}
   405  	return s.waitAnnounce(receiveConn, blockAnnouncement)
   406  }
   407  
   408  // waitAnnounce waits for a NewBlock or NewBlockHashes announcement from the node.
   409  func (s *Suite) waitAnnounce(conn *Conn, blockAnnouncement *NewBlock) error {
   410  	for {
   411  		switch msg := conn.readAndServe(s.chain, timeout).(type) {
   412  		case *NewBlock:
   413  			if !reflect.DeepEqual(blockAnnouncement.Block.Header(), msg.Block.Header()) {
   414  				return fmt.Errorf("wrong header in block announcement: \nexpected %v "+
   415  					"\ngot %v", blockAnnouncement.Block.Header(), msg.Block.Header())
   416  			}
   417  			if !reflect.DeepEqual(blockAnnouncement.TD, msg.TD) {
   418  				return fmt.Errorf("wrong TD in announcement: expected %v, got %v", blockAnnouncement.TD, msg.TD)
   419  			}
   420  			return nil
   421  		case *NewBlockHashes:
   422  			hashes := *msg
   423  			if blockAnnouncement.Block.Hash() != hashes[0].Hash {
   424  				return fmt.Errorf("wrong block hash in announcement: expected %v, got %v", blockAnnouncement.Block.Hash(), hashes[0].Hash)
   425  			}
   426  			return nil
   427  		case *NewPooledTransactionHashes:
   428  			// ignore tx announcements from previous tests
   429  			continue
   430  		default:
   431  			return fmt.Errorf("unexpected: %s", pretty.Sdump(msg))
   432  		}
   433  	}
   434  }
   435  
   436  func (s *Suite) waitForBlockImport(conn *Conn, block *types.Block, isEth66 bool) error {
   437  	defer conn.SetReadDeadline(time.Time{})
   438  	conn.SetReadDeadline(time.Now().Add(20 * time.Second))
   439  	// create request
   440  	req := &GetBlockHeaders{
   441  		Origin: eth.HashOrNumber{
   442  			Hash: block.Hash(),
   443  		},
   444  		Amount: 1,
   445  	}
   446  	// loop until BlockHeaders response contains desired block, confirming the
   447  	// node imported the block
   448  	for {
   449  		var (
   450  			headers BlockHeaders
   451  			err     error
   452  		)
   453  		if isEth66 {
   454  			requestID := uint64(54)
   455  			headers, err = conn.headersRequest(req, s.chain, eth66, requestID)
   456  		} else {
   457  			headers, err = conn.headersRequest(req, s.chain, eth65, 0)
   458  		}
   459  		if err != nil {
   460  			return fmt.Errorf("GetBlockHeader request failed: %v", err)
   461  		}
   462  		// if headers response is empty, node hasn't imported block yet, try again
   463  		if len(headers) == 0 {
   464  			time.Sleep(100 * time.Millisecond)
   465  			continue
   466  		}
   467  		if !reflect.DeepEqual(block.Header(), headers[0]) {
   468  			return fmt.Errorf("wrong header returned: wanted %v, got %v", block.Header(), headers[0])
   469  		}
   470  		return nil
   471  	}
   472  }
   473  
   474  func (s *Suite) oldAnnounce(isEth66 bool) error {
   475  	sendConn, receiveConn, err := s.createSendAndRecvConns(isEth66)
   476  	if err != nil {
   477  		return err
   478  	}
   479  	defer sendConn.Close()
   480  	defer receiveConn.Close()
   481  	if err := sendConn.peer(s.chain, nil); err != nil {
   482  		return fmt.Errorf("peering failed: %v", err)
   483  	}
   484  	if err := receiveConn.peer(s.chain, nil); err != nil {
   485  		return fmt.Errorf("peering failed: %v", err)
   486  	}
   487  	// create old block announcement
   488  	oldBlockAnnounce := &NewBlock{
   489  		Block: s.chain.blocks[len(s.chain.blocks)/2],
   490  		TD:    s.chain.blocks[len(s.chain.blocks)/2].Difficulty(),
   491  	}
   492  	if err := sendConn.Write(oldBlockAnnounce); err != nil {
   493  		return fmt.Errorf("could not write to connection: %v", err)
   494  	}
   495  	// wait to see if the announcement is propagated
   496  	switch msg := receiveConn.readAndServe(s.chain, time.Second*8).(type) {
   497  	case *NewBlock:
   498  		block := *msg
   499  		if block.Block.Hash() == oldBlockAnnounce.Block.Hash() {
   500  			return fmt.Errorf("unexpected: block propagated: %s", pretty.Sdump(msg))
   501  		}
   502  	case *NewBlockHashes:
   503  		hashes := *msg
   504  		for _, hash := range hashes {
   505  			if hash.Hash == oldBlockAnnounce.Block.Hash() {
   506  				return fmt.Errorf("unexpected: block announced: %s", pretty.Sdump(msg))
   507  			}
   508  		}
   509  	case *Error:
   510  		errMsg := *msg
   511  		// check to make sure error is timeout (propagation didn't come through == test successful)
   512  		if !strings.Contains(errMsg.String(), "timeout") {
   513  			return fmt.Errorf("unexpected error: %v", pretty.Sdump(msg))
   514  		}
   515  	default:
   516  		return fmt.Errorf("unexpected: %s", pretty.Sdump(msg))
   517  	}
   518  	return nil
   519  }
   520  
   521  func (s *Suite) maliciousHandshakes(t *utesting.T, isEth66 bool) error {
   522  	var (
   523  		conn *Conn
   524  		err  error
   525  	)
   526  	if isEth66 {
   527  		conn, err = s.dial66()
   528  		if err != nil {
   529  			return fmt.Errorf("dial failed: %v", err)
   530  		}
   531  	} else {
   532  		conn, err = s.dial()
   533  		if err != nil {
   534  			return fmt.Errorf("dial failed: %v", err)
   535  		}
   536  	}
   537  	defer conn.Close()
   538  	// write hello to client
   539  	pub0 := crypto.FromECDSAPub(&conn.ourKey.PublicKey)[1:]
   540  	handshakes := []*Hello{
   541  		{
   542  			Version: 5,
   543  			Caps: []p2p.Cap{
   544  				{Name: largeString(2), Version: 64},
   545  			},
   546  			ID: pub0,
   547  		},
   548  		{
   549  			Version: 5,
   550  			Caps: []p2p.Cap{
   551  				{Name: "eth", Version: 64},
   552  				{Name: "eth", Version: 65},
   553  			},
   554  			ID: append(pub0, byte(0)),
   555  		},
   556  		{
   557  			Version: 5,
   558  			Caps: []p2p.Cap{
   559  				{Name: "eth", Version: 64},
   560  				{Name: "eth", Version: 65},
   561  			},
   562  			ID: append(pub0, pub0...),
   563  		},
   564  		{
   565  			Version: 5,
   566  			Caps: []p2p.Cap{
   567  				{Name: "eth", Version: 64},
   568  				{Name: "eth", Version: 65},
   569  			},
   570  			ID: largeBuffer(2),
   571  		},
   572  		{
   573  			Version: 5,
   574  			Caps: []p2p.Cap{
   575  				{Name: largeString(2), Version: 64},
   576  			},
   577  			ID: largeBuffer(2),
   578  		},
   579  	}
   580  	for i, handshake := range handshakes {
   581  		t.Logf("Testing malicious handshake %v\n", i)
   582  		if err := conn.Write(handshake); err != nil {
   583  			return fmt.Errorf("could not write to connection: %v", err)
   584  		}
   585  		// check that the peer disconnected
   586  		for i := 0; i < 2; i++ {
   587  			switch msg := conn.readAndServe(s.chain, 20*time.Second).(type) {
   588  			case *Disconnect:
   589  			case *Error:
   590  			case *Hello:
   591  				// Discard one hello as Hello's are sent concurrently
   592  				continue
   593  			default:
   594  				return fmt.Errorf("unexpected: %s", pretty.Sdump(msg))
   595  			}
   596  		}
   597  		// dial for the next round
   598  		if isEth66 {
   599  			conn, err = s.dial66()
   600  			if err != nil {
   601  				return fmt.Errorf("dial failed: %v", err)
   602  			}
   603  		} else {
   604  			conn, err = s.dial()
   605  			if err != nil {
   606  				return fmt.Errorf("dial failed: %v", err)
   607  			}
   608  		}
   609  	}
   610  	return nil
   611  }
   612  
   613  func (s *Suite) maliciousStatus(conn *Conn) error {
   614  	if err := conn.handshake(); err != nil {
   615  		return fmt.Errorf("handshake failed: %v", err)
   616  	}
   617  	status := &Status{
   618  		ProtocolVersion: uint32(conn.negotiatedProtoVersion),
   619  		NetworkID:       s.chain.chainConfig.ChainID.Uint64(),
   620  		TD:              largeNumber(2),
   621  		Head:            s.chain.blocks[s.chain.Len()-1].Hash(),
   622  		Genesis:         s.chain.blocks[0].Hash(),
   623  		ForkID:          s.chain.ForkID(),
   624  	}
   625  	// get status
   626  	msg, err := conn.statusExchange(s.chain, status)
   627  	if err != nil {
   628  		return fmt.Errorf("status exchange failed: %v", err)
   629  	}
   630  	switch msg := msg.(type) {
   631  	case *Status:
   632  	default:
   633  		return fmt.Errorf("expected status, got: %#v ", msg)
   634  	}
   635  	// wait for disconnect
   636  	switch msg := conn.readAndServe(s.chain, timeout).(type) {
   637  	case *Disconnect:
   638  		return nil
   639  	case *Error:
   640  		return nil
   641  	default:
   642  		return fmt.Errorf("expected disconnect, got: %s", pretty.Sdump(msg))
   643  	}
   644  }
   645  
   646  func (s *Suite) hashAnnounce(isEth66 bool) error {
   647  	// create connections
   648  	sendConn, recvConn, err := s.createSendAndRecvConns(isEth66)
   649  	if err != nil {
   650  		return fmt.Errorf("failed to create connections: %v", err)
   651  	}
   652  	defer sendConn.Close()
   653  	defer recvConn.Close()
   654  	if err := sendConn.peer(s.chain, nil); err != nil {
   655  		return fmt.Errorf("peering failed: %v", err)
   656  	}
   657  	if err := recvConn.peer(s.chain, nil); err != nil {
   658  		return fmt.Errorf("peering failed: %v", err)
   659  	}
   660  	// create NewBlockHashes announcement
   661  	type anno struct {
   662  		Hash   common.Hash // Hash of one particular block being announced
   663  		Number uint64      // Number of one particular block being announced
   664  	}
   665  	nextBlock := s.fullChain.blocks[s.chain.Len()]
   666  	announcement := anno{Hash: nextBlock.Hash(), Number: nextBlock.Number().Uint64()}
   667  	newBlockHash := &NewBlockHashes{announcement}
   668  	if err := sendConn.Write(newBlockHash); err != nil {
   669  		return fmt.Errorf("failed to write to connection: %v", err)
   670  	}
   671  	// Announcement sent, now wait for a header request
   672  	var (
   673  		id             uint64
   674  		msg            Message
   675  		blockHeaderReq GetBlockHeaders
   676  	)
   677  	if isEth66 {
   678  		id, msg = sendConn.Read66()
   679  		switch msg := msg.(type) {
   680  		case GetBlockHeaders:
   681  			blockHeaderReq = msg
   682  		default:
   683  			return fmt.Errorf("unexpected %s", pretty.Sdump(msg))
   684  		}
   685  		if blockHeaderReq.Amount != 1 {
   686  			return fmt.Errorf("unexpected number of block headers requested: %v", blockHeaderReq.Amount)
   687  		}
   688  		if blockHeaderReq.Origin.Hash != announcement.Hash {
   689  			return fmt.Errorf("unexpected block header requested. Announced:\n %v\n Remote request:\n%v",
   690  				pretty.Sdump(announcement),
   691  				pretty.Sdump(blockHeaderReq))
   692  		}
   693  		if err := sendConn.Write66(&eth.BlockHeadersPacket66{
   694  			RequestId: id,
   695  			BlockHeadersPacket: eth.BlockHeadersPacket{
   696  				nextBlock.Header(),
   697  			},
   698  		}, BlockHeaders{}.Code()); err != nil {
   699  			return fmt.Errorf("failed to write to connection: %v", err)
   700  		}
   701  	} else {
   702  		msg = sendConn.Read()
   703  		switch msg := msg.(type) {
   704  		case *GetBlockHeaders:
   705  			blockHeaderReq = *msg
   706  		default:
   707  			return fmt.Errorf("unexpected %s", pretty.Sdump(msg))
   708  		}
   709  		if blockHeaderReq.Amount != 1 {
   710  			return fmt.Errorf("unexpected number of block headers requested: %v", blockHeaderReq.Amount)
   711  		}
   712  		if blockHeaderReq.Origin.Hash != announcement.Hash {
   713  			return fmt.Errorf("unexpected block header requested. Announced:\n %v\n Remote request:\n%v",
   714  				pretty.Sdump(announcement),
   715  				pretty.Sdump(blockHeaderReq))
   716  		}
   717  		if err := sendConn.Write(&BlockHeaders{nextBlock.Header()}); err != nil {
   718  			return fmt.Errorf("failed to write to connection: %v", err)
   719  		}
   720  	}
   721  	// wait for block announcement
   722  	msg = recvConn.readAndServe(s.chain, timeout)
   723  	switch msg := msg.(type) {
   724  	case *NewBlockHashes:
   725  		hashes := *msg
   726  		if len(hashes) != 1 {
   727  			return fmt.Errorf("unexpected new block hash announcement: wanted 1 announcement, got %d", len(hashes))
   728  		}
   729  		if nextBlock.Hash() != hashes[0].Hash {
   730  			return fmt.Errorf("unexpected block hash announcement, wanted %v, got %v", nextBlock.Hash(),
   731  				hashes[0].Hash)
   732  		}
   733  	case *NewBlock:
   734  		// node should only propagate NewBlock without having requested the body if the body is empty
   735  		nextBlockBody := nextBlock.Body()
   736  		if len(nextBlockBody.Transactions) != 0 || len(nextBlockBody.Uncles) != 0 {
   737  			return fmt.Errorf("unexpected non-empty new block propagated: %s", pretty.Sdump(msg))
   738  		}
   739  		if msg.Block.Hash() != nextBlock.Hash() {
   740  			return fmt.Errorf("mismatched hash of propagated new block: wanted %v, got %v",
   741  				nextBlock.Hash(), msg.Block.Hash())
   742  		}
   743  		// check to make sure header matches header that was sent to the node
   744  		if !reflect.DeepEqual(nextBlock.Header(), msg.Block.Header()) {
   745  			return fmt.Errorf("incorrect header received: wanted %v, got %v", nextBlock.Header(), msg.Block.Header())
   746  		}
   747  	default:
   748  		return fmt.Errorf("unexpected: %s", pretty.Sdump(msg))
   749  	}
   750  	// confirm node imported block
   751  	if err := s.waitForBlockImport(recvConn, nextBlock, isEth66); err != nil {
   752  		return fmt.Errorf("error waiting for node to import new block: %v", err)
   753  	}
   754  	// update the chain
   755  	s.chain.blocks = append(s.chain.blocks, nextBlock)
   756  	return nil
   757  }