github.com/daeglee/go-ethereum@v0.0.0-20190504220456-cad3e8d18e9b/swarm/pss/prox_test.go (about) 1 package pss 2 3 import ( 4 "context" 5 "crypto/ecdsa" 6 "encoding/binary" 7 "errors" 8 "fmt" 9 "strconv" 10 "strings" 11 "sync" 12 "testing" 13 "time" 14 15 "github.com/ethereum/go-ethereum/common" 16 "github.com/ethereum/go-ethereum/common/hexutil" 17 "github.com/ethereum/go-ethereum/log" 18 "github.com/ethereum/go-ethereum/node" 19 "github.com/ethereum/go-ethereum/p2p" 20 "github.com/ethereum/go-ethereum/p2p/enode" 21 "github.com/ethereum/go-ethereum/p2p/simulations/adapters" 22 "github.com/ethereum/go-ethereum/rpc" 23 "github.com/ethereum/go-ethereum/swarm/network" 24 "github.com/ethereum/go-ethereum/swarm/network/simulation" 25 "github.com/ethereum/go-ethereum/swarm/pot" 26 "github.com/ethereum/go-ethereum/swarm/state" 27 ) 28 29 // needed to make the enode id of the receiving node available to the handler for triggers 30 type handlerContextFunc func(*testData, *adapters.NodeConfig) *handler 31 32 // struct to notify reception of messages to simulation driver 33 // TODO To make code cleaner: 34 // - consider a separate pss unwrap to message event in sim framework (this will make eventual message propagation analysis with pss easier/possible in the future) 35 // - consider also test api calls to inspect handling results of messages 36 type handlerNotification struct { 37 id enode.ID 38 serial uint64 39 } 40 41 type testData struct { 42 mu sync.Mutex 43 sim *simulation.Simulation 44 handlerDone bool // set to true on termination of the simulation run 45 requiredMessages int 46 allowedMessages int 47 messageCount int 48 kademlias map[enode.ID]*network.Kademlia 49 nodeAddrs map[enode.ID][]byte // make predictable overlay addresses from the generated random enode ids 50 recipients map[int][]enode.ID // for logging output only 51 allowed map[int][]enode.ID // allowed recipients 52 expectedMsgs map[enode.ID][]uint64 // message serials we expect respective nodes to receive 53 allowedMsgs map[enode.ID][]uint64 // message serials we expect respective nodes to receive 54 senders map[int]enode.ID // originating nodes of the messages (intention is to choose as far as possible from the receiving neighborhood) 55 handlerC chan handlerNotification // passes message from pss message handler to simulation driver 56 doneC chan struct{} // terminates the handler channel listener 57 errC chan error // error to pass to main sim thread 58 msgC chan handlerNotification // message receipt notification to main sim thread 59 msgs [][]byte // recipient addresses of messages 60 } 61 62 var ( 63 pof = pot.DefaultPof(256) // generate messages and index them 64 topic = BytesToTopic([]byte{0xf3, 0x9e, 0x06, 0x82}) 65 ) 66 67 func (d *testData) getMsgCount() int { 68 d.mu.Lock() 69 defer d.mu.Unlock() 70 return d.messageCount 71 } 72 73 func (d *testData) incrementMsgCount() int { 74 d.mu.Lock() 75 defer d.mu.Unlock() 76 d.messageCount++ 77 return d.messageCount 78 } 79 80 func (d *testData) isDone() bool { 81 d.mu.Lock() 82 defer d.mu.Unlock() 83 return d.handlerDone 84 } 85 86 func (d *testData) setDone() { 87 d.mu.Lock() 88 defer d.mu.Unlock() 89 d.handlerDone = true 90 } 91 92 func getCmdParams(t *testing.T) (int, int, time.Duration) { 93 args := strings.Split(t.Name(), "/") 94 msgCount, err := strconv.ParseInt(args[2], 10, 16) 95 if err != nil { 96 t.Fatal(err) 97 } 98 nodeCount, err := strconv.ParseInt(args[1], 10, 16) 99 if err != nil { 100 t.Fatal(err) 101 } 102 timeoutStr := fmt.Sprintf("%ss", args[3]) 103 timeoutDur, err := time.ParseDuration(timeoutStr) 104 if err != nil { 105 t.Fatal(err) 106 } 107 return int(msgCount), int(nodeCount), timeoutDur 108 } 109 110 func newTestData() *testData { 111 return &testData{ 112 kademlias: make(map[enode.ID]*network.Kademlia), 113 nodeAddrs: make(map[enode.ID][]byte), 114 recipients: make(map[int][]enode.ID), 115 allowed: make(map[int][]enode.ID), 116 expectedMsgs: make(map[enode.ID][]uint64), 117 allowedMsgs: make(map[enode.ID][]uint64), 118 senders: make(map[int]enode.ID), 119 handlerC: make(chan handlerNotification), 120 doneC: make(chan struct{}), 121 errC: make(chan error), 122 msgC: make(chan handlerNotification), 123 } 124 } 125 126 func (d *testData) getKademlia(nodeId *enode.ID) (*network.Kademlia, error) { 127 kadif, ok := d.sim.NodeItem(*nodeId, simulation.BucketKeyKademlia) 128 if !ok { 129 return nil, fmt.Errorf("no kademlia entry for %v", nodeId) 130 } 131 kad, ok := kadif.(*network.Kademlia) 132 if !ok { 133 return nil, fmt.Errorf("invalid kademlia entry for %v", nodeId) 134 } 135 return kad, nil 136 } 137 138 func (d *testData) init(msgCount int) error { 139 log.Debug("TestProxNetwork start") 140 141 for _, nodeId := range d.sim.NodeIDs() { 142 kad, err := d.getKademlia(&nodeId) 143 if err != nil { 144 return err 145 } 146 d.nodeAddrs[nodeId] = kad.BaseAddr() 147 } 148 149 for i := 0; i < int(msgCount); i++ { 150 msgAddr := pot.RandomAddress() // we choose message addresses randomly 151 d.msgs = append(d.msgs, msgAddr.Bytes()) 152 smallestPo := 256 153 var targets []enode.ID 154 var closestPO int 155 156 // loop through all nodes and find the required and allowed recipients of each message 157 // (for more information, please see the comment to the main test function) 158 for _, nod := range d.sim.Net.GetNodes() { 159 po, _ := pof(d.msgs[i], d.nodeAddrs[nod.ID()], 0) 160 depth := d.kademlias[nod.ID()].NeighbourhoodDepth() 161 162 // only nodes with closest IDs (wrt the msg address) will be required recipients 163 if po > closestPO { 164 closestPO = po 165 targets = nil 166 targets = append(targets, nod.ID()) 167 } else if po == closestPO { 168 targets = append(targets, nod.ID()) 169 } 170 171 if po >= depth { 172 d.allowedMessages++ 173 d.allowed[i] = append(d.allowed[i], nod.ID()) 174 d.allowedMsgs[nod.ID()] = append(d.allowedMsgs[nod.ID()], uint64(i)) 175 } 176 177 // a node with the smallest PO (wrt msg) will be the sender, 178 // in order to increase the distance the msg must travel 179 if po < smallestPo { 180 smallestPo = po 181 d.senders[i] = nod.ID() 182 } 183 } 184 185 d.requiredMessages += len(targets) 186 for _, id := range targets { 187 d.recipients[i] = append(d.recipients[i], id) 188 d.expectedMsgs[id] = append(d.expectedMsgs[id], uint64(i)) 189 } 190 191 log.Debug("nn for msg", "targets", len(d.recipients[i]), "msgidx", i, "msg", common.Bytes2Hex(msgAddr[:8]), "sender", d.senders[i], "senderpo", smallestPo) 192 } 193 log.Debug("msgs to receive", "count", d.requiredMessages) 194 return nil 195 } 196 197 // Here we test specific functionality of the pss, setting the prox property of 198 // the handler. The tests generate a number of messages with random addresses. 199 // Then, for each message it calculates which nodes have the msg address 200 // within its nearest neighborhood depth, and stores those nodes as possible 201 // recipients. Those nodes that are the closest to the message address (nodes 202 // belonging to the deepest PO wrt the msg address) are stored as required 203 // recipients. The difference between allowed and required recipients results 204 // from the fact that the nearest neighbours are not necessarily reciprocal. 205 // Upon sending the messages, the test verifies that the respective message is 206 // passed to the message handlers of these required recipients. The test fails 207 // if a message is handled by recipient which is not listed among the allowed 208 // recipients of this particular message. It also fails after timeout, if not 209 // all the required recipients have received their respective messages. 210 // 211 // For example, if proximity order of certain msg address is 4, and node X 212 // has PO=5 wrt the message address, and nodes Y and Z have PO=6, then: 213 // nodes Y and Z will be considered required recipients of the msg, 214 // whereas nodes X, Y and Z will be allowed recipients. 215 func TestProxNetwork(t *testing.T) { 216 t.Run("16/16/15", testProxNetwork) 217 } 218 219 // params in run name: nodes/msgs 220 func TestProxNetworkLong(t *testing.T) { 221 if !*longrunning { 222 t.Skip("run with --longrunning flag to run extensive network tests") 223 } 224 t.Run("8/100/30", testProxNetwork) 225 t.Run("16/100/30", testProxNetwork) 226 t.Run("32/100/60", testProxNetwork) 227 t.Run("64/100/60", testProxNetwork) 228 t.Run("128/100/120", testProxNetwork) 229 } 230 231 func testProxNetwork(t *testing.T) { 232 tstdata := newTestData() 233 msgCount, nodeCount, timeout := getCmdParams(t) 234 handlerContextFuncs := make(map[Topic]handlerContextFunc) 235 handlerContextFuncs[topic] = nodeMsgHandler 236 services := newProxServices(tstdata, true, handlerContextFuncs, tstdata.kademlias) 237 tstdata.sim = simulation.New(services) 238 defer tstdata.sim.Close() 239 ctx, cancel := context.WithTimeout(context.Background(), timeout) 240 defer cancel() 241 filename := fmt.Sprintf("testdata/snapshot_%d.json", nodeCount) 242 err := tstdata.sim.UploadSnapshot(ctx, filename) 243 if err != nil { 244 t.Fatal(err) 245 } 246 err = tstdata.init(msgCount) // initialize the test data 247 if err != nil { 248 t.Fatal(err) 249 } 250 wrapper := func(c context.Context, _ *simulation.Simulation) error { 251 return testRoutine(tstdata, c) 252 } 253 result := tstdata.sim.Run(ctx, wrapper) // call the main test function 254 if result.Error != nil { 255 // context deadline exceeded 256 // however, it might just mean that not all possible messages are received 257 // now we must check if all required messages are received 258 cnt := tstdata.getMsgCount() 259 log.Debug("TestProxNetwork finished", "rcv", cnt) 260 if cnt < tstdata.requiredMessages { 261 t.Fatal(result.Error) 262 } 263 } 264 t.Logf("completed %d", result.Duration) 265 } 266 267 func (tstdata *testData) sendAllMsgs() { 268 for i, msg := range tstdata.msgs { 269 log.Debug("sending msg", "idx", i, "from", tstdata.senders[i]) 270 nodeClient, err := tstdata.sim.Net.GetNode(tstdata.senders[i]).Client() 271 if err != nil { 272 tstdata.errC <- err 273 } 274 var uvarByte [8]byte 275 binary.PutUvarint(uvarByte[:], uint64(i)) 276 nodeClient.Call(nil, "pss_sendRaw", hexutil.Encode(msg), hexutil.Encode(topic[:]), hexutil.Encode(uvarByte[:])) 277 } 278 log.Debug("all messages sent") 279 } 280 281 // testRoutine is the main test function, called by Simulation.Run() 282 func testRoutine(tstdata *testData, ctx context.Context) error { 283 go handlerChannelListener(tstdata, ctx) 284 go tstdata.sendAllMsgs() 285 received := 0 286 287 // collect incoming messages and terminate with corresponding status when message handler listener ends 288 for { 289 select { 290 case err := <-tstdata.errC: 291 return err 292 case hn := <-tstdata.msgC: 293 received++ 294 log.Debug("msg received", "msgs_received", received, "total_expected", tstdata.requiredMessages, "id", hn.id, "serial", hn.serial) 295 if received == tstdata.allowedMessages { 296 close(tstdata.doneC) 297 return nil 298 } 299 } 300 } 301 return nil 302 } 303 304 func handlerChannelListener(tstdata *testData, ctx context.Context) { 305 for { 306 select { 307 case <-tstdata.doneC: // graceful exit 308 tstdata.setDone() 309 tstdata.errC <- nil 310 return 311 312 case <-ctx.Done(): // timeout or cancel 313 tstdata.setDone() 314 tstdata.errC <- ctx.Err() 315 return 316 317 // incoming message from pss message handler 318 case handlerNotification := <-tstdata.handlerC: 319 // check if recipient has already received all its messages and notify to fail the test if so 320 aMsgs := tstdata.allowedMsgs[handlerNotification.id] 321 if len(aMsgs) == 0 { 322 tstdata.setDone() 323 tstdata.errC <- fmt.Errorf("too many messages received by recipient %x", handlerNotification.id) 324 return 325 } 326 327 // check if message serial is in expected messages for this recipient and notify to fail the test if not 328 idx := -1 329 for i, msg := range aMsgs { 330 if handlerNotification.serial == msg { 331 idx = i 332 break 333 } 334 } 335 if idx == -1 { 336 tstdata.setDone() 337 tstdata.errC <- fmt.Errorf("message %d received by wrong recipient %v", handlerNotification.serial, handlerNotification.id) 338 return 339 } 340 341 // message is ok, so remove that message serial from the recipient expectation array and notify the main sim thread 342 aMsgs[idx] = aMsgs[len(aMsgs)-1] 343 aMsgs = aMsgs[:len(aMsgs)-1] 344 tstdata.msgC <- handlerNotification 345 } 346 } 347 } 348 349 func nodeMsgHandler(tstdata *testData, config *adapters.NodeConfig) *handler { 350 return &handler{ 351 f: func(msg []byte, p *p2p.Peer, asymmetric bool, keyid string) error { 352 cnt := tstdata.incrementMsgCount() 353 log.Debug("nodeMsgHandler rcv", "cnt", cnt) 354 355 // using simple serial in message body, makes it easy to keep track of who's getting what 356 serial, c := binary.Uvarint(msg) 357 if c <= 0 { 358 log.Crit(fmt.Sprintf("corrupt message received by %x (uvarint parse returned %d)", config.ID, c)) 359 } 360 361 if tstdata.isDone() { 362 return errors.New("handlers aborted") // terminate if simulation is over 363 } 364 365 // pass message context to the listener in the simulation 366 tstdata.handlerC <- handlerNotification{ 367 id: config.ID, 368 serial: serial, 369 } 370 return nil 371 }, 372 caps: &handlerCaps{ 373 raw: true, // we use raw messages for simplicity 374 prox: true, 375 }, 376 } 377 } 378 379 // an adaptation of the same services setup as in pss_test.go 380 // replaces pss_test.go when those tests are rewritten to the new swarm/network/simulation package 381 func newProxServices(tstdata *testData, allowRaw bool, handlerContextFuncs map[Topic]handlerContextFunc, kademlias map[enode.ID]*network.Kademlia) map[string]simulation.ServiceFunc { 382 stateStore := state.NewInmemoryStore() 383 kademlia := func(id enode.ID, bzzkey []byte) *network.Kademlia { 384 if k, ok := kademlias[id]; ok { 385 return k 386 } 387 params := network.NewKadParams() 388 params.MaxBinSize = 3 389 params.MinBinSize = 1 390 params.MaxRetries = 1000 391 params.RetryExponent = 2 392 params.RetryInterval = 1000000 393 kademlias[id] = network.NewKademlia(bzzkey, params) 394 return kademlias[id] 395 } 396 return map[string]simulation.ServiceFunc{ 397 "bzz": func(ctx *adapters.ServiceContext, b *sync.Map) (node.Service, func(), error) { 398 var err error 399 var bzzPrivateKey *ecdsa.PrivateKey 400 // normally translation of enode id to swarm address is concealed by the network package 401 // however, we need to keep track of it in the test driver as well. 402 // if the translation in the network package changes, that can cause these tests to unpredictably fail 403 // therefore we keep a local copy of the translation here 404 addr := network.NewAddr(ctx.Config.Node()) 405 bzzPrivateKey, err = simulation.BzzPrivateKeyFromConfig(ctx.Config) 406 if err != nil { 407 return nil, nil, err 408 } 409 addr.OAddr = network.PrivateKeyToBzzKey(bzzPrivateKey) 410 b.Store(simulation.BucketKeyBzzPrivateKey, bzzPrivateKey) 411 hp := network.NewHiveParams() 412 hp.Discovery = false 413 config := &network.BzzConfig{ 414 OverlayAddr: addr.Over(), 415 UnderlayAddr: addr.Under(), 416 HiveParams: hp, 417 } 418 return network.NewBzz(config, kademlia(ctx.Config.ID, addr.OAddr), stateStore, nil, nil), nil, nil 419 }, 420 "pss": func(ctx *adapters.ServiceContext, b *sync.Map) (node.Service, func(), error) { 421 // execadapter does not exec init() 422 initTest() 423 424 // create keys in whisper and set up the pss object 425 ctxlocal, cancel := context.WithTimeout(context.Background(), time.Second*3) 426 defer cancel() 427 keys, err := wapi.NewKeyPair(ctxlocal) 428 privkey, err := w.GetPrivateKey(keys) 429 pssp := NewPssParams().WithPrivateKey(privkey) 430 pssp.AllowRaw = allowRaw 431 bzzPrivateKey, err := simulation.BzzPrivateKeyFromConfig(ctx.Config) 432 if err != nil { 433 return nil, nil, err 434 } 435 bzzKey := network.PrivateKeyToBzzKey(bzzPrivateKey) 436 pskad := kademlia(ctx.Config.ID, bzzKey) 437 ps, err := NewPss(pskad, pssp) 438 if err != nil { 439 return nil, nil, err 440 } 441 442 // register the handlers we've been passed 443 var deregisters []func() 444 for tpc, hndlrFunc := range handlerContextFuncs { 445 deregisters = append(deregisters, ps.Register(&tpc, hndlrFunc(tstdata, ctx.Config))) 446 } 447 448 // if handshake mode is set, add the controller 449 // TODO: This should be hooked to the handshake test file 450 if useHandshake { 451 SetHandshakeController(ps, NewHandshakeParams()) 452 } 453 454 // we expose some api calls for cheating 455 ps.addAPI(rpc.API{ 456 Namespace: "psstest", 457 Version: "0.3", 458 Service: NewAPITest(ps), 459 Public: false, 460 }) 461 462 b.Store(simulation.BucketKeyKademlia, pskad) 463 464 // return Pss and cleanups 465 return ps, func() { 466 // run the handler deregister functions in reverse order 467 for i := len(deregisters); i > 0; i-- { 468 deregisters[i-1]() 469 } 470 }, nil 471 }, 472 } 473 }