github.com/koko1123/flow-go-1@v0.29.6/module/dkg/controller_test.go (about) 1 package dkg 2 3 import ( 4 "fmt" 5 "os" 6 "testing" 7 "time" 8 9 "github.com/rs/zerolog" 10 "github.com/stretchr/testify/assert" 11 "github.com/stretchr/testify/require" 12 13 "github.com/koko1123/flow-go-1/model/flow" 14 msg "github.com/koko1123/flow-go-1/model/messages" 15 "github.com/koko1123/flow-go-1/module/signature" 16 "github.com/koko1123/flow-go-1/utils/unittest" 17 "github.com/onflow/flow-go/crypto" 18 ) 19 20 // node is a test object that simulates a running instance of the DKG protocol 21 // where transitions from one phase to another are dictated by a timer. 22 type node struct { 23 id int 24 controller *Controller 25 phase1Duration time.Duration 26 phase2Duration time.Duration 27 phase3Duration time.Duration 28 } 29 30 func newNode(id int, controller *Controller, 31 phase1Duration time.Duration, 32 phase2Duration time.Duration, 33 phase3Duration time.Duration) *node { 34 35 return &node{ 36 id: id, 37 controller: controller, 38 phase1Duration: phase1Duration, 39 phase2Duration: phase2Duration, 40 phase3Duration: phase3Duration, 41 } 42 } 43 44 func (n *node) run() error { 45 46 // runErrCh is used to receive potential errors from the async DKG run 47 // routine 48 runErrCh := make(chan error) 49 50 // start the DKG controller 51 go func() { 52 runErrCh <- n.controller.Run() 53 }() 54 55 // timers to control phase transitions 56 var phase1Timer <-chan time.Time 57 var phase2Timer <-chan time.Time 58 var phase3Timer <-chan time.Time 59 60 phase1Timer = time.After(n.phase1Duration) 61 62 for { 63 select { 64 case err := <-runErrCh: 65 // received an error from the async run routine 66 return fmt.Errorf("Async Run error: %w", err) 67 case <-phase1Timer: 68 err := n.controller.EndPhase1() 69 if err != nil { 70 return fmt.Errorf("Error transitioning to Phase 2: %w", err) 71 } 72 phase2Timer = time.After(n.phase2Duration) 73 case <-phase2Timer: 74 err := n.controller.EndPhase2() 75 if err != nil { 76 return fmt.Errorf("Error transitioning to Phase 3: %w", err) 77 } 78 phase3Timer = time.After(n.phase3Duration) 79 case <-phase3Timer: 80 err := n.controller.End() 81 if err != nil { 82 return fmt.Errorf("Error ending DKG: %w", err) 83 } 84 return nil 85 } 86 } 87 } 88 89 // broker is a test implementation of DKGBroker that enables nodes to exchange 90 // private and public messages through a shared set of channels. 91 type broker struct { 92 id int 93 privateChannels []chan msg.PrivDKGMessageIn 94 broadcastChannels []chan msg.BroadcastDKGMessage 95 logger zerolog.Logger 96 dkgInstanceID string 97 } 98 99 // PrivateSend implements the crypto.DKGProcessor interface. 100 func (b *broker) PrivateSend(dest int, data []byte) { 101 b.privateChannels[dest] <- msg.PrivDKGMessageIn{ 102 DKGMessage: msg.NewDKGMessage(data, b.dkgInstanceID), 103 CommitteeMemberIndex: uint64(b.id), 104 } 105 } 106 107 // Broadcast implements the crypto.DKGProcessor interface. 108 // 109 // ATTENTION: Normally the processor requires Broadcast to provide guaranteed 110 // delivery (either all nodes receive the message or none of them receive it). 111 // Here we are just assuming that with a long enough duration for phases 2 and 112 // 3, all nodes are guaranteed to see everyone's messages. So it is important 113 // to set timeouts carefully in the tests. 114 func (b *broker) Broadcast(data []byte) { 115 for i := 0; i < len(b.broadcastChannels); i++ { 116 if i == b.id { 117 continue 118 } 119 // epoch and phase are not relevant at the controller level 120 b.broadcastChannels[i] <- msg.BroadcastDKGMessage{ 121 DKGMessage: msg.NewDKGMessage(data, b.dkgInstanceID), 122 CommitteeMemberIndex: uint64(b.id), 123 } 124 } 125 } 126 127 // Disqualify implements the crypto.DKGProcessor interface. 128 func (b *broker) Disqualify(node int, log string) { 129 b.logger.Error().Msgf("node %d disqualified node %d: %s", b.id, node, log) 130 } 131 132 // FlagMisbehavior implements the crypto.DKGProcessor interface. 133 func (b *broker) FlagMisbehavior(node int, logData string) { 134 b.logger.Error().Msgf("node %d flagged node %d: %s", b.id, node, logData) 135 } 136 137 // GetIndex implements the DKGBroker interface. 138 func (b *broker) GetIndex() int { 139 return int(b.id) 140 } 141 142 // GetPrivateMsgCh implements the DKGBroker interface. 143 func (b *broker) GetPrivateMsgCh() <-chan msg.PrivDKGMessageIn { 144 return b.privateChannels[b.id] 145 } 146 147 // GetBroadcastMsgCh implements the DKGBroker interface. 148 func (b *broker) GetBroadcastMsgCh() <-chan msg.BroadcastDKGMessage { 149 return b.broadcastChannels[b.id] 150 } 151 152 // Poll implements the DKGBroker interface. 153 func (b *broker) Poll(referenceBlock flow.Identifier) error { return nil } 154 155 // SubmitResult implements the DKGBroker interface. 156 func (b *broker) SubmitResult(crypto.PublicKey, []crypto.PublicKey) error { return nil } 157 158 // Shutdown implements the DKGBroker interface. 159 func (b *broker) Shutdown() {} 160 161 type testCase struct { 162 totalNodes int 163 phase1Duration time.Duration 164 phase2Duration time.Duration 165 phase3Duration time.Duration 166 } 167 168 // TestDKGHappyPath tests the controller in optimal conditions, when all nodes 169 // are working correctly. 170 func TestDKGHappyPath(t *testing.T) { 171 // Define different test cases with varying number of nodes, and phase 172 // durations. Since these are all happy path cases, there are no messages 173 // sent during phases 2 and 3; all messaging is done in phase 1. So we can 174 // can set shorter durations for phases 2 and 3. 175 testCases := []testCase{ 176 {totalNodes: 5, phase1Duration: 1 * time.Second, phase2Duration: 10 * time.Millisecond, phase3Duration: 10 * time.Millisecond}, 177 {totalNodes: 10, phase1Duration: 2 * time.Second, phase2Duration: 50 * time.Millisecond, phase3Duration: 50 * time.Millisecond}, 178 {totalNodes: 15, phase1Duration: 5 * time.Second, phase2Duration: 100 * time.Millisecond, phase3Duration: 100 * time.Millisecond}, 179 } 180 181 for _, tc := range testCases { 182 t.Run(fmt.Sprintf("%d nodes", tc.totalNodes), func(t *testing.T) { 183 testDKG(t, tc.totalNodes, tc.totalNodes, tc.phase1Duration, tc.phase2Duration, tc.phase3Duration) 184 }) 185 } 186 } 187 188 // TestDKGThreshold tests that the controller results in a successful DKG as 189 // long as the minimum threshold for non-byzantine nodes is satisfied. 190 func TestDKGThreshold(t *testing.T) { 191 // define different test cases with varying number of nodes, and phase 192 // durations 193 testCases := []testCase{ 194 {totalNodes: 5, phase1Duration: 1 * time.Second, phase2Duration: 100 * time.Millisecond, phase3Duration: 100 * time.Millisecond}, 195 {totalNodes: 10, phase1Duration: 2 * time.Second, phase2Duration: 500 * time.Millisecond, phase3Duration: 500 * time.Millisecond}, 196 {totalNodes: 15, phase1Duration: 5 * time.Second, phase2Duration: time.Second, phase3Duration: time.Second}, 197 } 198 199 for _, tc := range testCases { 200 // gn is the minimum number of good nodes required for the DKG protocol 201 // to go well 202 gn := tc.totalNodes - signature.RandomBeaconThreshold(tc.totalNodes) 203 t.Run(fmt.Sprintf("%d/%d nodes", gn, tc.totalNodes), func(t *testing.T) { 204 testDKG(t, tc.totalNodes, gn, tc.phase1Duration, tc.phase2Duration, tc.phase3Duration) 205 }) 206 } 207 } 208 209 func testDKG(t *testing.T, totalNodes int, goodNodes int, phase1Duration, phase2Duration, phase3Duration time.Duration) { 210 nodes := initNodes(t, totalNodes, phase1Duration, phase2Duration, phase3Duration) 211 gnodes := nodes[:goodNodes] 212 213 // Start all the good nodes in parallel 214 for _, n := range gnodes { 215 go func(node *node) { 216 err := node.run() 217 require.NoError(t, err) 218 }(n) 219 } 220 221 // Wait until they are all shutdown 222 wait(t, gnodes, 5*phase1Duration) 223 224 // Check that all nodes have agreed on the same set of public keys 225 checkArtifacts(t, gnodes, totalNodes) 226 } 227 228 // Initialise nodes and communication channels. 229 func initNodes(t *testing.T, n int, phase1Duration, phase2Duration, phase3Duration time.Duration) []*node { 230 // Create the channels through which the nodes will communicate 231 privateChannels := make([]chan msg.PrivDKGMessageIn, 0, n) 232 broadcastChannels := make([]chan msg.BroadcastDKGMessage, 0, n) 233 for i := 0; i < n; i++ { 234 privateChannels = append(privateChannels, make(chan msg.PrivDKGMessageIn, 5*n*n)) 235 broadcastChannels = append(broadcastChannels, make(chan msg.BroadcastDKGMessage, 5*n*n)) 236 } 237 238 nodes := make([]*node, 0, n) 239 240 // Setup 241 for i := 0; i < n; i++ { 242 logger := zerolog.New(os.Stderr).With().Int("id", i).Logger() 243 244 broker := &broker{ 245 id: i, 246 privateChannels: privateChannels, 247 broadcastChannels: broadcastChannels, 248 logger: logger, 249 } 250 251 seed := unittest.SeedFixture(20) 252 253 dkg, err := crypto.NewJointFeldman(n, signature.RandomBeaconThreshold(n), i, broker) 254 require.NoError(t, err) 255 256 // create a config with no delays for tests 257 config := ControllerConfig{ 258 BaseStartDelay: 0, 259 BaseHandleFirstBroadcastDelay: 0, 260 HandleSubsequentBroadcastDelay: 0, 261 } 262 263 controller := NewController( 264 logger, 265 "dkg_test", 266 dkg, 267 seed, 268 broker, 269 config, 270 ) 271 require.NoError(t, err) 272 273 node := newNode(i, controller, phase1Duration, phase2Duration, phase3Duration) 274 nodes = append(nodes, node) 275 } 276 277 return nodes 278 } 279 280 // Wait for all the nodes to reach the SHUTDOWN state, or timeout. 281 func wait(t *testing.T, nodes []*node, timeout time.Duration) { 282 timer := time.After(timeout) 283 for { 284 select { 285 case <-timer: 286 t.Fatal("TIMEOUT") 287 default: 288 done := true 289 for _, node := range nodes { 290 if node.controller.GetState() != Shutdown { 291 done = false 292 break 293 } 294 } 295 if done { 296 return 297 } 298 time.Sleep(50 * time.Millisecond) 299 } 300 } 301 } 302 303 // Check that all nodes have produced the same set of public keys 304 func checkArtifacts(t *testing.T, nodes []*node, totalNodes int) { 305 _, refGroupPublicKey, refPublicKeys := nodes[0].controller.GetArtifacts() 306 307 for i := 1; i < len(nodes); i++ { 308 privateShare, groupPublicKey, publicKeys := nodes[i].controller.GetArtifacts() 309 310 require.NotEmpty(t, privateShare) 311 require.NotEmpty(t, groupPublicKey) 312 313 require.True(t, refGroupPublicKey.Equals(groupPublicKey), 314 "node %d has a different groupPubKey than node 0: %s %s", 315 i, 316 groupPublicKey, 317 refGroupPublicKey) 318 319 require.Len(t, publicKeys, totalNodes) 320 321 for j := 0; j < totalNodes; j++ { 322 if !refPublicKeys[j].Equals(publicKeys[j]) { 323 t.Fatalf("node %d has a different pubs[%d] than node 0: %s, %s", 324 i, 325 j, 326 refPublicKeys[j], 327 publicKeys[j]) 328 } 329 } 330 } 331 } 332 333 func TestDelay(t *testing.T) { 334 335 t.Run("should return 0 delay for <=0 inputs", func(t *testing.T) { 336 delay := computePreprocessingDelay(0, 100) 337 assert.Equal(t, delay, time.Duration(0)) 338 delay = computePreprocessingDelay(time.Hour, 0) 339 assert.Equal(t, delay, time.Duration(0)) 340 delay = computePreprocessingDelay(time.Millisecond, -1) 341 assert.Equal(t, delay, time.Duration(0)) 342 delay = computePreprocessingDelay(-time.Millisecond, 100) 343 assert.Equal(t, delay, time.Duration(0)) 344 }) 345 346 // NOTE: this is a probabilistic test. It will (extremely infrequently) fail. 347 t.Run("should return different values for same inputs", func(t *testing.T) { 348 d1 := computePreprocessingDelay(time.Hour, 100) 349 d2 := computePreprocessingDelay(time.Hour, 100) 350 assert.NotEqual(t, d1, d2) 351 }) 352 353 t.Run("should return values in expected range", func(t *testing.T) { 354 baseDelay := time.Second 355 dkgSize := 100 356 minDelay := time.Duration(0) 357 // m=b*n^2 358 expectedMaxDelay := time.Duration(int64(baseDelay) * int64(dkgSize) * int64(dkgSize)) 359 360 maxDelay := computePreprocessingDelayMax(baseDelay, dkgSize) 361 assert.Equal(t, expectedMaxDelay, maxDelay) 362 363 delay := computePreprocessingDelay(baseDelay, dkgSize) 364 assert.LessOrEqual(t, minDelay, delay) 365 assert.GreaterOrEqual(t, expectedMaxDelay, delay) 366 }) 367 368 t.Run("should return values in expected range for defaults", func(t *testing.T) { 369 baseDelay := DefaultBaseHandleFirstBroadcastDelay 370 dkgSize := 150 371 minDelay := time.Duration(0) 372 // m=b*n^2 373 expectedMaxDelay := time.Duration(int64(baseDelay) * int64(dkgSize) * int64(dkgSize)) 374 375 maxDelay := computePreprocessingDelayMax(baseDelay, dkgSize) 376 assert.Equal(t, expectedMaxDelay, maxDelay) 377 378 delay := computePreprocessingDelay(baseDelay, dkgSize) 379 assert.LessOrEqual(t, minDelay, delay) 380 assert.GreaterOrEqual(t, expectedMaxDelay, delay) 381 }) 382 }