github.com/prysmaticlabs/prysm@v1.4.4/beacon-chain/state/v1/references_test.go (about)

     1  package v1
     2  
     3  import (
     4  	"reflect"
     5  	"runtime"
     6  	"runtime/debug"
     7  	"testing"
     8  
     9  	"github.com/prysmaticlabs/prysm/shared/copyutil"
    10  
    11  	"github.com/prysmaticlabs/go-bitfield"
    12  	iface "github.com/prysmaticlabs/prysm/beacon-chain/state/interface"
    13  	p2ppb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
    14  	ethpb "github.com/prysmaticlabs/prysm/proto/eth/v1alpha1"
    15  	"github.com/prysmaticlabs/prysm/shared/bytesutil"
    16  	"github.com/prysmaticlabs/prysm/shared/testutil/assert"
    17  	"github.com/prysmaticlabs/prysm/shared/testutil/require"
    18  )
    19  
    20  func TestStateReferenceSharing_Finalizer(t *testing.T) {
    21  	// This test showcases the logic on a the RandaoMixes field with the GC finalizer.
    22  
    23  	a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{RandaoMixes: [][]byte{[]byte("foo")}})
    24  	require.NoError(t, err)
    25  	assert.Equal(t, uint(1), a.sharedFieldReferences[randaoMixes].Refs(), "Expected a single reference for RANDAO mixes")
    26  
    27  	func() {
    28  		// Create object in a different scope for GC
    29  		b := a.Copy()
    30  		assert.Equal(t, uint(2), a.sharedFieldReferences[randaoMixes].Refs(), "Expected 2 references to RANDAO mixes")
    31  		_ = b
    32  	}()
    33  
    34  	runtime.GC() // Should run finalizer on object b
    35  	assert.Equal(t, uint(1), a.sharedFieldReferences[randaoMixes].Refs(), "Expected 1 shared reference to RANDAO mixes!")
    36  
    37  	copied := a.Copy()
    38  	b, ok := copied.(*BeaconState)
    39  	require.Equal(t, true, ok)
    40  	assert.Equal(t, uint(2), b.sharedFieldReferences[randaoMixes].Refs(), "Expected 2 shared references to RANDAO mixes")
    41  	require.NoError(t, b.UpdateRandaoMixesAtIndex(0, []byte("bar")))
    42  	if b.sharedFieldReferences[randaoMixes].Refs() != 1 || a.sharedFieldReferences[randaoMixes].Refs() != 1 {
    43  		t.Error("Expected 1 shared reference to RANDAO mix for both a and b")
    44  	}
    45  }
    46  
    47  func TestStateReferenceCopy_NoUnexpectedRootsMutation(t *testing.T) {
    48  	root1, root2 := bytesutil.ToBytes32([]byte("foo")), bytesutil.ToBytes32([]byte("bar"))
    49  	a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{
    50  		BlockRoots: [][]byte{
    51  			root1[:],
    52  		},
    53  		StateRoots: [][]byte{
    54  			root1[:],
    55  		},
    56  	})
    57  	require.NoError(t, err)
    58  	assertRefCount(t, a, blockRoots, 1)
    59  	assertRefCount(t, a, stateRoots, 1)
    60  
    61  	// Copy, increases reference count.
    62  	copied := a.Copy()
    63  	b, ok := copied.(*BeaconState)
    64  	require.Equal(t, true, ok)
    65  	assertRefCount(t, a, blockRoots, 2)
    66  	assertRefCount(t, a, stateRoots, 2)
    67  	assertRefCount(t, b, blockRoots, 2)
    68  	assertRefCount(t, b, stateRoots, 2)
    69  	assert.Equal(t, 1, len(b.state.GetBlockRoots()), "No block roots found")
    70  	assert.Equal(t, 1, len(b.state.GetStateRoots()), "No state roots found")
    71  
    72  	// Assert shared state.
    73  	blockRootsA := a.state.GetBlockRoots()
    74  	stateRootsA := a.state.GetStateRoots()
    75  	blockRootsB := b.state.GetBlockRoots()
    76  	stateRootsB := b.state.GetStateRoots()
    77  	if len(blockRootsA) != len(blockRootsB) || len(blockRootsA) < 1 {
    78  		t.Errorf("Unexpected number of block roots, want: %v", 1)
    79  	}
    80  	if len(stateRootsA) != len(stateRootsB) || len(stateRootsA) < 1 {
    81  		t.Errorf("Unexpected number of state roots, want: %v", 1)
    82  	}
    83  	assertValFound(t, blockRootsA, root1[:])
    84  	assertValFound(t, blockRootsB, root1[:])
    85  	assertValFound(t, stateRootsA, root1[:])
    86  	assertValFound(t, stateRootsB, root1[:])
    87  
    88  	// Mutator should only affect calling state: a.
    89  	require.NoError(t, a.UpdateBlockRootAtIndex(0, root2))
    90  	require.NoError(t, a.UpdateStateRootAtIndex(0, root2))
    91  
    92  	// Assert no shared state mutation occurred only on state a (copy on write).
    93  	assertValNotFound(t, a.state.GetBlockRoots(), root1[:])
    94  	assertValNotFound(t, a.state.GetStateRoots(), root1[:])
    95  	assertValFound(t, a.state.GetBlockRoots(), root2[:])
    96  	assertValFound(t, a.state.GetStateRoots(), root2[:])
    97  	assertValFound(t, b.state.GetBlockRoots(), root1[:])
    98  	assertValFound(t, b.state.GetStateRoots(), root1[:])
    99  	if len(blockRootsA) != len(blockRootsB) || len(blockRootsA) < 1 {
   100  		t.Errorf("Unexpected number of block roots, want: %v", 1)
   101  	}
   102  	if len(stateRootsA) != len(stateRootsB) || len(stateRootsA) < 1 {
   103  		t.Errorf("Unexpected number of state roots, want: %v", 1)
   104  	}
   105  	assert.DeepEqual(t, root2[:], a.state.GetBlockRoots()[0], "Expected mutation not found")
   106  	assert.DeepEqual(t, root2[:], a.state.GetStateRoots()[0], "Expected mutation not found")
   107  	assert.DeepEqual(t, root1[:], blockRootsB[0], "Unexpected mutation found")
   108  	assert.DeepEqual(t, root1[:], stateRootsB[0], "Unexpected mutation found")
   109  
   110  	// Copy on write happened, reference counters are reset.
   111  	assertRefCount(t, a, blockRoots, 1)
   112  	assertRefCount(t, a, stateRoots, 1)
   113  	assertRefCount(t, b, blockRoots, 1)
   114  	assertRefCount(t, b, stateRoots, 1)
   115  }
   116  
   117  func TestStateReferenceCopy_NoUnexpectedRandaoMutation(t *testing.T) {
   118  
   119  	val1, val2 := []byte("foo"), []byte("bar")
   120  	a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{
   121  		RandaoMixes: [][]byte{
   122  			val1,
   123  		},
   124  	})
   125  	require.NoError(t, err)
   126  	assertRefCount(t, a, randaoMixes, 1)
   127  
   128  	// Copy, increases reference count.
   129  	copied := a.Copy()
   130  	b, ok := copied.(*BeaconState)
   131  	require.Equal(t, true, ok)
   132  	assertRefCount(t, a, randaoMixes, 2)
   133  	assertRefCount(t, b, randaoMixes, 2)
   134  	assert.Equal(t, 1, len(b.state.GetRandaoMixes()), "No randao mixes found")
   135  
   136  	// Assert shared state.
   137  	mixesA := a.state.GetRandaoMixes()
   138  	mixesB := b.state.GetRandaoMixes()
   139  	if len(mixesA) != len(mixesB) || len(mixesA) < 1 {
   140  		t.Errorf("Unexpected number of mix values, want: %v", 1)
   141  	}
   142  	assertValFound(t, mixesA, val1)
   143  	assertValFound(t, mixesB, val1)
   144  
   145  	// Mutator should only affect calling state: a.
   146  	require.NoError(t, a.UpdateRandaoMixesAtIndex(0, val2))
   147  
   148  	// Assert no shared state mutation occurred only on state a (copy on write).
   149  	if len(mixesA) != len(mixesB) || len(mixesA) < 1 {
   150  		t.Errorf("Unexpected number of mix values, want: %v", 1)
   151  	}
   152  	assertValFound(t, a.state.GetRandaoMixes(), val2)
   153  	assertValNotFound(t, a.state.GetRandaoMixes(), val1)
   154  	assertValFound(t, b.state.GetRandaoMixes(), val1)
   155  	assertValNotFound(t, b.state.GetRandaoMixes(), val2)
   156  	assertValFound(t, mixesB, val1)
   157  	assertValNotFound(t, mixesB, val2)
   158  	assert.DeepEqual(t, val2, a.state.GetRandaoMixes()[0], "Expected mutation not found")
   159  	assert.DeepEqual(t, val1, mixesB[0], "Unexpected mutation found")
   160  
   161  	// Copy on write happened, reference counters are reset.
   162  	assertRefCount(t, a, randaoMixes, 1)
   163  	assertRefCount(t, b, randaoMixes, 1)
   164  }
   165  
   166  func TestStateReferenceCopy_NoUnexpectedAttestationsMutation(t *testing.T) {
   167  	assertAttFound := func(vals []*p2ppb.PendingAttestation, val uint64) {
   168  		for i := range vals {
   169  			if reflect.DeepEqual(vals[i].AggregationBits, bitfield.NewBitlist(val)) {
   170  				return
   171  			}
   172  		}
   173  		t.Log(string(debug.Stack()))
   174  		t.Fatalf("Expected attestation not found (%v), want: %v", vals, val)
   175  	}
   176  	assertAttNotFound := func(vals []*p2ppb.PendingAttestation, val uint64) {
   177  		for i := range vals {
   178  			if reflect.DeepEqual(vals[i].AggregationBits, bitfield.NewBitlist(val)) {
   179  				t.Log(string(debug.Stack()))
   180  				t.Fatalf("Unexpected attestation found (%v): %v", vals, val)
   181  				return
   182  			}
   183  		}
   184  	}
   185  
   186  	a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{})
   187  	require.NoError(t, err)
   188  	assertRefCount(t, a, previousEpochAttestations, 1)
   189  	assertRefCount(t, a, currentEpochAttestations, 1)
   190  
   191  	// Update initial state.
   192  	atts := []*p2ppb.PendingAttestation{
   193  		{AggregationBits: bitfield.NewBitlist(1)},
   194  		{AggregationBits: bitfield.NewBitlist(2)},
   195  	}
   196  	a.setPreviousEpochAttestations(atts[:1])
   197  	a.setCurrentEpochAttestations(atts[:1])
   198  	curAtt, err := a.CurrentEpochAttestations()
   199  	require.NoError(t, err)
   200  	assert.Equal(t, 1, len(curAtt), "Unexpected number of attestations")
   201  	preAtt, err := a.PreviousEpochAttestations()
   202  	require.NoError(t, err)
   203  	assert.Equal(t, 1, len(preAtt), "Unexpected number of attestations")
   204  
   205  	// Copy, increases reference count.
   206  	copied := a.Copy()
   207  	b, ok := copied.(*BeaconState)
   208  	require.Equal(t, true, ok)
   209  	assertRefCount(t, a, previousEpochAttestations, 2)
   210  	assertRefCount(t, a, currentEpochAttestations, 2)
   211  	assertRefCount(t, b, previousEpochAttestations, 2)
   212  	assertRefCount(t, b, currentEpochAttestations, 2)
   213  	assert.Equal(t, 1, len(b.state.GetPreviousEpochAttestations()), "Unexpected number of attestations")
   214  	assert.Equal(t, 1, len(b.state.GetCurrentEpochAttestations()), "Unexpected number of attestations")
   215  
   216  	// Assert shared state.
   217  	curAttsA := a.state.GetCurrentEpochAttestations()
   218  	prevAttsA := a.state.GetPreviousEpochAttestations()
   219  	curAttsB := b.state.GetCurrentEpochAttestations()
   220  	prevAttsB := b.state.GetPreviousEpochAttestations()
   221  	if len(curAttsA) != len(curAttsB) || len(curAttsA) < 1 {
   222  		t.Errorf("Unexpected number of attestations, want: %v", 1)
   223  	}
   224  	if len(prevAttsA) != len(prevAttsB) || len(prevAttsA) < 1 {
   225  		t.Errorf("Unexpected number of attestations, want: %v", 1)
   226  	}
   227  	assertAttFound(curAttsA, 1)
   228  	assertAttFound(prevAttsA, 1)
   229  	assertAttFound(curAttsB, 1)
   230  	assertAttFound(prevAttsB, 1)
   231  
   232  	// Extends state a attestations.
   233  	require.NoError(t, a.AppendCurrentEpochAttestations(atts[1]))
   234  	require.NoError(t, a.AppendPreviousEpochAttestations(atts[1]))
   235  	curAtt, err = a.CurrentEpochAttestations()
   236  	require.NoError(t, err)
   237  	assert.Equal(t, 2, len(curAtt), "Unexpected number of attestations")
   238  	preAtt, err = a.PreviousEpochAttestations()
   239  	require.NoError(t, err)
   240  	assert.Equal(t, 2, len(preAtt), "Unexpected number of attestations")
   241  	assertAttFound(a.state.GetCurrentEpochAttestations(), 1)
   242  	assertAttFound(a.state.GetPreviousEpochAttestations(), 1)
   243  	assertAttFound(a.state.GetCurrentEpochAttestations(), 2)
   244  	assertAttFound(a.state.GetPreviousEpochAttestations(), 2)
   245  	assertAttFound(b.state.GetCurrentEpochAttestations(), 1)
   246  	assertAttFound(b.state.GetPreviousEpochAttestations(), 1)
   247  	assertAttNotFound(b.state.GetCurrentEpochAttestations(), 2)
   248  	assertAttNotFound(b.state.GetPreviousEpochAttestations(), 2)
   249  
   250  	// Mutator should only affect calling state: a.
   251  	applyToEveryAttestation := func(state *p2ppb.BeaconState) {
   252  		// One MUST copy on write.
   253  		atts = make([]*p2ppb.PendingAttestation, len(state.CurrentEpochAttestations))
   254  		copy(atts, state.CurrentEpochAttestations)
   255  		state.CurrentEpochAttestations = atts
   256  		for i := range state.GetCurrentEpochAttestations() {
   257  			att := copyutil.CopyPendingAttestation(state.CurrentEpochAttestations[i])
   258  			att.AggregationBits = bitfield.NewBitlist(3)
   259  			state.CurrentEpochAttestations[i] = att
   260  		}
   261  
   262  		atts = make([]*p2ppb.PendingAttestation, len(state.PreviousEpochAttestations))
   263  		copy(atts, state.PreviousEpochAttestations)
   264  		state.PreviousEpochAttestations = atts
   265  		for i := range state.GetPreviousEpochAttestations() {
   266  			att := copyutil.CopyPendingAttestation(state.PreviousEpochAttestations[i])
   267  			att.AggregationBits = bitfield.NewBitlist(3)
   268  			state.PreviousEpochAttestations[i] = att
   269  		}
   270  	}
   271  	applyToEveryAttestation(a.state)
   272  
   273  	// Assert no shared state mutation occurred only on state a (copy on write).
   274  	assertAttFound(a.state.GetCurrentEpochAttestations(), 3)
   275  	assertAttFound(a.state.GetPreviousEpochAttestations(), 3)
   276  	assertAttNotFound(a.state.GetCurrentEpochAttestations(), 1)
   277  	assertAttNotFound(a.state.GetPreviousEpochAttestations(), 1)
   278  	assertAttNotFound(a.state.GetCurrentEpochAttestations(), 2)
   279  	assertAttNotFound(a.state.GetPreviousEpochAttestations(), 2)
   280  	// State b must be unaffected.
   281  	assertAttNotFound(b.state.GetCurrentEpochAttestations(), 3)
   282  	assertAttNotFound(b.state.GetPreviousEpochAttestations(), 3)
   283  	assertAttFound(b.state.GetCurrentEpochAttestations(), 1)
   284  	assertAttFound(b.state.GetPreviousEpochAttestations(), 1)
   285  	assertAttNotFound(b.state.GetCurrentEpochAttestations(), 2)
   286  	assertAttNotFound(b.state.GetPreviousEpochAttestations(), 2)
   287  
   288  	// Copy on write happened, reference counters are reset.
   289  	assertRefCount(t, a, currentEpochAttestations, 1)
   290  	assertRefCount(t, b, currentEpochAttestations, 1)
   291  	assertRefCount(t, a, previousEpochAttestations, 1)
   292  	assertRefCount(t, b, previousEpochAttestations, 1)
   293  }
   294  
   295  func TestValidatorReferences_RemainsConsistent(t *testing.T) {
   296  	a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{
   297  		Validators: []*ethpb.Validator{
   298  			{PublicKey: []byte{'A'}},
   299  			{PublicKey: []byte{'B'}},
   300  			{PublicKey: []byte{'C'}},
   301  			{PublicKey: []byte{'D'}},
   302  			{PublicKey: []byte{'E'}},
   303  		},
   304  	})
   305  	require.NoError(t, err)
   306  
   307  	// Create a second state.
   308  	copied := a.Copy()
   309  	b, ok := copied.(*BeaconState)
   310  	require.Equal(t, true, ok)
   311  
   312  	// Update First Validator.
   313  	assert.NoError(t, a.UpdateValidatorAtIndex(0, &ethpb.Validator{PublicKey: []byte{'Z'}}))
   314  
   315  	assert.DeepNotEqual(t, a.state.Validators[0], b.state.Validators[0], "validators are equal when they are supposed to be different")
   316  	// Modify all validators from copied state.
   317  	assert.NoError(t, b.ApplyToEveryValidator(func(idx int, val *ethpb.Validator) (bool, *ethpb.Validator, error) {
   318  		return true, &ethpb.Validator{PublicKey: []byte{'V'}}, nil
   319  	}))
   320  
   321  	// Ensure reference is properly accounted for.
   322  	assert.NoError(t, a.ReadFromEveryValidator(func(idx int, val iface.ReadOnlyValidator) error {
   323  		assert.NotEqual(t, bytesutil.ToBytes48([]byte{'V'}), val.PublicKey())
   324  		return nil
   325  	}))
   326  }
   327  
   328  // assertRefCount checks whether reference count for a given state
   329  // at a given index is equal to expected amount.
   330  func assertRefCount(t *testing.T, b *BeaconState, idx fieldIndex, want uint) {
   331  	if cnt := b.sharedFieldReferences[idx].Refs(); cnt != want {
   332  		t.Errorf("Unexpected count of references for index %d, want: %v, got: %v", idx, want, cnt)
   333  	}
   334  }
   335  
   336  // assertValFound checks whether item with a given value exists in list.
   337  func assertValFound(t *testing.T, vals [][]byte, val []byte) {
   338  	for i := range vals {
   339  		if reflect.DeepEqual(vals[i], val) {
   340  			return
   341  		}
   342  	}
   343  	t.Log(string(debug.Stack()))
   344  	t.Fatalf("Expected value not found (%v), want: %v", vals, val)
   345  }
   346  
   347  // assertValNotFound checks whether item with a given value doesn't exist in list.
   348  func assertValNotFound(t *testing.T, vals [][]byte, val []byte) {
   349  	for i := range vals {
   350  		if reflect.DeepEqual(vals[i], val) {
   351  			t.Log(string(debug.Stack()))
   352  			t.Errorf("Unexpected value found (%v),: %v", vals, val)
   353  			return
   354  		}
   355  	}
   356  }