github.com/anacrolix/torrent@v1.61.0/peer_protocol/handshake.go (about)

     1  package peer_protocol
     2  
     3  import (
     4  	"context"
     5  	"encoding/hex"
     6  	"fmt"
     7  	"io"
     8  	"math/bits"
     9  	"strings"
    10  	"unsafe"
    11  
    12  	"github.com/anacrolix/missinggo/v2/panicif"
    13  
    14  	"github.com/anacrolix/torrent/internal/ctxrw"
    15  	"github.com/anacrolix/torrent/metainfo"
    16  )
    17  
    18  type ExtensionBit uint
    19  
    20  // https://www.bittorrent.org/beps/bep_0004.html
    21  // https://wiki.theory.org/BitTorrentSpecification.html#Reserved_Bytes
    22  const (
    23  	ExtensionBitDht  = 0 // http://www.bittorrent.org/beps/bep_0005.html
    24  	ExtensionBitFast = 2 // http://www.bittorrent.org/beps/bep_0006.html
    25  	// A peer connection initiator can set this when sending a v1 infohash during handshake if they
    26  	// allow the receiving end to upgrade to v2 by responding with the corresponding v2 infohash.
    27  	// BEP 52, and BEP 4. TODO: Set by default and then clear it when it's not appropriate to send.
    28  	ExtensionBitV2Upgrade                    = 4
    29  	ExtensionBitAzureusExtensionNegotiation1 = 16
    30  	ExtensionBitAzureusExtensionNegotiation2 = 17
    31  	// LibTorrent Extension Protocol, http://www.bittorrent.org/beps/bep_0010.html
    32  	ExtensionBitLtep = 20
    33  	// https://wiki.theory.org/BitTorrent_Location-aware_Protocol_1
    34  	ExtensionBitLocationAwareProtocol    = 43
    35  	ExtensionBitAzureusMessagingProtocol = 63 // https://www.bittorrent.org/beps/bep_0004.html
    36  
    37  )
    38  
    39  func handshakeWriter(w io.Writer, bb <-chan []byte, done chan<- error) {
    40  	var err error
    41  	for b := range bb {
    42  		_, err = w.Write(b)
    43  		if err != nil {
    44  			break
    45  		}
    46  	}
    47  	done <- err
    48  }
    49  
    50  type (
    51  	PeerExtensionBits [8]byte
    52  )
    53  
    54  var bitTags = []struct {
    55  	bit ExtensionBit
    56  	tag string
    57  }{
    58  	// Ordered by their bit position left to right.
    59  	{ExtensionBitAzureusMessagingProtocol, "amp"},
    60  	{ExtensionBitLocationAwareProtocol, "loc"},
    61  	{ExtensionBitLtep, "ltep"},
    62  	{ExtensionBitAzureusExtensionNegotiation2, "azen2"},
    63  	{ExtensionBitAzureusExtensionNegotiation1, "azen1"},
    64  	{ExtensionBitV2Upgrade, "v2"},
    65  	{ExtensionBitFast, "fast"},
    66  	{ExtensionBitDht, "dht"},
    67  }
    68  
    69  func (pex PeerExtensionBits) String() string {
    70  	pexHex := hex.EncodeToString(pex[:])
    71  	tags := make([]string, 0, len(bitTags)+1)
    72  	for _, bitTag := range bitTags {
    73  		if pex.GetBit(bitTag.bit) {
    74  			tags = append(tags, bitTag.tag)
    75  			pex.SetBit(bitTag.bit, false)
    76  		}
    77  	}
    78  	unknownCount := bits.OnesCount64(*(*uint64)((unsafe.Pointer(&pex[0]))))
    79  	if unknownCount != 0 {
    80  		tags = append(tags, fmt.Sprintf("%v unknown", unknownCount))
    81  	}
    82  	return fmt.Sprintf("%v (%s)", pexHex, strings.Join(tags, ", "))
    83  
    84  }
    85  
    86  func NewPeerExtensionBytes(bits ...ExtensionBit) (ret PeerExtensionBits) {
    87  	for _, b := range bits {
    88  		ret.SetBit(b, true)
    89  	}
    90  	return
    91  }
    92  
    93  func (pex PeerExtensionBits) SupportsExtended() bool {
    94  	return pex.GetBit(ExtensionBitLtep)
    95  }
    96  
    97  func (pex PeerExtensionBits) SupportsDHT() bool {
    98  	return pex.GetBit(ExtensionBitDht)
    99  }
   100  
   101  func (pex PeerExtensionBits) SupportsFast() bool {
   102  	return pex.GetBit(ExtensionBitFast)
   103  }
   104  
   105  func (pex *PeerExtensionBits) SetBit(bit ExtensionBit, on bool) {
   106  	if on {
   107  		pex[7-bit/8] |= 1 << (bit % 8)
   108  	} else {
   109  		pex[7-bit/8] &^= 1 << (bit % 8)
   110  	}
   111  }
   112  
   113  func (pex PeerExtensionBits) GetBit(bit ExtensionBit) bool {
   114  	return pex[7-bit/8]&(1<<(bit%8)) != 0
   115  }
   116  
   117  type HandshakeResult struct {
   118  	PeerExtensionBits
   119  	PeerID [20]byte
   120  	metainfo.Hash
   121  }
   122  
   123  // ih is nil if we expect the peer to declare the InfoHash, such as when the peer initiated the
   124  // connection. Returns ok if the Handshake was successful, and err if there was an unexpected
   125  // condition other than the peer simply abandoning the Handshake.
   126  func Handshake(
   127  	ctx context.Context,
   128  	sock io.ReadWriter,
   129  	ih *metainfo.Hash,
   130  	peerID [20]byte,
   131  	extensions PeerExtensionBits,
   132  ) (
   133  	res HandshakeResult, err error,
   134  ) {
   135  	sock = ctxrw.WrapReadWriter(ctx, sock)
   136  	// Bytes to be sent to the peer. Should never block the sender.
   137  	postCh := make(chan []byte, 4)
   138  	// A single error value sent when the writer completes.
   139  	writeDone := make(chan error, 1)
   140  	// Performs writes to the socket and ensures posts don't block.
   141  	go handshakeWriter(sock, postCh, writeDone)
   142  
   143  	defer func() {
   144  		close(postCh) // Done writing.
   145  		if err != nil {
   146  			return
   147  		}
   148  		// Wait until writes complete before returning from handshake.
   149  		err = <-writeDone
   150  		if err != nil {
   151  			err = fmt.Errorf("error writing: %w", err)
   152  		}
   153  	}()
   154  
   155  	post := func(bb []byte) {
   156  		panicif.SendBlocks(postCh, bb)
   157  	}
   158  
   159  	post(protocolBytes())
   160  	post(extensions[:])
   161  	if ih != nil { // We already know what we want.
   162  		post(ih[:])
   163  		post(peerID[:])
   164  	}
   165  
   166  	// Putting an array on the heap still escapes.
   167  	b := make([]byte, 68)
   168  	// Read in one hit to avoid potential overhead in underlying reader.
   169  	_, err = io.ReadFull(sock, b[:])
   170  	if err != nil {
   171  		return res, fmt.Errorf("while reading: %w", err)
   172  	}
   173  
   174  	p := b[:len(Protocol)]
   175  	// This gets optimized to runtime.memequal
   176  	if string(p) != Protocol {
   177  		return res, fmt.Errorf("unexpected protocol string %q", string(p))
   178  	}
   179  	b = b[len(p):]
   180  	read := func(dst []byte) {
   181  		n := copy(dst, b)
   182  		panicif.NotEq(n, len(dst))
   183  		b = b[n:]
   184  	}
   185  	read(res.PeerExtensionBits[:])
   186  	read(res.Hash[:])
   187  	read(res.PeerID[:])
   188  	panicif.NotEq(len(b), 0)
   189  	// peerExtensions.Add(res.PeerExtensionBits.String(), 1)
   190  
   191  	// TODO: Maybe we can just drop peers here if we're not interested. This
   192  	// could prevent them trying to reconnect, falsely believing there was
   193  	// just a problem.
   194  	if ih == nil { // We were waiting for the peer to tell us what they wanted.
   195  		post(res.Hash[:])
   196  		post(peerID[:])
   197  	}
   198  
   199  	return
   200  }