github.com/keltia/go-ipfs@v0.3.8-0.20150909044612-210793031c63/p2p/net/swarm/swarm_test.go (about)

     1  package swarm
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	metrics "github.com/ipfs/go-ipfs/metrics"
    13  	inet "github.com/ipfs/go-ipfs/p2p/net"
    14  	peer "github.com/ipfs/go-ipfs/p2p/peer"
    15  	testutil "github.com/ipfs/go-ipfs/util/testutil"
    16  
    17  	ma "github.com/ipfs/go-ipfs/Godeps/_workspace/src/github.com/jbenet/go-multiaddr"
    18  	context "github.com/ipfs/go-ipfs/Godeps/_workspace/src/golang.org/x/net/context"
    19  )
    20  
    21  func EchoStreamHandler(stream inet.Stream) {
    22  	go func() {
    23  		defer stream.Close()
    24  
    25  		// pull out the ipfs conn
    26  		c := stream.Conn()
    27  		log.Infof("%s ponging to %s", c.LocalPeer(), c.RemotePeer())
    28  
    29  		buf := make([]byte, 4)
    30  
    31  		for {
    32  			if _, err := stream.Read(buf); err != nil {
    33  				if err != io.EOF {
    34  					log.Info("ping receive error:", err)
    35  				}
    36  				return
    37  			}
    38  
    39  			if !bytes.Equal(buf, []byte("ping")) {
    40  				log.Infof("ping receive error: ping != %s %v", buf, buf)
    41  				return
    42  			}
    43  
    44  			if _, err := stream.Write([]byte("pong")); err != nil {
    45  				log.Info("pond send error:", err)
    46  				return
    47  			}
    48  		}
    49  	}()
    50  }
    51  
    52  func makeSwarms(ctx context.Context, t *testing.T, num int) []*Swarm {
    53  	swarms := make([]*Swarm, 0, num)
    54  
    55  	for i := 0; i < num; i++ {
    56  		localnp := testutil.RandPeerNetParamsOrFatal(t)
    57  
    58  		peerstore := peer.NewPeerstore()
    59  		peerstore.AddPubKey(localnp.ID, localnp.PubKey)
    60  		peerstore.AddPrivKey(localnp.ID, localnp.PrivKey)
    61  
    62  		addrs := []ma.Multiaddr{localnp.Addr}
    63  		swarm, err := NewSwarm(ctx, addrs, localnp.ID, peerstore, metrics.NewBandwidthCounter())
    64  		if err != nil {
    65  			t.Fatal(err)
    66  		}
    67  
    68  		swarm.SetStreamHandler(EchoStreamHandler)
    69  		swarms = append(swarms, swarm)
    70  	}
    71  
    72  	return swarms
    73  }
    74  
    75  func connectSwarms(t *testing.T, ctx context.Context, swarms []*Swarm) {
    76  
    77  	var wg sync.WaitGroup
    78  	connect := func(s *Swarm, dst peer.ID, addr ma.Multiaddr) {
    79  		// TODO: make a DialAddr func.
    80  		s.peers.AddAddr(dst, addr, peer.PermanentAddrTTL)
    81  		if _, err := s.Dial(ctx, dst); err != nil {
    82  			t.Fatal("error swarm dialing to peer", err)
    83  		}
    84  		wg.Done()
    85  	}
    86  
    87  	log.Info("Connecting swarms simultaneously.")
    88  	for _, s1 := range swarms {
    89  		for _, s2 := range swarms {
    90  			if s2.local != s1.local { // don't connect to self.
    91  				wg.Add(1)
    92  				connect(s1, s2.LocalPeer(), s2.ListenAddresses()[0]) // try the first.
    93  			}
    94  		}
    95  	}
    96  	wg.Wait()
    97  
    98  	for _, s := range swarms {
    99  		log.Infof("%s swarm routing table: %s", s.local, s.Peers())
   100  	}
   101  }
   102  
   103  func SubtestSwarm(t *testing.T, SwarmNum int, MsgNum int) {
   104  	// t.Skip("skipping for another test")
   105  
   106  	ctx := context.Background()
   107  	swarms := makeSwarms(ctx, t, SwarmNum)
   108  
   109  	// connect everyone
   110  	connectSwarms(t, ctx, swarms)
   111  
   112  	// ping/pong
   113  	for _, s1 := range swarms {
   114  		log.Debugf("-------------------------------------------------------")
   115  		log.Debugf("%s ping pong round", s1.local)
   116  		log.Debugf("-------------------------------------------------------")
   117  
   118  		_, cancel := context.WithCancel(ctx)
   119  		got := map[peer.ID]int{}
   120  		errChan := make(chan error, MsgNum*len(swarms))
   121  		streamChan := make(chan *Stream, MsgNum)
   122  
   123  		// send out "ping" x MsgNum to every peer
   124  		go func() {
   125  			defer close(streamChan)
   126  
   127  			var wg sync.WaitGroup
   128  			send := func(p peer.ID) {
   129  				defer wg.Done()
   130  
   131  				// first, one stream per peer (nice)
   132  				stream, err := s1.NewStreamWithPeer(p)
   133  				if err != nil {
   134  					errChan <- err
   135  					return
   136  				}
   137  
   138  				// send out ping!
   139  				for k := 0; k < MsgNum; k++ { // with k messages
   140  					msg := "ping"
   141  					log.Debugf("%s %s %s (%d)", s1.local, msg, p, k)
   142  					if _, err := stream.Write([]byte(msg)); err != nil {
   143  						errChan <- err
   144  						continue
   145  					}
   146  				}
   147  
   148  				// read it later
   149  				streamChan <- stream
   150  			}
   151  
   152  			for _, s2 := range swarms {
   153  				if s2.local == s1.local {
   154  					continue // dont send to self...
   155  				}
   156  
   157  				wg.Add(1)
   158  				go send(s2.local)
   159  			}
   160  			wg.Wait()
   161  		}()
   162  
   163  		// receive "pong" x MsgNum from every peer
   164  		go func() {
   165  			defer close(errChan)
   166  			count := 0
   167  			countShouldBe := MsgNum * (len(swarms) - 1)
   168  			for stream := range streamChan { // one per peer
   169  				defer stream.Close()
   170  
   171  				// get peer on the other side
   172  				p := stream.Conn().RemotePeer()
   173  
   174  				// receive pings
   175  				msgCount := 0
   176  				msg := make([]byte, 4)
   177  				for k := 0; k < MsgNum; k++ { // with k messages
   178  
   179  					// read from the stream
   180  					if _, err := stream.Read(msg); err != nil {
   181  						errChan <- err
   182  						continue
   183  					}
   184  
   185  					if string(msg) != "pong" {
   186  						errChan <- fmt.Errorf("unexpected message: %s", msg)
   187  						continue
   188  					}
   189  
   190  					log.Debugf("%s %s %s (%d)", s1.local, msg, p, k)
   191  					msgCount++
   192  				}
   193  
   194  				got[p] = msgCount
   195  				count += msgCount
   196  			}
   197  
   198  			if count != countShouldBe {
   199  				errChan <- fmt.Errorf("count mismatch: %d != %d", count, countShouldBe)
   200  			}
   201  		}()
   202  
   203  		// check any errors (blocks till consumer is done)
   204  		for err := range errChan {
   205  			if err != nil {
   206  				t.Error(err.Error())
   207  			}
   208  		}
   209  
   210  		log.Debugf("%s got pongs", s1.local)
   211  		if (len(swarms) - 1) != len(got) {
   212  			t.Errorf("got (%d) less messages than sent (%d).", len(got), len(swarms))
   213  		}
   214  
   215  		for p, n := range got {
   216  			if n != MsgNum {
   217  				t.Error("peer did not get all msgs", p, n, "/", MsgNum)
   218  			}
   219  		}
   220  
   221  		cancel()
   222  		<-time.After(10 * time.Millisecond)
   223  	}
   224  
   225  	for _, s := range swarms {
   226  		s.Close()
   227  	}
   228  }
   229  
   230  func TestSwarm(t *testing.T) {
   231  	// t.Skip("skipping for another test")
   232  	t.Parallel()
   233  
   234  	// msgs := 1000
   235  	msgs := 100
   236  	swarms := 5
   237  	SubtestSwarm(t, swarms, msgs)
   238  }
   239  
   240  func TestConnHandler(t *testing.T) {
   241  	// t.Skip("skipping for another test")
   242  	t.Parallel()
   243  
   244  	ctx := context.Background()
   245  	swarms := makeSwarms(ctx, t, 5)
   246  
   247  	gotconn := make(chan struct{}, 10)
   248  	swarms[0].SetConnHandler(func(conn *Conn) {
   249  		gotconn <- struct{}{}
   250  	})
   251  
   252  	connectSwarms(t, ctx, swarms)
   253  
   254  	<-time.After(time.Millisecond)
   255  	// should've gotten 5 by now.
   256  
   257  	swarms[0].SetConnHandler(nil)
   258  
   259  	expect := 4
   260  	for i := 0; i < expect; i++ {
   261  		select {
   262  		case <-time.After(time.Second):
   263  			t.Fatal("failed to get connections")
   264  		case <-gotconn:
   265  		}
   266  	}
   267  
   268  	select {
   269  	case <-gotconn:
   270  		t.Fatalf("should have connected to %d swarms", expect)
   271  	default:
   272  	}
   273  }
   274  
   275  func TestAddrBlocking(t *testing.T) {
   276  	ctx := context.Background()
   277  	swarms := makeSwarms(ctx, t, 2)
   278  
   279  	swarms[0].SetConnHandler(func(conn *Conn) {
   280  		t.Fatal("no connections should happen!")
   281  	})
   282  
   283  	_, block, err := net.ParseCIDR("127.0.0.1/8")
   284  	if err != nil {
   285  		t.Fatal(err)
   286  	}
   287  
   288  	swarms[1].Filters.AddDialFilter(block)
   289  
   290  	swarms[1].peers.AddAddr(swarms[0].LocalPeer(), swarms[0].ListenAddresses()[0], peer.PermanentAddrTTL)
   291  	_, err = swarms[1].Dial(ctx, swarms[0].LocalPeer())
   292  	if err == nil {
   293  		t.Fatal("dial should have failed")
   294  	}
   295  
   296  	swarms[0].peers.AddAddr(swarms[1].LocalPeer(), swarms[1].ListenAddresses()[0], peer.PermanentAddrTTL)
   297  	_, err = swarms[0].Dial(ctx, swarms[1].LocalPeer())
   298  	if err == nil {
   299  		t.Fatal("dial should have failed")
   300  	}
   301  }
   302  
   303  func TestFilterBounds(t *testing.T) {
   304  	ctx := context.Background()
   305  	swarms := makeSwarms(ctx, t, 2)
   306  
   307  	conns := make(chan struct{}, 8)
   308  	swarms[0].SetConnHandler(func(conn *Conn) {
   309  		conns <- struct{}{}
   310  	})
   311  
   312  	// Address that we wont be dialing from
   313  	_, block, err := net.ParseCIDR("192.0.0.1/8")
   314  	if err != nil {
   315  		t.Fatal(err)
   316  	}
   317  
   318  	// set filter on both sides, shouldnt matter
   319  	swarms[1].Filters.AddDialFilter(block)
   320  	swarms[0].Filters.AddDialFilter(block)
   321  
   322  	connectSwarms(t, ctx, swarms)
   323  
   324  	select {
   325  	case <-time.After(time.Second):
   326  		t.Fatal("should have gotten connection")
   327  	case <-conns:
   328  		t.Log("got connect")
   329  	}
   330  }