github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/module/dkg/controller_test.go (about)

     1  package dkg
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/onflow/crypto"
    10  	"github.com/rs/zerolog"
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/onflow/flow-go/model/flow"
    14  	msg "github.com/onflow/flow-go/model/messages"
    15  	"github.com/onflow/flow-go/module/signature"
    16  	"github.com/onflow/flow-go/utils/unittest"
    17  )
    18  
    19  // node is a test object that simulates a running instance of the DKG protocol
    20  // where transitions from one phase to another are dictated by a timer.
    21  type node struct {
    22  	id             int
    23  	controller     *Controller
    24  	phase1Duration time.Duration
    25  	phase2Duration time.Duration
    26  	phase3Duration time.Duration
    27  }
    28  
    29  func newNode(id int, controller *Controller,
    30  	phase1Duration time.Duration,
    31  	phase2Duration time.Duration,
    32  	phase3Duration time.Duration) *node {
    33  
    34  	return &node{
    35  		id:             id,
    36  		controller:     controller,
    37  		phase1Duration: phase1Duration,
    38  		phase2Duration: phase2Duration,
    39  		phase3Duration: phase3Duration,
    40  	}
    41  }
    42  
    43  func (n *node) run() error {
    44  
    45  	// runErrCh is used to receive potential errors from the async DKG run
    46  	// routine
    47  	runErrCh := make(chan error)
    48  
    49  	// start the DKG controller
    50  	go func() {
    51  		runErrCh <- n.controller.Run()
    52  	}()
    53  
    54  	// timers to control phase transitions
    55  	var phase1Timer <-chan time.Time
    56  	var phase2Timer <-chan time.Time
    57  	var phase3Timer <-chan time.Time
    58  
    59  	phase1Timer = time.After(n.phase1Duration)
    60  
    61  	for {
    62  		select {
    63  		case err := <-runErrCh:
    64  			// received an error from the async run routine
    65  			return fmt.Errorf("Async Run error: %w", err)
    66  		case <-phase1Timer:
    67  			err := n.controller.EndPhase1()
    68  			if err != nil {
    69  				return fmt.Errorf("Error transitioning to Phase 2: %w", err)
    70  			}
    71  			phase2Timer = time.After(n.phase2Duration)
    72  		case <-phase2Timer:
    73  			err := n.controller.EndPhase2()
    74  			if err != nil {
    75  				return fmt.Errorf("Error transitioning to Phase 3: %w", err)
    76  			}
    77  			phase3Timer = time.After(n.phase3Duration)
    78  		case <-phase3Timer:
    79  			err := n.controller.End()
    80  			if err != nil {
    81  				return fmt.Errorf("Error ending DKG: %w", err)
    82  			}
    83  			return nil
    84  		}
    85  	}
    86  }
    87  
    88  // broker is a test implementation of DKGBroker that enables nodes to exchange
    89  // private and public messages through a shared set of channels.
    90  type broker struct {
    91  	id                int
    92  	privateChannels   []chan msg.PrivDKGMessageIn
    93  	broadcastChannels []chan msg.BroadcastDKGMessage
    94  	logger            zerolog.Logger
    95  	dkgInstanceID     string
    96  }
    97  
    98  // PrivateSend implements the crypto.DKGProcessor interface.
    99  func (b *broker) PrivateSend(dest int, data []byte) {
   100  	b.privateChannels[dest] <- msg.PrivDKGMessageIn{
   101  		DKGMessage:           msg.NewDKGMessage(data, b.dkgInstanceID),
   102  		CommitteeMemberIndex: uint64(b.id),
   103  	}
   104  }
   105  
   106  // Broadcast implements the crypto.DKGProcessor interface.
   107  //
   108  // ATTENTION: Normally the processor requires Broadcast to provide guaranteed
   109  // delivery (either all nodes receive the message or none of them receive it).
   110  // Here we are just assuming that with a long enough duration for phases 2 and
   111  // 3, all nodes are guaranteed to see everyone's messages. So it is important
   112  // to set timeouts carefully in the tests.
   113  func (b *broker) Broadcast(data []byte) {
   114  	for i := 0; i < len(b.broadcastChannels); i++ {
   115  		if i == b.id {
   116  			continue
   117  		}
   118  		// epoch and phase are not relevant at the controller level
   119  		b.broadcastChannels[i] <- msg.BroadcastDKGMessage{
   120  			DKGMessage:           msg.NewDKGMessage(data, b.dkgInstanceID),
   121  			CommitteeMemberIndex: uint64(b.id),
   122  		}
   123  	}
   124  }
   125  
   126  // Disqualify implements the crypto.DKGProcessor interface.
   127  func (b *broker) Disqualify(node int, log string) {
   128  	b.logger.Error().Msgf("node %d disqualified node %d: %s", b.id, node, log)
   129  }
   130  
   131  // FlagMisbehavior implements the crypto.DKGProcessor interface.
   132  func (b *broker) FlagMisbehavior(node int, logData string) {
   133  	b.logger.Error().Msgf("node %d flagged node %d: %s", b.id, node, logData)
   134  }
   135  
   136  // GetIndex implements the DKGBroker interface.
   137  func (b *broker) GetIndex() int {
   138  	return int(b.id)
   139  }
   140  
   141  // GetPrivateMsgCh implements the DKGBroker interface.
   142  func (b *broker) GetPrivateMsgCh() <-chan msg.PrivDKGMessageIn {
   143  	return b.privateChannels[b.id]
   144  }
   145  
   146  // GetBroadcastMsgCh implements the DKGBroker interface.
   147  func (b *broker) GetBroadcastMsgCh() <-chan msg.BroadcastDKGMessage {
   148  	return b.broadcastChannels[b.id]
   149  }
   150  
   151  // Poll implements the DKGBroker interface.
   152  func (b *broker) Poll(referenceBlock flow.Identifier) error { return nil }
   153  
   154  // SubmitResult implements the DKGBroker interface.
   155  func (b *broker) SubmitResult(crypto.PublicKey, []crypto.PublicKey) error { return nil }
   156  
   157  // Shutdown implements the DKGBroker interface.
   158  func (b *broker) Shutdown() {}
   159  
   160  type testCase struct {
   161  	totalNodes     int
   162  	phase1Duration time.Duration
   163  	phase2Duration time.Duration
   164  	phase3Duration time.Duration
   165  }
   166  
   167  // TestDKGHappyPath tests the controller in optimal conditions, when all nodes
   168  // are working correctly.
   169  func TestDKGHappyPath(t *testing.T) {
   170  	// Define different test cases with varying number of nodes, and phase
   171  	// durations. Since these are all happy path cases, there are no messages
   172  	// sent during phases 2 and 3; all messaging is done in phase 1. So we can
   173  	// can set shorter durations for phases 2 and 3.
   174  	testCases := []testCase{
   175  		{totalNodes: 5, phase1Duration: 1 * time.Second, phase2Duration: 10 * time.Millisecond, phase3Duration: 10 * time.Millisecond},
   176  		{totalNodes: 10, phase1Duration: 2 * time.Second, phase2Duration: 50 * time.Millisecond, phase3Duration: 50 * time.Millisecond},
   177  		{totalNodes: 15, phase1Duration: 5 * time.Second, phase2Duration: 100 * time.Millisecond, phase3Duration: 100 * time.Millisecond},
   178  	}
   179  
   180  	for _, tc := range testCases {
   181  		t.Run(fmt.Sprintf("%d nodes", tc.totalNodes), func(t *testing.T) {
   182  			testDKG(t, tc.totalNodes, tc.totalNodes, tc.phase1Duration, tc.phase2Duration, tc.phase3Duration)
   183  		})
   184  	}
   185  }
   186  
   187  // TestDKGThreshold tests that the controller results in a successful DKG as
   188  // long as the minimum threshold for non-byzantine nodes is satisfied.
   189  func TestDKGThreshold(t *testing.T) {
   190  	// define different test cases with varying number of nodes, and phase
   191  	// durations
   192  	testCases := []testCase{
   193  		{totalNodes: 5, phase1Duration: 1 * time.Second, phase2Duration: 100 * time.Millisecond, phase3Duration: 100 * time.Millisecond},
   194  		{totalNodes: 10, phase1Duration: 2 * time.Second, phase2Duration: 500 * time.Millisecond, phase3Duration: 500 * time.Millisecond},
   195  		{totalNodes: 15, phase1Duration: 5 * time.Second, phase2Duration: time.Second, phase3Duration: time.Second},
   196  	}
   197  
   198  	for _, tc := range testCases {
   199  		// gn is the minimum number of good nodes required for the DKG protocol
   200  		// to go well
   201  		gn := tc.totalNodes - signature.RandomBeaconThreshold(tc.totalNodes)
   202  		t.Run(fmt.Sprintf("%d/%d nodes", gn, tc.totalNodes), func(t *testing.T) {
   203  			testDKG(t, tc.totalNodes, gn, tc.phase1Duration, tc.phase2Duration, tc.phase3Duration)
   204  		})
   205  	}
   206  }
   207  
   208  func testDKG(t *testing.T, totalNodes int, goodNodes int, phase1Duration, phase2Duration, phase3Duration time.Duration) {
   209  	nodes := initNodes(t, totalNodes, phase1Duration, phase2Duration, phase3Duration)
   210  	gnodes := nodes[:goodNodes]
   211  
   212  	// Start all the good nodes in parallel
   213  	for _, n := range gnodes {
   214  		go func(node *node) {
   215  			err := node.run()
   216  			require.NoError(t, err)
   217  		}(n)
   218  	}
   219  
   220  	// Wait until they are all shutdown
   221  	wait(t, gnodes, 5*phase1Duration)
   222  
   223  	// Check that all nodes have agreed on the same set of public keys
   224  	checkArtifacts(t, gnodes, totalNodes)
   225  }
   226  
   227  // Initialise nodes and communication channels.
   228  func initNodes(t *testing.T, n int, phase1Duration, phase2Duration, phase3Duration time.Duration) []*node {
   229  	// Create the channels through which the nodes will communicate
   230  	privateChannels := make([]chan msg.PrivDKGMessageIn, 0, n)
   231  	broadcastChannels := make([]chan msg.BroadcastDKGMessage, 0, n)
   232  	for i := 0; i < n; i++ {
   233  		privateChannels = append(privateChannels, make(chan msg.PrivDKGMessageIn, 5*n*n))
   234  		broadcastChannels = append(broadcastChannels, make(chan msg.BroadcastDKGMessage, 5*n*n))
   235  	}
   236  
   237  	nodes := make([]*node, 0, n)
   238  
   239  	// Setup
   240  	for i := 0; i < n; i++ {
   241  		logger := zerolog.New(os.Stderr).With().Int("id", i).Logger()
   242  
   243  		broker := &broker{
   244  			id:                i,
   245  			privateChannels:   privateChannels,
   246  			broadcastChannels: broadcastChannels,
   247  			logger:            logger,
   248  		}
   249  
   250  		seed := unittest.SeedFixture(crypto.KeyGenSeedMinLen)
   251  
   252  		dkg, err := crypto.NewJointFeldman(n, signature.RandomBeaconThreshold(n), i, broker)
   253  		require.NoError(t, err)
   254  
   255  		controller := NewController(
   256  			logger,
   257  			"dkg_test",
   258  			dkg,
   259  			seed,
   260  			broker,
   261  		)
   262  		require.NoError(t, err)
   263  
   264  		node := newNode(i, controller, phase1Duration, phase2Duration, phase3Duration)
   265  		nodes = append(nodes, node)
   266  	}
   267  
   268  	return nodes
   269  }
   270  
   271  // Wait for all the nodes to reach the SHUTDOWN state, or timeout.
   272  func wait(t *testing.T, nodes []*node, timeout time.Duration) {
   273  	timer := time.After(timeout)
   274  	for {
   275  		select {
   276  		case <-timer:
   277  			t.Fatal("TIMEOUT")
   278  		default:
   279  			done := true
   280  			for _, node := range nodes {
   281  				if node.controller.GetState() != Shutdown {
   282  					done = false
   283  					break
   284  				}
   285  			}
   286  			if done {
   287  				return
   288  			}
   289  			time.Sleep(50 * time.Millisecond)
   290  		}
   291  	}
   292  }
   293  
   294  // Check that all nodes have produced the same set of public keys
   295  func checkArtifacts(t *testing.T, nodes []*node, totalNodes int) {
   296  	_, refGroupPublicKey, refPublicKeys := nodes[0].controller.GetArtifacts()
   297  
   298  	for i := 1; i < len(nodes); i++ {
   299  		privateShare, groupPublicKey, publicKeys := nodes[i].controller.GetArtifacts()
   300  
   301  		require.NotEmpty(t, privateShare)
   302  		require.NotEmpty(t, groupPublicKey)
   303  
   304  		require.True(t, refGroupPublicKey.Equals(groupPublicKey),
   305  			"node %d has a different groupPubKey than node 0: %s %s",
   306  			i,
   307  			groupPublicKey,
   308  			refGroupPublicKey)
   309  
   310  		require.Len(t, publicKeys, totalNodes)
   311  
   312  		for j := 0; j < totalNodes; j++ {
   313  			if !refPublicKeys[j].Equals(publicKeys[j]) {
   314  				t.Fatalf("node %d has a different pubs[%d] than node 0: %s, %s",
   315  					i,
   316  					j,
   317  					refPublicKeys[j],
   318  					publicKeys[j])
   319  			}
   320  		}
   321  	}
   322  }