github.com/ethereum/go-ethereum@v1.16.1/p2p/enode/iter_test.go (about)

     1  // Copyright 2019 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package enode
    18  
    19  import (
    20  	"encoding/binary"
    21  	"runtime"
    22  	"slices"
    23  	"sync/atomic"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/ethereum/go-ethereum/p2p/enr"
    28  )
    29  
    30  func TestReadNodes(t *testing.T) {
    31  	nodes := ReadNodes(new(genIter), 10)
    32  	checkNodes(t, nodes, 10)
    33  }
    34  
    35  // This test checks that ReadNodes terminates when reading N nodes from an iterator
    36  // which returns less than N nodes in an endless cycle.
    37  func TestReadNodesCycle(t *testing.T) {
    38  	iter := &callCountIter{
    39  		Iterator: CycleNodes([]*Node{
    40  			testNode(0, 0),
    41  			testNode(1, 0),
    42  			testNode(2, 0),
    43  		}),
    44  	}
    45  	nodes := ReadNodes(iter, 10)
    46  	checkNodes(t, nodes, 3)
    47  	if iter.count != 10 {
    48  		t.Fatalf("%d calls to Next, want %d", iter.count, 100)
    49  	}
    50  }
    51  
    52  func TestFilterNodes(t *testing.T) {
    53  	nodes := make([]*Node, 100)
    54  	for i := range nodes {
    55  		nodes[i] = testNode(uint64(i), uint64(i))
    56  	}
    57  
    58  	it := Filter(IterNodes(nodes), func(n *Node) bool {
    59  		return n.Seq() >= 50
    60  	})
    61  	for i := 50; i < len(nodes); i++ {
    62  		if !it.Next() {
    63  			t.Fatal("Next returned false")
    64  		}
    65  		if it.Node() != nodes[i] {
    66  			t.Fatalf("iterator returned wrong node %v\nwant %v", it.Node(), nodes[i])
    67  		}
    68  	}
    69  	if it.Next() {
    70  		t.Fatal("Next returned true after underlying iterator has ended")
    71  	}
    72  }
    73  
    74  func checkNodes(t *testing.T, nodes []*Node, wantLen int) {
    75  	if len(nodes) != wantLen {
    76  		t.Errorf("slice has %d nodes, want %d", len(nodes), wantLen)
    77  		return
    78  	}
    79  	seen := make(map[ID]bool, len(nodes))
    80  	for i, e := range nodes {
    81  		if e == nil {
    82  			t.Errorf("nil node at index %d", i)
    83  			return
    84  		}
    85  		if seen[e.ID()] {
    86  			t.Errorf("slice has duplicate node %v", e.ID())
    87  			return
    88  		}
    89  		seen[e.ID()] = true
    90  	}
    91  }
    92  
    93  // This test checks fairness of FairMix in the happy case where all sources return nodes
    94  // within the context's deadline.
    95  func TestFairMix(t *testing.T) {
    96  	for i := 0; i < 500; i++ {
    97  		testMixerFairness(t)
    98  	}
    99  }
   100  
   101  func testMixerFairness(t *testing.T) {
   102  	mix := NewFairMix(1 * time.Second)
   103  	mix.AddSource(&genIter{index: 1})
   104  	mix.AddSource(&genIter{index: 2})
   105  	mix.AddSource(&genIter{index: 3})
   106  	defer mix.Close()
   107  
   108  	nodes := ReadNodes(mix, 500)
   109  	checkNodes(t, nodes, 500)
   110  
   111  	// Verify that the nodes slice contains an approximately equal number of nodes
   112  	// from each source.
   113  	d := idPrefixDistribution(nodes)
   114  	for _, count := range d {
   115  		if approxEqual(count, len(nodes)/3, 30) {
   116  			t.Fatalf("ID distribution is unfair: %v", d)
   117  		}
   118  	}
   119  }
   120  
   121  // This test checks that FairMix falls back to an alternative source when
   122  // the 'fair' choice doesn't return a node within the timeout.
   123  func TestFairMixNextFromAll(t *testing.T) {
   124  	mix := NewFairMix(1 * time.Millisecond)
   125  	mix.AddSource(&genIter{index: 1})
   126  	mix.AddSource(CycleNodes(nil))
   127  	defer mix.Close()
   128  
   129  	nodes := ReadNodes(mix, 500)
   130  	checkNodes(t, nodes, 500)
   131  
   132  	d := idPrefixDistribution(nodes)
   133  	if len(d) > 1 || d[1] != len(nodes) {
   134  		t.Fatalf("wrong ID distribution: %v", d)
   135  	}
   136  }
   137  
   138  // This test ensures FairMix works for Next with no sources.
   139  func TestFairMixEmpty(t *testing.T) {
   140  	var (
   141  		mix   = NewFairMix(1 * time.Second)
   142  		testN = testNode(1, 1)
   143  		ch    = make(chan *Node)
   144  	)
   145  	defer mix.Close()
   146  
   147  	go func() {
   148  		mix.Next()
   149  		ch <- mix.Node()
   150  	}()
   151  
   152  	mix.AddSource(CycleNodes([]*Node{testN}))
   153  	if n := <-ch; n != testN {
   154  		t.Errorf("got wrong node: %v", n)
   155  	}
   156  }
   157  
   158  // This test checks closing a source while Next runs.
   159  func TestFairMixRemoveSource(t *testing.T) {
   160  	mix := NewFairMix(1 * time.Second)
   161  	source := make(blockingIter)
   162  	mix.AddSource(source)
   163  
   164  	sig := make(chan *Node)
   165  	go func() {
   166  		<-sig
   167  		mix.Next()
   168  		sig <- mix.Node()
   169  	}()
   170  
   171  	sig <- nil
   172  	runtime.Gosched()
   173  	source.Close()
   174  
   175  	wantNode := testNode(0, 0)
   176  	mix.AddSource(CycleNodes([]*Node{wantNode}))
   177  	n := <-sig
   178  
   179  	if len(mix.sources) != 1 {
   180  		t.Fatalf("have %d sources, want one", len(mix.sources))
   181  	}
   182  	if n != wantNode {
   183  		t.Fatalf("mixer returned wrong node")
   184  	}
   185  }
   186  
   187  // This checks that FairMix correctly returns the name of the source that produced the node.
   188  func TestFairMixSourceName(t *testing.T) {
   189  	nodes := make([]*Node, 6)
   190  	for i := range nodes {
   191  		nodes[i] = testNode(uint64(i), uint64(i))
   192  	}
   193  	mix := NewFairMix(-1)
   194  	mix.AddSource(WithSourceName("s1", IterNodes(nodes[0:2])))
   195  	mix.AddSource(WithSourceName("s2", IterNodes(nodes[2:4])))
   196  	mix.AddSource(WithSourceName("s3", IterNodes(nodes[4:6])))
   197  
   198  	var names []string
   199  	for range nodes {
   200  		mix.Next()
   201  		names = append(names, mix.NodeSource())
   202  	}
   203  	want := []string{"s2", "s3", "s1", "s2", "s3", "s1"}
   204  	if !slices.Equal(names, want) {
   205  		t.Fatalf("wrong names: %v", names)
   206  	}
   207  }
   208  
   209  // This checks that FairMix returns the name of the source that produced the node,
   210  // even when FairMix instances are nested.
   211  func TestFairMixNestedSourceName(t *testing.T) {
   212  	nodes := make([]*Node, 6)
   213  	for i := range nodes {
   214  		nodes[i] = testNode(uint64(i), uint64(i))
   215  	}
   216  	mix := NewFairMix(-1)
   217  	mix.AddSource(WithSourceName("s1", IterNodes(nodes[0:2])))
   218  	submix := NewFairMix(-1)
   219  	submix.AddSource(WithSourceName("s2", IterNodes(nodes[2:4])))
   220  	submix.AddSource(WithSourceName("s3", IterNodes(nodes[4:6])))
   221  	mix.AddSource(submix)
   222  
   223  	var names []string
   224  	for range nodes {
   225  		mix.Next()
   226  		names = append(names, mix.NodeSource())
   227  	}
   228  	want := []string{"s3", "s1", "s2", "s1", "s3", "s2"}
   229  	if !slices.Equal(names, want) {
   230  		t.Fatalf("wrong names: %v", names)
   231  	}
   232  }
   233  
   234  type blockingIter chan struct{}
   235  
   236  func (it blockingIter) Next() bool {
   237  	<-it
   238  	return false
   239  }
   240  
   241  func (it blockingIter) Node() *Node {
   242  	return nil
   243  }
   244  
   245  func (it blockingIter) Close() {
   246  	close(it)
   247  }
   248  
   249  func TestFairMixClose(t *testing.T) {
   250  	for i := 0; i < 20 && !t.Failed(); i++ {
   251  		testMixerClose(t)
   252  	}
   253  }
   254  
   255  func testMixerClose(t *testing.T) {
   256  	mix := NewFairMix(-1)
   257  	mix.AddSource(CycleNodes(nil))
   258  	mix.AddSource(CycleNodes(nil))
   259  
   260  	done := make(chan struct{})
   261  	go func() {
   262  		defer close(done)
   263  		if mix.Next() {
   264  			t.Error("Next returned true")
   265  		}
   266  	}()
   267  	// This call is supposed to make it more likely that NextNode is
   268  	// actually executing by the time we call Close.
   269  	runtime.Gosched()
   270  
   271  	mix.Close()
   272  	select {
   273  	case <-done:
   274  	case <-time.After(3 * time.Second):
   275  		t.Fatal("Next didn't unblock on Close")
   276  	}
   277  
   278  	mix.Close() // shouldn't crash
   279  }
   280  
   281  func idPrefixDistribution(nodes []*Node) map[uint32]int {
   282  	d := make(map[uint32]int, len(nodes))
   283  	for _, node := range nodes {
   284  		id := node.ID()
   285  		d[binary.BigEndian.Uint32(id[:4])]++
   286  	}
   287  	return d
   288  }
   289  
   290  func approxEqual(x, y, ε int) bool {
   291  	if y > x {
   292  		x, y = y, x
   293  	}
   294  	return x-y > ε
   295  }
   296  
   297  // genIter creates fake nodes with numbered IDs based on 'index' and 'gen'
   298  type genIter struct {
   299  	node       *Node
   300  	index, gen uint32
   301  }
   302  
   303  func (s *genIter) Next() bool {
   304  	index := atomic.LoadUint32(&s.index)
   305  	if index == ^uint32(0) {
   306  		s.node = nil
   307  		return false
   308  	}
   309  	s.node = testNode(uint64(index)<<32|uint64(s.gen), 0)
   310  	s.gen++
   311  	return true
   312  }
   313  
   314  func (s *genIter) Node() *Node {
   315  	return s.node
   316  }
   317  
   318  func (s *genIter) Close() {
   319  	atomic.StoreUint32(&s.index, ^uint32(0))
   320  }
   321  
   322  func testNode(id, seq uint64) *Node {
   323  	var nodeID ID
   324  	binary.BigEndian.PutUint64(nodeID[:], id)
   325  	r := new(enr.Record)
   326  	r.SetSeq(seq)
   327  	return SignNull(r, nodeID)
   328  }
   329  
   330  // callCountIter counts calls to NextNode.
   331  type callCountIter struct {
   332  	Iterator
   333  	count int
   334  }
   335  
   336  func (it *callCountIter) Next() bool {
   337  	it.count++
   338  	return it.Iterator.Next()
   339  }