github.com/prysmaticlabs/prysm@v1.4.4/beacon-chain/state/v1/references_test.go (about) 1 package v1 2 3 import ( 4 "reflect" 5 "runtime" 6 "runtime/debug" 7 "testing" 8 9 "github.com/prysmaticlabs/prysm/shared/copyutil" 10 11 "github.com/prysmaticlabs/go-bitfield" 12 iface "github.com/prysmaticlabs/prysm/beacon-chain/state/interface" 13 p2ppb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" 14 ethpb "github.com/prysmaticlabs/prysm/proto/eth/v1alpha1" 15 "github.com/prysmaticlabs/prysm/shared/bytesutil" 16 "github.com/prysmaticlabs/prysm/shared/testutil/assert" 17 "github.com/prysmaticlabs/prysm/shared/testutil/require" 18 ) 19 20 func TestStateReferenceSharing_Finalizer(t *testing.T) { 21 // This test showcases the logic on a the RandaoMixes field with the GC finalizer. 22 23 a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{RandaoMixes: [][]byte{[]byte("foo")}}) 24 require.NoError(t, err) 25 assert.Equal(t, uint(1), a.sharedFieldReferences[randaoMixes].Refs(), "Expected a single reference for RANDAO mixes") 26 27 func() { 28 // Create object in a different scope for GC 29 b := a.Copy() 30 assert.Equal(t, uint(2), a.sharedFieldReferences[randaoMixes].Refs(), "Expected 2 references to RANDAO mixes") 31 _ = b 32 }() 33 34 runtime.GC() // Should run finalizer on object b 35 assert.Equal(t, uint(1), a.sharedFieldReferences[randaoMixes].Refs(), "Expected 1 shared reference to RANDAO mixes!") 36 37 copied := a.Copy() 38 b, ok := copied.(*BeaconState) 39 require.Equal(t, true, ok) 40 assert.Equal(t, uint(2), b.sharedFieldReferences[randaoMixes].Refs(), "Expected 2 shared references to RANDAO mixes") 41 require.NoError(t, b.UpdateRandaoMixesAtIndex(0, []byte("bar"))) 42 if b.sharedFieldReferences[randaoMixes].Refs() != 1 || a.sharedFieldReferences[randaoMixes].Refs() != 1 { 43 t.Error("Expected 1 shared reference to RANDAO mix for both a and b") 44 } 45 } 46 47 func TestStateReferenceCopy_NoUnexpectedRootsMutation(t *testing.T) { 48 root1, root2 := 
bytesutil.ToBytes32([]byte("foo")), bytesutil.ToBytes32([]byte("bar")) 49 a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{ 50 BlockRoots: [][]byte{ 51 root1[:], 52 }, 53 StateRoots: [][]byte{ 54 root1[:], 55 }, 56 }) 57 require.NoError(t, err) 58 assertRefCount(t, a, blockRoots, 1) 59 assertRefCount(t, a, stateRoots, 1) 60 61 // Copy, increases reference count. 62 copied := a.Copy() 63 b, ok := copied.(*BeaconState) 64 require.Equal(t, true, ok) 65 assertRefCount(t, a, blockRoots, 2) 66 assertRefCount(t, a, stateRoots, 2) 67 assertRefCount(t, b, blockRoots, 2) 68 assertRefCount(t, b, stateRoots, 2) 69 assert.Equal(t, 1, len(b.state.GetBlockRoots()), "No block roots found") 70 assert.Equal(t, 1, len(b.state.GetStateRoots()), "No state roots found") 71 72 // Assert shared state. 73 blockRootsA := a.state.GetBlockRoots() 74 stateRootsA := a.state.GetStateRoots() 75 blockRootsB := b.state.GetBlockRoots() 76 stateRootsB := b.state.GetStateRoots() 77 if len(blockRootsA) != len(blockRootsB) || len(blockRootsA) < 1 { 78 t.Errorf("Unexpected number of block roots, want: %v", 1) 79 } 80 if len(stateRootsA) != len(stateRootsB) || len(stateRootsA) < 1 { 81 t.Errorf("Unexpected number of state roots, want: %v", 1) 82 } 83 assertValFound(t, blockRootsA, root1[:]) 84 assertValFound(t, blockRootsB, root1[:]) 85 assertValFound(t, stateRootsA, root1[:]) 86 assertValFound(t, stateRootsB, root1[:]) 87 88 // Mutator should only affect calling state: a. 89 require.NoError(t, a.UpdateBlockRootAtIndex(0, root2)) 90 require.NoError(t, a.UpdateStateRootAtIndex(0, root2)) 91 92 // Assert no shared state mutation occurred only on state a (copy on write). 
93 assertValNotFound(t, a.state.GetBlockRoots(), root1[:]) 94 assertValNotFound(t, a.state.GetStateRoots(), root1[:]) 95 assertValFound(t, a.state.GetBlockRoots(), root2[:]) 96 assertValFound(t, a.state.GetStateRoots(), root2[:]) 97 assertValFound(t, b.state.GetBlockRoots(), root1[:]) 98 assertValFound(t, b.state.GetStateRoots(), root1[:]) 99 if len(blockRootsA) != len(blockRootsB) || len(blockRootsA) < 1 { 100 t.Errorf("Unexpected number of block roots, want: %v", 1) 101 } 102 if len(stateRootsA) != len(stateRootsB) || len(stateRootsA) < 1 { 103 t.Errorf("Unexpected number of state roots, want: %v", 1) 104 } 105 assert.DeepEqual(t, root2[:], a.state.GetBlockRoots()[0], "Expected mutation not found") 106 assert.DeepEqual(t, root2[:], a.state.GetStateRoots()[0], "Expected mutation not found") 107 assert.DeepEqual(t, root1[:], blockRootsB[0], "Unexpected mutation found") 108 assert.DeepEqual(t, root1[:], stateRootsB[0], "Unexpected mutation found") 109 110 // Copy on write happened, reference counters are reset. 111 assertRefCount(t, a, blockRoots, 1) 112 assertRefCount(t, a, stateRoots, 1) 113 assertRefCount(t, b, blockRoots, 1) 114 assertRefCount(t, b, stateRoots, 1) 115 } 116 117 func TestStateReferenceCopy_NoUnexpectedRandaoMutation(t *testing.T) { 118 119 val1, val2 := []byte("foo"), []byte("bar") 120 a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{ 121 RandaoMixes: [][]byte{ 122 val1, 123 }, 124 }) 125 require.NoError(t, err) 126 assertRefCount(t, a, randaoMixes, 1) 127 128 // Copy, increases reference count. 129 copied := a.Copy() 130 b, ok := copied.(*BeaconState) 131 require.Equal(t, true, ok) 132 assertRefCount(t, a, randaoMixes, 2) 133 assertRefCount(t, b, randaoMixes, 2) 134 assert.Equal(t, 1, len(b.state.GetRandaoMixes()), "No randao mixes found") 135 136 // Assert shared state. 
137 mixesA := a.state.GetRandaoMixes() 138 mixesB := b.state.GetRandaoMixes() 139 if len(mixesA) != len(mixesB) || len(mixesA) < 1 { 140 t.Errorf("Unexpected number of mix values, want: %v", 1) 141 } 142 assertValFound(t, mixesA, val1) 143 assertValFound(t, mixesB, val1) 144 145 // Mutator should only affect calling state: a. 146 require.NoError(t, a.UpdateRandaoMixesAtIndex(0, val2)) 147 148 // Assert no shared state mutation occurred only on state a (copy on write). 149 if len(mixesA) != len(mixesB) || len(mixesA) < 1 { 150 t.Errorf("Unexpected number of mix values, want: %v", 1) 151 } 152 assertValFound(t, a.state.GetRandaoMixes(), val2) 153 assertValNotFound(t, a.state.GetRandaoMixes(), val1) 154 assertValFound(t, b.state.GetRandaoMixes(), val1) 155 assertValNotFound(t, b.state.GetRandaoMixes(), val2) 156 assertValFound(t, mixesB, val1) 157 assertValNotFound(t, mixesB, val2) 158 assert.DeepEqual(t, val2, a.state.GetRandaoMixes()[0], "Expected mutation not found") 159 assert.DeepEqual(t, val1, mixesB[0], "Unexpected mutation found") 160 161 // Copy on write happened, reference counters are reset. 
162 assertRefCount(t, a, randaoMixes, 1) 163 assertRefCount(t, b, randaoMixes, 1) 164 } 165 166 func TestStateReferenceCopy_NoUnexpectedAttestationsMutation(t *testing.T) { 167 assertAttFound := func(vals []*p2ppb.PendingAttestation, val uint64) { 168 for i := range vals { 169 if reflect.DeepEqual(vals[i].AggregationBits, bitfield.NewBitlist(val)) { 170 return 171 } 172 } 173 t.Log(string(debug.Stack())) 174 t.Fatalf("Expected attestation not found (%v), want: %v", vals, val) 175 } 176 assertAttNotFound := func(vals []*p2ppb.PendingAttestation, val uint64) { 177 for i := range vals { 178 if reflect.DeepEqual(vals[i].AggregationBits, bitfield.NewBitlist(val)) { 179 t.Log(string(debug.Stack())) 180 t.Fatalf("Unexpected attestation found (%v): %v", vals, val) 181 return 182 } 183 } 184 } 185 186 a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{}) 187 require.NoError(t, err) 188 assertRefCount(t, a, previousEpochAttestations, 1) 189 assertRefCount(t, a, currentEpochAttestations, 1) 190 191 // Update initial state. 192 atts := []*p2ppb.PendingAttestation{ 193 {AggregationBits: bitfield.NewBitlist(1)}, 194 {AggregationBits: bitfield.NewBitlist(2)}, 195 } 196 a.setPreviousEpochAttestations(atts[:1]) 197 a.setCurrentEpochAttestations(atts[:1]) 198 curAtt, err := a.CurrentEpochAttestations() 199 require.NoError(t, err) 200 assert.Equal(t, 1, len(curAtt), "Unexpected number of attestations") 201 preAtt, err := a.PreviousEpochAttestations() 202 require.NoError(t, err) 203 assert.Equal(t, 1, len(preAtt), "Unexpected number of attestations") 204 205 // Copy, increases reference count. 
206 copied := a.Copy() 207 b, ok := copied.(*BeaconState) 208 require.Equal(t, true, ok) 209 assertRefCount(t, a, previousEpochAttestations, 2) 210 assertRefCount(t, a, currentEpochAttestations, 2) 211 assertRefCount(t, b, previousEpochAttestations, 2) 212 assertRefCount(t, b, currentEpochAttestations, 2) 213 assert.Equal(t, 1, len(b.state.GetPreviousEpochAttestations()), "Unexpected number of attestations") 214 assert.Equal(t, 1, len(b.state.GetCurrentEpochAttestations()), "Unexpected number of attestations") 215 216 // Assert shared state. 217 curAttsA := a.state.GetCurrentEpochAttestations() 218 prevAttsA := a.state.GetPreviousEpochAttestations() 219 curAttsB := b.state.GetCurrentEpochAttestations() 220 prevAttsB := b.state.GetPreviousEpochAttestations() 221 if len(curAttsA) != len(curAttsB) || len(curAttsA) < 1 { 222 t.Errorf("Unexpected number of attestations, want: %v", 1) 223 } 224 if len(prevAttsA) != len(prevAttsB) || len(prevAttsA) < 1 { 225 t.Errorf("Unexpected number of attestations, want: %v", 1) 226 } 227 assertAttFound(curAttsA, 1) 228 assertAttFound(prevAttsA, 1) 229 assertAttFound(curAttsB, 1) 230 assertAttFound(prevAttsB, 1) 231 232 // Extends state a attestations. 
233 require.NoError(t, a.AppendCurrentEpochAttestations(atts[1])) 234 require.NoError(t, a.AppendPreviousEpochAttestations(atts[1])) 235 curAtt, err = a.CurrentEpochAttestations() 236 require.NoError(t, err) 237 assert.Equal(t, 2, len(curAtt), "Unexpected number of attestations") 238 preAtt, err = a.PreviousEpochAttestations() 239 require.NoError(t, err) 240 assert.Equal(t, 2, len(preAtt), "Unexpected number of attestations") 241 assertAttFound(a.state.GetCurrentEpochAttestations(), 1) 242 assertAttFound(a.state.GetPreviousEpochAttestations(), 1) 243 assertAttFound(a.state.GetCurrentEpochAttestations(), 2) 244 assertAttFound(a.state.GetPreviousEpochAttestations(), 2) 245 assertAttFound(b.state.GetCurrentEpochAttestations(), 1) 246 assertAttFound(b.state.GetPreviousEpochAttestations(), 1) 247 assertAttNotFound(b.state.GetCurrentEpochAttestations(), 2) 248 assertAttNotFound(b.state.GetPreviousEpochAttestations(), 2) 249 250 // Mutator should only affect calling state: a. 251 applyToEveryAttestation := func(state *p2ppb.BeaconState) { 252 // One MUST copy on write. 
253 atts = make([]*p2ppb.PendingAttestation, len(state.CurrentEpochAttestations)) 254 copy(atts, state.CurrentEpochAttestations) 255 state.CurrentEpochAttestations = atts 256 for i := range state.GetCurrentEpochAttestations() { 257 att := copyutil.CopyPendingAttestation(state.CurrentEpochAttestations[i]) 258 att.AggregationBits = bitfield.NewBitlist(3) 259 state.CurrentEpochAttestations[i] = att 260 } 261 262 atts = make([]*p2ppb.PendingAttestation, len(state.PreviousEpochAttestations)) 263 copy(atts, state.PreviousEpochAttestations) 264 state.PreviousEpochAttestations = atts 265 for i := range state.GetPreviousEpochAttestations() { 266 att := copyutil.CopyPendingAttestation(state.PreviousEpochAttestations[i]) 267 att.AggregationBits = bitfield.NewBitlist(3) 268 state.PreviousEpochAttestations[i] = att 269 } 270 } 271 applyToEveryAttestation(a.state) 272 273 // Assert no shared state mutation occurred only on state a (copy on write). 274 assertAttFound(a.state.GetCurrentEpochAttestations(), 3) 275 assertAttFound(a.state.GetPreviousEpochAttestations(), 3) 276 assertAttNotFound(a.state.GetCurrentEpochAttestations(), 1) 277 assertAttNotFound(a.state.GetPreviousEpochAttestations(), 1) 278 assertAttNotFound(a.state.GetCurrentEpochAttestations(), 2) 279 assertAttNotFound(a.state.GetPreviousEpochAttestations(), 2) 280 // State b must be unaffected. 281 assertAttNotFound(b.state.GetCurrentEpochAttestations(), 3) 282 assertAttNotFound(b.state.GetPreviousEpochAttestations(), 3) 283 assertAttFound(b.state.GetCurrentEpochAttestations(), 1) 284 assertAttFound(b.state.GetPreviousEpochAttestations(), 1) 285 assertAttNotFound(b.state.GetCurrentEpochAttestations(), 2) 286 assertAttNotFound(b.state.GetPreviousEpochAttestations(), 2) 287 288 // Copy on write happened, reference counters are reset. 
289 assertRefCount(t, a, currentEpochAttestations, 1) 290 assertRefCount(t, b, currentEpochAttestations, 1) 291 assertRefCount(t, a, previousEpochAttestations, 1) 292 assertRefCount(t, b, previousEpochAttestations, 1) 293 } 294 295 func TestValidatorReferences_RemainsConsistent(t *testing.T) { 296 a, err := InitializeFromProtoUnsafe(&p2ppb.BeaconState{ 297 Validators: []*ethpb.Validator{ 298 {PublicKey: []byte{'A'}}, 299 {PublicKey: []byte{'B'}}, 300 {PublicKey: []byte{'C'}}, 301 {PublicKey: []byte{'D'}}, 302 {PublicKey: []byte{'E'}}, 303 }, 304 }) 305 require.NoError(t, err) 306 307 // Create a second state. 308 copied := a.Copy() 309 b, ok := copied.(*BeaconState) 310 require.Equal(t, true, ok) 311 312 // Update First Validator. 313 assert.NoError(t, a.UpdateValidatorAtIndex(0, ðpb.Validator{PublicKey: []byte{'Z'}})) 314 315 assert.DeepNotEqual(t, a.state.Validators[0], b.state.Validators[0], "validators are equal when they are supposed to be different") 316 // Modify all validators from copied state. 317 assert.NoError(t, b.ApplyToEveryValidator(func(idx int, val *ethpb.Validator) (bool, *ethpb.Validator, error) { 318 return true, ðpb.Validator{PublicKey: []byte{'V'}}, nil 319 })) 320 321 // Ensure reference is properly accounted for. 322 assert.NoError(t, a.ReadFromEveryValidator(func(idx int, val iface.ReadOnlyValidator) error { 323 assert.NotEqual(t, bytesutil.ToBytes48([]byte{'V'}), val.PublicKey()) 324 return nil 325 })) 326 } 327 328 // assertRefCount checks whether reference count for a given state 329 // at a given index is equal to expected amount. 330 func assertRefCount(t *testing.T, b *BeaconState, idx fieldIndex, want uint) { 331 if cnt := b.sharedFieldReferences[idx].Refs(); cnt != want { 332 t.Errorf("Unexpected count of references for index %d, want: %v, got: %v", idx, want, cnt) 333 } 334 } 335 336 // assertValFound checks whether item with a given value exists in list. 
func assertValFound(t *testing.T, vals [][]byte, val []byte) {
	// Mark as helper so a failure is reported at the caller's line.
	t.Helper()
	for i := range vals {
		if reflect.DeepEqual(vals[i], val) {
			return
		}
	}
	t.Log(string(debug.Stack()))
	t.Fatalf("Expected value not found (%v), want: %v", vals, val)
}

// assertValNotFound reports a (non-fatal) test error if vals contains an
// element reflect.DeepEqual to val. It stops scanning at the first match.
func assertValNotFound(t *testing.T, vals [][]byte, val []byte) {
	t.Helper()
	for i := range vals {
		if reflect.DeepEqual(vals[i], val) {
			t.Log(string(debug.Stack()))
			t.Errorf("Unexpected value found (%v): %v", vals, val)
			return
		}
	}
}