// github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/module/dkg/controller_test.go

package dkg

import (
	"fmt"
	"os"
	"testing"
	"time"

	"github.com/onflow/crypto"
	"github.com/rs/zerolog"
	"github.com/stretchr/testify/require"

	"github.com/onflow/flow-go/model/flow"
	msg "github.com/onflow/flow-go/model/messages"
	"github.com/onflow/flow-go/module/signature"
	"github.com/onflow/flow-go/utils/unittest"
)

// node is a test object that simulates a running instance of the DKG protocol
// where transitions from one phase to another are dictated by a timer.
type node struct {
	id             int
	controller     *Controller
	phase1Duration time.Duration
	phase2Duration time.Duration
	phase3Duration time.Duration
}

func newNode(id int, controller *Controller,
	phase1Duration time.Duration,
	phase2Duration time.Duration,
	phase3Duration time.Duration) *node {

	return &node{
		id:             id,
		controller:     controller,
		phase1Duration: phase1Duration,
		phase2Duration: phase2Duration,
		phase3Duration: phase3Duration,
	}
}

func (n *node) run() error {

	// runErrCh is used to receive potential errors from the async DKG run
	// routine
	runErrCh := make(chan error)

	// start the DKG controller
	go func() {
		runErrCh <- n.controller.Run()
	}()

	// timers to control phase transitions
	var phase1Timer <-chan time.Time
	var phase2Timer <-chan time.Time
	var phase3Timer <-chan time.Time

	phase1Timer = time.After(n.phase1Duration)

	for {
		select {
		case err := <-runErrCh:
			// received an error from the async run routine
			return fmt.Errorf("async run error: %w", err)
		case <-phase1Timer:
			err := n.controller.EndPhase1()
			if err != nil {
				return fmt.Errorf("error transitioning to phase 2: %w", err)
			}
			phase2Timer = time.After(n.phase2Duration)
		case <-phase2Timer:
			err := n.controller.EndPhase2()
			if err != nil {
				return fmt.Errorf("error transitioning to phase 3: %w", err)
			}
			phase3Timer = time.After(n.phase3Duration)
		case <-phase3Timer:
			err := n.controller.End()
			if err != nil {
				return fmt.Errorf("error ending DKG: %w", err)
			}
			return nil
		}
	}
}

// broker is a test implementation of DKGBroker that enables nodes to exchange
// private and public messages through a shared set of channels.
type broker struct {
	id                int
	privateChannels   []chan msg.PrivDKGMessageIn
	broadcastChannels []chan msg.BroadcastDKGMessage
	logger            zerolog.Logger
	dkgInstanceID     string
}

// PrivateSend implements the crypto.DKGProcessor interface.
func (b *broker) PrivateSend(dest int, data []byte) {
	b.privateChannels[dest] <- msg.PrivDKGMessageIn{
		DKGMessage:           msg.NewDKGMessage(data, b.dkgInstanceID),
		CommitteeMemberIndex: uint64(b.id),
	}
}
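// TestBrokerPrivateSend is a minimal illustrative sketch (added here for
// clarity, not part of the original suite) of the point-to-point delivery the
// test broker provides: a message sent with PrivateSend lands on the
// destination's private channel, tagged with the sender's committee index.
// The wiring mirrors initNodes below; the instance ID "sketch" is arbitrary.
func TestBrokerPrivateSend(t *testing.T) {
	private := []chan msg.PrivDKGMessageIn{
		make(chan msg.PrivDKGMessageIn, 1),
		make(chan msg.PrivDKGMessageIn, 1),
	}
	sender := &broker{id: 0, privateChannels: private, dkgInstanceID: "sketch"}
	receiver := &broker{id: 1, privateChannels: private, dkgInstanceID: "sketch"}

	sender.PrivateSend(1, []byte("hello"))

	received := <-receiver.GetPrivateMsgCh()
	require.Equal(t, uint64(0), received.CommitteeMemberIndex)
	require.Equal(t, []byte("hello"), received.Data)
}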
// Broadcast implements the crypto.DKGProcessor interface.
//
// ATTENTION: Normally the processor requires Broadcast to provide guaranteed
// delivery (either all nodes receive the message or none of them receive it).
// Here we are simply assuming that, with long enough durations for phases 2
// and 3, all nodes are guaranteed to see everyone's messages. So it is
// important to set the timeouts carefully in the tests.
func (b *broker) Broadcast(data []byte) {
	for i := 0; i < len(b.broadcastChannels); i++ {
		if i == b.id {
			continue
		}
		// epoch and phase are not relevant at the controller level
		b.broadcastChannels[i] <- msg.BroadcastDKGMessage{
			DKGMessage:           msg.NewDKGMessage(data, b.dkgInstanceID),
			CommitteeMemberIndex: uint64(b.id),
		}
	}
}

// Disqualify implements the crypto.DKGProcessor interface.
func (b *broker) Disqualify(node int, log string) {
	b.logger.Error().Msgf("node %d disqualified node %d: %s", b.id, node, log)
}

// FlagMisbehavior implements the crypto.DKGProcessor interface.
func (b *broker) FlagMisbehavior(node int, logData string) {
	b.logger.Error().Msgf("node %d flagged node %d: %s", b.id, node, logData)
}

// GetIndex implements the DKGBroker interface.
func (b *broker) GetIndex() int {
	return b.id
}

// GetPrivateMsgCh implements the DKGBroker interface.
func (b *broker) GetPrivateMsgCh() <-chan msg.PrivDKGMessageIn {
	return b.privateChannels[b.id]
}

// GetBroadcastMsgCh implements the DKGBroker interface.
func (b *broker) GetBroadcastMsgCh() <-chan msg.BroadcastDKGMessage {
	return b.broadcastChannels[b.id]
}

// Poll implements the DKGBroker interface.
func (b *broker) Poll(referenceBlock flow.Identifier) error { return nil }

// SubmitResult implements the DKGBroker interface.
func (b *broker) SubmitResult(crypto.PublicKey, []crypto.PublicKey) error { return nil }

// Shutdown implements the DKGBroker interface.
func (b *broker) Shutdown() {}

type testCase struct {
	totalNodes     int
	phase1Duration time.Duration
	phase2Duration time.Duration
	phase3Duration time.Duration
}

// TestDKGHappyPath tests the controller in optimal conditions, when all nodes
// are working correctly.
func TestDKGHappyPath(t *testing.T) {
	// Define different test cases with a varying number of nodes and phase
	// durations. Since these are all happy-path cases, no messages are sent
	// during phases 2 and 3; all messaging is done in phase 1. So we can set
	// shorter durations for phases 2 and 3.
	testCases := []testCase{
		{totalNodes: 5, phase1Duration: 1 * time.Second, phase2Duration: 10 * time.Millisecond, phase3Duration: 10 * time.Millisecond},
		{totalNodes: 10, phase1Duration: 2 * time.Second, phase2Duration: 50 * time.Millisecond, phase3Duration: 50 * time.Millisecond},
		{totalNodes: 15, phase1Duration: 5 * time.Second, phase2Duration: 100 * time.Millisecond, phase3Duration: 100 * time.Millisecond},
	}

	for _, tc := range testCases {
		t.Run(fmt.Sprintf("%d nodes", tc.totalNodes), func(t *testing.T) {
			testDKG(t, tc.totalNodes, tc.totalNodes, tc.phase1Duration, tc.phase2Duration, tc.phase3Duration)
		})
	}
}
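// TestThresholdArithmetic is an illustrative sketch (not part of the original
// suite) of the arithmetic behind TestDKGThreshold below. It assumes that
// signature.RandomBeaconThreshold follows the usual t = floor((n-1)/2) rule,
// in which case gn = n - t good nodes form a strict majority: e.g. for
// n = 10, t = 4 and gn = 6 honest nodes suffice.
func TestThresholdArithmetic(t *testing.T) {
	for _, n := range []int{5, 10, 15} {
		threshold := signature.RandomBeaconThreshold(n) // assumed: floor((n-1)/2)
		require.Equal(t, (n-1)/2, threshold)
		// the good nodes started by TestDKGThreshold outnumber the threshold
		require.Greater(t, n-threshold, threshold)
	}
}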
// TestDKGThreshold tests that the controller results in a successful DKG as
// long as the minimum threshold for non-byzantine nodes is satisfied.
func TestDKGThreshold(t *testing.T) {
	// define different test cases with a varying number of nodes and phase
	// durations
	testCases := []testCase{
		{totalNodes: 5, phase1Duration: 1 * time.Second, phase2Duration: 100 * time.Millisecond, phase3Duration: 100 * time.Millisecond},
		{totalNodes: 10, phase1Duration: 2 * time.Second, phase2Duration: 500 * time.Millisecond, phase3Duration: 500 * time.Millisecond},
		{totalNodes: 15, phase1Duration: 5 * time.Second, phase2Duration: time.Second, phase3Duration: time.Second},
	}

	for _, tc := range testCases {
		// gn is the minimum number of good nodes required for the DKG
		// protocol to succeed
		gn := tc.totalNodes - signature.RandomBeaconThreshold(tc.totalNodes)
		t.Run(fmt.Sprintf("%d/%d nodes", gn, tc.totalNodes), func(t *testing.T) {
			testDKG(t, tc.totalNodes, gn, tc.phase1Duration, tc.phase2Duration, tc.phase3Duration)
		})
	}
}

func testDKG(t *testing.T, totalNodes int, goodNodes int, phase1Duration, phase2Duration, phase3Duration time.Duration) {
	nodes := initNodes(t, totalNodes, phase1Duration, phase2Duration, phase3Duration)
	gnodes := nodes[:goodNodes]

	// Start all the good nodes in parallel
	for _, n := range gnodes {
		go func(node *node) {
			err := node.run()
			require.NoError(t, err)
		}(n)
	}

	// Wait until they have all shut down
	wait(t, gnodes, 5*phase1Duration)

	// Check that all nodes have agreed on the same set of public keys
	checkArtifacts(t, gnodes, totalNodes)
}

// initNodes initialises the nodes and their communication channels.
func initNodes(t *testing.T, n int, phase1Duration, phase2Duration, phase3Duration time.Duration) []*node {
	// Create the channels through which the nodes will communicate. The
	// buffers are generously sized (5*n*n) so that the brokers' synchronous
	// sends do not block.
	privateChannels := make([]chan msg.PrivDKGMessageIn, 0, n)
	broadcastChannels := make([]chan msg.BroadcastDKGMessage, 0, n)
	for i := 0; i < n; i++ {
		privateChannels = append(privateChannels, make(chan msg.PrivDKGMessageIn, 5*n*n))
		broadcastChannels = append(broadcastChannels, make(chan msg.BroadcastDKGMessage, 5*n*n))
	}

	nodes := make([]*node, 0, n)

	// Setup
	for i := 0; i < n; i++ {
		logger := zerolog.New(os.Stderr).With().Int("id", i).Logger()

		broker := &broker{
			id:                i,
			privateChannels:   privateChannels,
			broadcastChannels: broadcastChannels,
			logger:            logger,
		}

		seed := unittest.SeedFixture(crypto.KeyGenSeedMinLen)

		dkg, err := crypto.NewJointFeldman(n, signature.RandomBeaconThreshold(n), i, broker)
		require.NoError(t, err)

		controller := NewController(
			logger,
			"dkg_test",
			dkg,
			seed,
			broker,
		)

		node := newNode(i, controller, phase1Duration, phase2Duration, phase3Duration)
		nodes = append(nodes, node)
	}

	return nodes
}
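// TestInitNodesWiring is a small illustrative sketch (not part of the
// original suite) of the invariant the harness above relies on: each node's
// index matches its position in the shared channel slices, which is what
// routes PrivateSend and Broadcast traffic to the right recipients. The
// phase durations are arbitrary since the controllers are never run.
func TestInitNodesWiring(t *testing.T) {
	nodes := initNodes(t, 4, time.Second, time.Second, time.Second)
	require.Len(t, nodes, 4)
	for i, n := range nodes {
		require.Equal(t, i, n.id)
	}
}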
// wait waits for all the nodes to reach the SHUTDOWN state, or fails the test
// on timeout.
func wait(t *testing.T, nodes []*node, timeout time.Duration) {
	timer := time.After(timeout)
	for {
		select {
		case <-timer:
			t.Fatal("TIMEOUT")
		default:
			done := true
			for _, node := range nodes {
				if node.controller.GetState() != Shutdown {
					done = false
					break
				}
			}
			if done {
				return
			}
			time.Sleep(50 * time.Millisecond)
		}
	}
}

// checkArtifacts checks that all nodes have produced the same set of public
// keys.
func checkArtifacts(t *testing.T, nodes []*node, totalNodes int) {
	_, refGroupPublicKey, refPublicKeys := nodes[0].controller.GetArtifacts()

	for i := 1; i < len(nodes); i++ {
		privateShare, groupPublicKey, publicKeys := nodes[i].controller.GetArtifacts()

		require.NotEmpty(t, privateShare)
		require.NotEmpty(t, groupPublicKey)

		require.True(t, refGroupPublicKey.Equals(groupPublicKey),
			"node %d has a different groupPubKey than node 0: %s %s",
			i,
			groupPublicKey,
			refGroupPublicKey)

		require.Len(t, publicKeys, totalNodes)

		for j := 0; j < totalNodes; j++ {
			if !refPublicKeys[j].Equals(publicKeys[j]) {
				t.Fatalf("node %d has a different pubs[%d] than node 0: %s, %s",
					i,
					j,
					refPublicKeys[j],
					publicKeys[j])
			}
		}
	}
}
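// waitEventually is an alternative to wait, sketched here for comparison and
// not used by the tests above: testify's require.Eventually polls the same
// shutdown condition without the manual timer/sleep loop.
func waitEventually(t *testing.T, nodes []*node, timeout time.Duration) {
	require.Eventually(t, func() bool {
		for _, node := range nodes {
			if node.controller.GetState() != Shutdown {
				return false
			}
		}
		return true
	}, timeout, 50*time.Millisecond)
}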