github.com/onflow/flow-go@v0.33.17/utils/unittest/mocks/protocol_state.go (about)

     1  package mocks
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  
     7  	"github.com/stretchr/testify/mock"
     8  
     9  	"github.com/onflow/flow-go/model/flow"
    10  	"github.com/onflow/flow-go/state/protocol"
    11  	protocolmock "github.com/onflow/flow-go/state/protocol/mock"
    12  	"github.com/onflow/flow-go/storage"
    13  )
    14  
    15  // ProtocolState is a mocked version of protocol state, which
    16  // has very close behavior to the real implementation
    17  // but for testing purpose.
    18  // If you are testing a module that depends on protocol state's
    19  // behavior, but you don't want to mock up the methods and its return
    20  // value, then just use this module
    21  type ProtocolState struct {
    22  	sync.Mutex
    23  	protocol.ParticipantState
    24  	blocks    map[flow.Identifier]*flow.Block
    25  	children  map[flow.Identifier][]flow.Identifier
    26  	heights   map[uint64]*flow.Block
    27  	finalized uint64
    28  	root      *flow.Block
    29  	result    *flow.ExecutionResult
    30  	seal      *flow.Seal
    31  }
    32  
    33  func NewProtocolState() *ProtocolState {
    34  	return &ProtocolState{
    35  		blocks:   make(map[flow.Identifier]*flow.Block),
    36  		children: make(map[flow.Identifier][]flow.Identifier),
    37  		heights:  make(map[uint64]*flow.Block),
    38  	}
    39  }
    40  
    41  type Params struct {
    42  	state *ProtocolState
    43  }
    44  
    45  func (p *Params) ChainID() (flow.ChainID, error) {
    46  	return p.state.root.Header.ChainID, nil
    47  }
    48  
    49  func (p *Params) SporkID() (flow.Identifier, error) {
    50  	return flow.ZeroID, fmt.Errorf("not implemented")
    51  }
    52  
    53  func (p *Params) SporkRootBlockHeight() (uint64, error) {
    54  	return 0, fmt.Errorf("not implemented")
    55  }
    56  
    57  func (p *Params) ProtocolVersion() (uint, error) {
    58  	return 0, fmt.Errorf("not implemented")
    59  }
    60  
    61  func (p *Params) EpochCommitSafetyThreshold() (uint64, error) {
    62  	return 0, fmt.Errorf("not implemented")
    63  }
    64  
    65  func (p *Params) EpochFallbackTriggered() (bool, error) {
    66  	return false, fmt.Errorf("not implemented")
    67  }
    68  
    69  func (p *Params) FinalizedRoot() (*flow.Header, error) {
    70  	return p.state.root.Header, nil
    71  }
    72  
    73  func (p *Params) SealedRoot() (*flow.Header, error) {
    74  	return p.FinalizedRoot()
    75  }
    76  
    77  func (p *Params) Seal() (*flow.Seal, error) {
    78  	return nil, fmt.Errorf("not implemented")
    79  }
    80  
    81  func (ps *ProtocolState) Params() protocol.Params {
    82  	return &Params{
    83  		state: ps,
    84  	}
    85  }
    86  
    87  func (ps *ProtocolState) AtBlockID(blockID flow.Identifier) protocol.Snapshot {
    88  	ps.Lock()
    89  	defer ps.Unlock()
    90  
    91  	snapshot := new(protocolmock.Snapshot)
    92  	block, ok := ps.blocks[blockID]
    93  	if ok {
    94  		snapshot.On("Head").Return(block.Header, nil)
    95  	} else {
    96  		snapshot.On("Head").Return(nil, storage.ErrNotFound)
    97  	}
    98  	return snapshot
    99  }
   100  
   101  func (ps *ProtocolState) AtHeight(height uint64) protocol.Snapshot {
   102  	ps.Lock()
   103  	defer ps.Unlock()
   104  
   105  	snapshot := new(protocolmock.Snapshot)
   106  	block, ok := ps.heights[height]
   107  	if ok {
   108  		snapshot.On("Head").Return(block.Header, nil)
   109  	} else {
   110  		snapshot.On("Head").Return(nil, storage.ErrNotFound)
   111  	}
   112  	return snapshot
   113  }
   114  
   115  func (ps *ProtocolState) Final() protocol.Snapshot {
   116  	ps.Lock()
   117  	defer ps.Unlock()
   118  
   119  	final, ok := ps.heights[ps.finalized]
   120  	if !ok {
   121  		return nil
   122  	}
   123  
   124  	snapshot := new(protocolmock.Snapshot)
   125  	snapshot.On("Head").Return(final.Header, nil)
   126  	finalID := final.ID()
   127  	mocked := snapshot.On("Descendants")
   128  	mocked.RunFn = func(args mock.Arguments) {
   129  		// not concurrent safe
   130  		pendings := pending(ps, finalID)
   131  		mocked.ReturnArguments = mock.Arguments{pendings, nil}
   132  	}
   133  
   134  	return snapshot
   135  }
   136  
   137  func pending(ps *ProtocolState, blockID flow.Identifier) []flow.Identifier {
   138  	var pendingIDs []flow.Identifier
   139  	pendingIDs, ok := ps.children[blockID]
   140  
   141  	if !ok {
   142  		return pendingIDs
   143  	}
   144  
   145  	for _, pendingID := range pendingIDs {
   146  		additionalIDs := pending(ps, pendingID)
   147  		pendingIDs = append(pendingIDs, additionalIDs...)
   148  	}
   149  
   150  	return pendingIDs
   151  }
   152  
   153  func (m *ProtocolState) Bootstrap(root *flow.Block, result *flow.ExecutionResult, seal *flow.Seal) error {
   154  	m.Lock()
   155  	defer m.Unlock()
   156  
   157  	if _, ok := m.blocks[root.ID()]; ok {
   158  		return storage.ErrAlreadyExists
   159  	}
   160  
   161  	m.blocks[root.ID()] = root
   162  	m.root = root
   163  	m.result = result
   164  	m.seal = seal
   165  	m.heights[root.Header.Height] = root
   166  	m.finalized = root.Header.Height
   167  	return nil
   168  }
   169  
   170  func (m *ProtocolState) Extend(block *flow.Block) error {
   171  	m.Lock()
   172  	defer m.Unlock()
   173  
   174  	id := block.ID()
   175  	if _, ok := m.blocks[id]; ok {
   176  		return storage.ErrAlreadyExists
   177  	}
   178  
   179  	if _, ok := m.blocks[block.Header.ParentID]; !ok {
   180  		return fmt.Errorf("could not retrieve parent")
   181  	}
   182  
   183  	m.blocks[id] = block
   184  
   185  	// index children
   186  	children, ok := m.children[block.Header.ParentID]
   187  	if !ok {
   188  		children = make([]flow.Identifier, 0)
   189  	}
   190  
   191  	children = append(children, id)
   192  	m.children[block.Header.ParentID] = children
   193  
   194  	return nil
   195  }
   196  
   197  func (m *ProtocolState) Finalize(blockID flow.Identifier) error {
   198  	m.Lock()
   199  	defer m.Unlock()
   200  
   201  	block, ok := m.blocks[blockID]
   202  	if !ok {
   203  		return fmt.Errorf("could not retrieve final header")
   204  	}
   205  
   206  	if block.Header.Height <= m.finalized {
   207  		return fmt.Errorf("could not finalize old blocks")
   208  	}
   209  
   210  	// update heights
   211  	cur := block
   212  	for height := cur.Header.Height; height > m.finalized; height-- {
   213  		parent, ok := m.blocks[cur.Header.ParentID]
   214  		if !ok {
   215  			return fmt.Errorf("parent does not exist for block at height: %v, parentID: %v", cur.Header.Height, cur.Header.ParentID)
   216  		}
   217  		m.heights[height] = cur
   218  		cur = parent
   219  	}
   220  
   221  	m.finalized = block.Header.Height
   222  
   223  	return nil
   224  }