github.com/avahowell/sia@v0.5.1-beta.0.20160524050156-83dcc3d37c94/modules/gateway/rpc_test.go (about)

     1  package gateway
     2  
     3  import (
     4  	"testing"
     5  	"time"
     6  
     7  	"github.com/NebulousLabs/Sia/encoding"
     8  	"github.com/NebulousLabs/Sia/modules"
     9  )
    10  
    11  func TestRPCID(t *testing.T) {
    12  	cases := map[rpcID]string{
    13  		rpcID{}:                                       "        ",
    14  		rpcID{'f', 'o', 'o'}:                          "foo     ",
    15  		rpcID{'f', 'o', 'o', 'b', 'a', 'r', 'b', 'a'}: "foobarba",
    16  	}
    17  	for id, s := range cases {
    18  		if id.String() != s {
    19  			t.Errorf("rpcID.String mismatch: expected %v, got %v", s, id.String())
    20  		}
    21  	}
    22  }
    23  
    24  func TestHandlerName(t *testing.T) {
    25  	cases := map[string]rpcID{
    26  		"":          {},
    27  		"foo":       {'f', 'o', 'o'},
    28  		"foobarbaz": {'f', 'o', 'o', 'b', 'a', 'r', 'b', 'a'},
    29  	}
    30  	for s, id := range cases {
    31  		if hid := handlerName(s); hid != id {
    32  			t.Errorf("handlerName mismatch: expected %v, got %v", id, hid)
    33  		}
    34  	}
    35  }
    36  
    37  func TestRPC(t *testing.T) {
    38  	g1 := newTestingGateway("TestRPC1", t)
    39  	defer g1.Close()
    40  
    41  	if err := g1.RPC("foo.com:123", "", nil); err == nil {
    42  		t.Fatal("RPC on unconnected peer succeeded")
    43  	}
    44  
    45  	g2 := newTestingGateway("TestRPC2", t)
    46  	defer g2.Close()
    47  
    48  	err := g1.Connect(g2.Address())
    49  	if err != nil {
    50  		t.Fatal("failed to connect:", err)
    51  	}
    52  
    53  	g2.RegisterRPC("Foo", func(conn modules.PeerConn) error {
    54  		var i uint64
    55  		err := encoding.ReadObject(conn, &i, 8)
    56  		if err != nil {
    57  			return err
    58  		} else if i == 0xdeadbeef {
    59  			return encoding.WriteObject(conn, "foo")
    60  		} else {
    61  			return encoding.WriteObject(conn, "bar")
    62  		}
    63  	})
    64  
    65  	var foo string
    66  	err = g1.RPC(g2.Address(), "Foo", func(conn modules.PeerConn) error {
    67  		err := encoding.WriteObject(conn, 0xdeadbeef)
    68  		if err != nil {
    69  			return err
    70  		}
    71  		return encoding.ReadObject(conn, &foo, 11)
    72  	})
    73  	if err != nil {
    74  		t.Fatal(err)
    75  	}
    76  	if foo != "foo" {
    77  		t.Fatal("Foo gave wrong response:", foo)
    78  	}
    79  
    80  	// wrong number should produce an error
    81  	err = g1.RPC(g2.Address(), "Foo", func(conn modules.PeerConn) error {
    82  		err := encoding.WriteObject(conn, 0xbadbeef)
    83  		if err != nil {
    84  			return err
    85  		}
    86  		return encoding.ReadObject(conn, &foo, 11)
    87  	})
    88  	if err != nil {
    89  		t.Fatal(err)
    90  	}
    91  	if foo != "bar" {
    92  		t.Fatal("Foo gave wrong response:", foo)
    93  	}
    94  
    95  	// don't read or write anything
    96  	err = g1.RPC(g2.Address(), "Foo", func(modules.PeerConn) error {
    97  		return errNoPeers // any non-nil error will do
    98  	})
    99  	if err == nil {
   100  		t.Fatal("bad RPC did not produce an error")
   101  	}
   102  
   103  	g1.peers[g2.Address()].sess.Close()
   104  	if err := g1.RPC(g2.Address(), "Foo", nil); err == nil {
   105  		t.Fatal("RPC on closed peer connection succeeded")
   106  	}
   107  }
   108  
   109  func TestThreadedHandleConn(t *testing.T) {
   110  	g1 := newTestingGateway("TestThreadedHandleConn1", t)
   111  	defer g1.Close()
   112  	g2 := newTestingGateway("TestThreadedHandleConn2", t)
   113  	defer g2.Close()
   114  
   115  	err := g1.Connect(g2.Address())
   116  	if err != nil {
   117  		t.Fatal("failed to connect:", err)
   118  	}
   119  
   120  	g2.RegisterRPC("Foo", func(conn modules.PeerConn) error {
   121  		var i uint64
   122  		err := encoding.ReadObject(conn, &i, 8)
   123  		if err != nil {
   124  			return err
   125  		} else if i == 0xdeadbeef {
   126  			return encoding.WriteObject(conn, "foo")
   127  		} else {
   128  			return encoding.WriteObject(conn, "bar")
   129  		}
   130  	})
   131  
   132  	// custom rpc fn (doesn't automatically write rpcID)
   133  	rpcFn := func(fn func(modules.PeerConn) error) error {
   134  		conn, err := g1.peers[g2.Address()].open()
   135  		if err != nil {
   136  			return err
   137  		}
   138  		defer conn.Close()
   139  		return fn(conn)
   140  	}
   141  
   142  	// bad rpcID
   143  	err = rpcFn(func(conn modules.PeerConn) error {
   144  		return encoding.WriteObject(conn, [3]byte{1, 2, 3})
   145  	})
   146  	if err != nil {
   147  		t.Fatal("rpcFn failed:", err)
   148  	}
   149  
   150  	// unknown rpcID
   151  	err = rpcFn(func(conn modules.PeerConn) error {
   152  		return encoding.WriteObject(conn, handlerName("bar"))
   153  	})
   154  	if err != nil {
   155  		t.Fatal("rpcFn failed:", err)
   156  	}
   157  
   158  	// valid rpcID
   159  	err = rpcFn(func(conn modules.PeerConn) error {
   160  		return encoding.WriteObject(conn, handlerName("Foo"))
   161  	})
   162  	if err != nil {
   163  		t.Fatal("rpcFn failed:", err)
   164  	}
   165  }
   166  
   167  // TestBroadcast tests that calling broadcast with a slice of peers only
   168  // broadcasts to those peers.
   169  func TestBroadcast(t *testing.T) {
   170  	g1 := newTestingGateway("TestBroadcast1", t)
   171  	defer g1.Close()
   172  	g2 := newTestingGateway("TestBroadcast2", t)
   173  	defer g2.Close()
   174  	g3 := newTestingGateway("TestBroadcast3", t)
   175  	defer g3.Close()
   176  
   177  	err := g1.Connect(g2.Address())
   178  	if err != nil {
   179  		t.Fatal("failed to connect:", err)
   180  	}
   181  	err = g1.Connect(g3.Address())
   182  	if err != nil {
   183  		t.Fatal("failed to connect:", err)
   184  	}
   185  
   186  	var g2Payload, g3Payload string
   187  	g2DoneChan := make(chan struct{})
   188  	g3DoneChan := make(chan struct{})
   189  	bothDoneChan := make(chan struct{})
   190  
   191  	g2.RegisterRPC("Recv", func(conn modules.PeerConn) error {
   192  		encoding.ReadObject(conn, &g2Payload, 100)
   193  		g2DoneChan <- struct{}{}
   194  		return nil
   195  	})
   196  	g3.RegisterRPC("Recv", func(conn modules.PeerConn) error {
   197  		encoding.ReadObject(conn, &g3Payload, 100)
   198  		g3DoneChan <- struct{}{}
   199  		return nil
   200  	})
   201  
   202  	// Test that broadcasting to all peers in g1.Peers() broadcasts to all peers.
   203  	peers := g1.Peers()
   204  	g1.Broadcast("Recv", "bar", peers)
   205  	go func() {
   206  		<-g2DoneChan
   207  		<-g3DoneChan
   208  		bothDoneChan <- struct{}{}
   209  	}()
   210  	select {
   211  	case <-bothDoneChan:
   212  		// Both g2 and g3 should receive the broadcast.
   213  	case <-time.After(200 * time.Millisecond):
   214  		t.Fatal("broadcasting to gateway.Peers() should broadcast to all peers")
   215  	}
   216  	if g2Payload != "bar" || g3Payload != "bar" {
   217  		t.Fatal("broadcast failed:", g2Payload, g3Payload)
   218  	}
   219  
   220  	// Test that broadcasting to only g2 does not broadcast to g3.
   221  	peers = make([]modules.Peer, 0)
   222  	for _, p := range g1.Peers() {
   223  		if p.NetAddress == g2.Address() {
   224  			peers = append(peers, p)
   225  			break
   226  		}
   227  	}
   228  	g1.Broadcast("Recv", "baz", peers)
   229  	select {
   230  	case <-g2DoneChan:
   231  		// Only g2 should receive a broadcast.
   232  	case <-g3DoneChan:
   233  		t.Error("broadcast broadcasted to peers not in the peers arg")
   234  	case <-time.After(200 * time.Millisecond):
   235  		t.Fatal("called broadcast with g2 in peers list, but g2 didn't receive it.")
   236  	}
   237  	if g2Payload != "baz" {
   238  		t.Fatal("broadcast failed:", g2Payload)
   239  	}
   240  
   241  	// Test that broadcasting to only g3 does not broadcast to g2.
   242  	peers = make([]modules.Peer, 0)
   243  	for _, p := range g1.Peers() {
   244  		if p.NetAddress == g3.Address() {
   245  			peers = append(peers, p)
   246  			break
   247  		}
   248  	}
   249  	g1.Broadcast("Recv", "qux", peers)
   250  	select {
   251  	case <-g2DoneChan:
   252  		t.Error("broadcast broadcasted to peers not in the peers arg")
   253  	case <-g3DoneChan:
   254  		// Only g3 should receive a broadcast.
   255  	case <-time.After(200 * time.Millisecond):
   256  		t.Fatal("called broadcast with g3 in peers list, but g3 didn't receive it.")
   257  	}
   258  	if g3Payload != "qux" {
   259  		t.Fatal("broadcast failed:", g3Payload)
   260  	}
   261  
   262  	// Test that broadcasting to an empty slice (but not nil!) does not broadcast
   263  	// to g2 or g3.
   264  	peers = make([]modules.Peer, 0)
   265  	g1.Broadcast("Recv", "quux", peers)
   266  	select {
   267  	case <-g2DoneChan:
   268  		t.Error("broadcast broadcasted to peers not in the peers arg")
   269  	case <-g3DoneChan:
   270  		t.Error("broadcast broadcasted to peers not in the peers arg")
   271  	case <-time.After(200 * time.Millisecond):
   272  		// Neither peer should receive a broadcast.
   273  	}
   274  
   275  	// Test that calling broadcast with nil peers does not broadcast to g2 or g3.
   276  	g1.Broadcast("Recv", "foo", nil)
   277  	select {
   278  	case <-g2DoneChan:
   279  		t.Error("broadcast broadcasted to peers not in the peers arg")
   280  	case <-g3DoneChan:
   281  		t.Error("broadcast broadcasted to peers not in the peers arg")
   282  	case <-time.After(200 * time.Millisecond):
   283  		// Neither peer should receive a broadcast.
   284  	}
   285  }
   286  
   287  // TestOutboundAndInboundRPCs tests that both inbound and outbound connections
   288  // can successfully make RPC calls.
   289  func TestOutboundAndInboundRPCs(t *testing.T) {
   290  	g1 := newTestingGateway("TestRPC1", t)
   291  	defer g1.Close()
   292  	g2 := newTestingGateway("TestRPC2", t)
   293  	defer g2.Close()
   294  
   295  	rpcChanG1 := make(chan struct{})
   296  	rpcChanG2 := make(chan struct{})
   297  
   298  	g1.RegisterRPC("recv", func(conn modules.PeerConn) error {
   299  		rpcChanG1 <- struct{}{}
   300  		return nil
   301  	})
   302  	g2.RegisterRPC("recv", func(conn modules.PeerConn) error {
   303  		rpcChanG2 <- struct{}{}
   304  		return nil
   305  	})
   306  
   307  	err := g1.Connect(g2.Address())
   308  	if err != nil {
   309  		t.Fatal(err)
   310  	}
   311  	time.Sleep(10 * time.Millisecond)
   312  
   313  	err = g1.RPC(g2.Address(), "recv", func(conn modules.PeerConn) error { return nil })
   314  	if err != nil {
   315  		t.Fatal(err)
   316  	}
   317  	<-rpcChanG2
   318  
   319  	// Call the "recv" RPC on g1. We don't know g1's address as g2 sees it, so we
   320  	// get it from the first address in g2's peer list.
   321  	var addr modules.NetAddress
   322  	for p_addr := range g2.peers {
   323  		addr = p_addr
   324  		break
   325  	}
   326  	err = g2.RPC(addr, "recv", func(conn modules.PeerConn) error { return nil })
   327  	if err != nil {
   328  		t.Fatal(err)
   329  	}
   330  	<-rpcChanG1
   331  }
   332  
   333  // TestCallingRPCFromRPC tests that calling an RPC from an RPC works.
   334  func TestCallingRPCFromRPC(t *testing.T) {
   335  	g1 := newTestingGateway("TestCallingRPCFromRPC1", t)
   336  	defer g1.Close()
   337  	g2 := newTestingGateway("TestCallingRPCFromRPC2", t)
   338  	defer g2.Close()
   339  
   340  	errChan := make(chan error)
   341  	g1.RegisterRPC("FOO", func(conn modules.PeerConn) error {
   342  		err := g1.RPC(modules.NetAddress(conn.RemoteAddr().String()), "BAR", func(conn modules.PeerConn) error { return nil })
   343  		errChan <- err
   344  		return err
   345  	})
   346  
   347  	barChan := make(chan struct{})
   348  	g2.RegisterRPC("BAR", func(conn modules.PeerConn) error {
   349  		barChan <- struct{}{}
   350  		return nil
   351  	})
   352  
   353  	err := g1.Connect(g2.Address())
   354  	if err != nil {
   355  		t.Fatal(err)
   356  	}
   357  
   358  	// Call the "FOO" RPC on g1. We don't know g1's address as g2 sees it, so we
   359  	// get it from the first address in g2's peer list.
   360  	var addr modules.NetAddress
   361  	for _, p := range g2.Peers() {
   362  		addr = p.NetAddress
   363  		break
   364  	}
   365  	err = g2.RPC(addr, "FOO", func(conn modules.PeerConn) error { return nil })
   366  
   367  	select {
   368  	case err = <-errChan:
   369  		if err != nil {
   370  			t.Fatal(err)
   371  		}
   372  	case <-time.After(200 * time.Millisecond):
   373  		t.Fatal("expected FOO RPC to be called")
   374  	}
   375  
   376  	select {
   377  	case <-barChan:
   378  	case <-time.After(200 * time.Millisecond):
   379  		t.Fatal("expected BAR RPC to be called")
   380  	}
   381  }