github.com/koko1123/flow-go-1@v0.29.6/utils/unittest/mocks/protocol_state.go (about) 1 package mocks 2 3 import ( 4 "fmt" 5 "sync" 6 7 "github.com/stretchr/testify/mock" 8 9 "github.com/koko1123/flow-go-1/model/flow" 10 "github.com/koko1123/flow-go-1/state/protocol" 11 protocolmock "github.com/koko1123/flow-go-1/state/protocol/mock" 12 "github.com/koko1123/flow-go-1/storage" 13 ) 14 15 // ProtocolState is a mocked version of protocol state, which 16 // has very close behavior to the real implementation 17 // but for testing purpose. 18 // If you are testing a module that depends on protocol state's 19 // behavior, but you don't want to mock up the methods and its return 20 // value, then just use this module 21 type ProtocolState struct { 22 sync.Mutex 23 protocol.MutableState 24 blocks map[flow.Identifier]*flow.Block 25 children map[flow.Identifier][]flow.Identifier 26 heights map[uint64]*flow.Block 27 finalized uint64 28 root *flow.Block 29 result *flow.ExecutionResult 30 seal *flow.Seal 31 } 32 33 func NewProtocolState() *ProtocolState { 34 return &ProtocolState{ 35 blocks: make(map[flow.Identifier]*flow.Block), 36 children: make(map[flow.Identifier][]flow.Identifier), 37 heights: make(map[uint64]*flow.Block), 38 } 39 } 40 41 type Params struct { 42 state *ProtocolState 43 } 44 45 func (p *Params) ChainID() (flow.ChainID, error) { 46 return p.state.root.Header.ChainID, nil 47 } 48 49 func (p *Params) SporkID() (flow.Identifier, error) { 50 return flow.ZeroID, fmt.Errorf("not implemented") 51 } 52 53 func (p *Params) ProtocolVersion() (uint, error) { 54 return 0, fmt.Errorf("not implemented") 55 } 56 57 func (p *Params) Root() (*flow.Header, error) { 58 return p.state.root.Header, nil 59 } 60 61 func (p *Params) Seal() (*flow.Seal, error) { 62 return nil, fmt.Errorf("not implemented") 63 } 64 65 func (ps *ProtocolState) Params() protocol.Params { 66 return &Params{ 67 state: ps, 68 } 69 } 70 71 func (ps *ProtocolState) AtBlockID(blockID flow.Identifier) protocol.Snapshot { 72 ps.Lock() 73 defer ps.Unlock() 74 75 snapshot := new(protocolmock.Snapshot) 76 block, ok := ps.blocks[blockID] 77 if ok { 78 snapshot.On("Head").Return(block.Header, nil) 79 } else { 80 snapshot.On("Head").Return(nil, storage.ErrNotFound) 81 } 82 return snapshot 83 } 84 85 func (ps *ProtocolState) AtHeight(height uint64) protocol.Snapshot { 86 ps.Lock() 87 defer ps.Unlock() 88 89 snapshot := new(protocolmock.Snapshot) 90 block, ok := ps.heights[height] 91 if ok { 92 snapshot.On("Head").Return(block.Header, nil) 93 } else { 94 snapshot.On("Head").Return(nil, storage.ErrNotFound) 95 } 96 return snapshot 97 } 98 99 func (ps *ProtocolState) Final() protocol.Snapshot { 100 ps.Lock() 101 defer ps.Unlock() 102 103 final, ok := ps.heights[ps.finalized] 104 if !ok { 105 return nil 106 } 107 108 snapshot := new(protocolmock.Snapshot) 109 snapshot.On("Head").Return(final.Header, nil) 110 finalID := final.ID() 111 mocked := snapshot.On("Descendants") 112 mocked.RunFn = func(args mock.Arguments) { 113 // not concurrent safe 114 pendings := pending(ps, finalID) 115 mocked.ReturnArguments = mock.Arguments{pendings, nil} 116 } 117 118 mocked = snapshot.On("ValidDescendants") 119 mocked.RunFn = func(args mock.Arguments) { 120 // not concurrent safe 121 pendings := pending(ps, finalID) 122 mocked.ReturnArguments = mock.Arguments{pendings, nil} 123 } 124 return snapshot 125 } 126 127 func pending(ps *ProtocolState, blockID flow.Identifier) []flow.Identifier { 128 var pendingIDs []flow.Identifier 129 pendingIDs, ok := ps.children[blockID] 130 131 if !ok { 132 return pendingIDs 133 } 134 135 for _, pendingID := range pendingIDs { 136 additionalIDs := pending(ps, pendingID) 137 pendingIDs = append(pendingIDs, additionalIDs...) 138 } 139 140 return pendingIDs 141 } 142 143 func (m *ProtocolState) Bootstrap(root *flow.Block, result *flow.ExecutionResult, seal *flow.Seal) error { 144 m.Lock() 145 defer m.Unlock() 146 147 if _, ok := m.blocks[root.ID()]; ok { 148 return storage.ErrAlreadyExists 149 } 150 151 m.blocks[root.ID()] = root 152 m.root = root 153 m.result = result 154 m.seal = seal 155 m.heights[root.Header.Height] = root 156 m.finalized = root.Header.Height 157 return nil 158 } 159 160 func (m *ProtocolState) Extend(block *flow.Block) error { 161 m.Lock() 162 defer m.Unlock() 163 164 id := block.ID() 165 if _, ok := m.blocks[id]; ok { 166 return storage.ErrAlreadyExists 167 } 168 169 if _, ok := m.blocks[block.Header.ParentID]; !ok { 170 return fmt.Errorf("could not retrieve parent") 171 } 172 173 m.blocks[id] = block 174 175 // index children 176 children, ok := m.children[block.Header.ParentID] 177 if !ok { 178 children = make([]flow.Identifier, 0) 179 } 180 181 children = append(children, id) 182 m.children[block.Header.ParentID] = children 183 184 return nil 185 } 186 187 func (m *ProtocolState) Finalize(blockID flow.Identifier) error { 188 m.Lock() 189 defer m.Unlock() 190 191 block, ok := m.blocks[blockID] 192 if !ok { 193 return fmt.Errorf("could not retrieve final header") 194 } 195 196 if block.Header.Height <= m.finalized { 197 return fmt.Errorf("could not finalize old blocks") 198 } 199 200 // update heights 201 cur := block 202 for height := cur.Header.Height; height > m.finalized; height-- { 203 parent, ok := m.blocks[cur.Header.ParentID] 204 if !ok { 205 return fmt.Errorf("parent does not exist for block at height: %v, parentID: %v", cur.Header.Height, cur.Header.ParentID) 206 } 207 m.heights[height] = cur 208 cur = parent 209 } 210 211 m.finalized = block.Header.Height 212 213 return nil 214 }