github.com/jonasnick/go-ethereum@v0.7.12-0.20150216215225-22176f05d387/p2p/discover/udp.go (about)

     1  package discover
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/ecdsa"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"time"
    10  
    11  	"github.com/jonasnick/go-ethereum/crypto"
    12  	"github.com/jonasnick/go-ethereum/logger"
    13  	"github.com/jonasnick/go-ethereum/p2p/nat"
    14  	"github.com/jonasnick/go-ethereum/rlp"
    15  )
    16  
// log is the package-level logger used for all discovery messages; it is
// created with the name "P2P Discovery".
var log = logger.NewLogger("P2P Discovery")
    18  
    19  // Errors
    20  var (
    21  	errPacketTooSmall = errors.New("too small")
    22  	errBadHash        = errors.New("bad hash")
    23  	errExpired        = errors.New("expired")
    24  	errTimeout        = errors.New("RPC timeout")
    25  	errClosed         = errors.New("socket closed")
    26  )
    27  
    28  // Timeouts
    29  const (
    30  	respTimeout = 300 * time.Millisecond
    31  	sendTimeout = 300 * time.Millisecond
    32  	expiration  = 20 * time.Second
    33  
    34  	refreshInterval = 1 * time.Hour
    35  )
    36  
    37  // RPC packet types
    38  const (
    39  	pingPacket = iota + 1 // zero is 'reserved'
    40  	pongPacket
    41  	findnodePacket
    42  	neighborsPacket
    43  )
    44  
    45  // RPC request structures
    46  type (
    47  	ping struct {
    48  		IP         string // our IP
    49  		Port       uint16 // our port
    50  		Expiration uint64
    51  	}
    52  
    53  	// reply to Ping
    54  	pong struct {
    55  		ReplyTok   []byte
    56  		Expiration uint64
    57  	}
    58  
    59  	findnode struct {
    60  		// Id to look up. The responding node will send back nodes
    61  		// closest to the target.
    62  		Target     NodeID
    63  		Expiration uint64
    64  	}
    65  
    66  	// reply to findnode
    67  	neighbors struct {
    68  		Nodes      []*Node
    69  		Expiration uint64
    70  	}
    71  )
    72  
// rpcNode pairs a node ID with an IP address and port.
// NOTE(review): not referenced anywhere in this file; presumably a wire
// representation of a node — confirm against other users before changing.
type rpcNode struct {
	IP   string
	Port uint16
	ID   NodeID
}

// udp implements the RPC protocol.
//
// ListenUDP starts loop and readLoop as separate goroutines. ping and
// findnode register expected replies through addpending, and the packet
// handlers forward decoded replies through replies; loop matches the two.
type udp struct {
	// conn is the socket used for both sending and receiving.
	conn *net.UDPConn
	// priv signs every outgoing packet.
	priv *ecdsa.PrivateKey
	// addpending carries new reply callbacks from pending() to loop.
	addpending chan *pending
	// replies carries decoded pong/neighbors packets from the handlers to loop.
	replies chan reply
	// closing is closed by close(); it stops loop and makes pending() fail.
	closing chan struct{}
	// NOTE(review): nat is never assigned in this file; ListenUDP uses its
	// natm argument directly.
	nat nat.Interface

	*Table
}

// pending represents a pending reply.
//
// Some implementations of the protocol wish to send more than one
// reply packet to findnode. In general, any neighbors packet cannot
// be matched up with a specific findnode packet.
//
// Our implementation handles this by storing a callback function for
// each pending reply. Incoming packets from a node are dispatched
// to all the callback functions for that node.
type pending struct {
	// These fields must match in the reply.
	from  NodeID
	ptype byte

	// Time when the request must complete. loop sets it to
	// time.Now().Add(respTimeout) when the entry is registered.
	deadline time.Time

	// callback is called when a matching reply arrives. If it returns
	// true, the callback is removed from the pending reply queue.
	// If it returns false, the reply is considered incomplete and
	// the callback will be invoked again for the next matching reply.
	callback func(resp interface{}) (done bool)

	// errc receives nil when the callback indicates completion or an
	// error if no further reply is received within the timeout.
	errc chan<- error
}

