github.com/anacrolix/torrent@v1.61.0/ltep_test.go (about)

     1  package torrent_test
     2  
     3  import (
     4  	"math/rand"
     5  	"strconv"
     6  	"testing"
     7  
     8  	"github.com/anacrolix/sync"
     9  	qt "github.com/go-quicktest/qt"
    10  
    11  	. "github.com/anacrolix/torrent"
    12  	"github.com/anacrolix/torrent/internal/testutil"
    13  	pp "github.com/anacrolix/torrent/peer_protocol"
    14  )
    15  
    16  const (
    17  	testRepliesToOddsExtensionName  = "pm_me_odds"
    18  	testRepliesToEvensExtensionName = "pm_me_evens"
    19  )
    20  
    21  func countHandler(
    22  	t *testing.T,
    23  	wg *sync.WaitGroup,
    24  	// Name of the endpoint that this handler is for, for logging.
    25  	handlerName string,
    26  	// Whether we expect evens or odds
    27  	expectedMod2 uint,
    28  	// Extension name of messages we expect to handle.
    29  	answerToName pp.ExtensionName,
    30  	// Extension name of messages we expect to send.
    31  	replyToName pp.ExtensionName,
    32  	// Signal done when this value is seen.
    33  	doneValue uint,
    34  ) func(event PeerConnReadExtensionMessageEvent) {
    35  	return func(event PeerConnReadExtensionMessageEvent) {
    36  		// Read handshake, don't look it up.
    37  		if event.ExtensionNumber == 0 {
    38  			return
    39  		}
    40  		name, builtin, err := event.PeerConn.LocalLtepProtocolMap.LookupId(event.ExtensionNumber)
    41  		qt.Assert(t, qt.IsNil(err))
    42  		// Not a user protocol.
    43  		if builtin {
    44  			return
    45  		}
    46  		switch name {
    47  		case answerToName:
    48  			u64, err := strconv.ParseUint(string(event.Payload), 10, 0)
    49  			qt.Assert(t, qt.IsNil(err))
    50  			i := uint(u64)
    51  			t.Logf("%v got %d", handlerName, i)
    52  			if i == doneValue {
    53  				wg.Done()
    54  				return
    55  			}
    56  			qt.Assert(t, qt.Equals(i%2, expectedMod2))
    57  			go func() {
    58  				qt.Assert(t, qt.IsNil(
    59  					event.PeerConn.WriteExtendedMessage(
    60  						replyToName,
    61  						[]byte(strconv.FormatUint(uint64(i+1), 10)))))
    62  			}()
    63  		default:
    64  			t.Fatalf("got unexpected extension name %q", name)
    65  		}
    66  	}
    67  }
    68  
    69  func TestUserLtep(t *testing.T) {
    70  	var wg sync.WaitGroup
    71  
    72  	makeCfg := func() *ClientConfig {
    73  		cfg := TestingConfig(t)
    74  		// Only want a single connection to between the clients.
    75  		cfg.DisableUTP = true
    76  		cfg.DisableIPv6 = true
    77  		return cfg
    78  	}
    79  
    80  	evensCfg := makeCfg()
    81  	evensCfg.Callbacks.ReadExtendedHandshake = func(pc *PeerConn, msg *pp.ExtendedHandshakeMessage) {
    82  		// The client lock is held while handling this event, so we have to do synchronous work in a
    83  		// separate goroutine.
    84  		go func() {
    85  			// Check sending an extended message for a protocol the peer doesn't support is an error.
    86  			qt.Check(t, qt.IsNotNil(pc.WriteExtendedMessage("pm_me_floats", []byte("3.142"))))
    87  			// Kick things off by sending a 1.
    88  			qt.Check(t, qt.IsNil(pc.WriteExtendedMessage(testRepliesToOddsExtensionName, []byte("1"))))
    89  		}()
    90  	}
    91  	evensCfg.Callbacks.PeerConnReadExtensionMessage = append(
    92  		evensCfg.Callbacks.PeerConnReadExtensionMessage,
    93  		countHandler(t, &wg, "evens", 0, testRepliesToEvensExtensionName, testRepliesToOddsExtensionName, 100))
    94  	evensCfg.Callbacks.PeerConnAdded = append(evensCfg.Callbacks.PeerConnAdded, func(conn *PeerConn) {
    95  		conn.LocalLtepProtocolMap.AddUserProtocol(testRepliesToEvensExtensionName)
    96  		qt.Assert(t, qt.HasLen(conn.LocalLtepProtocolMap.Index[conn.LocalLtepProtocolMap.NumBuiltin:], 1))
    97  	})
    98  
    99  	oddsCfg := makeCfg()
   100  	oddsCfg.Callbacks.PeerConnAdded = append(oddsCfg.Callbacks.PeerConnAdded, func(conn *PeerConn) {
   101  		conn.LocalLtepProtocolMap.AddUserProtocol(testRepliesToOddsExtensionName)
   102  		qt.Assert(t, qt.HasLen(conn.LocalLtepProtocolMap.Index[conn.LocalLtepProtocolMap.NumBuiltin:], 1))
   103  	})
   104  	oddsCfg.Callbacks.PeerConnReadExtensionMessage = append(
   105  		oddsCfg.Callbacks.PeerConnReadExtensionMessage,
   106  		countHandler(t, &wg, "odds", 1, testRepliesToOddsExtensionName, testRepliesToEvensExtensionName, 100))
   107  
   108  	cl1, err := NewClient(oddsCfg)
   109  	qt.Assert(t, qt.IsNil(err))
   110  	defer cl1.Close()
   111  	cl2, err := NewClient(evensCfg)
   112  	qt.Assert(t, qt.IsNil(err))
   113  	defer cl2.Close()
   114  	addOpts := AddTorrentOpts{}
   115  	rand.Read(addOpts.InfoHash[:])
   116  	t1, _ := cl1.AddTorrentOpt(addOpts)
   117  	t2, _ := cl2.AddTorrentOpt(addOpts)
   118  	defer testutil.ExportStatusWriter(cl1, "cl1", t)()
   119  	defer testutil.ExportStatusWriter(cl2, "cl2", t)()
   120  	// Expect one PeerConn to see the value.
   121  	wg.Add(1)
   122  	added := t1.AddClientPeer(cl2)
   123  	// Ensure some addresses for the other client were added.
   124  	qt.Assert(t, qt.Not(qt.Equals(added, 0)))
   125  	wg.Wait()
   126  	_ = t2
   127  }