github.com/koko1123/flow-go-1@v0.29.6/utils/unittest/mocks/protocol_state.go (about)

     1  package mocks
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  
     7  	"github.com/stretchr/testify/mock"
     8  
     9  	"github.com/koko1123/flow-go-1/model/flow"
    10  	"github.com/koko1123/flow-go-1/state/protocol"
    11  	protocolmock "github.com/koko1123/flow-go-1/state/protocol/mock"
    12  	"github.com/koko1123/flow-go-1/storage"
    13  )
    14  
// ProtocolState is an in-memory stand-in for the protocol state, intended for
// tests. It tracks blocks added via Bootstrap and Extend and finalization via
// Finalize, closely following the real implementation's behavior. Use it when
// the module under test depends on protocol-state behavior and you do not want
// to set up mock expectations and return values for every method by hand.
type ProtocolState struct {
	// Guards all fields below; every method in this file that reads or writes
	// them acquires it.
	sync.Mutex
	// Left nil by NewProtocolState, so calling any MutableState method that
	// ProtocolState does not define itself panics with a nil dereference.
	protocol.MutableState
	// All known blocks (the Bootstrap root plus every block added by Extend),
	// keyed by block ID.
	blocks map[flow.Identifier]*flow.Block
	// Parent block ID -> IDs of its children, in the order Extend added them.
	children map[flow.Identifier][]flow.Identifier
	// Finalized blocks keyed by height (root set by Bootstrap, others by Finalize).
	heights map[uint64]*flow.Block
	// Height of the latest finalized block.
	finalized uint64
	// Root block, execution result and seal passed to Bootstrap.
	root   *flow.Block
	result *flow.ExecutionResult
	seal   *flow.Seal
}
    32  
    33  func NewProtocolState() *ProtocolState {
    34  	return &ProtocolState{
    35  		blocks:   make(map[flow.Identifier]*flow.Block),
    36  		children: make(map[flow.Identifier][]flow.Identifier),
    37  		heights:  make(map[uint64]*flow.Block),
    38  	}
    39  }
    40  