// reply is a decoded reply packet handed from a packet handler to loop,
// which matches it against pending entries by from and ptype.
// Field order matters: handlers build it with a positional literal.
type reply struct {
	from  NodeID
	ptype byte
	data  interface{}
}
   124  
   125  // ListenUDP returns a new table that listens for UDP packets on laddr.
   126  func ListenUDP(priv *ecdsa.PrivateKey, laddr string, natm nat.Interface) (*Table, error) {
   127  	addr, err := net.ResolveUDPAddr("udp", laddr)
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  	conn, err := net.ListenUDP("udp", addr)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	udp := &udp{
   136  		conn:       conn,
   137  		priv:       priv,
   138  		closing:    make(chan struct{}),
   139  		addpending: make(chan *pending),
   140  		replies:    make(chan reply),
   141  	}
   142  
   143  	realaddr := conn.LocalAddr().(*net.UDPAddr)
   144  	if natm != nil {
   145  		if !realaddr.IP.IsLoopback() {
   146  			go nat.Map(natm, udp.closing, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
   147  		}
   148  		// TODO: react to external IP changes over time.
   149  		if ext, err := natm.ExternalIP(); err == nil {
   150  			realaddr = &net.UDPAddr{IP: ext, Port: realaddr.Port}
   151  		}
   152  	}
   153  	udp.Table = newTable(udp, PubkeyID(&priv.PublicKey), realaddr)
   154  
   155  	go udp.loop()
   156  	go udp.readLoop()
   157  	log.Infoln("Listening, ", udp.self)
   158  	return udp.Table, nil
   159  }
   160  
// close shuts the transport down. Closing t.closing makes loop fail all
// pending requests with errClosed and return, and makes pending() fail new
// requests immediately; closing the socket makes the blocked read in
// readLoop return an error, ending that goroutine. Calling close twice
// panics (close of a closed channel).
func (t *udp) close() {
	close(t.closing)
	t.conn.Close()
	// TODO: wait for the loops to end.
}
   166  
   167  // ping sends a ping message to the given node and waits for a reply.
   168  func (t *udp) ping(e *Node) error {
   169  	// TODO: maybe check for ReplyTo field in callback to measure RTT
   170  	errc := t.pending(e.ID, pongPacket, func(interface{}) bool { return true })
   171  	t.send(e, pingPacket, ping{
   172  		IP:         t.self.IP.String(),
   173  		Port:       uint16(t.self.TCPPort),
   174  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   175  	})
   176  	return <-errc
   177  }
   178  
   179  // findnode sends a findnode request to the given node and waits until
   180  // the node has sent up to k neighbors.
   181  func (t *udp) findnode(to *Node, target NodeID) ([]*Node, error) {
   182  	nodes := make([]*Node, 0, bucketSize)
   183  	nreceived := 0
   184  	errc := t.pending(to.ID, neighborsPacket, func(r interface{}) bool {
   185  		reply := r.(*neighbors)
   186  		for _, n := range reply.Nodes {
   187  			nreceived++
   188  			if n.isValid() {
   189  				nodes = append(nodes, n)
   190  			}
   191  		}
   192  		return nreceived >= bucketSize
   193  	})
   194  
   195  	t.send(to, findnodePacket, findnode{
   196  		Target:     target,
   197  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   198  	})
   199  	err := <-errc
   200  	return nodes, err
   201  }
   202  
   203  // pending adds a reply callback to the pending reply queue.
   204  // see the documentation of type pending for a detailed explanation.
   205  func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
   206  	ch := make(chan error, 1)
   207  	p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
   208  	select {
   209  	case t.addpending <- p:
   210  		// loop will handle it
   211  	case <-t.closing:
   212  		ch <- errClosed
   213  	}
   214  	return ch
   215  }
   216  
   217  // loop runs in its own goroutin. it keeps track of
   218  // the refresh timer and the pending reply queue.
   219  func (t *udp) loop() {
   220  	var (
   221  		pending      []*pending
   222  		nextDeadline time.Time
   223  		timeout      = time.NewTimer(0)
   224  		refresh      = time.NewTicker(refreshInterval)
   225  	)
   226  	<-timeout.C // ignore first timeout
   227  	defer refresh.Stop()
   228  	defer timeout.Stop()
   229  
   230  	rearmTimeout := func() {
   231  		if len(pending) == 0 || nextDeadline == pending[0].deadline {
   232  			return
   233  		}
   234  		nextDeadline = pending[0].deadline
   235  		timeout.Reset(nextDeadline.Sub(time.Now()))
   236  	}
   237  
   238  	for {
   239  		select {
   240  		case <-refresh.C:
   241  			go t.refresh()
   242  
   243  		case <-t.closing:
   244  			for _, p := range pending {
   245  				p.errc <- errClosed
   246  			}
   247  			return
   248  
   249  		case p := <-t.addpending:
   250  			p.deadline = time.Now().Add(respTimeout)
   251  			pending = append(pending, p)
   252  			rearmTimeout()
   253  
   254  		case reply := <-t.replies:
   255  			// run matching callbacks, remove if they return false.
   256  			for i, p := range pending {
   257  				if reply.from == p.from && reply.ptype == p.ptype && p.callback(reply.data) {
   258  					p.errc <- nil
   259  					copy(pending[i:], pending[i+1:])
   260  					pending = pending[:len(pending)-1]
   261  					i--
   262  				}
   263  			}
   264  			rearmTimeout()
   265  
   266  		case now := <-timeout.C:
   267  			// notify and remove callbacks whose deadline is in the past.
   268  			i := 0
   269  			for ; i < len(pending) && now.After(pending[i].deadline); i++ {
   270  				pending[i].errc <- errTimeout
   271  			}
   272  			if i > 0 {
   273  				copy(pending, pending[i:])
   274  				pending = pending[:len(pending)-i]
   275  			}
   276  			rearmTimeout()
   277  		}
   278  	}
   279  }
   280  
   281  const (
   282  	macSize  = 256 / 8
   283  	sigSize  = 520 / 8
   284  	headSize = macSize + sigSize // space of packet frame data
   285  )
   286  
   287  var headSpace = make([]byte, headSize)
   288  
   289  func (t *udp) send(to *Node, ptype byte, req interface{}) error {
   290  	b := new(bytes.Buffer)
   291  	b.Write(headSpace)
   292  	b.WriteByte(ptype)
   293  	if err := rlp.Encode(b, req); err != nil {
   294  		log.Errorln("error encoding packet:", err)
   295  		return err
   296  	}
   297  
   298  	packet := b.Bytes()
   299  	sig, err := crypto.Sign(crypto.Sha3(packet[headSize:]), t.priv)
   300  	if err != nil {
   301  		log.Errorln("could not sign packet:", err)
   302  		return err
   303  	}
   304  	copy(packet[macSize:], sig)
   305  	// add the hash to the front. Note: this doesn't protect the
   306  	// packet in any way. Our public key will be part of this hash in
   307  	// the future.
   308  	copy(packet, crypto.Sha3(packet[macSize:]))
   309  
   310  	toaddr := &net.UDPAddr{IP: to.IP, Port: to.DiscPort}
   311  	log.DebugDetailf(">>> %v %T %v\n", toaddr, req, req)
   312  	if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
   313  		log.DebugDetailln("UDP send failed:", err)
   314  	}
   315  	return err
   316  }
   317  
   318  // readLoop runs in its own goroutine. it handles incoming UDP packets.
   319  func (t *udp) readLoop() {
   320  	defer t.conn.Close()
   321  	buf := make([]byte, 4096) // TODO: good buffer size
   322  	for {
   323  		nbytes, from, err := t.conn.ReadFromUDP(buf)
   324  		if err != nil {
   325  			return
   326  		}
   327  		if err := t.packetIn(from, buf[:nbytes]); err != nil {
   328  			log.Debugf("Bad packet from %v: %v\n", from, err)
   329  		}
   330  	}
   331  }
   332  
   333  func (t *udp) packetIn(from *net.UDPAddr, buf []byte) error {
   334  	if len(buf) < headSize+1 {
   335  		return errPacketTooSmall
   336  	}
   337  	hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
   338  	shouldhash := crypto.Sha3(buf[macSize:])
   339  	if !bytes.Equal(hash, shouldhash) {
   340  		return errBadHash
   341  	}
   342  	fromID, err := recoverNodeID(crypto.Sha3(buf[headSize:]), sig)
   343  	if err != nil {
   344  		return err
   345  	}
   346  
   347  	var req interface {
   348  		handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
   349  	}
   350  	switch ptype := sigdata[0]; ptype {
   351  	case pingPacket:
   352  		req = new(ping)
   353  	case pongPacket:
   354  		req = new(pong)
   355  	case findnodePacket:
   356  		req = new(findnode)
   357  	case neighborsPacket:
   358  		req = new(neighbors)
   359  	default:
   360  		return fmt.Errorf("unknown type: %d", ptype)
   361  	}
   362  	if err := rlp.Decode(bytes.NewReader(sigdata[1:]), req); err != nil {
   363  		return err
   364  	}
   365  	log.DebugDetailf("<<< %v %T %v\n", from, req, req)
   366  	return req.handle(t, from, fromID, hash)
   367  }
   368  
   369  func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
   370  	if expired(req.Expiration) {
   371  		return errExpired
   372  	}
   373  	t.mutex.Lock()
   374  	// Note: we're ignoring the provided IP address right now
   375  	n := t.bumpOrAdd(fromID, from)
   376  	if req.Port != 0 {
   377  		n.TCPPort = int(req.Port)
   378  	}
   379  	t.mutex.Unlock()
   380  
   381  	t.send(n, pongPacket, pong{
   382  		ReplyTok:   mac,
   383  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   384  	})
   385  	return nil
   386  }
   387  
   388  func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
   389  	if expired(req.Expiration) {
   390  		return errExpired
   391  	}
   392  	t.mutex.Lock()
   393  	t.bump(fromID)
   394  	t.mutex.Unlock()
   395  
   396  	t.replies <- reply{fromID, pongPacket, req}
   397  	return nil
   398  }
   399  
   400  func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
   401  	if expired(req.Expiration) {
   402  		return errExpired
   403  	}
   404  	t.mutex.Lock()
   405  	e := t.bumpOrAdd(fromID, from)
   406  	closest := t.closest(req.Target, bucketSize).entries
   407  	t.mutex.Unlock()
   408  
   409  	t.send(e, neighborsPacket, neighbors{
   410  		Nodes:      closest,
   411  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   412  	})
   413  	return nil
   414  }
   415  
   416  func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
   417  	if expired(req.Expiration) {
   418  		return errExpired
   419  	}
   420  	t.mutex.Lock()
   421  	t.bump(fromID)
   422  	t.add(req.Nodes)
   423  	t.mutex.Unlock()
   424  
   425  	t.replies <- reply{fromID, neighborsPacket, req}
   426  	return nil
   427  }
   428  
   429  func expired(ts uint64) bool {
   430  	return time.Unix(int64(ts), 0).Before(time.Now())
   431  }