github.com/koko1123/flow-go-1@v0.29.6/admin/commands/common/read_protocol_state_blocks_test.go (about)

     1  package common
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/mock"
    12  	"github.com/stretchr/testify/require"
    13  	"github.com/stretchr/testify/suite"
    14  
    15  	"github.com/koko1123/flow-go-1/admin"
    16  	"github.com/koko1123/flow-go-1/admin/commands"
    17  	"github.com/koko1123/flow-go-1/model/flow"
    18  	"github.com/koko1123/flow-go-1/state/protocol"
    19  	"github.com/koko1123/flow-go-1/state/protocol/invalid"
    20  	protocolmock "github.com/koko1123/flow-go-1/state/protocol/mock"
    21  	storagemock "github.com/koko1123/flow-go-1/storage/mock"
    22  	"github.com/koko1123/flow-go-1/utils/unittest"
    23  )
    24  
    25  type ReadProtocolStateBlocksSuite struct {
    26  	suite.Suite
    27  
    28  	command commands.AdminCommand
    29  	state   *protocolmock.State
    30  	blocks  *storagemock.Blocks
    31  
    32  	final     *flow.Block
    33  	sealed    *flow.Block
    34  	allBlocks []*flow.Block
    35  }
    36  
    37  func TestReadProtocolStateBlocks(t *testing.T) {
    38  	suite.Run(t, new(ReadProtocolStateBlocksSuite))
    39  }
    40  
    41  func createSnapshot(head *flow.Header) protocol.Snapshot {
    42  	snapshot := &protocolmock.Snapshot{}
    43  	snapshot.On("Head").Return(
    44  		func() *flow.Header {
    45  			return head
    46  		},
    47  		nil,
    48  	)
    49  	return snapshot
    50  }
    51  
    52  func (suite *ReadProtocolStateBlocksSuite) SetupTest() {
    53  	suite.state = new(protocolmock.State)
    54  	suite.blocks = new(storagemock.Blocks)
    55  
    56  	var blocks []*flow.Block
    57  
    58  	genesis := unittest.GenesisFixture()
    59  	blocks = append(blocks, genesis)
    60  	sealed := unittest.BlockWithParentFixture(genesis.Header)
    61  	blocks = append(blocks, sealed)
    62  	final := unittest.BlockWithParentFixture(sealed.Header)
    63  	blocks = append(blocks, final)
    64  	final = unittest.BlockWithParentFixture(final.Header)
    65  	blocks = append(blocks, final)
    66  	final = unittest.BlockWithParentFixture(final.Header)
    67  	blocks = append(blocks, final)
    68  
    69  	suite.allBlocks = blocks
    70  	suite.sealed = sealed
    71  	suite.final = final
    72  
    73  	suite.state.On("Final").Return(createSnapshot(final.Header))
    74  	suite.state.On("Sealed").Return(createSnapshot(sealed.Header))
    75  	suite.state.On("AtBlockID", mock.Anything).Return(
    76  		func(blockID flow.Identifier) protocol.Snapshot {
    77  			for _, block := range blocks {
    78  				if block.ID() == blockID {
    79  					return createSnapshot(block.Header)
    80  				}
    81  			}
    82  			return invalid.NewSnapshot(fmt.Errorf("invalid block ID: %v", blockID))
    83  		},
    84  	)
    85  	suite.state.On("AtHeight", mock.Anything).Return(
    86  		func(height uint64) protocol.Snapshot {
    87  			if int(height) < len(blocks) {
    88  				block := blocks[height]
    89  				return createSnapshot(block.Header)
    90  			}
    91  			return invalid.NewSnapshot(fmt.Errorf("invalid height: %v", height))
    92  		},
    93  	)
    94  
    95  	suite.blocks.On("ByID", mock.Anything).Return(
    96  		func(blockID flow.Identifier) *flow.Block {
    97  			for _, block := range blocks {
    98  				if block.ID() == blockID {
    99  					return block
   100  				}
   101  			}
   102  			return nil
   103  		},
   104  		func(blockID flow.Identifier) error {
   105  			for _, block := range blocks {
   106  				if block.ID() == blockID {
   107  					return nil
   108  				}
   109  			}
   110  			return errors.New("block not found")
   111  		},
   112  	)
   113  
   114  	suite.command = NewReadProtocolStateBlocksCommand(suite.state, suite.blocks)
   115  }
   116  
   117  func (suite *ReadProtocolStateBlocksSuite) TestValidateInvalidFormat() {
   118  	assert.Error(suite.T(), suite.command.Validator(&admin.CommandRequest{
   119  		Data: true,
   120  	}))
   121  	assert.Error(suite.T(), suite.command.Validator(&admin.CommandRequest{
   122  		Data: 420,
   123  	}))
   124  	assert.Error(suite.T(), suite.command.Validator(&admin.CommandRequest{
   125  		Data: "foo",
   126  	}))
   127  	assert.Error(suite.T(), suite.command.Validator(&admin.CommandRequest{
   128  		Data: map[string]interface{}{
   129  			"blah": 123,
   130  		},
   131  	}))
   132  }
   133  
   134  func (suite *ReadProtocolStateBlocksSuite) TestValidateInvalidBlock() {
   135  	assert.Error(suite.T(), suite.command.Validator(&admin.CommandRequest{
   136  		Data: map[string]interface{}{
   137  			"block": true,
   138  		},
   139  	}))
   140  	assert.Error(suite.T(), suite.command.Validator(&admin.CommandRequest{
   141  		Data: map[string]interface{}{
   142  			"block": "",
   143  		},
   144  	}))
   145  	assert.Error(suite.T(), suite.command.Validator(&admin.CommandRequest{
   146  		Data: map[string]interface{}{
   147  			"block": "uhznms",
   148  		},
   149  	}))
   150  	assert.Error(suite.T(), suite.command.Validator(&admin.CommandRequest{
   151  		Data: map[string]interface{}{
   152  			"block": "deadbeef",
   153  		},
   154  	}))
   155  }
   156  
   157  func (suite *ReadProtocolStateBlocksSuite) TestValidateInvalidBlockHeight() {
   158  	assert.Error(suite.T(), suite.command.Validator(&admin.CommandRequest{
   159  		Data: map[string]interface{}{
   160  			"block": float64(-1),
   161  		},
   162  	}))
   163  	assert.Error(suite.T(), suite.command.Validator(&admin.CommandRequest{
   164  		Data: map[string]interface{}{
   165  			"block": float64(1.1),
   166  		},
   167  	}))
   168  }
   169  
   170  func (suite *ReadProtocolStateBlocksSuite) TestValidateInvalidN() {
   171  	assert.Error(suite.T(), suite.command.Validator(&admin.CommandRequest{
   172  		Data: map[string]interface{}{
   173  			"block": 1,
   174  			"n":     "foo",
   175  		},
   176  	}))
   177  	assert.Error(suite.T(), suite.command.Validator(&admin.CommandRequest{
   178  		Data: map[string]interface{}{
   179  			"block": 1,
   180  			"n":     float64(1.1),
   181  		},
   182  	}))
   183  	assert.Error(suite.T(), suite.command.Validator(&admin.CommandRequest{
   184  		Data: map[string]interface{}{
   185  			"block": 1,
   186  			"n":     float64(0),
   187  		},
   188  	}))
   189  }
   190  
   191  func (suite *ReadProtocolStateBlocksSuite) getBlocks(reqData map[string]interface{}) []*flow.Block {
   192  	ctx, cancel := context.WithCancel(context.Background())
   193  	defer cancel()
   194  
   195  	req := &admin.CommandRequest{
   196  		Data: reqData,
   197  	}
   198  	require.NoError(suite.T(), suite.command.Validator(req))
   199  	result, err := suite.command.Handler(ctx, req)
   200  	require.NoError(suite.T(), err)
   201  
   202  	var blocks []*flow.Block
   203  	data, err := json.Marshal(result)
   204  	require.NoError(suite.T(), err)
   205  	require.NoError(suite.T(), json.Unmarshal(data, &blocks))
   206  
   207  	return blocks
   208  }
   209  
   210  func (suite *ReadProtocolStateBlocksSuite) TestHandleFinal() {
   211  	blocks := suite.getBlocks(map[string]interface{}{
   212  		"block": "final",
   213  	})
   214  	require.Len(suite.T(), blocks, 1)
   215  	require.EqualValues(suite.T(), blocks[0], suite.final)
   216  }
   217  
   218  func (suite *ReadProtocolStateBlocksSuite) TestHandleSealed() {
   219  	blocks := suite.getBlocks(map[string]interface{}{
   220  		"block": "sealed",
   221  	})
   222  	require.Len(suite.T(), blocks, 1)
   223  	require.EqualValues(suite.T(), blocks[0], suite.sealed)
   224  }
   225  
   226  func (suite *ReadProtocolStateBlocksSuite) TestHandleHeight() {
   227  	for i, block := range suite.allBlocks {
   228  		responseBlocks := suite.getBlocks(map[string]interface{}{
   229  			"block": float64(i),
   230  		})
   231  		require.Len(suite.T(), responseBlocks, 1)
   232  		require.EqualValues(suite.T(), responseBlocks[0], block)
   233  	}
   234  }
   235  
   236  func (suite *ReadProtocolStateBlocksSuite) TestHandleID() {
   237  	for _, block := range suite.allBlocks {
   238  		responseBlocks := suite.getBlocks(map[string]interface{}{
   239  			"block": block.ID().String(),
   240  		})
   241  		require.Len(suite.T(), responseBlocks, 1)
   242  		require.EqualValues(suite.T(), responseBlocks[0], block)
   243  	}
   244  }
   245  
   246  func (suite *ReadProtocolStateBlocksSuite) TestHandleNExceedsRootBlock() {
   247  	responseBlocks := suite.getBlocks(map[string]interface{}{
   248  		"block": "final",
   249  		"n":     float64(len(suite.allBlocks) + 1),
   250  	})
   251  	require.Len(suite.T(), responseBlocks, len(suite.allBlocks))
   252  	require.ElementsMatch(suite.T(), responseBlocks, suite.allBlocks)
   253  }