// Params is the protocol.Params view returned by ProtocolState.Params. It
// reads values from the ProtocolState it wraps at call time.
type Params struct {
	// The backing mock state.
	state *ProtocolState
}
    44  
    45  func (p *Params) ChainID() (flow.ChainID, error) {
    46  	return p.state.root.Header.ChainID, nil
    47  }
    48  
    49  func (p *Params) SporkID() (flow.Identifier, error) {
    50  	return flow.ZeroID, fmt.Errorf("not implemented")
    51  }
    52  
    53  func (p *Params) ProtocolVersion() (uint, error) {
    54  	return 0, fmt.Errorf("not implemented")
    55  }
    56  
    57  func (p *Params) Root() (*flow.Header, error) {
    58  	return p.state.root.Header, nil
    59  }
    60  
    61  func (p *Params) Seal() (*flow.Seal, error) {
    62  	return nil, fmt.Errorf("not implemented")
    63  }
    64  
    65  func (ps *ProtocolState) Params() protocol.Params {
    66  	return &Params{
    67  		state: ps,
    68  	}
    69  }
    70  
    71  func (ps *ProtocolState) AtBlockID(blockID flow.Identifier) protocol.Snapshot {
    72  	ps.Lock()
    73  	defer ps.Unlock()
    74  
    75  	snapshot := new(protocolmock.Snapshot)
    76  	block, ok := ps.blocks[blockID]
    77  	if ok {
    78  		snapshot.On("Head").Return(block.Header, nil)
    79  	} else {
    80  		snapshot.On("Head").Return(nil, storage.ErrNotFound)
    81  	}
    82  	return snapshot
    83  }
    84  
    85  func (ps *ProtocolState) AtHeight(height uint64) protocol.Snapshot {
    86  	ps.Lock()
    87  	defer ps.Unlock()
    88  
    89  	snapshot := new(protocolmock.Snapshot)
    90  	block, ok := ps.heights[height]
    91  	if ok {
    92  		snapshot.On("Head").Return(block.Header, nil)
    93  	} else {
    94  		snapshot.On("Head").Return(nil, storage.ErrNotFound)
    95  	}
    96  	return snapshot
    97  }
    98  
    99  func (ps *ProtocolState) Final() protocol.Snapshot {
   100  	ps.Lock()
   101  	defer ps.Unlock()
   102  
   103  	final, ok := ps.heights[ps.finalized]
   104  	if !ok {
   105  		return nil
   106  	}
   107  
   108  	snapshot := new(protocolmock.Snapshot)
   109  	snapshot.On("Head").Return(final.Header, nil)
   110  	finalID := final.ID()
   111  	mocked := snapshot.On("Descendants")
   112  	mocked.RunFn = func(args mock.Arguments) {
   113  		// not concurrent safe
   114  		pendings := pending(ps, finalID)
   115  		mocked.ReturnArguments = mock.Arguments{pendings, nil}
   116  	}
   117  
   118  	mocked = snapshot.On("ValidDescendants")
   119  	mocked.RunFn = func(args mock.Arguments) {
   120  		// not concurrent safe
   121  		pendings := pending(ps, finalID)
   122  		mocked.ReturnArguments = mock.Arguments{pendings, nil}
   123  	}
   124  	return snapshot
   125  }
   126  
   127  func pending(ps *ProtocolState, blockID flow.Identifier) []flow.Identifier {
   128  	var pendingIDs []flow.Identifier
   129  	pendingIDs, ok := ps.children[blockID]
   130  
   131  	if !ok {
   132  		return pendingIDs
   133  	}
   134  
   135  	for _, pendingID := range pendingIDs {
   136  		additionalIDs := pending(ps, pendingID)
   137  		pendingIDs = append(pendingIDs, additionalIDs...)
   138  	}
   139  
   140  	return pendingIDs
   141  }
   142  
   143  func (m *ProtocolState) Bootstrap(root *flow.Block, result *flow.ExecutionResult, seal *flow.Seal) error {
   144  	m.Lock()
   145  	defer m.Unlock()
   146  
   147  	if _, ok := m.blocks[root.ID()]; ok {
   148  		return storage.ErrAlreadyExists
   149  	}
   150  
   151  	m.blocks[root.ID()] = root
   152  	m.root = root
   153  	m.result = result
   154  	m.seal = seal
   155  	m.heights[root.Header.Height] = root
   156  	m.finalized = root.Header.Height
   157  	return nil
   158  }
   159  
   160  func (m *ProtocolState) Extend(block *flow.Block) error {
   161  	m.Lock()
   162  	defer m.Unlock()
   163  
   164  	id := block.ID()
   165  	if _, ok := m.blocks[id]; ok {
   166  		return storage.ErrAlreadyExists
   167  	}
   168  
   169  	if _, ok := m.blocks[block.Header.ParentID]; !ok {
   170  		return fmt.Errorf("could not retrieve parent")
   171  	}
   172  
   173  	m.blocks[id] = block
   174  
   175  	// index children
   176  	children, ok := m.children[block.Header.ParentID]
   177  	if !ok {
   178  		children = make([]flow.Identifier, 0)
   179  	}
   180  
   181  	children = append(children, id)
   182  	m.children[block.Header.ParentID] = children
   183  
   184  	return nil
   185  }
   186  
   187  func (m *ProtocolState) Finalize(blockID flow.Identifier) error {
   188  	m.Lock()
   189  	defer m.Unlock()
   190  
   191  	block, ok := m.blocks[blockID]
   192  	if !ok {
   193  		return fmt.Errorf("could not retrieve final header")
   194  	}
   195  
   196  	if block.Header.Height <= m.finalized {
   197  		return fmt.Errorf("could not finalize old blocks")
   198  	}
   199  
   200  	// update heights
   201  	cur := block
   202  	for height := cur.Header.Height; height > m.finalized; height-- {
   203  		parent, ok := m.blocks[cur.Header.ParentID]
   204  		if !ok {
   205  			return fmt.Errorf("parent does not exist for block at height: %v, parentID: %v", cur.Header.Height, cur.Header.ParentID)
   206  		}
   207  		m.heights[height] = cur
   208  		cur = parent
   209  	}
   210  
   211  	m.finalized = block.Header.Height
   212  
   213  	return nil
   214  }