github.com/koko1123/flow-go-1@v0.29.6/module/dkg/controller_test.go (about)

     1  package dkg
     2  
     3  import (
     4  	"fmt"
     5  	"os"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/rs/zerolog"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/koko1123/flow-go-1/model/flow"
    14  	msg "github.com/koko1123/flow-go-1/model/messages"
    15  	"github.com/koko1123/flow-go-1/module/signature"
    16  	"github.com/koko1123/flow-go-1/utils/unittest"
    17  	"github.com/onflow/flow-go/crypto"
    18  )
    19  
    20  // node is a test object that simulates a running instance of the DKG protocol
    21  // where transitions from one phase to another are dictated by a timer.
    22  type node struct {
    23  	id             int
    24  	controller     *Controller
    25  	phase1Duration time.Duration
    26  	phase2Duration time.Duration
    27  	phase3Duration time.Duration
    28  }
    29  
    30  func newNode(id int, controller *Controller,
    31  	phase1Duration time.Duration,
    32  	phase2Duration time.Duration,
    33  	phase3Duration time.Duration) *node {
    34  
    35  	return &node{
    36  		id:             id,
    37  		controller:     controller,
    38  		phase1Duration: phase1Duration,
    39  		phase2Duration: phase2Duration,
    40  		phase3Duration: phase3Duration,
    41  	}
    42  }
    43  
    44  func (n *node) run() error {
    45  
    46  	// runErrCh is used to receive potential errors from the async DKG run
    47  	// routine
    48  	runErrCh := make(chan error)
    49  
    50  	// start the DKG controller
    51  	go func() {
    52  		runErrCh <- n.controller.Run()
    53  	}()
    54  
    55  	// timers to control phase transitions
    56  	var phase1Timer <-chan time.Time
    57  	var phase2Timer <-chan time.Time
    58  	var phase3Timer <-chan time.Time
    59  
    60  	phase1Timer = time.After(n.phase1Duration)
    61  
    62  	for {
    63  		select {
    64  		case err := <-runErrCh:
    65  			// received an error from the async run routine
    66  			return fmt.Errorf("Async Run error: %w", err)
    67  		case <-phase1Timer:
    68  			err := n.controller.EndPhase1()
    69  			if err != nil {
    70  				return fmt.Errorf("Error transitioning to Phase 2: %w", err)
    71  			}
    72  			phase2Timer = time.After(n.phase2Duration)
    73  		case <-phase2Timer:
    74  			err := n.controller.EndPhase2()
    75  			if err != nil {
    76  				return fmt.Errorf("Error transitioning to Phase 3: %w", err)
    77  			}
    78  			phase3Timer = time.After(n.phase3Duration)
    79  		case <-phase3Timer:
    80  			err := n.controller.End()
    81  			if err != nil {
    82  				return fmt.Errorf("Error ending DKG: %w", err)
    83  			}
    84  			return nil
    85  		}
    86  	}
    87  }
    88  
    89  // broker is a test implementation of DKGBroker that enables nodes to exchange
    90  // private and public messages through a shared set of channels.
    91  type broker struct {
    92  	id                int
    93  	privateChannels   []chan msg.PrivDKGMessageIn
    94  	broadcastChannels []chan msg.BroadcastDKGMessage
    95  	logger            zerolog.Logger
    96  	dkgInstanceID     string
    97  }
    98  
    99  // PrivateSend implements the crypto.DKGProcessor interface.
   100  func (b *broker) PrivateSend(dest int, data []byte) {
   101  	b.privateChannels[dest] <- msg.PrivDKGMessageIn{
   102  		DKGMessage:           msg.NewDKGMessage(data, b.dkgInstanceID),
   103  		CommitteeMemberIndex: uint64(b.id),
   104  	}
   105  }
   106  
   107  // Broadcast implements the crypto.DKGProcessor interface.
   108  //
   109  // ATTENTION: Normally the processor requires Broadcast to provide guaranteed
   110  // delivery (either all nodes receive the message or none of them receive it).
   111  // Here we are just assuming that with a long enough duration for phases 2 and
   112  // 3, all nodes are guaranteed to see everyone's messages. So it is important
   113  // to set timeouts carefully in the tests.
   114  func (b *broker) Broadcast(data []byte) {
   115  	for i := 0; i < len(b.broadcastChannels); i++ {
   116  		if i == b.id {
   117  			continue
   118  		}
   119  		// epoch and phase are not relevant at the controller level
   120  		b.broadcastChannels[i] <- msg.BroadcastDKGMessage{
   121  			DKGMessage:           msg.NewDKGMessage(data, b.dkgInstanceID),
   122  			CommitteeMemberIndex: uint64(b.id),
   123  		}
   124  	}
   125  }
   126  
   127  // Disqualify implements the crypto.DKGProcessor interface.
   128  func (b *broker) Disqualify(node int, log string) {
   129  	b.logger.Error().Msgf("node %d disqualified node %d: %s", b.id, node, log)
   130  }
   131  
   132  // FlagMisbehavior implements the crypto.DKGProcessor interface.
   133  func (b *broker) FlagMisbehavior(node int, logData string) {
   134  	b.logger.Error().Msgf("node %d flagged node %d: %s", b.id, node, logData)
   135  }
   136  
   137  // GetIndex implements the DKGBroker interface.
   138  func (b *broker) GetIndex() int {
   139  	return int(b.id)
   140  }
   141  
   142  // GetPrivateMsgCh implements the DKGBroker interface.
   143  func (b *broker) GetPrivateMsgCh() <-chan msg.PrivDKGMessageIn {
   144  	return b.privateChannels[b.id]
   145  }
   146  
   147  // GetBroadcastMsgCh implements the DKGBroker interface.
   148  func (b *broker) GetBroadcastMsgCh() <-chan msg.BroadcastDKGMessage {
   149  	return b.broadcastChannels[b.id]
   150  }
   151  
   152  // Poll implements the DKGBroker interface.
   153  func (b *broker) Poll(referenceBlock flow.Identifier) error { return nil }
   154  
   155  // SubmitResult implements the DKGBroker interface.
   156  func (b *broker) SubmitResult(crypto.PublicKey, []crypto.PublicKey) error { return nil }
   157  
   158  // Shutdown implements the DKGBroker interface.
   159  func (b *broker) Shutdown() {}
   160  
   161  type testCase struct {
   162  	totalNodes     int
   163  	phase1Duration time.Duration
   164  	phase2Duration time.Duration
   165  	phase3Duration time.Duration
   166  }
   167  
   168  // TestDKGHappyPath tests the controller in optimal conditions, when all nodes
   169  // are working correctly.
   170  func TestDKGHappyPath(t *testing.T) {
   171  	// Define different test cases with varying number of nodes, and phase
   172  	// durations. Since these are all happy path cases, there are no messages
   173  	// sent during phases 2 and 3; all messaging is done in phase 1. So we can
   174  	// can set shorter durations for phases 2 and 3.
   175  	testCases := []testCase{
   176  		{totalNodes: 5, phase1Duration: 1 * time.Second, phase2Duration: 10 * time.Millisecond, phase3Duration: 10 * time.Millisecond},
   177  		{totalNodes: 10, phase1Duration: 2 * time.Second, phase2Duration: 50 * time.Millisecond, phase3Duration: 50 * time.Millisecond},
   178  		{totalNodes: 15, phase1Duration: 5 * time.Second, phase2Duration: 100 * time.Millisecond, phase3Duration: 100 * time.Millisecond},
   179  	}
   180  
   181  	for _, tc := range testCases {
   182  		t.Run(fmt.Sprintf("%d nodes", tc.totalNodes), func(t *testing.T) {
   183  			testDKG(t, tc.totalNodes, tc.totalNodes, tc.phase1Duration, tc.phase2Duration, tc.phase3Duration)
   184  		})
   185  	}
   186  }
   187  
   188  // TestDKGThreshold tests that the controller results in a successful DKG as
   189  // long as the minimum threshold for non-byzantine nodes is satisfied.
   190  func TestDKGThreshold(t *testing.T) {
   191  	// define different test cases with varying number of nodes, and phase
   192  	// durations
   193  	testCases := []testCase{
   194  		{totalNodes: 5, phase1Duration: 1 * time.Second, phase2Duration: 100 * time.Millisecond, phase3Duration: 100 * time.Millisecond},
   195  		{totalNodes: 10, phase1Duration: 2 * time.Second, phase2Duration: 500 * time.Millisecond, phase3Duration: 500 * time.Millisecond},
   196  		{totalNodes: 15, phase1Duration: 5 * time.Second, phase2Duration: time.Second, phase3Duration: time.Second},
   197  	}
   198  
   199  	for _, tc := range testCases {
   200  		// gn is the minimum number of good nodes required for the DKG protocol
   201  		// to go well
   202  		gn := tc.totalNodes - signature.RandomBeaconThreshold(tc.totalNodes)
   203  		t.Run(fmt.Sprintf("%d/%d nodes", gn, tc.totalNodes), func(t *testing.T) {
   204  			testDKG(t, tc.totalNodes, gn, tc.phase1Duration, tc.phase2Duration, tc.phase3Duration)
   205  		})
   206  	}
   207  }
   208  
   209  func testDKG(t *testing.T, totalNodes int, goodNodes int, phase1Duration, phase2Duration, phase3Duration time.Duration) {
   210  	nodes := initNodes(t, totalNodes, phase1Duration, phase2Duration, phase3Duration)
   211  	gnodes := nodes[:goodNodes]
   212  
   213  	// Start all the good nodes in parallel
   214  	for _, n := range gnodes {
   215  		go func(node *node) {
   216  			err := node.run()
   217  			require.NoError(t, err)
   218  		}(n)
   219  	}
   220  
   221  	// Wait until they are all shutdown
   222  	wait(t, gnodes, 5*phase1Duration)
   223  
   224  	// Check that all nodes have agreed on the same set of public keys
   225  	checkArtifacts(t, gnodes, totalNodes)
   226  }
   227  
   228  // Initialise nodes and communication channels.
   229  func initNodes(t *testing.T, n int, phase1Duration, phase2Duration, phase3Duration time.Duration) []*node {
   230  	// Create the channels through which the nodes will communicate
   231  	privateChannels := make([]chan msg.PrivDKGMessageIn, 0, n)
   232  	broadcastChannels := make([]chan msg.BroadcastDKGMessage, 0, n)
   233  	for i := 0; i < n; i++ {
   234  		privateChannels = append(privateChannels, make(chan msg.PrivDKGMessageIn, 5*n*n))
   235  		broadcastChannels = append(broadcastChannels, make(chan msg.BroadcastDKGMessage, 5*n*n))
   236  	}
   237  
   238  	nodes := make([]*node, 0, n)
   239  
   240  	// Setup
   241  	for i := 0; i < n; i++ {
   242  		logger := zerolog.New(os.Stderr).With().Int("id", i).Logger()
   243  
   244  		broker := &broker{
   245  			id:                i,
   246  			privateChannels:   privateChannels,
   247  			broadcastChannels: broadcastChannels,
   248  			logger:            logger,
   249  		}
   250  
   251  		seed := unittest.SeedFixture(20)
   252  
   253  		dkg, err := crypto.NewJointFeldman(n, signature.RandomBeaconThreshold(n), i, broker)
   254  		require.NoError(t, err)
   255  
   256  		// create a config with no delays for tests
   257  		config := ControllerConfig{
   258  			BaseStartDelay:                 0,
   259  			BaseHandleFirstBroadcastDelay:  0,
   260  			HandleSubsequentBroadcastDelay: 0,
   261  		}
   262  
   263  		controller := NewController(
   264  			logger,
   265  			"dkg_test",
   266  			dkg,
   267  			seed,
   268  			broker,
   269  			config,
   270  		)
   271  		require.NoError(t, err)
   272  
   273  		node := newNode(i, controller, phase1Duration, phase2Duration, phase3Duration)
   274  		nodes = append(nodes, node)
   275  	}
   276  
   277  	return nodes
   278  }
   279  
   280  // Wait for all the nodes to reach the SHUTDOWN state, or timeout.
   281  func wait(t *testing.T, nodes []*node, timeout time.Duration) {
   282  	timer := time.After(timeout)
   283  	for {
   284  		select {
   285  		case <-timer:
   286  			t.Fatal("TIMEOUT")
   287  		default:
   288  			done := true
   289  			for _, node := range nodes {
   290  				if node.controller.GetState() != Shutdown {
   291  					done = false
   292  					break
   293  				}
   294  			}
   295  			if done {
   296  				return
   297  			}
   298  			time.Sleep(50 * time.Millisecond)
   299  		}
   300  	}
   301  }
   302  
   303  // Check that all nodes have produced the same set of public keys
   304  func checkArtifacts(t *testing.T, nodes []*node, totalNodes int) {
   305  	_, refGroupPublicKey, refPublicKeys := nodes[0].controller.GetArtifacts()
   306  
   307  	for i := 1; i < len(nodes); i++ {
   308  		privateShare, groupPublicKey, publicKeys := nodes[i].controller.GetArtifacts()
   309  
   310  		require.NotEmpty(t, privateShare)
   311  		require.NotEmpty(t, groupPublicKey)
   312  
   313  		require.True(t, refGroupPublicKey.Equals(groupPublicKey),
   314  			"node %d has a different groupPubKey than node 0: %s %s",
   315  			i,
   316  			groupPublicKey,
   317  			refGroupPublicKey)
   318  
   319  		require.Len(t, publicKeys, totalNodes)
   320  
   321  		for j := 0; j < totalNodes; j++ {
   322  			if !refPublicKeys[j].Equals(publicKeys[j]) {
   323  				t.Fatalf("node %d has a different pubs[%d] than node 0: %s, %s",
   324  					i,
   325  					j,
   326  					refPublicKeys[j],
   327  					publicKeys[j])
   328  			}
   329  		}
   330  	}
   331  }
   332  
   333  func TestDelay(t *testing.T) {
   334  
   335  	t.Run("should return 0 delay for <=0 inputs", func(t *testing.T) {
   336  		delay := computePreprocessingDelay(0, 100)
   337  		assert.Equal(t, delay, time.Duration(0))
   338  		delay = computePreprocessingDelay(time.Hour, 0)
   339  		assert.Equal(t, delay, time.Duration(0))
   340  		delay = computePreprocessingDelay(time.Millisecond, -1)
   341  		assert.Equal(t, delay, time.Duration(0))
   342  		delay = computePreprocessingDelay(-time.Millisecond, 100)
   343  		assert.Equal(t, delay, time.Duration(0))
   344  	})
   345  
   346  	// NOTE: this is a probabilistic test. It will (extremely infrequently) fail.
   347  	t.Run("should return different values for same inputs", func(t *testing.T) {
   348  		d1 := computePreprocessingDelay(time.Hour, 100)
   349  		d2 := computePreprocessingDelay(time.Hour, 100)
   350  		assert.NotEqual(t, d1, d2)
   351  	})
   352  
   353  	t.Run("should return values in expected range", func(t *testing.T) {
   354  		baseDelay := time.Second
   355  		dkgSize := 100
   356  		minDelay := time.Duration(0)
   357  		// m=b*n^2
   358  		expectedMaxDelay := time.Duration(int64(baseDelay) * int64(dkgSize) * int64(dkgSize))
   359  
   360  		maxDelay := computePreprocessingDelayMax(baseDelay, dkgSize)
   361  		assert.Equal(t, expectedMaxDelay, maxDelay)
   362  
   363  		delay := computePreprocessingDelay(baseDelay, dkgSize)
   364  		assert.LessOrEqual(t, minDelay, delay)
   365  		assert.GreaterOrEqual(t, expectedMaxDelay, delay)
   366  	})
   367  
   368  	t.Run("should return values in expected range for defaults", func(t *testing.T) {
   369  		baseDelay := DefaultBaseHandleFirstBroadcastDelay
   370  		dkgSize := 150
   371  		minDelay := time.Duration(0)
   372  		// m=b*n^2
   373  		expectedMaxDelay := time.Duration(int64(baseDelay) * int64(dkgSize) * int64(dkgSize))
   374  
   375  		maxDelay := computePreprocessingDelayMax(baseDelay, dkgSize)
   376  		assert.Equal(t, expectedMaxDelay, maxDelay)
   377  
   378  		delay := computePreprocessingDelay(baseDelay, dkgSize)
   379  		assert.LessOrEqual(t, minDelay, delay)
   380  		assert.GreaterOrEqual(t, expectedMaxDelay, delay)
   381  	})
   382  }