github.com/ylsGit/go-ethereum@v1.6.5/core/state/statedb_test.go (about) 1 // Copyright 2016 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package state 18 19 import ( 20 "bytes" 21 "encoding/binary" 22 "fmt" 23 "math" 24 "math/big" 25 "math/rand" 26 "reflect" 27 "strings" 28 "testing" 29 "testing/quick" 30 31 "github.com/ethereum/go-ethereum/common" 32 "github.com/ethereum/go-ethereum/core/types" 33 "github.com/ethereum/go-ethereum/ethdb" 34 ) 35 36 // Tests that updating a state trie does not leak any database writes prior to 37 // actually committing the state. 38 func TestUpdateLeaks(t *testing.T) { 39 // Create an empty state database 40 db, _ := ethdb.NewMemDatabase() 41 state, _ := New(common.Hash{}, db) 42 43 // Update it with some accounts 44 for i := byte(0); i < 255; i++ { 45 addr := common.BytesToAddress([]byte{i}) 46 state.AddBalance(addr, big.NewInt(int64(11*i))) 47 state.SetNonce(addr, uint64(42*i)) 48 if i%2 == 0 { 49 state.SetState(addr, common.BytesToHash([]byte{i, i, i}), common.BytesToHash([]byte{i, i, i, i})) 50 } 51 if i%3 == 0 { 52 state.SetCode(addr, []byte{i, i, i, i, i}) 53 } 54 state.IntermediateRoot(false) 55 } 56 // Ensure that no data was leaked into the database 57 for _, key := range db.Keys() { 58 value, _ := db.Get(key) 59 t.Errorf("State leaked into database: %x -> %x", key, value) 60 } 61 } 62 63 // Tests that no intermediate state of an object is stored into the database, 64 // only the one right before the commit. 65 func TestIntermediateLeaks(t *testing.T) { 66 // Create two state databases, one transitioning to the final state, the other final from the beginning 67 transDb, _ := ethdb.NewMemDatabase() 68 finalDb, _ := ethdb.NewMemDatabase() 69 transState, _ := New(common.Hash{}, transDb) 70 finalState, _ := New(common.Hash{}, finalDb) 71 72 modify := func(state *StateDB, addr common.Address, i, tweak byte) { 73 state.SetBalance(addr, big.NewInt(int64(11*i)+int64(tweak))) 74 state.SetNonce(addr, uint64(42*i+tweak)) 75 if i%2 == 0 { 76 state.SetState(addr, common.Hash{i, i, i, 0}, common.Hash{}) 77 state.SetState(addr, common.Hash{i, i, i, tweak}, common.Hash{i, i, i, i, tweak}) 78 } 79 if i%3 == 0 { 80 state.SetCode(addr, []byte{i, i, i, i, i, tweak}) 81 } 82 } 83 84 // Modify the transient state. 85 for i := byte(0); i < 255; i++ { 86 modify(transState, common.Address{byte(i)}, i, 0) 87 } 88 // Write modifications to trie. 89 transState.IntermediateRoot(false) 90 91 // Overwrite all the data with new values in the transient database. 92 for i := byte(0); i < 255; i++ { 93 modify(transState, common.Address{byte(i)}, i, 99) 94 modify(finalState, common.Address{byte(i)}, i, 99) 95 } 96 97 // Commit and cross check the databases. 98 if _, err := transState.Commit(false); err != nil { 99 t.Fatalf("failed to commit transition state: %v", err) 100 } 101 if _, err := finalState.Commit(false); err != nil { 102 t.Fatalf("failed to commit final state: %v", err) 103 } 104 for _, key := range finalDb.Keys() { 105 if _, err := transDb.Get(key); err != nil { 106 val, _ := finalDb.Get(key) 107 t.Errorf("entry missing from the transition database: %x -> %x", key, val) 108 } 109 } 110 for _, key := range transDb.Keys() { 111 if _, err := finalDb.Get(key); err != nil { 112 val, _ := transDb.Get(key) 113 t.Errorf("extra entry in the transition database: %x -> %x", key, val) 114 } 115 } 116 } 117 118 func TestSnapshotRandom(t *testing.T) { 119 config := &quick.Config{MaxCount: 1000} 120 err := quick.Check((*snapshotTest).run, config) 121 if cerr, ok := err.(*quick.CheckError); ok { 122 test := cerr.In[0].(*snapshotTest) 123 t.Errorf("%v:\n%s", test.err, test) 124 } else if err != nil { 125 t.Error(err) 126 } 127 } 128 129 // A snapshotTest checks that reverting StateDB snapshots properly undoes all changes 130 // captured by the snapshot. Instances of this test with pseudorandom content are created 131 // by Generate. 132 // 133 // The test works as follows: 134 // 135 // A new state is created and all actions are applied to it. Several snapshots are taken 136 // in between actions. The test then reverts each snapshot. For each snapshot the actions 137 // leading up to it are replayed on a fresh, empty state. The behaviour of all public 138 // accessor methods on the reverted state must match the return value of the equivalent 139 // methods on the replayed state. 140 type snapshotTest struct { 141 addrs []common.Address // all account addresses 142 actions []testAction // modifications to the state 143 snapshots []int // actions indexes at which snapshot is taken 144 err error // failure details are reported through this field 145 } 146 147 type testAction struct { 148 name string 149 fn func(testAction, *StateDB) 150 args []int64 151 noAddr bool 152 } 153 154 // newTestAction creates a random action that changes state. 155 func newTestAction(addr common.Address, r *rand.Rand) testAction { 156 actions := []testAction{ 157 { 158 name: "SetBalance", 159 fn: func(a testAction, s *StateDB) { 160 s.SetBalance(addr, big.NewInt(a.args[0])) 161 }, 162 args: make([]int64, 1), 163 }, 164 { 165 name: "AddBalance", 166 fn: func(a testAction, s *StateDB) { 167 s.AddBalance(addr, big.NewInt(a.args[0])) 168 }, 169 args: make([]int64, 1), 170 }, 171 { 172 name: "SetNonce", 173 fn: func(a testAction, s *StateDB) { 174 s.SetNonce(addr, uint64(a.args[0])) 175 }, 176 args: make([]int64, 1), 177 }, 178 { 179 name: "SetState", 180 fn: func(a testAction, s *StateDB) { 181 var key, val common.Hash 182 binary.BigEndian.PutUint16(key[:], uint16(a.args[0])) 183 binary.BigEndian.PutUint16(val[:], uint16(a.args[1])) 184 s.SetState(addr, key, val) 185 }, 186 args: make([]int64, 2), 187 }, 188 { 189 name: "SetCode", 190 fn: func(a testAction, s *StateDB) { 191 code := make([]byte, 16) 192 binary.BigEndian.PutUint64(code, uint64(a.args[0])) 193 binary.BigEndian.PutUint64(code[8:], uint64(a.args[1])) 194 s.SetCode(addr, code) 195 }, 196 args: make([]int64, 2), 197 }, 198 { 199 name: "CreateAccount", 200 fn: func(a testAction, s *StateDB) { 201 s.CreateAccount(addr) 202 }, 203 }, 204 { 205 name: "Suicide", 206 fn: func(a testAction, s *StateDB) { 207 s.Suicide(addr) 208 }, 209 }, 210 { 211 name: "AddRefund", 212 fn: func(a testAction, s *StateDB) { 213 s.AddRefund(big.NewInt(a.args[0])) 214 }, 215 args: make([]int64, 1), 216 noAddr: true, 217 }, 218 { 219 name: "AddLog", 220 fn: func(a testAction, s *StateDB) { 221 data := make([]byte, 2) 222 binary.BigEndian.PutUint16(data, uint16(a.args[0])) 223 s.AddLog(&types.Log{Address: addr, Data: data}) 224 }, 225 args: make([]int64, 1), 226 }, 227 } 228 action := actions[r.Intn(len(actions))] 229 var nameargs []string 230 if !action.noAddr { 231 nameargs = append(nameargs, addr.Hex()) 232 } 233 for _, i := range action.args { 234 action.args[i] = rand.Int63n(100) 235 nameargs = append(nameargs, fmt.Sprint(action.args[i])) 236 } 237 action.name += strings.Join(nameargs, ", ") 238 return action 239 } 240 241 // Generate returns a new snapshot test of the given size. All randomness is 242 // derived from r. 243 func (*snapshotTest) Generate(r *rand.Rand, size int) reflect.Value { 244 // Generate random actions. 245 addrs := make([]common.Address, 50) 246 for i := range addrs { 247 addrs[i][0] = byte(i) 248 } 249 actions := make([]testAction, size) 250 for i := range actions { 251 addr := addrs[r.Intn(len(addrs))] 252 actions[i] = newTestAction(addr, r) 253 } 254 // Generate snapshot indexes. 255 nsnapshots := int(math.Sqrt(float64(size))) 256 if size > 0 && nsnapshots == 0 { 257 nsnapshots = 1 258 } 259 snapshots := make([]int, nsnapshots) 260 snaplen := len(actions) / nsnapshots 261 for i := range snapshots { 262 // Try to place the snapshots some number of actions apart from each other. 263 snapshots[i] = (i * snaplen) + r.Intn(snaplen) 264 } 265 return reflect.ValueOf(&snapshotTest{addrs, actions, snapshots, nil}) 266 } 267 268 func (test *snapshotTest) String() string { 269 out := new(bytes.Buffer) 270 sindex := 0 271 for i, action := range test.actions { 272 if len(test.snapshots) > sindex && i == test.snapshots[sindex] { 273 fmt.Fprintf(out, "---- snapshot %d ----\n", sindex) 274 sindex++ 275 } 276 fmt.Fprintf(out, "%4d: %s\n", i, action.name) 277 } 278 return out.String() 279 } 280 281 func (test *snapshotTest) run() bool { 282 // Run all actions and create snapshots. 283 var ( 284 db, _ = ethdb.NewMemDatabase() 285 state, _ = New(common.Hash{}, db) 286 snapshotRevs = make([]int, len(test.snapshots)) 287 sindex = 0 288 ) 289 for i, action := range test.actions { 290 if len(test.snapshots) > sindex && i == test.snapshots[sindex] { 291 snapshotRevs[sindex] = state.Snapshot() 292 sindex++ 293 } 294 action.fn(action, state) 295 } 296 297 // Revert all snapshots in reverse order. Each revert must yield a state 298 // that is equivalent to fresh state with all actions up the snapshot applied. 299 for sindex--; sindex >= 0; sindex-- { 300 checkstate, _ := New(common.Hash{}, db) 301 for _, action := range test.actions[:test.snapshots[sindex]] { 302 action.fn(action, checkstate) 303 } 304 state.RevertToSnapshot(snapshotRevs[sindex]) 305 if err := test.checkEqual(state, checkstate); err != nil { 306 test.err = fmt.Errorf("state mismatch after revert to snapshot %d\n%v", sindex, err) 307 return false 308 } 309 } 310 return true 311 } 312 313 // checkEqual checks that methods of state and checkstate return the same values. 314 func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error { 315 for _, addr := range test.addrs { 316 var err error 317 checkeq := func(op string, a, b interface{}) bool { 318 if err == nil && !reflect.DeepEqual(a, b) { 319 err = fmt.Errorf("got %s(%s) == %v, want %v", op, addr.Hex(), a, b) 320 return false 321 } 322 return true 323 } 324 // Check basic accessor methods. 325 checkeq("Exist", state.Exist(addr), checkstate.Exist(addr)) 326 checkeq("HasSuicided", state.HasSuicided(addr), checkstate.HasSuicided(addr)) 327 checkeq("GetBalance", state.GetBalance(addr), checkstate.GetBalance(addr)) 328 checkeq("GetNonce", state.GetNonce(addr), checkstate.GetNonce(addr)) 329 checkeq("GetCode", state.GetCode(addr), checkstate.GetCode(addr)) 330 checkeq("GetCodeHash", state.GetCodeHash(addr), checkstate.GetCodeHash(addr)) 331 checkeq("GetCodeSize", state.GetCodeSize(addr), checkstate.GetCodeSize(addr)) 332 // Check storage. 333 if obj := state.getStateObject(addr); obj != nil { 334 state.ForEachStorage(addr, func(key, val common.Hash) bool { 335 return checkeq("GetState("+key.Hex()+")", val, checkstate.GetState(addr, key)) 336 }) 337 checkstate.ForEachStorage(addr, func(key, checkval common.Hash) bool { 338 return checkeq("GetState("+key.Hex()+")", state.GetState(addr, key), checkval) 339 }) 340 } 341 if err != nil { 342 return err 343 } 344 } 345 346 if state.GetRefund().Cmp(checkstate.GetRefund()) != 0 { 347 return fmt.Errorf("got GetRefund() == %d, want GetRefund() == %d", 348 state.GetRefund(), checkstate.GetRefund()) 349 } 350 if !reflect.DeepEqual(state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) { 351 return fmt.Errorf("got GetLogs(common.Hash{}) == %v, want GetLogs(common.Hash{}) == %v", 352 state.GetLogs(common.Hash{}), checkstate.GetLogs(common.Hash{})) 353 } 354 return nil 355 } 356 357 func TestTouchDelete(t *testing.T) { 358 db, _ := ethdb.NewMemDatabase() 359 state, _ := New(common.Hash{}, db) 360 state.GetOrNewStateObject(common.Address{}) 361 root, _ := state.Commit(false) 362 state.Reset(root) 363 364 snapshot := state.Snapshot() 365 state.AddBalance(common.Address{}, new(big.Int)) 366 if len(state.stateObjectsDirty) != 1 { 367 t.Fatal("expected one dirty state object") 368 } 369 370 state.RevertToSnapshot(snapshot) 371 if len(state.stateObjectsDirty) != 0 { 372 t.Fatal("expected no dirty state object") 373 } 374 }