github.com/avahowell/sia@v0.5.1-beta.0.20160524050156-83dcc3d37c94/modules/gateway/peers_test.go (about)

     1  package gateway
     2  
     3  import (
     4  	"net"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/NebulousLabs/Sia/build"
     9  	"github.com/NebulousLabs/Sia/encoding"
    10  	"github.com/NebulousLabs/Sia/modules"
    11  	"github.com/NebulousLabs/muxado"
    12  )
    13  
    14  // dummyConn implements the net.Conn interface, but does not carry any actual
    15  // data. It is passed to muxado, because passing nil results in segfaults.
    16  type dummyConn struct {
    17  	net.Conn
    18  }
    19  
    20  // muxado uses these methods when sending its GoAway signal
    21  func (dc *dummyConn) Write(p []byte) (int, error) { return len(p), nil }
    22  
    23  func (dc *dummyConn) Close() error { return nil }
    24  
    25  func (dc *dummyConn) SetWriteDeadline(time.Time) error { return nil }
    26  
    27  func TestAddPeer(t *testing.T) {
    28  	g := newTestingGateway("TestAddPeer", t)
    29  	defer g.Close()
    30  	id := g.mu.Lock()
    31  	defer g.mu.Unlock(id)
    32  	g.addPeer(&peer{
    33  		Peer: modules.Peer{
    34  			NetAddress: "foo.com:123",
    35  		},
    36  		sess: muxado.Client(new(dummyConn)),
    37  	})
    38  	if len(g.peers) != 1 {
    39  		t.Fatal("gateway did not add peer")
    40  	}
    41  }
    42  
    43  func TestRandomInboundPeer(t *testing.T) {
    44  	g := newTestingGateway("TestRandomInboundPeer", t)
    45  	defer g.Close()
    46  	id := g.mu.Lock()
    47  	defer g.mu.Unlock(id)
    48  	_, err := g.randomInboundPeer()
    49  	if err != errNoPeers {
    50  		t.Fatal("expected errNoPeers, got", err)
    51  	}
    52  
    53  	g.addPeer(&peer{
    54  		Peer: modules.Peer{
    55  			NetAddress: "foo.com:123",
    56  			Inbound:    true,
    57  		},
    58  		sess: muxado.Client(new(dummyConn)),
    59  	})
    60  	if len(g.peers) != 1 {
    61  		t.Fatal("gateway did not add peer")
    62  	}
    63  	addr, err := g.randomInboundPeer()
    64  	if err != nil || addr != "foo.com:123" {
    65  		t.Fatal("gateway did not select random peer")
    66  	}
    67  }
    68  
    69  func TestListen(t *testing.T) {
    70  	if testing.Short() {
    71  		t.SkipNow()
    72  	}
    73  
    74  	g := newTestingGateway("TestListen", t)
    75  	defer g.Close()
    76  
    77  	// compliant connect with old version
    78  	conn, err := net.Dial("tcp", string(g.Address()))
    79  	if err != nil {
    80  		t.Fatal("dial failed:", err)
    81  	}
    82  	addr := modules.NetAddress(conn.LocalAddr().String())
    83  	// send version
    84  	if err := encoding.WriteObject(conn, "0.1"); err != nil {
    85  		t.Fatal("couldn't write version")
    86  	}
    87  	// read ack
    88  	var ack string
    89  	if err := encoding.ReadObject(conn, &ack, maxAddrLength); err != nil {
    90  		t.Fatal(err)
    91  	} else if ack != "reject" {
    92  		t.Fatal("gateway should have rejected old version")
    93  	}
    94  
    95  	// a simple 'conn.Close' would not obey the muxado disconnect protocol
    96  	muxado.Client(conn).Close()
    97  
    98  	// compliant connect
    99  	conn, err = net.Dial("tcp", string(g.Address()))
   100  	if err != nil {
   101  		t.Fatal("dial failed:", err)
   102  	}
   103  	addr = modules.NetAddress(conn.LocalAddr().String())
   104  	// send version
   105  	if err := encoding.WriteObject(conn, build.Version); err != nil {
   106  		t.Fatal("couldn't write version")
   107  	}
   108  	// read ack
   109  	if err := encoding.ReadObject(conn, &ack, maxAddrLength); err != nil {
   110  		t.Fatal(err)
   111  	} else if ack == "reject" {
   112  		t.Fatal("gateway should have given ack")
   113  	}
   114  
   115  	// g should add the peer
   116  	var ok bool
   117  	for !ok {
   118  		id := g.mu.RLock()
   119  		_, ok = g.peers[addr]
   120  		g.mu.RUnlock(id)
   121  	}
   122  
   123  	muxado.Client(conn).Close()
   124  
   125  	// g should remove the peer
   126  	for ok {
   127  		id := g.mu.RLock()
   128  		_, ok = g.peers[addr]
   129  		g.mu.RUnlock(id)
   130  	}
   131  
   132  	// uncompliant connect
   133  	conn, err = net.Dial("tcp", string(g.Address()))
   134  	if err != nil {
   135  		t.Fatal("dial failed:", err)
   136  	}
   137  	if _, err := conn.Write([]byte("missing length prefix")); err != nil {
   138  		t.Fatal("couldn't write malformed header")
   139  	}
   140  	// g should have closed the connection
   141  	if n, err := conn.Write([]byte("closed")); err != nil && n > 0 {
   142  		t.Error("write succeeded after closed connection")
   143  	}
   144  }
   145  
   146  func TestConnect(t *testing.T) {
   147  	if testing.Short() {
   148  		t.SkipNow()
   149  	}
   150  	// create bootstrap peer
   151  	bootstrap := newTestingGateway("TestConnect1", t)
   152  	defer bootstrap.Close()
   153  
   154  	// give it a node
   155  	id := bootstrap.mu.Lock()
   156  	bootstrap.addNode(dummyNode)
   157  	bootstrap.mu.Unlock(id)
   158  
   159  	// create peer who will connect to bootstrap
   160  	g := newTestingGateway("TestConnect2", t)
   161  	defer g.Close()
   162  
   163  	// first simulate a "bad" connect, where bootstrap won't share its nodes
   164  	bootstrap.handlers[handlerName("ShareNodes")] = func(modules.PeerConn) error {
   165  		return nil
   166  	}
   167  	// connect
   168  	err := g.Connect(bootstrap.Address())
   169  	if err != nil {
   170  		t.Fatal(err)
   171  	}
   172  	// g should not have the node
   173  	if g.removeNode(dummyNode) == nil {
   174  		t.Fatal("bootstrapper should not have received dummyNode:", g.nodes)
   175  	}
   176  
   177  	// split 'em up
   178  	g.Disconnect(bootstrap.Address())
   179  	bootstrap.Disconnect(g.Address())
   180  
   181  	// now restore the correct ShareNodes RPC and try again
   182  	bootstrap.handlers[handlerName("ShareNodes")] = bootstrap.shareNodes
   183  	err = g.Connect(bootstrap.Address())
   184  	if err != nil {
   185  		t.Fatal(err)
   186  	}
   187  	// g should have the node
   188  	time.Sleep(100 * time.Millisecond)
   189  	id = g.mu.RLock()
   190  	if _, ok := g.nodes[dummyNode]; !ok {
   191  		t.Fatal("bootstrapper should have received dummyNode:", g.nodes)
   192  	}
   193  	g.mu.RUnlock(id)
   194  }
   195  
   196  // TestConnectRejects tests that Gateway.Connect only accepts peers with
   197  // sufficient and valid versions.
   198  func TestConnectRejects(t *testing.T) {
   199  	if testing.Short() {
   200  		t.SkipNow()
   201  	}
   202  	g := newTestingGateway("TestConnectRejects", t)
   203  	// Setup a listener that mocks Gateway.acceptConn, but sends the
   204  	// version sent over mockVersionChan instead of build.Version.
   205  	listener, err := net.Listen("tcp", "localhost:0")
   206  	if err != nil {
   207  		t.Fatal(err)
   208  	}
   209  	mockVersionChan := make(chan string)
   210  	go func() {
   211  		for {
   212  			mockVersion := <-mockVersionChan
   213  			conn, err := listener.Accept()
   214  			if err != nil {
   215  				t.Fatal(err)
   216  			}
   217  			// Read remote peer version.
   218  			var remoteVersion string
   219  			if err := encoding.ReadObject(conn, &remoteVersion, maxAddrLength); err != nil {
   220  				t.Fatal(err)
   221  			}
   222  			// Write our mock version.
   223  			if err := encoding.WriteObject(conn, mockVersion); err != nil {
   224  				t.Fatal(err)
   225  			}
   226  		}
   227  	}()
   228  
   229  	tests := []struct {
   230  		version             string
   231  		errWant             error
   232  		insufficientVersion bool
   233  		msg                 string
   234  	}{
   235  		// Test that Connect fails when the remote peer's version is "reject".
   236  		{
   237  			version: "reject",
   238  			errWant: errPeerRejectedConn,
   239  			msg:     "Connect should fail when the remote peer rejects the connection",
   240  		},
   241  		// Test that Connect fails when the remote peer's version is ascii gibberish.
   242  		{
   243  			version:             "foobar",
   244  			insufficientVersion: true,
   245  			msg:                 "Connect should fail when the remote peer's version is ascii gibberish",
   246  		},
   247  		// Test that Connect fails when the remote peer's version is utf8 gibberish.
   248  		{
   249  			version:             "世界",
   250  			insufficientVersion: true,
   251  			msg:                 "Connect should fail when the remote peer's version is utf8 gibberish",
   252  		},
   253  		// Test that Connect fails when the remote peer's version is < 0.4.0 (0).
   254  		{
   255  			version:             "0",
   256  			insufficientVersion: true,
   257  			msg:                 "Connect should fail when the remote peer's version is 0",
   258  		},
   259  		{
   260  			version:             "0.0.0",
   261  			insufficientVersion: true,
   262  			msg:                 "Connect should fail when the remote peer's version is 0.0.0",
   263  		},
   264  		{
   265  			version:             "0000.0000.0000",
   266  			insufficientVersion: true,
   267  			msg:                 "Connect should fail when the remote peer's version is 0000.0000.0000",
   268  		},
   269  		{
   270  			version:             "0.3.9",
   271  			insufficientVersion: true,
   272  			msg:                 "Connect should fail when the remote peer's version is 0.3.9",
   273  		},
   274  		{
   275  			version:             "0.3.9999",
   276  			insufficientVersion: true,
   277  			msg:                 "Connect should fail when the remote peer's version is 0.3.9999",
   278  		},
   279  		{
   280  			version:             "0.3.9.9.9",
   281  			insufficientVersion: true,
   282  			msg:                 "Connect should fail when the remote peer's version is 0.3.9.9.9",
   283  		},
   284  		// Test that Connect succeeds when the remote peer's version is 0.4.0.
   285  		{
   286  			version: "0.4.0",
   287  			msg:     "Connect should succeed when the remote peer's version is 0.4.0",
   288  		},
   289  		// Test that Connect succeeds when the remote peer's version is > 0.4.0.
   290  		{
   291  			version: "9",
   292  			msg:     "Connect should succeed when the remote peer's version is 9",
   293  		},
   294  		{
   295  			version: "9.9.9",
   296  			msg:     "Connect should succeed when the remote peer's version is 9.9.9",
   297  		},
   298  		{
   299  			version: "9999.9999.9999",
   300  			msg:     "Connect should succeed when the remote peer's version is 9999.9999.9999",
   301  		},
   302  	}
   303  	for _, tt := range tests {
   304  		mockVersionChan <- tt.version
   305  		err = g.Connect(modules.NetAddress(listener.Addr().String()))
   306  		if tt.insufficientVersion {
   307  			// Check that the error is the expected type.
   308  			if _, ok := err.(insufficientVersionError); !ok {
   309  				t.Fatalf("expected Connect to error with insufficientVersionError: %s", tt.msg)
   310  			}
   311  		} else {
   312  			// Check that the error is the expected error.
   313  			if err != tt.errWant {
   314  				t.Fatalf("expected Connect to error with '%v', but got '%v': %s", tt.errWant, err, tt.msg)
   315  			}
   316  		}
   317  		g.Disconnect(modules.NetAddress(listener.Addr().String()))
   318  	}
   319  	listener.Close()
   320  }
   321  
   322  // mockGatewayWithVersion is a mock implementation of Gateway that sends a mock
   323  // version on Connect instead of build.Version.
   324  type mockGatewayWithVersion struct {
   325  	*Gateway
   326  	version    string
   327  	versionACK chan string
   328  }
   329  
   330  // Connect is a mock implementation of modules.Gateway.Connect that provides a
   331  // mock version to peers it connects to instead of build.Version. The version
   332  // ack written by the remote peer is written to the versionACK channel.
   333  func (g mockGatewayWithVersion) Connect(addr modules.NetAddress) error {
   334  	conn, err := net.DialTimeout("tcp", string(addr), dialTimeout)
   335  	if err != nil {
   336  		return err
   337  	}
   338  	// send mocked version
   339  	if err := encoding.WriteObject(conn, g.version); err != nil {
   340  		return err
   341  	}
   342  	// read version ack
   343  	var remoteVersion string
   344  	if err := encoding.ReadObject(conn, &remoteVersion, maxAddrLength); err != nil {
   345  		return err
   346  	}
   347  	g.versionACK <- remoteVersion
   348  
   349  	return nil
   350  }
   351  
   352  // TestAcceptConnRejects tests that Gateway.acceptConn only accepts peers with
   353  // sufficient and valid versions.
   354  func TestAcceptConnRejects(t *testing.T) {
   355  	if testing.Short() {
   356  		t.SkipNow()
   357  	}
   358  	g := newTestingGateway("TestAcceptConnRejects1", t)
   359  	defer g.Close()
   360  	mg := mockGatewayWithVersion{
   361  		Gateway:    newTestingGateway("TestAcceptConnRejects2", t),
   362  		versionACK: make(chan string),
   363  	}
   364  	defer mg.Close()
   365  
   366  	tests := []struct {
   367  		remoteVersion       string
   368  		versionResponseWant string
   369  		msg                 string
   370  	}{
   371  		// Test that acceptConn fails when the remote peer's version is "reject".
   372  		{
   373  			remoteVersion:       "reject",
   374  			versionResponseWant: "reject",
   375  			msg:                 "acceptConn shouldn't accept a remote peer whose version is \"reject\"",
   376  		},
   377  		// Test that acceptConn fails when the remote peer's version is ascii gibberish.
   378  		{
   379  			remoteVersion:       "foobar",
   380  			versionResponseWant: "reject",
   381  			msg:                 "acceptConn shouldn't accept a remote peer whose version is ascii giberish",
   382  		},
   383  		// Test that acceptConn fails when the remote peer's version is utf8 gibberish.
   384  		{
   385  			remoteVersion:       "世界",
   386  			versionResponseWant: "reject",
   387  			msg:                 "acceptConn shouldn't accept a remote peer whose version is utf8 giberish",
   388  		},
   389  		// Test that acceptConn fails when the remote peer's version is < 0.4.0 (0).
   390  		{
   391  			remoteVersion:       "0",
   392  			versionResponseWant: "reject",
   393  			msg:                 "acceptConn shouldn't accept a remote peer whose version is 0",
   394  		},
   395  		{
   396  			remoteVersion:       "0.0.0",
   397  			versionResponseWant: "reject",
   398  			msg:                 "acceptConn shouldn't accept a remote peer whose version is 0.0.0",
   399  		},
   400  		{
   401  			remoteVersion:       "0000.0000.0000",
   402  			versionResponseWant: "reject",
   403  			msg:                 "acceptConn shouldn't accept a remote peer whose version is 0000.000.000",
   404  		},
   405  		{
   406  			remoteVersion:       "0.3.9",
   407  			versionResponseWant: "reject",
   408  			msg:                 "acceptConn shouldn't accept a remote peer whose version is 0.3.9",
   409  		},
   410  		{
   411  			remoteVersion:       "0.3.9999",
   412  			versionResponseWant: "reject",
   413  			msg:                 "acceptConn shouldn't accept a remote peer whose version is 0.3.9999",
   414  		},
   415  		{
   416  			remoteVersion:       "0.3.9.9.9",
   417  			versionResponseWant: "reject",
   418  			msg:                 "acceptConn shouldn't accept a remote peer whose version is 0.3.9.9.9",
   419  		},
   420  		// Test that acceptConn succeeds when the remote peer's version is 0.4.0.
   421  		{
   422  			remoteVersion:       "0.4.0",
   423  			versionResponseWant: build.Version,
   424  			msg:                 "acceptConn should accept a remote peer whose version is 0.4.0",
   425  		},
   426  		// Test that acceptConn succeeds when the remote peer's version is > 0.4.0.
   427  		{
   428  			remoteVersion:       "9",
   429  			versionResponseWant: build.Version,
   430  			msg:                 "acceptConn should accept a remote peer whose version is 9",
   431  		},
   432  		{
   433  			remoteVersion:       "9.9.9",
   434  			versionResponseWant: build.Version,
   435  			msg:                 "acceptConn should accept a remote peer whose version is 9.9.9",
   436  		},
   437  		{
   438  			remoteVersion:       "9999.9999.9999",
   439  			versionResponseWant: build.Version,
   440  			msg:                 "acceptConn should accept a remote peer whose version is 9999.9999.9999",
   441  		},
   442  	}
   443  	for _, tt := range tests {
   444  		mg.version = tt.remoteVersion
   445  		go func() {
   446  			err := mg.Connect(g.Address())
   447  			if err != nil {
   448  				t.Fatal(err)
   449  			}
   450  		}()
   451  		remoteVersion := <-mg.versionACK
   452  		if remoteVersion != tt.versionResponseWant {
   453  			t.Fatalf(tt.msg)
   454  		}
   455  		g.Disconnect(mg.Address())
   456  		mg.Disconnect(g.Address())
   457  	}
   458  }
   459  
   460  func TestDisconnect(t *testing.T) {
   461  	g := newTestingGateway("TestDisconnect", t)
   462  	defer g.Close()
   463  
   464  	if err := g.Disconnect("bar.com:123"); err == nil {
   465  		t.Fatal("disconnect removed unconnected peer")
   466  	}
   467  
   468  	// dummy listener to accept connection
   469  	l, err := net.Listen("tcp", "localhost:0")
   470  	if err != nil {
   471  		t.Fatal("couldn't start listener:", err)
   472  	}
   473  	go func() {
   474  		_, err := l.Accept()
   475  		if err != nil {
   476  			t.Fatal("accept failed:", err)
   477  		}
   478  		// conn.Close()
   479  	}()
   480  	// skip standard connection protocol
   481  	conn, err := net.Dial("tcp", l.Addr().String())
   482  	if err != nil {
   483  		t.Fatal("dial failed:", err)
   484  	}
   485  	id := g.mu.Lock()
   486  	g.addPeer(&peer{
   487  		Peer: modules.Peer{
   488  			NetAddress: "foo.com:123",
   489  		},
   490  		sess: muxado.Client(conn),
   491  	})
   492  	g.mu.Unlock(id)
   493  	if err := g.Disconnect("foo.com:123"); err != nil {
   494  		t.Fatal("disconnect failed:", err)
   495  	}
   496  }
   497  
   498  func TestPeerManager(t *testing.T) {
   499  	if testing.Short() {
   500  		t.SkipNow()
   501  	}
   502  
   503  	g1 := newTestingGateway("TestPeerManager1", t)
   504  	defer g1.Close()
   505  
   506  	// create a valid node to connect to
   507  	g2 := newTestingGateway("TestPeerManager2", t)
   508  	defer g2.Close()
   509  
   510  	// g1's node list should only contain g2
   511  	id := g1.mu.Lock()
   512  	g1.nodes = map[modules.NetAddress]struct{}{}
   513  	g1.nodes[g2.Address()] = struct{}{}
   514  	g1.mu.Unlock(id)
   515  
   516  	// when peerManager wakes up, it should connect to g2.
   517  	time.Sleep(6 * time.Second)
   518  
   519  	id = g1.mu.RLock()
   520  	defer g1.mu.RUnlock(id)
   521  	if len(g1.peers) != 1 || g1.peers[g2.Address()] == nil {
   522  		t.Fatal("gateway did not connect to g2:", g1.peers)
   523  	}
   524  }