github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/utils/unittest/mocks/protocol_state.go (about)

     1  package mocks
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  
     7  	"github.com/stretchr/testify/mock"
     8  
     9  	"github.com/onflow/flow-go/model/flow"
    10  	"github.com/onflow/flow-go/state/protocol"
    11  	protocolmock "github.com/onflow/flow-go/state/protocol/mock"
    12  	"github.com/onflow/flow-go/storage"
    13  )
    14  
    15  // ProtocolState is a mocked version of protocol state, which
    16  // has very close behavior to the real implementation
    17  // but for testing purpose.
    18  // If you are testing a module that depends on protocol state's
    19  // behavior, but you don't want to mock up the methods and its return
    20  // value, then just use this module
    21  type ProtocolState struct {
    22  	sync.Mutex
    23  	protocol.ParticipantState
    24  	blocks    map[flow.Identifier]*flow.Block
    25  	children  map[flow.Identifier][]flow.Identifier
    26  	heights   map[uint64]*flow.Block
    27  	finalized uint64
    28  	root      *flow.Block
    29  	result    *flow.ExecutionResult
    30  	seal      *flow.Seal
    31  }
    32  
    33  func NewProtocolState() *ProtocolState {
    34  	return &ProtocolState{
    35  		blocks:   make(map[flow.Identifier]*flow.Block),
    36  		children: make(map[flow.Identifier][]flow.Identifier),
    37  		heights:  make(map[uint64]*flow.Block),
    38  	}
    39  }
    40  
    41  type Params struct {
    42  	state *ProtocolState
    43  }
    44  
    45  func (p *Params) ChainID() flow.ChainID {
    46  	return p.state.root.Header.ChainID
    47  }
    48  
    49  func (p *Params) SporkID() flow.Identifier {
    50  	return flow.ZeroID
    51  }
    52  
    53  func (p *Params) SporkRootBlockHeight() uint64 {
    54  	return 0
    55  }
    56  
    57  func (p *Params) ProtocolVersion() uint {
    58  	return 0
    59  }
    60  
    61  func (p *Params) EpochCommitSafetyThreshold() uint64 {
    62  	return 0
    63  }
    64  
    65  func (p *Params) EpochFallbackTriggered() (bool, error) {
    66  	return false, fmt.Errorf("not implemented")
    67  }
    68  
    69  func (p *Params) FinalizedRoot() *flow.Header {
    70  	return p.state.root.Header
    71  }
    72  
    73  func (p *Params) SealedRoot() *flow.Header {
    74  	return p.FinalizedRoot()
    75  }
    76  
    77  func (p *Params) Seal() *flow.Seal {
    78  	return nil
    79  }
    80  
    81  func (ps *ProtocolState) Params() protocol.Params {
    82  	return &Params{
    83  		state: ps,
    84  	}
    85  }
    86  
    87  func (ps *ProtocolState) AtBlockID(blockID flow.Identifier) protocol.Snapshot {
    88  	ps.Lock()
    89  	defer ps.Unlock()
    90  
    91  	snapshot := new(protocolmock.Snapshot)
    92  	block, ok := ps.blocks[blockID]
    93  	if ok {
    94  		snapshot.On("Head").Return(block.Header, nil)
    95  	} else {
    96  		snapshot.On("Head").Return(nil, storage.ErrNotFound)
    97  	}
    98  	return snapshot
    99  }
   100  
   101  func (ps *ProtocolState) AtHeight(height uint64) protocol.Snapshot {
   102  	ps.Lock()
   103  	defer ps.Unlock()
   104  
   105  	snapshot := new(protocolmock.Snapshot)
   106  	block, ok := ps.heights[height]
   107  	if ok {
   108  		snapshot.On("Head").Return(block.Header, nil)
   109  		mocked := snapshot.On("Descendants")
   110  		mocked.RunFn = func(args mock.Arguments) {
   111  			pendings := pending(ps, block.Header.ID())
   112  			mocked.ReturnArguments = mock.Arguments{pendings, nil}
   113  		}
   114  
   115  	} else {
   116  		snapshot.On("Head").Return(nil, storage.ErrNotFound)
   117  	}
   118  	return snapshot
   119  }
   120  
   121  func (ps *ProtocolState) Final() protocol.Snapshot {
   122  	ps.Lock()
   123  	defer ps.Unlock()
   124  
   125  	final, ok := ps.heights[ps.finalized]
   126  	if !ok {
   127  		return nil
   128  	}
   129  
   130  	snapshot := new(protocolmock.Snapshot)
   131  	snapshot.On("Head").Return(final.Header, nil)
   132  	finalID := final.ID()
   133  	mocked := snapshot.On("Descendants")
   134  	mocked.RunFn = func(args mock.Arguments) {
   135  		// not concurrent safe
   136  		pendings := pending(ps, finalID)
   137  		mocked.ReturnArguments = mock.Arguments{pendings, nil}
   138  	}
   139  
   140  	return snapshot
   141  }
   142  
   143  func pending(ps *ProtocolState, blockID flow.Identifier) []flow.Identifier {
   144  	var pendingIDs []flow.Identifier
   145  	pendingIDs, ok := ps.children[blockID]
   146  
   147  	if !ok {
   148  		return pendingIDs
   149  	}
   150  
   151  	for _, pendingID := range pendingIDs {
   152  		additionalIDs := pending(ps, pendingID)
   153  		pendingIDs = append(pendingIDs, additionalIDs...)
   154  	}
   155  
   156  	return pendingIDs
   157  }
   158  
   159  func (m *ProtocolState) Bootstrap(root *flow.Block, result *flow.ExecutionResult, seal *flow.Seal) error {
   160  	m.Lock()
   161  	defer m.Unlock()
   162  
   163  	if _, ok := m.blocks[root.ID()]; ok {
   164  		return storage.ErrAlreadyExists
   165  	}
   166  
   167  	m.blocks[root.ID()] = root
   168  	m.root = root
   169  	m.result = result
   170  	m.seal = seal
   171  	m.heights[root.Header.Height] = root
   172  	m.finalized = root.Header.Height
   173  	return nil
   174  }
   175  
   176  func (m *ProtocolState) Extend(block *flow.Block) error {
   177  	m.Lock()
   178  	defer m.Unlock()
   179  
   180  	id := block.ID()
   181  	if _, ok := m.blocks[id]; ok {
   182  		return storage.ErrAlreadyExists
   183  	}
   184  
   185  	if _, ok := m.blocks[block.Header.ParentID]; !ok {
   186  		return fmt.Errorf("could not retrieve parent")
   187  	}
   188  
   189  	m.blocks[id] = block
   190  
   191  	// index children
   192  	children, ok := m.children[block.Header.ParentID]
   193  	if !ok {
   194  		children = make([]flow.Identifier, 0)
   195  	}
   196  
   197  	children = append(children, id)
   198  	m.children[block.Header.ParentID] = children
   199  
   200  	return nil
   201  }
   202  
   203  func (m *ProtocolState) Finalize(blockID flow.Identifier) error {
   204  	m.Lock()
   205  	defer m.Unlock()
   206  
   207  	block, ok := m.blocks[blockID]
   208  	if !ok {
   209  		return fmt.Errorf("could not retrieve final header")
   210  	}
   211  
   212  	if block.Header.Height <= m.finalized {
   213  		return fmt.Errorf("could not finalize old blocks")
   214  	}
   215  
   216  	// update heights
   217  	cur := block
   218  	for height := cur.Header.Height; height > m.finalized; height-- {
   219  		parent, ok := m.blocks[cur.Header.ParentID]
   220  		if !ok {
   221  			return fmt.Errorf("parent does not exist for block at height: %v, parentID: %v", cur.Header.Height, cur.Header.ParentID)
   222  		}
   223  		m.heights[height] = cur
   224  		cur = parent
   225  	}
   226  
   227  	m.finalized = block.Header.Height
   228  
   229  	return nil
   230  }