github.com/jeffallen/go-ethereum@v1.1.4-0.20150910155051-571d3236c49c/p2p/discover/udp.go (about)

     1  // Copyright 2015 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package discover
    18  
    19  import (
    20  	"bytes"
    21  	"container/list"
    22  	"crypto/ecdsa"
    23  	"errors"
    24  	"fmt"
    25  	"net"
    26  	"time"
    27  
    28  	"github.com/ethereum/go-ethereum/crypto"
    29  	"github.com/ethereum/go-ethereum/logger"
    30  	"github.com/ethereum/go-ethereum/logger/glog"
    31  	"github.com/ethereum/go-ethereum/p2p/nat"
    32  	"github.com/ethereum/go-ethereum/rlp"
    33  )
    34  
    35  const Version = 4
    36  
    37  // Errors
    38  var (
    39  	errPacketTooSmall   = errors.New("too small")
    40  	errBadHash          = errors.New("bad hash")
    41  	errExpired          = errors.New("expired")
    42  	errBadVersion       = errors.New("version mismatch")
    43  	errUnsolicitedReply = errors.New("unsolicited reply")
    44  	errUnknownNode      = errors.New("unknown node")
    45  	errTimeout          = errors.New("RPC timeout")
    46  	errClockWarp        = errors.New("reply deadline too far in the future")
    47  	errClosed           = errors.New("socket closed")
    48  )
    49  
    50  // Timeouts
    51  const (
    52  	respTimeout = 500 * time.Millisecond
    53  	sendTimeout = 500 * time.Millisecond
    54  	expiration  = 20 * time.Second
    55  
    56  	refreshInterval = 1 * time.Hour
    57  )
    58  
    59  // RPC packet types
    60  const (
    61  	pingPacket = iota + 1 // zero is 'reserved'
    62  	pongPacket
    63  	findnodePacket
    64  	neighborsPacket
    65  )
    66  
    67  // RPC request structures
    68  type (
    69  	ping struct {
    70  		Version    uint
    71  		From, To   rpcEndpoint
    72  		Expiration uint64
    73  	}
    74  
    75  	// pong is the reply to ping.
    76  	pong struct {
    77  		// This field should mirror the UDP envelope address
    78  		// of the ping packet, which provides a way to discover the
    79  		// the external address (after NAT).
    80  		To rpcEndpoint
    81  
    82  		ReplyTok   []byte // This contains the hash of the ping packet.
    83  		Expiration uint64 // Absolute timestamp at which the packet becomes invalid.
    84  	}
    85  
    86  	// findnode is a query for nodes close to the given target.
    87  	findnode struct {
    88  		Target     NodeID // doesn't need to be an actual public key
    89  		Expiration uint64
    90  	}
    91  
    92  	// reply to findnode
    93  	neighbors struct {
    94  		Nodes      []rpcNode
    95  		Expiration uint64
    96  	}
    97  
    98  	rpcNode struct {
    99  		IP  net.IP // len 4 for IPv4 or 16 for IPv6
   100  		UDP uint16 // for discovery protocol
   101  		TCP uint16 // for RLPx protocol
   102  		ID  NodeID
   103  	}
   104  
   105  	rpcEndpoint struct {
   106  		IP  net.IP // len 4 for IPv4 or 16 for IPv6
   107  		UDP uint16 // for discovery protocol
   108  		TCP uint16 // for RLPx protocol
   109  	}
   110  )
   111  
   112  func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
   113  	ip := addr.IP.To4()
   114  	if ip == nil {
   115  		ip = addr.IP.To16()
   116  	}
   117  	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
   118  }
   119  
   120  func nodeFromRPC(rn rpcNode) (n *Node, valid bool) {
   121  	// TODO: don't accept localhost, LAN addresses from internet hosts
   122  	// TODO: check public key is on secp256k1 curve
   123  	if rn.IP.IsMulticast() || rn.IP.IsUnspecified() || rn.UDP == 0 {
   124  		return nil, false
   125  	}
   126  	return newNode(rn.ID, rn.IP, rn.UDP, rn.TCP), true
   127  }
   128  
   129  func nodeToRPC(n *Node) rpcNode {
   130  	return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP}
   131  }
   132  
   133  type packet interface {
   134  	handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
   135  }
   136  
   137  type conn interface {
   138  	ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
   139  	WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
   140  	Close() error
   141  	LocalAddr() net.Addr
   142  }
   143  
   144  // udp implements the RPC protocol.
   145  type udp struct {
   146  	conn        conn
   147  	priv        *ecdsa.PrivateKey
   148  	ourEndpoint rpcEndpoint
   149  
   150  	addpending chan *pending
   151  	gotreply   chan reply
   152  
   153  	closing chan struct{}
   154  	nat     nat.Interface
   155  
   156  	*Table
   157  }
   158  
   159  // pending represents a pending reply.
   160  //
   161  // some implementations of the protocol wish to send more than one
   162  // reply packet to findnode. in general, any neighbors packet cannot
   163  // be matched up with a specific findnode packet.
   164  //
   165  // our implementation handles this by storing a callback function for
   166  // each pending reply. incoming packets from a node are dispatched
   167  // to all the callback functions for that node.
   168  type pending struct {
   169  	// these fields must match in the reply.
   170  	from  NodeID
   171  	ptype byte
   172  
   173  	// time when the request must complete
   174  	deadline time.Time
   175  
   176  	// callback is called when a matching reply arrives. if it returns
   177  	// true, the callback is removed from the pending reply queue.
   178  	// if it returns false, the reply is considered incomplete and
   179  	// the callback will be invoked again for the next matching reply.
   180  	callback func(resp interface{}) (done bool)
   181  
   182  	// errc receives nil when the callback indicates completion or an
   183  	// error if no further reply is received within the timeout.
   184  	errc chan<- error
   185  }
   186  
   187  type reply struct {
   188  	from  NodeID
   189  	ptype byte
   190  	data  interface{}
   191  	// loop indicates whether there was
   192  	// a matching request by sending on this channel.
   193  	matched chan<- bool
   194  }
   195  
   196  // ListenUDP returns a new table that listens for UDP packets on laddr.
   197  func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface, nodeDBPath string) (*Table, error) {
   198  	addr, err := net.ResolveUDPAddr("udp", laddr)
   199  	if err != nil {
   200  		return nil, err
   201  	}
   202  	conn, err := net.ListenUDP("udp", addr)
   203  	if err != nil {
   204  		return nil, err
   205  	}
   206  	tab, _ := newUDP(priv, conn, natm, nodeDBPath)
   207  	glog.V(logger.Info).Infoln("Listening,", tab.self)
   208  	return tab, nil
   209  }
   210  
   211  func newUDP(priv *ecdsa.PrivateKey, c conn, natm nat.Interface, nodeDBPath string) (*Table, *udp) {
   212  	udp := &udp{
   213  		conn:       c,
   214  		priv:       priv,
   215  		closing:    make(chan struct{}),
   216  		gotreply:   make(chan reply),
   217  		addpending: make(chan *pending),
   218  	}
   219  	realaddr := c.LocalAddr().(*net.UDPAddr)
   220  	if natm != nil {
   221  		if !realaddr.IP.IsLoopback() {
   222  			go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
   223  		}
   224  		// TODO: react to external IP changes over time.
   225  		if ext, err := natm.ExternalIP(); err == nil {
   226  			realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
   227  		}
   228  	}
   229  	// TODO: separate TCP port
   230  	udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
   231  	udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr, nodeDBPath)
   232  	go udp.loop()
   233  	go udp.readLoop()
   234  	return udp.Table, udp
   235  }
   236  
   237  func (t *udp) close() {
   238  	close(t.closing)
   239  	t.conn.Close()
   240  	// TODO: wait for the loops to end.
   241  }
   242  
   243  // ping sends a ping message to the given node and waits for a reply.
   244  func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
   245  	// TODO: maybe check for ReplyTo field in callback to measure RTT
   246  	errc := t.pending(toid, pongPacket, func(interface{}) bool { return true })
   247  	t.send(toaddr, pingPacket, ping{
   248  		Version:    Version,
   249  		From:       t.ourEndpoint,
   250  		To:         makeEndpoint(toaddr, 0), // TODO: maybe use known TCP port from DB
   251  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   252  	})
   253  	return <-errc
   254  }
   255  
   256  func (t *udp) waitping(from NodeID) error {
   257  	return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
   258  }
   259  
   260  // findnode sends a findnode request to the given node and waits until
   261  // the node has sent up to k neighbors.
   262  func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
   263  	nodes := make([]*Node, 0, bucketSize)
   264  	nreceived := 0
   265  	errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
   266  		reply := r.(*neighbors)
   267  		for _, rn := range reply.Nodes {
   268  			nreceived++
   269  			if n, valid := nodeFromRPC(rn); valid {
   270  				nodes = append(nodes, n)
   271  			}
   272  		}
   273  		return nreceived >= bucketSize
   274  	})
   275  	t.send(toaddr, findnodePacket, findnode{
   276  		Target:     target,
   277  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   278  	})
   279  	err := <-errc
   280  	return nodes, err
   281  }
   282  
   283  // pending adds a reply callback to the pending reply queue.
   284  // see the documentation of type pending for a detailed explanation.
   285  func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
   286  	ch := make(chan error, 1)
   287  	p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
   288  	select {
   289  	case t.addpending <- p:
   290  		// loop will handle it
   291  	case <-t.closing:
   292  		ch <- errClosed
   293  	}
   294  	return ch
   295  }
   296  
   297  func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
   298  	matched := make(chan bool, 1)
   299  	select {
   300  	case t.gotreply <- reply{from, ptype, req, matched}:
   301  		// loop will handle it
   302  		return <-matched
   303  	case <-t.closing:
   304  		return false
   305  	}
   306  }
   307  
   308  // loop runs in its own goroutin. it keeps track of
   309  // the refresh timer and the pending reply queue.
   310  func (t *udp) loop() {
   311  	var (
   312  		plist       = list.New()
   313  		timeout     = time.NewTimer(0)
   314  		nextTimeout *pending // head of plist when timeout was last reset
   315  		refresh     = time.NewTicker(refreshInterval)
   316  	)
   317  	<-timeout.C // ignore first timeout
   318  	defer refresh.Stop()
   319  	defer timeout.Stop()
   320  
   321  	resetTimeout := func() {
   322  		if plist.Front() == nil || nextTimeout == plist.Front().Value {
   323  			return
   324  		}
   325  		// Start the timer so it fires when the next pending reply has expired.
   326  		now := time.Now()
   327  		for el := plist.Front(); el != nil; el = el.Next() {
   328  			nextTimeout = el.Value.(*pending)
   329  			if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout {
   330  				timeout.Reset(dist)
   331  				return
   332  			}
   333  			// Remove pending replies whose deadline is too far in the
   334  			// future. These can occur if the system clock jumped
   335  			// backwards after the deadline was assigned.
   336  			nextTimeout.errc <- errClockWarp
   337  			plist.Remove(el)
   338  		}
   339  		nextTimeout = nil
   340  		timeout.Stop()
   341  	}
   342  
   343  	for {
   344  		resetTimeout()
   345  
   346  		select {
   347  		case <-refresh.C:
   348  			go t.refresh()
   349  
   350  		case <-t.closing:
   351  			for el := plist.Front(); el != nil; el = el.Next() {
   352  				el.Value.(*pending).errc <- errClosed
   353  			}
   354  			return
   355  
   356  		case p := <-t.addpending:
   357  			p.deadline = time.Now().Add(respTimeout)
   358  			plist.PushBack(p)
   359  
   360  		case r := <-t.gotreply:
   361  			var matched bool
   362  			for el := plist.Front(); el != nil; el = el.Next() {
   363  				p := el.Value.(*pending)
   364  				if p.from == r.from && p.ptype == r.ptype {
   365  					matched = true
   366  					// Remove the matcher if its callback indicates
   367  					// that all replies have been received. This is
   368  					// required for packet types that expect multiple
   369  					// reply packets.
   370  					if p.callback(r.data) {
   371  						p.errc <- nil
   372  						plist.Remove(el)
   373  					}
   374  				}
   375  			}
   376  			r.matched <- matched
   377  
   378  		case now := <-timeout.C:
   379  			nextTimeout = nil
   380  			// Notify and remove callbacks whose deadline is in the past.
   381  			for el := plist.Front(); el != nil; el = el.Next() {
   382  				p := el.Value.(*pending)
   383  				if now.After(p.deadline) || now.Equal(p.deadline) {
   384  					p.errc <- errTimeout
   385  					plist.Remove(el)
   386  				}
   387  			}
   388  		}
   389  	}
   390  }
   391  
   392  const (
   393  	macSize  = 256 / 8
   394  	sigSize  = 520 / 8
   395  	headSize = macSize + sigSize // space of packet frame data
   396  )
   397  
   398  var (
   399  	headSpace = make([]byte, headSize)
   400  
   401  	// Neighbors replies are sent across multiple packets to
   402  	// stay below the 1280 byte limit. We compute the maximum number
   403  	// of entries by stuffing a packet until it grows too large.
   404  	maxNeighbors int
   405  )
   406  
   407  func init() {
   408  	p := neighbors{Expiration: ^uint64(0)}
   409  	maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
   410  	for n := 0; ; n++ {
   411  		p.Nodes = append(p.Nodes, maxSizeNode)
   412  		size, _, err := rlp.EncodeToReader(p)
   413  		if err != nil {
   414  			// If this ever happens, it will be caught by the unit tests.
   415  			panic("cannot encode: " + err.Error())
   416  		}
   417  		if headSize+size+1 >= 1280 {
   418  			maxNeighbors = n
   419  			break
   420  		}
   421  	}
   422  }
   423  
   424  func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req interface{}) error {
   425  	packet, err := encodePacket(t.priv, ptype, req)
   426  	if err != nil {
   427  		return err
   428  	}
   429  	glog.V(logger.Detail).Infof(">>> %v %T\n", toaddr, req)
   430  	if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
   431  		glog.V(logger.Detail).Infoln("UDP send failed:", err)
   432  	}
   433  	return err
   434  }
   435  
   436  func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) ([]byte, error) {
   437  	b := new(bytes.Buffer)
   438  	b.Write(headSpace)
   439  	b.WriteByte(ptype)
   440  	if err := rlp.Encode(b, req); err != nil {
   441  		glog.V(logger.Error).Infoln("error encoding packet:", err)
   442  		return nil, err
   443  	}
   444  	packet := b.Bytes()
   445  	sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), priv)
   446  	if err != nil {
   447  		glog.V(logger.Error).Infoln("could not sign packet:", err)
   448  		return nil, err
   449  	}
   450  	copy(packet[macSize:], sig)
   451  	// add the hash to the front. Note: this doesn't protect the
   452  	// packet in any way. Our public key will be part of this hash in
   453  	// The future.
   454  	copy(packet, crypto.Sha3(packet[macSize:]))
   455  	return packet, nil
   456  }
   457  
   458  type tempError interface {
   459  	Temporary() bool
   460  }
   461  
   462  // readLoop runs in its own goroutine. it handles incoming UDP packets.
   463  func (t *udp) readLoop() {
   464  	defer t.conn.Close()
   465  	// Discovery packets are defined to be no larger than 1280 bytes.
   466  	// Packets larger than this size will be cut at the end and treated
   467  	// as invalid because their hash won't match.
   468  	buf := make([]byte, 1280)
   469  	for {
   470  		nbytes, from, err := t.conn.ReadFromUDP(buf)
   471  		if tempErr, ok := err.(tempError); ok && tempErr.Temporary() {
   472  			// Ignore temporary read errors.
   473  			glog.V(logger.Debug).Infof("Temporary read error: %v", err)
   474  			continue
   475  		} else if err != nil {
   476  			// Shut down the loop for permament errors.
   477  			glog.V(logger.Debug).Infof("Read error: %v", err)
   478  			return
   479  		}
   480  		t.handlePacket(from, buf[:nbytes])
   481  	}
   482  }
   483  
   484  func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
   485  	packet, fromID, hash, err := decodePacket(buf)
   486  	if err != nil {
   487  		glog.V(logger.Debug).Infof("Bad packet from %v: %v\n", from, err)
   488  		return err
   489  	}
   490  	status := "ok"
   491  	if err = packet.handle(t, from, fromID, hash); err != nil {
   492  		status = err.Error()
   493  	}
   494  	glog.V(logger.Detail).Infof("<<< %v %T: %s\n", from, packet, status)
   495  	return err
   496  }
   497  
   498  func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
   499  	if len(buf) < headSize+1 {
   500  		return nil, NodeID{}, nil, errPacketTooSmall
   501  	}
   502  	hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
   503  	shouldhash := crypto.Sha3(buf[macSize:])
   504  	if !bytes.Equal(hash, shouldhash) {
   505  		return nil, NodeID{}, nil, errBadHash
   506  	}
   507  	fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
   508  	if err != nil {
   509  		return nil, NodeID{}, hash, err
   510  	}
   511  	var req packet
   512  	switch ptype := sigdata[0]; ptype {
   513  	case pingPacket:
   514  		req = new(ping)
   515  	case pongPacket:
   516  		req = new(pong)
   517  	case findnodePacket:
   518  		req = new(findnode)
   519  	case neighborsPacket:
   520  		req = new(neighbors)
   521  	default:
   522  		return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
   523  	}
   524  	err = rlp.DecodeBytes(sigdata[1:], req)
   525  	return req, fromID, hash, err
   526  }
   527  
   528  func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
   529  	if expired(req.Expiration) {
   530  		return errExpired
   531  	}
   532  	if req.Version != Version {
   533  		return errBadVersion
   534  	}
   535  	t.send(from, pongPacket, pong{
   536  		To:         makeEndpoint(from, req.From.TCP),
   537  		ReplyTok:   mac,
   538  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   539  	})
   540  	if !t.handleReply(fromID, pingPacket, req) {
   541  		// Note: we're ignoring the provided IP address right now
   542  		go t.bond(true, fromID, from, req.From.TCP)
   543  	}
   544  	return nil
   545  }
   546  
   547  func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
   548  	if expired(req.Expiration) {
   549  		return errExpired
   550  	}
   551  	if !t.handleReply(fromID, pongPacket, req) {
   552  		return errUnsolicitedReply
   553  	}
   554  	return nil
   555  }
   556  
   557  func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
   558  	if expired(req.Expiration) {
   559  		return errExpired
   560  	}
   561  	if t.db.node(fromID) == nil {
   562  		// No bond exists, we don't process the packet. This prevents
   563  		// an attack vector where the discovery protocol could be used
   564  		// to amplify traffic in a DDOS attack. A malicious actor
   565  		// would send a findnode request with the IP address and UDP
   566  		// port of the target as the source address. The recipient of
   567  		// the findnode packet would then send a neighbors packet
   568  		// (which is a much bigger packet than findnode) to the victim.
   569  		return errUnknownNode
   570  	}
   571  	target := crypto.Sha3Hash(req.Target[:])
   572  	t.mutex.Lock()
   573  	closest := t.closest(target, bucketSize).entries
   574  	t.mutex.Unlock()
   575  
   576  	p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
   577  	// Send neighbors in chunks with at most maxNeighbors per packet
   578  	// to stay below the 1280 byte limit.
   579  	for i, n := range closest {
   580  		p.Nodes = append(p.Nodes, nodeToRPC(n))
   581  		if len(p.Nodes) == maxNeighbors || i == len(closest)-1 {
   582  			t.send(from, neighborsPacket, p)
   583  			p.Nodes = p.Nodes[:0]
   584  		}
   585  	}
   586  	return nil
   587  }
   588  
   589  func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
   590  	if expired(req.Expiration) {
   591  		return errExpired
   592  	}
   593  	if !t.handleReply(fromID, neighborsPacket, req) {
   594  		return errUnsolicitedReply
   595  	}
   596  	return nil
   597  }
   598  
   599  func expired(ts uint64) bool {
   600  	return time.Unix(int64(ts), 0).Before(time.Now())
   601  }