gopkg.in/dedis/onet.v2@v2.0.0-20181115163211-c8f3724038a7/local.go (about)

     1  package onet
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"net"
     9  	"os"
    10  	"strconv"
    11  	"strings"
    12  	"sync"
    13  	"testing"
    14  	"time"
    15  
    16  	"gopkg.in/dedis/kyber.v2"
    17  	"gopkg.in/dedis/kyber.v2/util/key"
    18  	"gopkg.in/dedis/onet.v2/log"
    19  	"gopkg.in/dedis/onet.v2/network"
    20  )
    21  
    22  // LeakyTestCheck represents an enum to indicate how deep CloseAll needs to
    23  // check the tests.
    24  type LeakyTestCheck int
    25  
    26  const (
    27  	// CheckNone will make CloseAll not check anything.
    28  	CheckNone LeakyTestCheck = iota + 1
    29  	// CheckGoroutines will only check for leaking goroutines.
    30  	CheckGoroutines
    31  	// CheckAll will also check for leaking Overlay.Processors and
    32  	// ProtocolInstances.
    33  	CheckAll
    34  )
    35  
    36  // TestClose interface allows a service to clean up for the tests. It will only
    37  // be called when a test calls `LocalTest.CloseAll()`.
    38  type TestClose interface {
    39  	// TestClose can clean up things needed in the service.
    40  	TestClose()
    41  }
    42  
    43  // LocalTest represents all that is needed for a local test-run
    44  type LocalTest struct {
    45  	// A map of ServerIdentity.Id to Servers
    46  	Servers map[network.ServerIdentityID]*Server
    47  	// A map of ServerIdentity.Id to Overlays
    48  	Overlays map[network.ServerIdentityID]*Overlay
    49  	// A map of ServerIdentity.Id to Services
    50  	Services map[network.ServerIdentityID]map[ServiceID]Service
    51  	// A map of Tree.Id to Trees
    52  	Trees map[TreeID]*Tree
    53  	// All single nodes
    54  	Nodes []*TreeNodeInstance
    55  	// How carefully to check for leaking resources at the end of the test.
    56  	Check LeakyTestCheck
    57  	// are we running tcp or local layer
    58  	mode string
    59  	// TLS certificate if we want TLS for websocket
    60  	webSocketTLSCertificate []byte
    61  	// TLS certificate key if we want TLS for websocket
    62  	webSocketTLSCertificateKey []byte
    63  	// the context for the local connections
    64  	// it enables to have multiple local test running simultaneously
    65  	ctx   *network.LocalManager
    66  	Suite network.Suite
    67  	path  string
    68  	// Once closed is set, do not allow further operations on it,
    69  	// since now the temp directory is gone.
    70  	closed bool
    71  	T      *testing.T
    72  
    73  	// keep the latestPort used so that we can add nodes later
    74  	latestPort int
    75  }
    76  
    77  const (
    78  	// TCP represents the TCP mode of networking for this local test
    79  	TCP = "tcp"
    80  	// Local represents the Local mode of networking for this local test
    81  	Local = "local"
    82  )
    83  
    84  // NewLocalTest creates a new Local handler that can be used to test protocols
    85  // locally
    86  func NewLocalTest(s network.Suite) *LocalTest {
    87  	dir, err := ioutil.TempDir("", "onet")
    88  	if err != nil {
    89  		log.Fatal("could not create temp directory: ", err)
    90  	}
    91  
    92  	return &LocalTest{
    93  		Servers:    make(map[network.ServerIdentityID]*Server),
    94  		Overlays:   make(map[network.ServerIdentityID]*Overlay),
    95  		Services:   make(map[network.ServerIdentityID]map[ServiceID]Service),
    96  		Trees:      make(map[TreeID]*Tree),
    97  		Nodes:      make([]*TreeNodeInstance, 0, 1),
    98  		Check:      CheckGoroutines,
    99  		mode:       Local,
   100  		ctx:        network.NewLocalManager(),
   101  		Suite:      s,
   102  		path:       dir,
   103  		latestPort: 2000,
   104  	}
   105  }
   106  
   107  // NewLocalTestT is like NewLocalTest but also stores the testing.T variable.
   108  func NewLocalTestT(s network.Suite, t *testing.T) *LocalTest {
   109  	l := NewLocalTest(s)
   110  	l.T = t
   111  	return l
   112  }
   113  
   114  // NewTCPTest returns a LocalTest but using a TCPRouter as the underlying
   115  // communication layer.
   116  func NewTCPTest(s network.Suite) *LocalTest {
   117  	t := NewLocalTest(s)
   118  	t.mode = TCP
   119  	return t
   120  }
   121  
   122  // NewTCPTestWithTLS returns a LocalTest but using a TCPRouter as the
   123  // underlying communication layer and containing information for TLS setup.
   124  func NewTCPTestWithTLS(s network.Suite, wsTLSCertificate []byte,
   125  	wsTLSCertificateKey []byte) *LocalTest {
   126  	t := NewLocalTest(s)
   127  	t.mode = TCP
   128  	t.webSocketTLSCertificate = wsTLSCertificate
   129  	t.webSocketTLSCertificateKey = wsTLSCertificateKey
   130  	return t
   131  }
   132  
   133  // StartProtocol takes a name and a tree and will create a
   134  // new Node with the protocol 'name' running from the tree-root
   135  func (l *LocalTest) StartProtocol(name string, t *Tree) (ProtocolInstance, error) {
   136  	l.panicClosed()
   137  	rootServerIdentityID := t.Root.ServerIdentity.ID
   138  	for _, h := range l.Servers {
   139  		if h.ServerIdentity.ID.Equal(rootServerIdentityID) {
   140  			// XXX do we really need multiples overlays ? Can't we just use the
   141  			// Node, since it is already dispatched as like a TreeNode ?
   142  			return l.Overlays[h.ServerIdentity.ID].StartProtocol(name, t, NilServiceID)
   143  		}
   144  	}
   145  	return nil, errors.New("Didn't find server for tree-root")
   146  }
   147  
   148  // CreateProtocol takes a name and a tree and will create a
   149  // new Node with the protocol 'name' without running it
   150  func (l *LocalTest) CreateProtocol(name string, t *Tree) (ProtocolInstance, error) {
   151  	l.panicClosed()
   152  	rootServerIdentityID := t.Root.ServerIdentity.ID
   153  	for _, h := range l.Servers {
   154  		if h.ServerIdentity.ID.Equal(rootServerIdentityID) {
   155  			// XXX do we really need multiples overlays ? Can't we just use the
   156  			// Node, since it is already dispatched as like a TreeNode ?
   157  			return l.Overlays[h.ServerIdentity.ID].CreateProtocol(name, t, NilServiceID)
   158  		}
   159  	}
   160  	return nil, errors.New("Didn't find server for tree-root")
   161  }
   162  
   163  // GenServers returns n Servers with a localRouter
   164  func (l *LocalTest) GenServers(n int) []*Server {
   165  	l.panicClosed()
   166  	servers := l.genLocalHosts(n)
   167  	for _, server := range servers {
   168  		server.ServerIdentity.SetPrivate(server.private)
   169  		l.Servers[server.ServerIdentity.ID] = server
   170  		l.Overlays[server.ServerIdentity.ID] = server.overlay
   171  		l.Services[server.ServerIdentity.ID] = server.serviceManager.services
   172  	}
   173  	return servers
   174  
   175  }
   176  
   177  // GenTree will create a tree of n servers with a localRouter, and returns the
   178  // list of servers and the associated roster / tree.
   179  func (l *LocalTest) GenTree(n int, register bool) ([]*Server, *Roster, *Tree) {
   180  	l.panicClosed()
   181  	servers := l.GenServers(n)
   182  
   183  	list := l.GenRosterFromHost(servers...)
   184  	tree := list.GenerateBinaryTree()
   185  	l.Trees[tree.ID] = tree
   186  	if register {
   187  		servers[0].overlay.RegisterRoster(list)
   188  		servers[0].overlay.RegisterTree(tree)
   189  	}
   190  	return servers, list, tree
   191  
   192  }
   193  
   194  // GenBigTree will create a tree of n servers.
   195  // If register is true, the Roster and Tree will be registered with the overlay.
   196  // 'nbrServers' is how many servers are created
   197  // 'nbrTreeNodes' is how many TreeNodes are created
   198  // nbrServers can be smaller than nbrTreeNodes, in which case a given server will
   199  // be used more than once in the tree.
   200  func (l *LocalTest) GenBigTree(nbrTreeNodes, nbrServers, bf int, register bool) ([]*Server, *Roster, *Tree) {
   201  	l.panicClosed()
   202  	servers := l.GenServers(nbrServers)
   203  
   204  	list := l.GenRosterFromHost(servers...)
   205  	tree := list.GenerateBigNaryTree(bf, nbrTreeNodes)
   206  	l.Trees[tree.ID] = tree
   207  	if register {
   208  		servers[0].overlay.RegisterRoster(list)
   209  		servers[0].overlay.RegisterTree(tree)
   210  	}
   211  	return servers, list, tree
   212  }
   213  
   214  // GenRosterFromHost takes a number of servers as arguments and creates
   215  // an Roster.
   216  func (l *LocalTest) GenRosterFromHost(servers ...*Server) *Roster {
   217  	l.panicClosed()
   218  	var entities []*network.ServerIdentity
   219  	for i := range servers {
   220  		entities = append(entities, servers[i].ServerIdentity)
   221  	}
   222  	return NewRoster(entities)
   223  }
   224  
   225  func (l *LocalTest) panicClosed() {
   226  	if l.closed {
   227  		panic("attempt to use LocalTest after CloseAll")
   228  	}
   229  }
   230  
   231  // WaitDone loops until all protocolInstances are done or
   232  // the timeout is reached. If all protocolInstances are closed
   233  // within the timeout, nil is returned.
   234  func (l *LocalTest) WaitDone(t time.Duration) error {
   235  	var lingering []string
   236  	for i := 0; i < 10; i++ {
   237  		lingering = []string{}
   238  		for _, o := range l.Overlays {
   239  			o.instancesLock.Lock()
   240  			for si, pi := range o.protocolInstances {
   241  				lingering = append(lingering, fmt.Sprintf("ProtocolInstance type %T on %s with id %s",
   242  					pi, o.ServerIdentity(), si))
   243  			}
   244  			o.instancesLock.Unlock()
   245  		}
   246  		for _, s := range l.Servers {
   247  			disp, ok := s.serviceManager.Dispatcher.(*network.RoutineDispatcher)
   248  			if ok && disp.GetRoutines() > 0 {
   249  				lingering = append(lingering, fmt.Sprintf("RoutineDispatcher has %v routines running on %s", disp.GetRoutines(), s.ServerIdentity))
   250  			}
   251  		}
   252  		if len(lingering) == 0 {
   253  			return nil
   254  		}
   255  		time.Sleep(t / 10)
   256  	}
   257  	return errors.New("still have things lingering: " + strings.Join(lingering, "\n"))
   258  }
   259  
   260  // CloseAll closes all the servers.
   261  func (l *LocalTest) CloseAll() {
   262  	log.Lvl3("Stopping all")
   263  	if r := recover(); r != nil {
   264  		// Make sure that a panic is correctly caught, as CloseAll is most often
   265  		// called in a `defer` statement, and we don't want to show leaking
   266  		// go-routines or hanging protocolInstances if a panic occurs.
   267  		panic(r)
   268  	}
   269  	if l.T != nil && l.T.Failed() {
   270  		return
   271  	}
   272  
   273  	InformAllServersStopped()
   274  
   275  	// If the debug-level is 0, we copy all errors to a buffer that
   276  	// will be discarded at the end.
   277  	if log.DebugVisible() == 0 {
   278  		log.OutputToBuf()
   279  	}
   280  
   281  	if err := l.WaitDone(5 * time.Second); err != nil {
   282  		switch l.Check {
   283  		case CheckNone:
   284  			// Ignore waitDone
   285  		case CheckGoroutines:
   286  			// Only print a warning
   287  			if l.T != nil {
   288  				l.T.Log("Warning:", err)
   289  			} else {
   290  				log.Warn("Warning:", err)
   291  			}
   292  		case CheckAll:
   293  			// Fail if there are leaking processes or protocolInstances
   294  			if l.T != nil {
   295  				l.T.Fatal(err.Error())
   296  			} else {
   297  				log.Fatal(err.Error())
   298  			}
   299  		}
   300  	}
   301  
   302  	for _, node := range l.Nodes {
   303  		log.Lvl3("Closing node", node)
   304  		node.closeDispatch()
   305  	}
   306  	l.Nodes = make([]*TreeNodeInstance, 0)
   307  
   308  	sd := sync.WaitGroup{}
   309  	for _, srv := range l.Servers {
   310  		sd.Add(1)
   311  		go func(server *Server) {
   312  			log.Lvl3("Closing server", server.ServerIdentity.Address)
   313  			err := server.Close()
   314  			if err != nil {
   315  				log.Error("Closing server", server.ServerIdentity.Address,
   316  					"gives error", err)
   317  			}
   318  
   319  			for server.Listening() {
   320  				log.Lvl1("Sleeping while waiting to close...")
   321  				time.Sleep(10 * time.Millisecond)
   322  			}
   323  			sd.Done()
   324  		}(srv)
   325  	}
   326  	sd.Wait()
   327  	l.Servers = map[network.ServerIdentityID]*Server{}
   328  	l.ctx.Stop()
   329  
   330  	os.RemoveAll(l.path)
   331  	l.closed = true
   332  
   333  	if log.DebugVisible() == 0 {
   334  		log.OutputToOs()
   335  	}
   336  	if l.Check != CheckNone {
   337  		log.AfterTest(nil)
   338  	}
   339  }
   340  
   341  // getTree returns the tree of the given TreeNode
   342  func (l *LocalTest) getTree(tn *TreeNode) *Tree {
   343  	l.panicClosed()
   344  	var tree *Tree
   345  	for _, t := range l.Trees {
   346  		if tn.IsInTree(t) {
   347  			tree = t
   348  			break
   349  		}
   350  	}
   351  	return tree
   352  }
   353  
   354  // NewTreeNodeInstance creates a new node on a TreeNode
   355  func (l *LocalTest) NewTreeNodeInstance(tn *TreeNode, protName string) (*TreeNodeInstance, error) {
   356  	l.panicClosed()
   357  	o := l.Overlays[tn.ServerIdentity.ID]
   358  	if o == nil {
   359  		return nil, errors.New("Didn't find corresponding overlay")
   360  	}
   361  	tree := l.getTree(tn)
   362  	if tree == nil {
   363  		return nil, errors.New("Didn't find tree corresponding to TreeNode")
   364  	}
   365  	protID := ProtocolNameToID(protName)
   366  	if !l.Servers[tn.ServerIdentity.ID].protocols.ProtocolExists(protID) {
   367  		return nil, errors.New("Didn't find protocol: " + protName)
   368  	}
   369  	tok := &Token{
   370  		TreeID:     tree.ID,
   371  		TreeNodeID: tn.ID,
   372  	}
   373  	io := o.protoIO.getByName(protName)
   374  	node := newTreeNodeInstance(o, tok, tn, io)
   375  	l.Nodes = append(l.Nodes, node)
   376  	return node, nil
   377  }
   378  
   379  // GetTreeNodeInstances returns all TreeNodeInstances that belong to a server
   380  func (l *LocalTest) GetTreeNodeInstances(id network.ServerIdentityID) []*TreeNodeInstance {
   381  	l.panicClosed()
   382  	var nodes []*TreeNodeInstance
   383  	for _, n := range l.Overlays[id].instances {
   384  		nodes = append(nodes, n)
   385  	}
   386  	return nodes
   387  }
   388  
   389  // sendTreeNode injects a message directly in the Overlay-layer, bypassing
   390  // Host and Network
   391  func (l *LocalTest) sendTreeNode(proto string, from, to *TreeNodeInstance, msg network.Message) error {
   392  	l.panicClosed()
   393  	if !from.Tree().ID.Equal(to.Tree().ID) {
   394  		return errors.New("Can't send from one tree to another")
   395  	}
   396  	onetMsg := &ProtocolMsg{
   397  		Msg:     msg,
   398  		MsgType: network.MessageType(msg),
   399  		From:    from.token,
   400  		To:      to.token,
   401  	}
   402  	io := l.Overlays[to.ServerIdentity().ID].protoIO.getByName(proto)
   403  	return to.overlay.TransmitMsg(onetMsg, io)
   404  }
   405  
   406  // addPendingTreeMarshal takes a treeMarshal and adds it to the list of the
   407  // known trees, also triggering dispatching of onet-messages waiting for that
   408  // tree
   409  func (l *LocalTest) addPendingTreeMarshal(c *Server, tm *TreeMarshal) {
   410  	l.panicClosed()
   411  	c.overlay.addPendingTreeMarshal(tm)
   412  }
   413  
   414  // checkPendingTreeMarshal looks whether there are any treeMarshals to be
   415  // called
   416  func (l *LocalTest) checkPendingTreeMarshal(c *Server, el *Roster) {
   417  	l.panicClosed()
   418  	c.overlay.checkPendingTreeMarshal(el)
   419  }
   420  
   421  // GetPrivate returns the private key of a server
   422  func (l *LocalTest) GetPrivate(c *Server) kyber.Scalar {
   423  	return c.private
   424  }
   425  
   426  // GetServices returns a slice of all services asked for.
   427  // The sid is the id of the service that will be collected.
   428  func (l *LocalTest) GetServices(servers []*Server, sid ServiceID) []Service {
   429  	services := make([]Service, len(servers))
   430  	for i, h := range servers {
   431  		services[i] = l.Services[h.ServerIdentity.ID][sid]
   432  	}
   433  	return services
   434  }
   435  
   436  // MakeSRS creates and returns nbr Servers, the associated Roster and the
   437  // Service object of the first server in the list having sid as a ServiceID.
   438  func (l *LocalTest) MakeSRS(s network.Suite, nbr int, sid ServiceID) ([]*Server, *Roster, Service) {
   439  	l.panicClosed()
   440  	servers := l.GenServers(nbr)
   441  	el := l.GenRosterFromHost(servers...)
   442  	return servers, el, l.Services[servers[0].ServerIdentity.ID][sid]
   443  }
   444  
   445  // NewPrivIdentity returns a secret + ServerIdentity. The SI will have
   446  // "localserver:+port as first address.
   447  func NewPrivIdentity(suite network.Suite, port int) (kyber.Scalar, *network.ServerIdentity) {
   448  	address := network.NewLocalAddress("127.0.0.1:" + strconv.Itoa(port))
   449  	kp := key.NewKeyPair(suite)
   450  	id := network.NewServerIdentity(kp.Public, address)
   451  	return kp.Private, id
   452  }
   453  
   454  // NewTCPServer creates a new server with a tcpRouter with "localhost:"+port as an
   455  // address.
   456  func newTCPServer(s network.Suite, port int, path string) *Server {
   457  	priv, id := NewPrivIdentity(s, port)
   458  	addr := network.NewTCPAddress(id.Address.NetworkAddress())
   459  	id2 := network.NewServerIdentity(id.Public, addr)
   460  	var tcpHost *network.TCPHost
   461  	// For the websocket we need a port at the address one higher than the
   462  	// TCPHost. Let TCPHost chose a port, then check if the port+1 is also
   463  	// available. Else redo the search.
   464  	for {
   465  		var err error
   466  		tcpHost, err = network.NewTCPHost(id2, s)
   467  		if err != nil {
   468  			panic(err)
   469  		}
   470  		id.Address = tcpHost.Address()
   471  		if port != 0 {
   472  			break
   473  		}
   474  		port, err := strconv.Atoi(id.Address.Port())
   475  		if err != nil {
   476  			panic(err)
   477  		}
   478  		addr := net.JoinHostPort(id.Address.Host(), strconv.Itoa(port+1))
   479  		if l, err := net.Listen("tcp", addr); err == nil {
   480  			l.Close()
   481  			break
   482  		}
   483  		log.Lvl2("Found closed port:", addr)
   484  	}
   485  	id.Address = network.NewAddress(id.Address.ConnType(), "127.0.0.1:"+id.Address.Port())
   486  	router := network.NewRouter(id, tcpHost)
   487  	router.UnauthOk = true
   488  	return newServer(s, path, router, priv)
   489  }
   490  
   491  // NewLocalServer returns a new server using a LocalRouter (channels) to communicate.
   492  // At the return of this function, the router is already Run()ing in a go
   493  // routine.
   494  func NewLocalServer(s network.Suite, port int) *Server {
   495  	dir, err := ioutil.TempDir("", "example")
   496  	if err != nil {
   497  		log.Fatal(err)
   498  	}
   499  
   500  	priv, id := NewPrivIdentity(s, port)
   501  	localRouter, err := network.NewLocalRouter(id, s)
   502  	if err != nil {
   503  		panic(err)
   504  	}
   505  	h := newServer(s, dir, localRouter, priv)
   506  	h.StartInBackground()
   507  	return h
   508  }
   509  
   510  // NewClient returns *Client for which the types depend on the mode of the
   511  // LocalContext.
   512  func (l *LocalTest) NewClient(serviceName string) *Client {
   513  	switch l.mode {
   514  	case TCP:
   515  		return NewClient(l.Suite, serviceName)
   516  	default:
   517  		log.Fatal("Can't make local client")
   518  		return nil
   519  	}
   520  }
   521  
   522  // NewClientKeep returns *Client for which the types depend on the mode of the
   523  // LocalContext, the connection is not closed after sending requests.
   524  func (l *LocalTest) NewClientKeep(serviceName string) *Client {
   525  	switch l.mode {
   526  	case TCP:
   527  		return NewClientKeep(l.Suite, serviceName)
   528  	default:
   529  		log.Fatal("Can't make local client")
   530  		return nil
   531  	}
   532  }
   533  
   534  // genLocalHosts returns n servers created with a localRouter
   535  func (l *LocalTest) genLocalHosts(n int) []*Server {
   536  	l.panicClosed()
   537  	servers := make([]*Server, n)
   538  	for i := 0; i < n; i++ {
   539  		port := l.latestPort
   540  		l.latestPort += 10
   541  		servers[i] = l.NewServer(l.Suite, port)
   542  	}
   543  	return servers
   544  }
   545  
   546  // NewServer returns a new server which type is determined by the local mode:
   547  // TCP or Local. If it's TCP, then an available port is used, otherwise, the
   548  // port given in argument is used.
   549  func (l *LocalTest) NewServer(s network.Suite, port int) *Server {
   550  	l.panicClosed()
   551  	var server *Server
   552  	switch l.mode {
   553  	case TCP:
   554  		server = l.newTCPServer(s)
   555  		// Set TLS certificate if any configuration available
   556  		if len(l.webSocketTLSCertificate) > 0 && len(l.webSocketTLSCertificateKey) > 0 {
   557  			cert, err := tls.X509KeyPair(l.webSocketTLSCertificate, l.webSocketTLSCertificateKey)
   558  			if err != nil {
   559  				panic(err)
   560  			}
   561  			server.WebSocket.Lock()
   562  			server.WebSocket.TLSConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
   563  			server.WebSocket.Unlock()
   564  		}
   565  		server.StartInBackground()
   566  	default:
   567  		server = l.NewLocalServer(s, port)
   568  	}
   569  	return server
   570  }
   571  
   572  // NewTCPServer returns a new TCP Server attached to this LocalTest, configured
   573  // for TLS if possible (if anything in LocalTest.webSocketTLSCertificate/Key).
   574  func (l *LocalTest) newTCPServer(s network.Suite) *Server {
   575  	l.panicClosed()
   576  	server := newTCPServer(s, 0, l.path)
   577  	l.Servers[server.ServerIdentity.ID] = server
   578  	l.Overlays[server.ServerIdentity.ID] = server.overlay
   579  	l.Services[server.ServerIdentity.ID] = server.serviceManager.services
   580  
   581  	return server
   582  }
   583  
   584  // NewLocalServer returns a fresh Host using local connections within the context
   585  // of this LocalTest
   586  func (l *LocalTest) NewLocalServer(s network.Suite, port int) *Server {
   587  	l.panicClosed()
   588  	priv, id := NewPrivIdentity(s, port)
   589  	localRouter, err := network.NewLocalRouterWithManager(l.ctx, id, s)
   590  	if err != nil {
   591  		panic(err)
   592  	}
   593  	server := newServer(s, l.path, localRouter, priv)
   594  	server.StartInBackground()
   595  	l.Servers[server.ServerIdentity.ID] = server
   596  	l.Overlays[server.ServerIdentity.ID] = server.overlay
   597  	l.Services[server.ServerIdentity.ID] = server.serviceManager.services
   598  
   599  	return server
   600  
   601  }