github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/utils/unittest/mocks/epoch_query.go (about)

     1  package mocks
     2  
     3  import (
     4  	"sync"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/require"
     8  
     9  	"github.com/onflow/flow-go/model/flow"
    10  	"github.com/onflow/flow-go/state/protocol"
    11  	"github.com/onflow/flow-go/state/protocol/invalid"
    12  )
    13  
    14  // EpochQuery implements protocol.EpochQuery for testing purposes.
    15  // Safe for concurrent use by multiple goroutines.
    16  type EpochQuery struct {
    17  	t         *testing.T
    18  	mu        sync.RWMutex
    19  	counter   uint64                    // represents the current epoch
    20  	byCounter map[uint64]protocol.Epoch // all epochs
    21  }
    22  
    23  func NewEpochQuery(t *testing.T, counter uint64, epochs ...protocol.Epoch) *EpochQuery {
    24  	mock := &EpochQuery{
    25  		t:         t,
    26  		counter:   counter,
    27  		byCounter: make(map[uint64]protocol.Epoch),
    28  	}
    29  
    30  	for _, epoch := range epochs {
    31  		mock.Add(epoch)
    32  	}
    33  
    34  	return mock
    35  }
    36  
    37  func (mock *EpochQuery) Current() protocol.Epoch {
    38  	mock.mu.RLock()
    39  	defer mock.mu.RUnlock()
    40  	return mock.byCounter[mock.counter]
    41  }
    42  
    43  func (mock *EpochQuery) Next() protocol.Epoch {
    44  	mock.mu.RLock()
    45  	defer mock.mu.RUnlock()
    46  	epoch, exists := mock.byCounter[mock.counter+1]
    47  	if !exists {
    48  		return invalid.NewEpoch(protocol.ErrNextEpochNotSetup)
    49  	}
    50  	return epoch
    51  }
    52  
    53  func (mock *EpochQuery) Previous() protocol.Epoch {
    54  	mock.mu.RLock()
    55  	defer mock.mu.RUnlock()
    56  	epoch, exists := mock.byCounter[mock.counter-1]
    57  	if !exists {
    58  		return invalid.NewEpoch(protocol.ErrNoPreviousEpoch)
    59  	}
    60  	return epoch
    61  }
    62  
    63  // Phase returns a phase consistent with the current epoch state.
    64  func (mock *EpochQuery) Phase() flow.EpochPhase {
    65  	mock.mu.RLock()
    66  	defer mock.mu.RUnlock()
    67  	_, exists := mock.byCounter[mock.counter+1]
    68  	if exists {
    69  		return flow.EpochPhaseSetup
    70  	}
    71  	return flow.EpochPhaseStaking
    72  }
    73  
    74  func (mock *EpochQuery) ByCounter(counter uint64) protocol.Epoch {
    75  	mock.mu.RLock()
    76  	defer mock.mu.RUnlock()
    77  	return mock.byCounter[counter]
    78  }
    79  
    80  func (mock *EpochQuery) Transition() {
    81  	mock.mu.Lock()
    82  	defer mock.mu.Unlock()
    83  	mock.counter++
    84  }
    85  
    86  func (mock *EpochQuery) Add(epoch protocol.Epoch) {
    87  	mock.mu.Lock()
    88  	defer mock.mu.Unlock()
    89  	counter, err := epoch.Counter()
    90  	require.NoError(mock.t, err, "cannot add epoch with invalid counter")
    91  	mock.byCounter[counter] = epoch
    92  }