github.com/aquanetwork/aquachain@v1.7.8/core/state/statedb_test.go (about) 1 // Copyright 2016 The aquachain Authors 2 // This file is part of the aquachain library. 3 // 4 // The aquachain library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The aquachain library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>. 16 17 package state 18 19 import ( 20 "bytes" 21 "encoding/binary" 22 "fmt" 23 "math" 24 "math/big" 25 "math/rand" 26 "reflect" 27 "strings" 28 "testing" 29 "testing/quick" 30 31 check "gopkg.in/check.v1" 32 33 "gitlab.com/aquachain/aquachain/aquadb" 34 "gitlab.com/aquachain/aquachain/common" 35 "gitlab.com/aquachain/aquachain/core/types" 36 ) 37 38 // Tests that updating a state trie does not leak any database writes prior to 39 // actually committing the state. 40 func TestUpdateLeaks(t *testing.T) { 41 // Create an empty state database 42 db := aquadb.NewMemDatabase() 43 state, _ := New(common.Hash{}, NewDatabase(db)) 44 45 // Update it with some accounts 46 for i := byte(0); i < 255; i++ { 47 addr := common.BytesToAddress([]byte{i}) 48 state.AddBalance(addr, big.NewInt(int64(11*i))) 49 state.SetNonce(addr, uint64(42*i)) 50 if i%2 == 0 { 51 state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i})) 52 } 53 if i%3 == 0 { 54 state.SetCode(addr, []byte{i, i, i, i, i}) 55 } 56 state.IntermediateRoot(false) 57 } 58 // Ensure that no data was leaked into the database 59 for _, key := range db.Keys() { 60 value, _ := db.Get(key) 61 t.Errorf("State leaked into database: %x -> %x", key, value) 62 } 63 } 64 65 // Tests that no intermediate state of an object is stored into the database, 66 // only the one right before the commit. 67 func TestIntermediateLeaks(t *testing.T) { 68 // Create two state databases, one transitioning to the final state, the other final from the beginning 69 transDb := aquadb.NewMemDatabase() 70 finalDb := aquadb.NewMemDatabase() 71 transState, _ := New(common.Hash{}, NewDatabase(transDb)) 72 finalState, _ := New(common.Hash{}, NewDatabase(finalDb)) 73 74 modify := func(state *StateDB, addr common.Address, i, tweak byte) { 75 state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak))) 76 state.SetNonce(addr, uint64(42*i+tweak)) 77 if i%2 == 0 { 78 state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{}) 79 state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak}) 80 } 81 if i%3 == 0 { 82 state.SetCode(addr, []byte{i, i, i, i, i, tweak}) 83 } 84 } 85 86 // Modify the transient state. 87 for i := byte(0); i < 255; i++ { 88 modify(transState, common.Address{byte(i)}, i, 0) 89 } 90 // Write modifications to trie. 91 transState.IntermediateRoot(false) 92 93 // Overwrite all the data with new values in the transient database. 94 for i := byte(0); i < 255; i++ { 95 modify(transState, common.Address{byte(i)}, i, 99) 96 modify(finalState, common.Address{byte(i)}, i, 99) 97 } 98 99 // Commit and cross check the databases. 100 if _, err := transState.Commit(false); err != nil { 101 t.Fatalf("failed to commit transition state: %v", err) 102 } 103 if _, err := finalState.Commit(false); err != nil { 104 t.Fatalf("failed to commit final state: %v", err) 105 } 106 for _, key := range finalDb.Keys() { 107 if _, err := transDb.Get(key); err != nil { 108 val, _ := finalDb.Get(key) 109 t.Errorf("entry missing from the transition database: %x -> %x", key, val) 110 } 111 } 112 for _, key := range transDb.Keys() { 113 if _, err := finalDb.Get(key); err != nil { 114 val, _ := transDb.Get(key) 115 t.Errorf("extra entry in the transition database: %x -> %x", key, val) 116 } 117 } 118 } 119 120 // TestCopy tests that copying a statedb object indeed makes the original and 121 // the copy independent of each other. This test is a regression test against 122 // https://gitlab.com/aquachain/aquachain/pull/15549. 123 func TestCopy(t *testing.T) { 124 // Create a random state test to copy and modify "independently" 125 db := aquadb.NewMemDatabase() 126 orig, _ := New(common.Hash{}, NewDatabase(db)) 127 128 for i := byte(0); i < 255; i++ { 129 obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) 130 obj.AddBalance(big.NewInt(int64(i))) 131 orig.updateStateObject(obj) 132 } 133 orig.Finalise(false) 134 135 // Copy the state, modify both in-memory 136 copy := orig.Copy() 137 138 for i := byte(0); i < 255; i++ { 139 origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) 140 copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) 141 142 origObj.AddBalance(big.NewInt(2 * int64(i))) 143 copyObj.AddBalance(big.NewInt(3 * int64(i))) 144 145 orig.updateStateObject(origObj) 146 copy.updateStateObject(copyObj) 147 } 148 // Finalise the changes on both concurrently 149 done := make(chan struct{}) 150 go func() { 151 orig.Finalise(true) 152 close(done) 153 }() 154 copy.Finalise(true) 155 <-done 156 157 // Verify that the two states have been updated independently 158 for i := byte(0); i < 255; i++ { 159 origObj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i})) 160 copyObj := copy.GetOrNewStateObject(common.BytesToAddress([]byte{i})) 161 162 if want := big.NewInt(3 * int64(i)); origObj.Balance().Cmp(want) != 0 { 163 t.Errorf("orig obj %d: balance mismatch: have %v, want %v", i, origObj.Balance(), want) 164 } 165 if want := big.NewInt(4 * int64(i)); copyObj.Balance().Cmp(want) != 0 { 166 t.Errorf("copy obj %d: balance mismatch: have %v, want %v", i, copyObj.Balance(), want) 167 } 168 } 169 } 170 171 func TestSnapshotRandom(t *testing.T) { 172 config := &quick.Config{MaxCount: 1000} 173 err := quick.Check((*snapshotTest).run, config) 174 if cerr, ok := err.(*quick.CheckError); ok { 175 test := cerr.In[0].(*snapshotTest) 176 t.Errorf("%v:\n%s", test.err, test) 177 } else if err != nil { 178 t.Error(err) 179 } 180 } 181 182 // A snapshotTest checks that reverting StateDB snapshots properly undoes all changes 183 // captured by the snapshot. Instances of this test with pseudorandom content are created 184 // by Generate. 185 // 186 // The test works as follows: 187 // 188 // A new state is created and all actions are applied to it. Several snapshots are taken 189 // in between actions. The test then reverts each snapshot. For each snapshot the actions 190 // leading up to it are replayed on a fresh, empty state. The behaviour of all public 191 // accessor methods on the reverted state must match the return value of the equivalent 192 // methods on the replayed state. 193 type snapshotTest struct { 194 addrs []common.Address // all account addresses 195 actions []testAction // modifications to the state 196 snapshots []int // actions indexes at which snapshot is taken 197 err error // failure details are reported through this field 198 } 199 200 type testAction struct { 201 name string 202 fn func(testAction, *StateDB) 203 args []int64 204 noAddr bool 205 } 206 207 // newTestAction creates a random action that changes state. 208 func newTestAction(addr common.Address, r *rand.Rand) testAction { 209 actions := []testAction{ 210 { 211 name: "SetBalance", 212 fn: func(a testAction, s *StateDB) { 213 s.SetBalance(addr, big.NewInt(a.args[0])) 214 }, 215 args: make([]int64, 1), 216 }, 217 { 218 name: "AddBalance", 219 fn: func(a testAction, s *StateDB) { 220 s.AddBalance(addr, big.NewInt(a.args[0])) 221 }, 222 args: make([]int64, 1), 223 }, 224 { 225 name: "SetNonce", 226 fn: func(a testAction, s *StateDB) { 227 s.SetNonce(addr, uint64(a.args[0])) 228 }, 229 args: make([]int64, 1), 230 }, 231 { 232 name: "SetState", 233 fn: func(a testAction, s *StateDB) { 234 var key, val common.Hash 235 binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) 236 binary.BigEndian.PutUint16(val[:], uint16(a.args[1])) 237 s.SetState(addr, key, val) 238 }, 239 args: make([]int64, 2), 240 }, 241 { 242 name: "SetCode", 243 fn: func(a testAction, s *StateDB) { 244 code := make([]byte, 16) 245 binary.BigEndian.PutUint64(code, uint64(a.args[0])) 246 binary.BigEndian.PutUint64(code[8:], uint64(a.args[1])) 247 s.SetCode(addr, code) 248 }, 249 args: make([]int64, 2), 250 }, 251 { 252 name: "CreateAccount", 253 fn: func(a testAction, s *StateDB) { 254 s.CreateAccount(addr) 255 }, 256 }, 257 { 258 name: "Suicide", 259 fn: func(a testAction, s *StateDB) { 260 s.Suicide(addr) 261 }, 262 }, 263 { 264 name: "AddRefund", 265 fn: func(a testAction, s *StateDB) { 266 s.AddRefund(uint64(a.args[0])) 267 }, 268 args: make([]int64, 1), 269 noAddr: true, 270 }, 271 { 272 name: "AddLog", 273 fn: func(a testAction, s *StateDB) { 274 data := make([]byte, 2) 275 binary.BigEndian.PutUint16(data, uint16(a.args[0])) 276 s.AddLog(&types.Log{Address: addr, Data: data}) 277 }, 278 args: make([]int64, 1), 279 }, 280 } 281 action := actions[r.Intn(len(actions))] 282 var nameargs []string 283 if !action.noAddr { 284 nameargs = append(nameargs, addr.Hex()) 285 } 286 for _, i := range action.args { 287 action.args[i] = rand.Int63n(100) 288 nameargs = append(nameargs, fmt.Sprint(action.args[i])) 289 } 290 action.name += strings.Join(nameargs, ", ") 291 return action 292 } 293 294 // Generate returns a new snapshot test of the given size. All randomness is 295 // derived from r. 296 func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value { 297 // Generate random actions. 298 addrs := make([]common.Address, 50) 299 for i := range addrs { 300 addrs[i][0] = byte(i) 301 } 302 actions := make([]testAction, size) 303 for i := range actions { 304 addr := addrs[r.Intn(len(addrs))] 305 actions[i] = newTestAction(addr, r) 306 } 307 // Generate snapshot indexes. 308 nsnapshots := int(math.Sqrt(float64(size))) 309 if size > 0 && nsnapshots == 0 { 310 nsnapshots = 1 311 } 312 snapshots := make([]int, nsnapshots) 313 snaplen := len(actions) / nsnapshots 314 for i := range snapshots { 315 // Try to place the snapshots some number of actions apart from each other. 316 snapshots[i] = (i * snaplen) + r.Intn(snaplen) 317 } 318 return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil}) 319 } 320 321 func (test *snapshotTest) String() string { 322 out := new(bytes.Buffer) 323 sindex := 0 324 for i, action := range test.actions { 325 if len(test.snapshots) > sindex && i == test.snapshots[sindex] { 326 fmt.Fprintf(out, "---- snapshot %d ----\n", sindex) 327 sindex++ 328 } 329 fmt.Fprintf(out, "%4d: %s\n", i, action.name) 330 } 331 return out.String() 332 } 333 334 func (test *snapshotTest) run() bool { 335 // Run all actions and create snapshots. 336 var ( 337 db = aquadb.NewMemDatabase() 338 state, _ = New(common.Hash{}, NewDatabase(db)) 339 snapshotRevs = make([]int, len(test.snapshots)) 340 sindex = 0 341 ) 342 for i, action := range test.actions { 343 if len(test.snapshots) > sindex && i == test.snapshots[sindex] { 344 snapshotRevs[sindex] = state.Snapshot() 345 sindex++ 346 } 347 action.fn(action, state) 348 } 349 // Revert all snapshots in reverse order. Each revert must yield a state 350 // that is equivalent to fresh state with all actions up the snapshot applied. 351 for sindex--; sindex >= 0; sindex-- { 352 checkstate, _ := New(common.Hash{}, state.Database()) 353 for _, action := range test.actions[:test.snapshots[sindex]] { 354 action.fn(action, checkstate) 355 } 356 state.RevertToSnapshot(snapshotRevs[sindex]) 357 if err := test.checkEqual(state, checkstate); err != nil { 358 test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err) 359 return false 360 } 361 } 362 return true 363 } 364 365 // checkEqual checks that methods of state and checkstate return the same values. 366 func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { 367 for _, addr := range test.addrs { 368 var err error 369 checkeq := func(op string, a, b interface{}) bool { 370 if err == nil && !reflect.DeepEqual(a, b) { 371 err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b) 372 return false 373 } 374 return true 375 } 376 // Check basic accessor methods. 377 checkeq("Exist", state.Exist(addr), checkstate.Exist(addr)) 378 checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr)) 379 checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr)) 380 checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr)) 381 checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr)) 382 checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr)) 383 checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr)) 384 // Check storage. 385 if obj := state.getStateObject(addr); obj != nil { 386 state.ForEachStorage(addr, func(key, val common.Hash) bool { 387 return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key)) 388 }) 389 checkstate.ForEachStorage(addr, func(key, checkval common.Hash) bool { 390 return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval) 391 }) 392 } 393 if err != nil { 394 return err 395 } 396 } 397 398 if state.GetRefund() != checkstate.GetRefund() { 399 return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d", 400 state.GetRefund(), checkstate.GetRefund()) 401 } 402 if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) { 403 return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v", 404 state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) 405 } 406 return nil 407 } 408 409 func (s *StateSuite) TestTouchDelete(c *check.C) { 410 s.state.GetOrNewStateObject(common.Address{}) 411 root, _ := s.state.Commit(false) 412 s.state.Reset(root) 413 414 snapshot := s.state.Snapshot() 415 s.state.AddBalance(common.Address{}, new(big.Int)) 416 if len(s.state.stateObjectsDirty) != 1 { 417 c.Fatal("expected one dirty state object") 418 } 419 s.state.RevertToSnapshot(snapshot) 420 if len(s.state.stateObjectsDirty) != 0 { 421 c.Fatal("expected no dirty state object") 422 } 423 }