github.com/onflow/flow-go@v0.35.7-crescendo-preview.23-atree-inlining/utils/unittest/mocks/epoch_query.go (about) 1 package mocks 2 3 import ( 4 "sync" 5 "testing" 6 7 "github.com/stretchr/testify/require" 8 9 "github.com/onflow/flow-go/model/flow" 10 "github.com/onflow/flow-go/state/protocol" 11 "github.com/onflow/flow-go/state/protocol/invalid" 12 ) 13 14 // EpochQuery implements protocol.EpochQuery for testing purposes. 15 // Safe for concurrent use by multiple goroutines. 16 type EpochQuery struct { 17 t *testing.T 18 mu sync.RWMutex 19 counter uint64 // represents the current epoch 20 byCounter map[uint64]protocol.Epoch // all epochs 21 } 22 23 func NewEpochQuery(t *testing.T, counter uint64, epochs ...protocol.Epoch) *EpochQuery { 24 mock := &EpochQuery{ 25 t: t, 26 counter: counter, 27 byCounter: make(map[uint64]protocol.Epoch), 28 } 29 30 for _, epoch := range epochs { 31 mock.Add(epoch) 32 } 33 34 return mock 35 } 36 37 func (mock *EpochQuery) Current() protocol.Epoch { 38 mock.mu.RLock() 39 defer mock.mu.RUnlock() 40 return mock.byCounter[mock.counter] 41 } 42 43 func (mock *EpochQuery) Next() protocol.Epoch { 44 mock.mu.RLock() 45 defer mock.mu.RUnlock() 46 epoch, exists := mock.byCounter[mock.counter+1] 47 if !exists { 48 return invalid.NewEpoch(protocol.ErrNextEpochNotSetup) 49 } 50 return epoch 51 } 52 53 func (mock *EpochQuery) Previous() protocol.Epoch { 54 mock.mu.RLock() 55 defer mock.mu.RUnlock() 56 epoch, exists := mock.byCounter[mock.counter-1] 57 if !exists { 58 return invalid.NewEpoch(protocol.ErrNoPreviousEpoch) 59 } 60 return epoch 61 } 62 63 // Phase returns a phase consistent with the current epoch state. 64 func (mock *EpochQuery) Phase() flow.EpochPhase { 65 mock.mu.RLock() 66 defer mock.mu.RUnlock() 67 _, exists := mock.byCounter[mock.counter+1] 68 if exists { 69 return flow.EpochPhaseSetup 70 } 71 return flow.EpochPhaseStaking 72 } 73 74 func (mock *EpochQuery) ByCounter(counter uint64) protocol.Epoch { 75 mock.mu.RLock() 76 defer mock.mu.RUnlock() 77 return mock.byCounter[counter] 78 } 79 80 func (mock *EpochQuery) Transition() { 81 mock.mu.Lock() 82 defer mock.mu.Unlock() 83 mock.counter++ 84 } 85 86 func (mock *EpochQuery) Add(epoch protocol.Epoch) { 87 mock.mu.Lock() 88 defer mock.mu.Unlock() 89 counter, err := epoch.Counter() 90 require.NoError(mock.t, err, "cannot add epoch with invalid counter") 91 mock.byCounter[counter] = epoch 92 }