gitlab.com/SiaPrime/SiaPrime@v1.4.1/modules/gateway/rpc_test.go (about)

     1  package gateway
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"sync"
     7  	"sync/atomic"
     8  	"testing"
     9  	"time"
    10  
    11  	"gitlab.com/SiaPrime/SiaPrime/encoding"
    12  	"gitlab.com/SiaPrime/SiaPrime/modules"
    13  )
    14  
    15  func TestRPCID(t *testing.T) {
    16  	cases := map[rpcID]string{
    17  		{}:                                       "        ",
    18  		{'f', 'o', 'o'}:                          "foo     ",
    19  		{'f', 'o', 'o', 'b', 'a', 'r', 'b', 'a'}: "foobarba",
    20  	}
    21  	for id, s := range cases {
    22  		if id.String() != s {
    23  			t.Errorf("rpcID.String mismatch: expected %v, got %v", s, id.String())
    24  		}
    25  	}
    26  }
    27  
    28  func TestHandlerName(t *testing.T) {
    29  	cases := map[string]rpcID{
    30  		"":          {},
    31  		"foo":       {'f', 'o', 'o'},
    32  		"foobarbaz": {'f', 'o', 'o', 'b', 'a', 'r', 'b', 'a'},
    33  	}
    34  	for s, id := range cases {
    35  		if hid := handlerName(s); hid != id {
    36  			t.Errorf("handlerName mismatch: expected %v, got %v", id, hid)
    37  		}
    38  	}
    39  }
    40  
    41  // TestRegisterRPC tests that registering the same RPC twice causes a panic.
    42  func TestRegisterRPC(t *testing.T) {
    43  	if testing.Short() {
    44  		t.SkipNow()
    45  	}
    46  	t.Parallel()
    47  	g := newTestingGateway(t)
    48  	defer g.Close()
    49  
    50  	g.RegisterRPC("Foo", func(conn modules.PeerConn) error { return nil })
    51  	defer func() {
    52  		if r := recover(); r == nil {
    53  			t.Error("Registering the same RPC twice did not cause a panic")
    54  		}
    55  	}()
    56  	g.RegisterRPC("Foo", func(conn modules.PeerConn) error { return nil })
    57  }
    58  
    59  // TestUnregisterRPC tests that unregistering an RPC causes calls to it to
    60  // fail, and checks that unregistering a non-registered RPC causes a panic.
    61  func TestUnregisterRPC(t *testing.T) {
    62  	if testing.Short() {
    63  		t.SkipNow()
    64  	}
    65  	t.Parallel()
    66  	g1 := newNamedTestingGateway(t, "1")
    67  	defer g1.Close()
    68  	g2 := newNamedTestingGateway(t, "2")
    69  	defer g2.Close()
    70  
    71  	err := g2.Connect(g1.Address())
    72  	if err != nil {
    73  		t.Fatal(err)
    74  	}
    75  
    76  	dummyFunc := func(conn modules.PeerConn) error {
    77  		var str string
    78  		return encoding.ReadObject(conn, &str, 11)
    79  	}
    80  
    81  	// Register RPC and check that calling it succeeds.
    82  	g1.RegisterRPC("Foo", func(conn modules.PeerConn) error {
    83  		return encoding.WriteObject(conn, "foo")
    84  	})
    85  	err = g2.RPC(g1.Address(), "Foo", dummyFunc)
    86  	if err != nil {
    87  		t.Errorf("calling registered RPC on g1 returned %q", err)
    88  	}
    89  	// Unregister RPC and check that calling it fails.
    90  	g1.UnregisterRPC("Foo")
    91  	err = g2.RPC(g1.Address(), "Foo", dummyFunc)
    92  	if err.Error() != io.EOF.Error() {
    93  		t.Errorf("calling unregistered RPC on g1 returned %v instead of %v", err, io.EOF)
    94  	}
    95  
    96  	// Unregister again and check that it panics.
    97  	defer func() {
    98  		if r := recover(); r == nil {
    99  			t.Error("Unregistering an unregistered RPC did not cause a panic")
   100  		}
   101  	}()
   102  	g1.UnregisterRPC("Foo")
   103  }
   104  
   105  // TestRegisterConnectCall tests that registering the same on-connect call
   106  // twice causes a panic.
   107  func TestRegisterConnectCall(t *testing.T) {
   108  	if testing.Short() {
   109  		t.SkipNow()
   110  	}
   111  	t.Parallel()
   112  	g := newTestingGateway(t)
   113  	defer g.Close()
   114  
   115  	// Register an on-connect call.
   116  	g.RegisterConnectCall("Foo", func(conn modules.PeerConn) error { return nil })
   117  	defer func() {
   118  		if r := recover(); r == nil {
   119  			t.Error("Registering the same on-connect call twice did not cause a panic")
   120  		}
   121  	}()
   122  	g.RegisterConnectCall("Foo", func(conn modules.PeerConn) error { return nil })
   123  }
   124  
   125  // TestUnregisterConnectCallPanics tests that unregistering the same on-connect
   126  // call twice causes a panic.
   127  func TestUnregisterConnectCallPanics(t *testing.T) {
   128  	if testing.Short() {
   129  		t.SkipNow()
   130  	}
   131  	t.Parallel()
   132  	g1 := newNamedTestingGateway(t, "1")
   133  	defer g1.Close()
   134  	g2 := newNamedTestingGateway(t, "2")
   135  	defer g2.Close()
   136  
   137  	rpcChan := make(chan struct{})
   138  
   139  	// Register on-connect call and test that RPC is called on connect.
   140  	g1.RegisterConnectCall("Foo", func(conn modules.PeerConn) error {
   141  		rpcChan <- struct{}{}
   142  		return nil
   143  	})
   144  	err := g1.Connect(g2.Address())
   145  	if err != nil {
   146  		t.Fatal(err)
   147  	}
   148  	select {
   149  	case <-rpcChan:
   150  	case <-time.After(200 * time.Millisecond):
   151  		t.Fatal("ConnectCall not called on Connect after it was registered")
   152  	}
   153  	// Disconnect, unregister on-connect call, and test that RPC is not called on connect.
   154  	err = g1.Disconnect(g2.Address())
   155  	if err != nil {
   156  		t.Fatal(err)
   157  	}
   158  	g1.UnregisterConnectCall("Foo")
   159  	err = g1.Connect(g2.Address())
   160  	if err != nil {
   161  		t.Fatal(err)
   162  	}
   163  	select {
   164  	case <-rpcChan:
   165  		t.Fatal("ConnectCall called on Connect after it was unregistered")
   166  	case <-time.After(200 * time.Millisecond):
   167  	}
   168  	// Unregister again and check that it panics.
   169  	defer func() {
   170  		if r := recover(); r == nil {
   171  			t.Error("Unregistering an unregistered on-connect call did not cause a panic")
   172  		}
   173  	}()
   174  	g1.UnregisterConnectCall("Foo")
   175  }
   176  
   177  func TestRPC(t *testing.T) {
   178  	if testing.Short() {
   179  		t.SkipNow()
   180  	}
   181  	t.Parallel()
   182  	g1 := newNamedTestingGateway(t, "1")
   183  	defer g1.Close()
   184  
   185  	if err := g1.RPC("foo.com:123", "", nil); err == nil {
   186  		t.Fatal("RPC on unconnected peer succeeded")
   187  	}
   188  
   189  	g2 := newNamedTestingGateway(t, "2")
   190  	defer g2.Close()
   191  
   192  	err := g1.Connect(g2.Address())
   193  	if err != nil {
   194  		t.Fatal("failed to connect:", err)
   195  	}
   196  
   197  	g2.RegisterRPC("Foo", func(conn modules.PeerConn) error {
   198  		var i uint64
   199  		err := encoding.ReadObject(conn, &i, 8)
   200  		if err != nil {
   201  			return err
   202  		} else if i == 0xdeadbeef {
   203  			return encoding.WriteObject(conn, "foo")
   204  		} else {
   205  			return encoding.WriteObject(conn, "bar")
   206  		}
   207  	})
   208  
   209  	var foo string
   210  	err = g1.RPC(g2.Address(), "Foo", func(conn modules.PeerConn) error {
   211  		err := encoding.WriteObject(conn, 0xdeadbeef)
   212  		if err != nil {
   213  			return err
   214  		}
   215  		return encoding.ReadObject(conn, &foo, 11)
   216  	})
   217  	if err != nil {
   218  		t.Fatal(err)
   219  	}
   220  	if foo != "foo" {
   221  		t.Fatal("Foo gave wrong response:", foo)
   222  	}
   223  
   224  	// wrong number should produce an error
   225  	err = g1.RPC(g2.Address(), "Foo", func(conn modules.PeerConn) error {
   226  		err := encoding.WriteObject(conn, 0xbadbeef)
   227  		if err != nil {
   228  			return err
   229  		}
   230  		return encoding.ReadObject(conn, &foo, 11)
   231  	})
   232  	if err != nil {
   233  		t.Fatal(err)
   234  	}
   235  	if foo != "bar" {
   236  		t.Fatal("Foo gave wrong response:", foo)
   237  	}
   238  
   239  	// don't read or write anything
   240  	err = g1.RPC(g2.Address(), "Foo", func(modules.PeerConn) error {
   241  		return errNoPeers // any non-nil error will do
   242  	})
   243  	if err == nil {
   244  		t.Fatal("bad RPC did not produce an error")
   245  	}
   246  
   247  	g1.peers[g2.Address()].sess.Close()
   248  	if err := g1.RPC(g2.Address(), "Foo", nil); err == nil {
   249  		t.Fatal("RPC on closed peer connection succeeded")
   250  	}
   251  }
   252  
   253  func TestThreadedHandleConn(t *testing.T) {
   254  	if testing.Short() {
   255  		t.SkipNow()
   256  	}
   257  	t.Parallel()
   258  	g1 := newNamedTestingGateway(t, "1")
   259  	defer g1.Close()
   260  	g2 := newNamedTestingGateway(t, "2")
   261  	defer g2.Close()
   262  
   263  	err := g1.Connect(g2.Address())
   264  	if err != nil {
   265  		t.Fatal("failed to connect:", err)
   266  	}
   267  
   268  	g2.RegisterRPC("Foo", func(conn modules.PeerConn) error {
   269  		var i uint64
   270  		err := encoding.ReadObject(conn, &i, 8)
   271  		if err != nil {
   272  			return err
   273  		} else if i == 0xdeadbeef {
   274  			return encoding.WriteObject(conn, "foo")
   275  		} else {
   276  			return encoding.WriteObject(conn, "bar")
   277  		}
   278  	})
   279  
   280  	// custom rpc fn (doesn't automatically write rpcID)
   281  	rpcFn := func(fn func(modules.PeerConn) error) error {
   282  		conn, err := g1.peers[g2.Address()].open()
   283  		if err != nil {
   284  			return err
   285  		}
   286  		defer conn.Close()
   287  		return fn(conn)
   288  	}
   289  
   290  	// bad rpcID
   291  	err = rpcFn(func(conn modules.PeerConn) error {
   292  		return encoding.WriteObject(conn, [3]byte{1, 2, 3})
   293  	})
   294  	if err != nil {
   295  		t.Fatal("rpcFn failed:", err)
   296  	}
   297  
   298  	// unknown rpcID
   299  	err = rpcFn(func(conn modules.PeerConn) error {
   300  		return encoding.WriteObject(conn, handlerName("bar"))
   301  	})
   302  	if err != nil {
   303  		t.Fatal("rpcFn failed:", err)
   304  	}
   305  
   306  	// valid rpcID
   307  	err = rpcFn(func(conn modules.PeerConn) error {
   308  		return encoding.WriteObject(conn, handlerName("Foo"))
   309  	})
   310  	if err != nil {
   311  		t.Fatal("rpcFn failed:", err)
   312  	}
   313  }
   314  
   315  // TestBroadcast tests that calling broadcast with a slice of peers only
   316  // broadcasts to those peers.
   317  func TestBroadcast(t *testing.T) {
   318  	if testing.Short() {
   319  		t.SkipNow()
   320  	}
   321  	t.Parallel()
   322  	g1 := newNamedTestingGateway(t, "1")
   323  	defer g1.Close()
   324  	g2 := newNamedTestingGateway(t, "2")
   325  	defer g2.Close()
   326  	g3 := newNamedTestingGateway(t, "3")
   327  	defer g3.Close()
   328  
   329  	err := g1.Connect(g2.Address())
   330  	if err != nil {
   331  		t.Fatal("failed to connect:", err)
   332  	}
   333  	err = g1.Connect(g3.Address())
   334  	if err != nil {
   335  		t.Fatal("failed to connect:", err)
   336  	}
   337  
   338  	var g2Payload, g3Payload string
   339  	g2DoneChan := make(chan struct{})
   340  	g3DoneChan := make(chan struct{})
   341  	bothDoneChan := make(chan struct{})
   342  
   343  	g2.RegisterRPC("Recv", func(conn modules.PeerConn) error {
   344  		encoding.ReadObject(conn, &g2Payload, 100)
   345  		g2DoneChan <- struct{}{}
   346  		return nil
   347  	})
   348  	g3.RegisterRPC("Recv", func(conn modules.PeerConn) error {
   349  		encoding.ReadObject(conn, &g3Payload, 100)
   350  		g3DoneChan <- struct{}{}
   351  		return nil
   352  	})
   353  
   354  	// Test that broadcasting to all peers in g1.Peers() broadcasts to all peers.
   355  	peers := g1.Peers()
   356  	g1.Broadcast("Recv", "bar", peers)
   357  	go func() {
   358  		<-g2DoneChan
   359  		<-g3DoneChan
   360  		bothDoneChan <- struct{}{}
   361  	}()
   362  	select {
   363  	case <-bothDoneChan:
   364  		// Both g2 and g3 should receive the broadcast.
   365  	case <-time.After(5 * time.Second):
   366  		t.Fatal("broadcasting to gateway.Peers() should broadcast to all peers")
   367  	}
   368  	if g2Payload != "bar" || g3Payload != "bar" {
   369  		t.Fatal("broadcast failed:", g2Payload, g3Payload)
   370  	}
   371  
   372  	// Test that broadcasting to only g2 does not broadcast to g3.
   373  	peers = make([]modules.Peer, 0)
   374  	for _, p := range g1.Peers() {
   375  		if p.NetAddress == g2.Address() {
   376  			peers = append(peers, p)
   377  			break
   378  		}
   379  	}
   380  	g1.Broadcast("Recv", "baz", peers)
   381  	select {
   382  	case <-g2DoneChan:
   383  		// Only g2 should receive a broadcast.
   384  	case <-g3DoneChan:
   385  		t.Error("broadcast broadcasted to peers not in the peers arg")
   386  	case <-time.After(200 * time.Millisecond):
   387  		t.Fatal("called broadcast with g2 in peers list, but g2 didn't receive it.")
   388  	}
   389  	if g2Payload != "baz" {
   390  		t.Fatal("broadcast failed:", g2Payload)
   391  	}
   392  
   393  	// Test that broadcasting to only g3 does not broadcast to g2.
   394  	peers = make([]modules.Peer, 0)
   395  	for _, p := range g1.Peers() {
   396  		if p.NetAddress == g3.Address() {
   397  			peers = append(peers, p)
   398  			break
   399  		}
   400  	}
   401  	g1.Broadcast("Recv", "qux", peers)
   402  	select {
   403  	case <-g2DoneChan:
   404  		t.Error("broadcast broadcasted to peers not in the peers arg")
   405  	case <-g3DoneChan:
   406  		// Only g3 should receive a broadcast.
   407  	case <-time.After(200 * time.Millisecond):
   408  		t.Fatal("called broadcast with g3 in peers list, but g3 didn't receive it.")
   409  	}
   410  	if g3Payload != "qux" {
   411  		t.Fatal("broadcast failed:", g3Payload)
   412  	}
   413  
   414  	// Test that broadcasting to an empty slice (but not nil!) does not broadcast
   415  	// to g2 or g3.
   416  	peers = make([]modules.Peer, 0)
   417  	g1.Broadcast("Recv", "quux", peers)
   418  	select {
   419  	case <-g2DoneChan:
   420  		t.Error("broadcast broadcasted to peers not in the peers arg")
   421  	case <-g3DoneChan:
   422  		t.Error("broadcast broadcasted to peers not in the peers arg")
   423  	case <-time.After(200 * time.Millisecond):
   424  		// Neither peer should receive a broadcast.
   425  	}
   426  
   427  	// Test that calling broadcast with nil peers does not broadcast to g2 or g3.
   428  	g1.Broadcast("Recv", "foo", nil)
   429  	select {
   430  	case <-g2DoneChan:
   431  		t.Error("broadcast broadcasted to peers not in the peers arg")
   432  	case <-g3DoneChan:
   433  		t.Error("broadcast broadcasted to peers not in the peers arg")
   434  	case <-time.After(200 * time.Millisecond):
   435  		// Neither peer should receive a broadcast.
   436  	}
   437  }
   438  
   439  // TestOutboundAndInboundRPCs tests that both inbound and outbound connections
   440  // can successfully make RPC calls.
   441  func TestOutboundAndInboundRPCs(t *testing.T) {
   442  	if testing.Short() {
   443  		t.SkipNow()
   444  	}
   445  	t.Parallel()
   446  	g1 := newNamedTestingGateway(t, "1")
   447  	defer g1.Close()
   448  	g2 := newNamedTestingGateway(t, "2")
   449  	defer g2.Close()
   450  
   451  	rpcChanG1 := make(chan struct{})
   452  	rpcChanG2 := make(chan struct{})
   453  
   454  	g1.RegisterRPC("recv", func(conn modules.PeerConn) error {
   455  		rpcChanG1 <- struct{}{}
   456  		return nil
   457  	})
   458  	g2.RegisterRPC("recv", func(conn modules.PeerConn) error {
   459  		rpcChanG2 <- struct{}{}
   460  		return nil
   461  	})
   462  
   463  	err := g1.Connect(g2.Address())
   464  	if err != nil {
   465  		t.Fatal(err)
   466  	}
   467  	time.Sleep(10 * time.Millisecond)
   468  
   469  	err = g1.RPC(g2.Address(), "recv", func(conn modules.PeerConn) error { return nil })
   470  	if err != nil {
   471  		t.Fatal(err)
   472  	}
   473  	<-rpcChanG2
   474  
   475  	// Call the "recv" RPC on g1. We don't know g1's address as g2 sees it, so we
   476  	// get it from the first address in g2's peer list.
   477  	var addr modules.NetAddress
   478  	for pAddr := range g2.peers {
   479  		addr = pAddr
   480  		break
   481  	}
   482  	err = g2.RPC(addr, "recv", func(conn modules.PeerConn) error { return nil })
   483  	if err != nil {
   484  		t.Fatal(err)
   485  	}
   486  	<-rpcChanG1
   487  }
   488  
   489  // TestCallingRPCFromRPC tests that calling an RPC from an RPC works.
   490  func TestCallingRPCFromRPC(t *testing.T) {
   491  	if testing.Short() {
   492  		t.SkipNow()
   493  	}
   494  	t.Parallel()
   495  	g1 := newNamedTestingGateway(t, "1")
   496  	defer g1.Close()
   497  	g2 := newNamedTestingGateway(t, "2")
   498  	defer g2.Close()
   499  
   500  	errChan := make(chan error)
   501  	g1.RegisterRPC("FOO", func(conn modules.PeerConn) error {
   502  		err := g1.RPC(conn.RPCAddr(), "BAR", func(conn modules.PeerConn) error { return nil })
   503  		errChan <- err
   504  		return err
   505  	})
   506  
   507  	barChan := make(chan struct{})
   508  	g2.RegisterRPC("BAR", func(conn modules.PeerConn) error {
   509  		barChan <- struct{}{}
   510  		return nil
   511  	})
   512  
   513  	err := g1.Connect(g2.Address())
   514  	if err != nil {
   515  		t.Fatal(err)
   516  	}
   517  
   518  	// Wait for g2 to accept the connection
   519  	for {
   520  		if len(g2.Peers()) > 0 {
   521  			break
   522  		}
   523  	}
   524  
   525  	err = g2.RPC(g1.Address(), "FOO", func(conn modules.PeerConn) error {
   526  		return nil
   527  	})
   528  	if err != nil {
   529  		t.Fatal(err)
   530  	}
   531  
   532  	select {
   533  	case err = <-errChan:
   534  		if err != nil {
   535  			t.Fatal(err)
   536  		}
   537  	case <-time.After(500 * time.Millisecond):
   538  		t.Fatal("expected FOO RPC to be called")
   539  	}
   540  
   541  	select {
   542  	case <-barChan:
   543  	case <-time.After(1 * time.Second):
   544  		t.Fatal("expected BAR RPC to be called")
   545  	}
   546  }
   547  
   548  // TestRPCRatelimit checks that a peer calling an RPC repeatedly does not result
   549  // in a crash.
   550  func TestRPCRatelimit(t *testing.T) {
   551  	if testing.Short() {
   552  		t.SkipNow()
   553  	}
   554  	t.Parallel()
   555  	g1 := newNamedTestingGateway(t, "1")
   556  	defer g1.Close()
   557  	g2 := newNamedTestingGateway(t, "2")
   558  	defer g2.Close()
   559  
   560  	var atomicCalls, atomicErrs uint64
   561  	g2.RegisterRPC("recv", func(conn modules.PeerConn) error {
   562  		_, err := conn.Write([]byte("hi"))
   563  		if err != nil {
   564  			atomic.AddUint64(&atomicErrs, 1)
   565  			return err
   566  		}
   567  		atomic.AddUint64(&atomicCalls, 1)
   568  		return nil
   569  	})
   570  
   571  	err := g1.Connect(g2.Address())
   572  	if err != nil {
   573  		t.Fatal(err)
   574  	}
   575  	// Block until the connection is confirmed.
   576  	for i := 0; i < 50; i++ {
   577  		time.Sleep(10 * time.Millisecond)
   578  		g1.mu.Lock()
   579  		g1Peers := len(g1.peers)
   580  		g1.mu.Unlock()
   581  		g2.mu.Lock()
   582  		g2Peers := len(g2.peers)
   583  		g2.mu.Unlock()
   584  		if g1Peers > 0 || g2Peers > 0 {
   585  			break
   586  		}
   587  	}
   588  	g1.mu.Lock()
   589  	g1Peers := len(g1.peers)
   590  	g1.mu.Unlock()
   591  	g2.mu.Lock()
   592  	g2Peers := len(g2.peers)
   593  	g2.mu.Unlock()
   594  	if g1Peers == 0 || g2Peers == 0 {
   595  		t.Fatal("Peers did not connect to eachother")
   596  	}
   597  
   598  	// Call "recv" in a tight loop. Check that the number of successful calls
   599  	// does not exceed the ratelimit.
   600  	start := time.Now()
   601  	var wg sync.WaitGroup
   602  	targetDuration := rpcStdDeadline * 4 / 3
   603  	maxCallsForDuration := targetDuration / peerRPCDelay
   604  	callVolume := int(maxCallsForDuration * 3 / 5)
   605  	for i := 0; i < callVolume; i++ {
   606  		wg.Add(1)
   607  		go func() {
   608  			defer wg.Done()
   609  			// Call an RPC on our peer. Error is ignored, as many are expected
   610  			// and indicate that the test is working.
   611  			_ = g1.RPC(g2.Address(), "recv", func(conn modules.PeerConn) error {
   612  				buf := make([]byte, 2)
   613  				_, err := conn.Read(buf)
   614  				if err != nil {
   615  					return err
   616  				}
   617  				if string(buf) != "hi" {
   618  					return errors.New("caller rpc failed")
   619  				}
   620  				return nil
   621  			})
   622  		}()
   623  		// Sleep for a little bit so that the connections are coming all in a
   624  		// row instead of all at once. But sleep for little enough time that the
   625  		// number of connectings is still far surpassing the allowed ratelimit.
   626  		time.Sleep(peerRPCDelay / 10)
   627  	}
   628  	wg.Wait()
   629  
   630  	stop := time.Now()
   631  	elapsed := stop.Sub(start)
   632  	expected := peerRPCDelay * (time.Duration(atomic.LoadUint64(&atomicCalls)) + 1)
   633  	if elapsed*10/9 < expected {
   634  		t.Error("ratelimit does not seem to be effective", expected, elapsed)
   635  	}
   636  }