github.com/core-coin/go-core/v2@v2.1.9/cmd/devp2p/internal/v5test/framework.go (about)

     1  // Copyright 2020 by the Authors
     2  // This file is part of go-core.
     3  //
     4  // go-core is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // go-core is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU General Public License
    15  // along with go-core. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package v5test
    18  
    19  import (
    20  	"bytes"
    21  	crand "crypto/rand"
    22  	"encoding/binary"
    23  	"fmt"
    24  	"net"
    25  	"time"
    26  
    27  	"github.com/core-coin/go-core/v2/common/mclock"
    28  	"github.com/core-coin/go-core/v2/crypto"
    29  	"github.com/core-coin/go-core/v2/p2p/discover/v5wire"
    30  	"github.com/core-coin/go-core/v2/p2p/enode"
    31  	"github.com/core-coin/go-core/v2/p2p/enr"
    32  )
    33  
    34  // readError represents an error during packet reading.
    35  // This exists to facilitate type-switching on the result of conn.read.
    36  type readError struct {
    37  	err error
    38  }
    39  
    40  func (p *readError) Kind() byte          { return 99 }
    41  func (p *readError) Name() string        { return fmt.Sprintf("error: %v", p.err) }
    42  func (p *readError) Error() string       { return p.err.Error() }
    43  func (p *readError) Unwrap() error       { return p.err }
    44  func (p *readError) RequestID() []byte   { return nil }
    45  func (p *readError) SetRequestID([]byte) {}
    46  
    47  // readErrorf creates a readError with the given text.
    48  func readErrorf(format string, args ...interface{}) *readError {
    49  	return &readError{fmt.Errorf(format, args...)}
    50  }
    51  
    52  // This is the response timeout used in tests.
    53  const waitTime = 300 * time.Millisecond
    54  
    55  // conn is a connection to the node under test.
    56  type conn struct {
    57  	localNode  *enode.LocalNode
    58  	localKey   *crypto.PrivateKey
    59  	remote     *enode.Node
    60  	remoteAddr *net.UDPAddr
    61  	listeners  []net.PacketConn
    62  
    63  	log           logger
    64  	codec         *v5wire.Codec
    65  	lastRequest   v5wire.Packet
    66  	lastChallenge *v5wire.Whoareyou
    67  	idCounter     uint32
    68  }
    69  
    70  type logger interface {
    71  	Logf(string, ...interface{})
    72  }
    73  
    74  // newConn sets up a connection to the given node.
    75  func newConn(dest *enode.Node, log logger) *conn {
    76  	key, err := crypto.GenerateKey(crand.Reader)
    77  	if err != nil {
    78  		panic(err)
    79  	}
    80  	db, err := enode.OpenDB("")
    81  	if err != nil {
    82  		panic(err)
    83  	}
    84  	ln := enode.NewLocalNode(db, key)
    85  
    86  	return &conn{
    87  		localKey:   key,
    88  		localNode:  ln,
    89  		remote:     dest,
    90  		remoteAddr: &net.UDPAddr{IP: dest.IP(), Port: dest.UDP()},
    91  		codec:      v5wire.NewCodec(ln, key, mclock.System{}),
    92  		log:        log,
    93  	}
    94  }
    95  
    96  func (tc *conn) setEndpoint(c net.PacketConn) {
    97  	tc.localNode.SetStaticIP(laddr(c).IP)
    98  	tc.localNode.SetFallbackUDP(laddr(c).Port)
    99  }
   100  
   101  func (tc *conn) listen(ip string) net.PacketConn {
   102  	l, err := net.ListenPacket("udp", fmt.Sprintf("%v:0", ip))
   103  	if err != nil {
   104  		panic(err)
   105  	}
   106  	tc.listeners = append(tc.listeners, l)
   107  	return l
   108  }
   109  
   110  // close shuts down all listeners and the local node.
   111  func (tc *conn) close() {
   112  	for _, l := range tc.listeners {
   113  		l.Close()
   114  	}
   115  	tc.localNode.Database().Close()
   116  }
   117  
   118  // nextReqID creates a request id.
   119  func (tc *conn) nextReqID() []byte {
   120  	id := make([]byte, 4)
   121  	tc.idCounter++
   122  	binary.BigEndian.PutUint32(id, tc.idCounter)
   123  	return id
   124  }
   125  
   126  // reqresp performs a request/response interaction on the given connection.
   127  // The request is retried if a handshake is requested.
   128  func (tc *conn) reqresp(c net.PacketConn, req v5wire.Packet) v5wire.Packet {
   129  	reqnonce := tc.write(c, req, nil)
   130  	switch resp := tc.read(c).(type) {
   131  	case *v5wire.Whoareyou:
   132  		if resp.Nonce != reqnonce {
   133  			return readErrorf("wrong nonce %x in WHOAREYOU (want %x)", resp.Nonce[:], reqnonce[:])
   134  		}
   135  		resp.Node = tc.remote
   136  		tc.write(c, req, resp)
   137  		return tc.read(c)
   138  	default:
   139  		return resp
   140  	}
   141  }
   142  
   143  // findnode sends a FINDNODE request and waits for its responses.
   144  func (tc *conn) findnode(c net.PacketConn, dists []uint) ([]*enode.Node, error) {
   145  	var (
   146  		findnode = &v5wire.Findnode{ReqID: tc.nextReqID(), Distances: dists}
   147  		reqnonce = tc.write(c, findnode, nil)
   148  		first    = true
   149  		total    uint8
   150  		results  []*enode.Node
   151  	)
   152  	for n := 1; n > 0; {
   153  		switch resp := tc.read(c).(type) {
   154  		case *v5wire.Whoareyou:
   155  			// Handle handshake.
   156  			if resp.Nonce == reqnonce {
   157  				resp.Node = tc.remote
   158  				tc.write(c, findnode, resp)
   159  			} else {
   160  				return nil, fmt.Errorf("unexpected WHOAREYOU (nonce %x), waiting for NODES", resp.Nonce[:])
   161  			}
   162  		case *v5wire.Ping:
   163  			// Handle ping from remote.
   164  			tc.write(c, &v5wire.Pong{
   165  				ReqID:  resp.ReqID,
   166  				ENRSeq: tc.localNode.Seq(),
   167  			}, nil)
   168  		case *v5wire.Nodes:
   169  			// Got NODES! Check request ID.
   170  			if !bytes.Equal(resp.ReqID, findnode.ReqID) {
   171  				return nil, fmt.Errorf("NODES response has wrong request id %x", resp.ReqID)
   172  			}
   173  			// Check total count. It should be greater than one
   174  			// and needs to be the same across all responses.
   175  			if first {
   176  				if resp.Total == 0 || resp.Total > 6 {
   177  					return nil, fmt.Errorf("invalid NODES response 'total' %d (not in (0,7))", resp.Total)
   178  				}
   179  				total = resp.Total
   180  				n = int(total) - 1
   181  				first = false
   182  			} else {
   183  				n--
   184  				if resp.Total != total {
   185  					return nil, fmt.Errorf("invalid NODES response 'total' %d (!= %d)", resp.Total, total)
   186  				}
   187  			}
   188  			// Check nodes.
   189  			nodes, err := checkRecords(resp.Nodes)
   190  			if err != nil {
   191  				return nil, fmt.Errorf("invalid node in NODES response: %v", err)
   192  			}
   193  			results = append(results, nodes...)
   194  		default:
   195  			return nil, fmt.Errorf("expected NODES, got %v", resp)
   196  		}
   197  	}
   198  	return results, nil
   199  }
   200  
   201  // write sends a packet on the given connection.
   202  func (tc *conn) write(c net.PacketConn, p v5wire.Packet, challenge *v5wire.Whoareyou) v5wire.Nonce {
   203  	packet, nonce, err := tc.codec.Encode(tc.remote.ID(), tc.remoteAddr.String(), p, challenge)
   204  	if err != nil {
   205  		panic(fmt.Errorf("can't encode %v packet: %v", p.Name(), err))
   206  	}
   207  	if _, err := c.WriteTo(packet, tc.remoteAddr); err != nil {
   208  		tc.logf("Can't send %s: %v", p.Name(), err)
   209  	} else {
   210  		tc.logf(">> %s", p.Name())
   211  	}
   212  	return nonce
   213  }
   214  
   215  // read waits for an incoming packet on the given connection.
   216  func (tc *conn) read(c net.PacketConn) v5wire.Packet {
   217  	buf := make([]byte, 1280)
   218  	if err := c.SetReadDeadline(time.Now().Add(waitTime)); err != nil {
   219  		return &readError{err}
   220  	}
   221  	n, fromAddr, err := c.ReadFrom(buf)
   222  	if err != nil {
   223  		return &readError{err}
   224  	}
   225  	_, _, p, err := tc.codec.Decode(buf[:n], fromAddr.String())
   226  	if err != nil {
   227  		return &readError{err}
   228  	}
   229  	tc.logf("<< %s", p.Name())
   230  	return p
   231  }
   232  
   233  // logf prints to the test log.
   234  func (tc *conn) logf(format string, args ...interface{}) {
   235  	if tc.log != nil {
   236  		tc.log.Logf("(%s) %s", tc.localNode.ID().TerminalString(), fmt.Sprintf(format, args...))
   237  	}
   238  }
   239  
   240  func laddr(c net.PacketConn) *net.UDPAddr {
   241  	return c.LocalAddr().(*net.UDPAddr)
   242  }
   243  
   244  func checkRecords(records []*enr.Record) ([]*enode.Node, error) {
   245  	nodes := make([]*enode.Node, len(records))
   246  	for i := range records {
   247  		n, err := enode.New(enode.ValidSchemes, records[i])
   248  		if err != nil {
   249  			return nil, err
   250  		}
   251  		nodes[i] = n
   252  	}
   253  	return nodes, nil
   254  }
   255  
   256  func containsUint(ints []uint, x uint) bool {
   257  	for i := range ints {
   258  		if ints[i] == x {
   259  			return true
   260  		}
   261  	}
   262  	return false
   263  }