github.com/MetalBlockchain/metalgo@v1.11.9/snow/engine/snowman/getter/getter_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package getter
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/prometheus/client_golang/prometheus"
    13  	"github.com/stretchr/testify/require"
    14  	"go.uber.org/mock/gomock"
    15  
    16  	"github.com/MetalBlockchain/metalgo/ids"
    17  	"github.com/MetalBlockchain/metalgo/snow/consensus/snowman"
    18  	"github.com/MetalBlockchain/metalgo/snow/consensus/snowman/snowmantest"
    19  	"github.com/MetalBlockchain/metalgo/snow/engine/common"
    20  	"github.com/MetalBlockchain/metalgo/snow/engine/snowman/block"
    21  	"github.com/MetalBlockchain/metalgo/utils/logging"
    22  	"github.com/MetalBlockchain/metalgo/utils/set"
    23  )
    24  
    25  var errUnknownBlock = errors.New("unknown block")
    26  
    27  type StateSyncEnabledMock struct {
    28  	*block.TestVM
    29  	*block.MockStateSyncableVM
    30  }
    31  
    32  func newTest(t *testing.T) (common.AllGetsServer, StateSyncEnabledMock, *common.SenderTest) {
    33  	ctrl := gomock.NewController(t)
    34  
    35  	vm := StateSyncEnabledMock{
    36  		TestVM:              &block.TestVM{},
    37  		MockStateSyncableVM: block.NewMockStateSyncableVM(ctrl),
    38  	}
    39  
    40  	sender := &common.SenderTest{
    41  		T: t,
    42  	}
    43  	sender.Default(true)
    44  
    45  	bs, err := New(
    46  		vm,
    47  		sender,
    48  		logging.NoLog{},
    49  		time.Second,
    50  		2000,
    51  		prometheus.NewRegistry(),
    52  	)
    53  	require.NoError(t, err)
    54  
    55  	return bs, vm, sender
    56  }
    57  
    58  func TestAcceptedFrontier(t *testing.T) {
    59  	require := require.New(t)
    60  	bs, vm, sender := newTest(t)
    61  
    62  	blkID := ids.GenerateTestID()
    63  	vm.LastAcceptedF = func(context.Context) (ids.ID, error) {
    64  		return blkID, nil
    65  	}
    66  
    67  	var accepted ids.ID
    68  	sender.SendAcceptedFrontierF = func(_ context.Context, _ ids.NodeID, _ uint32, containerID ids.ID) {
    69  		accepted = containerID
    70  	}
    71  
    72  	require.NoError(bs.GetAcceptedFrontier(context.Background(), ids.EmptyNodeID, 0))
    73  	require.Equal(blkID, accepted)
    74  }
    75  
    76  func TestFilterAccepted(t *testing.T) {
    77  	require := require.New(t)
    78  	bs, vm, sender := newTest(t)
    79  
    80  	acceptedBlk := snowmantest.BuildChild(snowmantest.Genesis)
    81  	require.NoError(acceptedBlk.Accept(context.Background()))
    82  
    83  	unknownBlkID := ids.GenerateTestID()
    84  
    85  	vm.GetBlockF = func(_ context.Context, blkID ids.ID) (snowman.Block, error) {
    86  		switch blkID {
    87  		case snowmantest.GenesisID:
    88  			return snowmantest.Genesis, nil
    89  		case acceptedBlk.ID():
    90  			return acceptedBlk, nil
    91  		case unknownBlkID:
    92  			return nil, errUnknownBlock
    93  		default:
    94  			require.FailNow(errUnknownBlock.Error())
    95  			return nil, errUnknownBlock
    96  		}
    97  	}
    98  
    99  	var accepted []ids.ID
   100  	sender.SendAcceptedF = func(_ context.Context, _ ids.NodeID, _ uint32, frontier []ids.ID) {
   101  		accepted = frontier
   102  	}
   103  
   104  	blkIDs := set.Of(snowmantest.GenesisID, acceptedBlk.ID(), unknownBlkID)
   105  	require.NoError(bs.GetAccepted(context.Background(), ids.EmptyNodeID, 0, blkIDs))
   106  
   107  	require.Len(accepted, 2)
   108  	require.Contains(accepted, snowmantest.GenesisID)
   109  	require.Contains(accepted, acceptedBlk.ID())
   110  	require.NotContains(accepted, unknownBlkID)
   111  }