github.com/koko1123/flow-go-1@v0.29.6/state/fork/traversal_test.go (about)

     1  package fork
     2  
     3  import (
     4  	"errors"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/mock"
     8  	"github.com/stretchr/testify/suite"
     9  
    10  	"github.com/koko1123/flow-go-1/model/flow"
    11  	"github.com/koko1123/flow-go-1/storage"
    12  	mockstorage "github.com/koko1123/flow-go-1/storage/mock"
    13  	"github.com/koko1123/flow-go-1/utils/unittest"
    14  )
    15  
    16  func TestTraverse(t *testing.T) {
    17  	suite.Run(t, new(TraverseSuite))
    18  }
    19  
    20  type TraverseSuite struct {
    21  	suite.Suite
    22  
    23  	byID     map[flow.Identifier]*flow.Header
    24  	byHeight map[uint64]*flow.Header
    25  	headers  *mockstorage.Headers
    26  	genesis  *flow.Header
    27  }
    28  
    29  func (s *TraverseSuite) SetupTest() {
    30  	// create a storage.Headers mock with a backing map
    31  	s.byID = make(map[flow.Identifier]*flow.Header)
    32  	s.byHeight = make(map[uint64]*flow.Header)
    33  	s.headers = new(mockstorage.Headers)
    34  	s.headers.On("ByBlockID", mock.Anything).Return(
    35  		func(id flow.Identifier) *flow.Header {
    36  			return s.byID[id]
    37  		},
    38  		func(id flow.Identifier) error {
    39  			_, ok := s.byID[id]
    40  			if !ok {
    41  				return storage.ErrNotFound
    42  			}
    43  			return nil
    44  		})
    45  
    46  	// populate the mocked header storage with genesis and 10 child blocks
    47  	genesis := unittest.BlockHeaderFixture()
    48  	genesis.Height = 0
    49  	s.byID[genesis.ID()] = genesis
    50  	s.byHeight[genesis.Height] = genesis
    51  	s.genesis = genesis
    52  
    53  	parent := genesis
    54  	for i := 0; i < 10; i++ {
    55  		child := unittest.BlockHeaderWithParentFixture(parent)
    56  		s.byID[child.ID()] = child
    57  		s.byHeight[child.Height] = child
    58  		parent = child
    59  	}
    60  }
    61  
    62  // TestTraverse_MissingForkHead tests the behaviour of block traversing for the
    63  // case where the fork head is an unknown block. We expect:
    64  // * traversal errors
    65  // * traversal does _not_ invoke the visitor callback
    66  func (s *TraverseSuite) TestTraverse_MissingForkHead() {
    67  	unknownForkHead := unittest.IdentifierFixture()
    68  
    69  	visitor := func(_ *flow.Header) error {
    70  		s.Require().Fail("visitor should not be called")
    71  		return nil
    72  	}
    73  
    74  	s.Run("TraverseBackward from non-existent start block", func() {
    75  		err := TraverseBackward(s.headers, unknownForkHead, visitor, IncludingBlock(s.genesis.ID()))
    76  		s.Require().Error(err)
    77  	})
    78  
    79  	// should return error and not call callback when start block doesn't exist
    80  	s.Run("non-existent start block", func() {
    81  		err := TraverseForward(s.headers, unknownForkHead, visitor, IncludingBlock(s.genesis.ID()))
    82  		s.Require().Error(err)
    83  	})
    84  }
    85  
    86  // TestTraverse_VisitorError tests the behaviour of block traversing for the
    87  // case where the visitor callback errors. We expect
    88  // * the visitor error is propagated by the block traversal
    89  func (s *TraverseSuite) TestTraverse_VisitorError() {
    90  	forkHead := s.byHeight[8].ID()
    91  
    92  	visitorError := errors.New("some visitor error")
    93  	visitor := func(_ *flow.Header) error { return visitorError }
    94  
    95  	s.Run("TraverseBackward with visitor error", func() {
    96  		err := TraverseBackward(s.headers, forkHead, visitor, IncludingHeight(1))
    97  		s.Require().ErrorIs(err, visitorError)
    98  	})
    99  
   100  	s.Run("TraverseForward with visitor error", func() {
   101  		err := TraverseForward(s.headers, forkHead, visitor, IncludingHeight(1))
   102  		s.Require().ErrorIs(err, visitorError)
   103  	})
   104  }
   105  
   106  // TestTraverse_UnknownTerminalBlock tests the behaviour of block traversing
   107  // for the case where the terminal block is unknown
   108  func (s *TraverseSuite) TestTraverse_UnknownTerminalBlock() {
   109  	forkHead := s.byHeight[8].ID()
   110  	unknownTerminal := unittest.IdentifierFixture()
   111  	visitor := func(_ *flow.Header) error {
   112  		s.Require().Fail("visitor should not be called")
   113  		return nil
   114  	}
   115  
   116  	s.Run("backwards traversal with non-existent terminal block (inclusive)", func() {
   117  		err := TraverseBackward(s.headers, forkHead, visitor, IncludingBlock(unknownTerminal))
   118  		s.Require().Error(err)
   119  	})
   120  
   121  	s.Run("backwards traversal with non-existent terminal block (exclusive)", func() {
   122  		err := TraverseBackward(s.headers, forkHead, visitor, ExcludingBlock(unknownTerminal))
   123  		s.Require().Error(err)
   124  	})
   125  
   126  	s.Run("forward traversal with non-existent terminal block (inclusive)", func() {
   127  		err := TraverseForward(s.headers, forkHead, visitor, IncludingBlock(unknownTerminal))
   128  		s.Require().Error(err)
   129  	})
   130  
   131  	s.Run("forward traversal with non-existent terminal block (exclusive)", func() {
   132  		err := TraverseForward(s.headers, forkHead, visitor, ExcludingBlock(unknownTerminal))
   133  		s.Require().Error(err)
   134  	})
   135  }
   136  
   137  // TestTraverseBackward_DownToBlock tests different happy-path scenarios for reverse
   138  // block traversing where the terminal block (lowest block) is specified by its ID
   139  func (s *TraverseSuite) TestTraverseBackward_DownToBlock() {
   140  
   141  	// edge case where start == end and the end block is _excluded_
   142  	s.Run("zero blocks to traverse", func() {
   143  		start := s.byHeight[5].ID()
   144  		end := s.byHeight[5].ID()
   145  
   146  		err := TraverseBackward(s.headers, start, func(header *flow.Header) error {
   147  			s.Require().Fail("visitor should not be called")
   148  			return nil
   149  		}, ExcludingBlock(end))
   150  		s.Require().NoError(err)
   151  	})
   152  
   153  	// edge case where start == end and the end block is _included_
   154  	s.Run("single block to traverse", func() {
   155  		start := s.byHeight[5].ID()
   156  		end := s.byHeight[5].ID()
   157  
   158  		called := 0
   159  		err := TraverseBackward(s.headers, start, func(header *flow.Header) error {
   160  			// should call callback for single block in traversal path
   161  			s.Require().Equal(start, header.ID())
   162  			// track calls - should only be called once
   163  			called++
   164  			return nil
   165  		}, IncludingBlock(end))
   166  		s.Require().NoError(err)
   167  		s.Require().Equal(1, called)
   168  	})
   169  
   170  	// should call the callback exactly once for each block in traversal path
   171  	// and not return an error
   172  	s.Run("multi-block traversal including terminal block", func() {
   173  		startHeight := uint64(8)
   174  		endHeight := uint64(4)
   175  
   176  		start := s.byHeight[startHeight].ID()
   177  		end := s.byHeight[endHeight].ID()
   178  
   179  		// assert that we are receiving the correct block at each height
   180  		height := startHeight
   181  		err := TraverseBackward(s.headers, start, func(header *flow.Header) error {
   182  			expectedID := s.byHeight[height].ID()
   183  			s.Require().Equal(expectedID, header.ID())
   184  			height--
   185  			return nil
   186  		}, IncludingBlock(end))
   187  		s.Require().NoError(err)
   188  		s.Require().Equal(endHeight, height+1)
   189  	})
   190  
   191  	// should call the callback exactly once for each block in traversal path
   192  	// and not return an error
   193  	s.Run("multi-block traversal excluding terminal block", func() {
   194  		startHeight := uint64(8)
   195  		endHeight := uint64(4)
   196  
   197  		start := s.byHeight[startHeight].ID()
   198  		end := s.byHeight[endHeight].ID()
   199  
   200  		// assert that we are receiving the correct block at each height
   201  		height := startHeight
   202  		err := TraverseBackward(s.headers, start, func(header *flow.Header) error {
   203  			expectedID := s.byHeight[height].ID()
   204  			s.Require().Equal(expectedID, header.ID())
   205  			height--
   206  			return nil
   207  		}, ExcludingBlock(end))
   208  		s.Require().NoError(err)
   209  		s.Require().Equal(endHeight, height)
   210  	})
   211  
   212  	// edge case where we traverse only the genesis block
   213  	s.Run("traversing only genesis block", func() {
   214  		genesisID := s.genesis.ID()
   215  
   216  		called := 0
   217  		err := TraverseBackward(s.headers, genesisID, func(header *flow.Header) error {
   218  			// should call callback for single block in traversal path
   219  			s.Require().Equal(genesisID, header.ID())
   220  			// track calls - should only be called once
   221  			called++
   222  			return nil
   223  		}, IncludingBlock(genesisID))
   224  		s.Require().NoError(err)
   225  		s.Require().Equal(1, called)
   226  	})
   227  }
   228  
   229  // TestTraverseBackward_DownToHeight tests different happy-path scenarios for reverse
   230  // block traversing where the terminal block (lowest block) is specified by height
   231  func (s *TraverseSuite) TestTraverseBackward_DownToHeight() {
   232  
   233  	// edge case where start == end and the end block is _excluded_
   234  	s.Run("zero blocks to traverse", func() {
   235  		startHeight := uint64(5)
   236  		start := s.byHeight[startHeight].ID()
   237  
   238  		err := TraverseBackward(s.headers, start, func(header *flow.Header) error {
   239  			s.Require().Fail("visitor should not be called")
   240  			return nil
   241  		}, ExcludingHeight(startHeight))
   242  		s.Require().NoError(err)
   243  	})
   244  
   245  	// edge case where start == end and the end block is _included_
   246  	s.Run("single block to traverse", func() {
   247  		startHeight := uint64(5)
   248  		start := s.byHeight[startHeight].ID()
   249  
   250  		called := 0
   251  		err := TraverseBackward(s.headers, start, func(header *flow.Header) error {
   252  			// should call callback for single block in traversal path
   253  			s.Require().Equal(start, header.ID())
   254  			// track calls - should only be called once
   255  			called++
   256  			return nil
   257  		}, IncludingHeight(startHeight))
   258  		s.Require().NoError(err)
   259  		s.Require().Equal(1, called)
   260  	})
   261  
   262  	// should call the callback exactly once for each block in traversal path
   263  	// and not return an error
   264  	s.Run("multi-block traversal including terminal block", func() {
   265  		startHeight := uint64(8)
   266  		endHeight := uint64(4)
   267  		start := s.byHeight[startHeight].ID()
   268  
   269  		// assert that we are receiving the correct block at each height
   270  		height := startHeight
   271  		err := TraverseBackward(s.headers, start, func(header *flow.Header) error {
   272  			expectedID := s.byHeight[height].ID()
   273  			s.Require().Equal(expectedID, header.ID())
   274  			height--
   275  			return nil
   276  		}, IncludingHeight(endHeight))
   277  		s.Require().NoError(err)
   278  		s.Require().Equal(endHeight, height+1)
   279  	})
   280  
   281  	// should call the callback exactly once for each block in traversal path
   282  	// and not return an error
   283  	s.Run("multi-block traversal excluding terminal block", func() {
   284  		startHeight := uint64(8)
   285  		endHeight := uint64(4)
   286  		start := s.byHeight[startHeight].ID()
   287  
   288  		// assert that we are receiving the correct block at each height
   289  		height := startHeight
   290  		err := TraverseBackward(s.headers, start, func(header *flow.Header) error {
   291  			expectedID := s.byHeight[height].ID()
   292  			s.Require().Equal(expectedID, header.ID())
   293  			height--
   294  			return nil
   295  		}, ExcludingHeight(endHeight))
   296  		s.Require().NoError(err)
   297  		s.Require().Equal(endHeight, height)
   298  	})
   299  
   300  	// edge case where we traverse only the genesis block
   301  	s.Run("traversing only genesis block", func() {
   302  		genesisID := s.genesis.ID()
   303  
   304  		called := 0
   305  		err := TraverseBackward(s.headers, genesisID, func(header *flow.Header) error {
   306  			// should call callback for single block in traversal path
   307  			s.Require().Equal(genesisID, header.ID())
   308  			// track calls - should only be called once
   309  			called++
   310  			return nil
   311  		}, IncludingHeight(s.genesis.Height))
   312  		s.Require().NoError(err)
   313  		s.Require().Equal(1, called)
   314  	})
   315  }
   316  
   317  // TestTraverseForward_UpFromBlock tests different happy-path scenarios for parent-first
   318  // block traversing where the terminal block (lowest block) is specified by its ID
   319  func (s *TraverseSuite) TestTraverseForward_UpFromBlock() {
   320  
   321  	// edge case where start == end and the terminal block is _excluded_
   322  	s.Run("zero blocks to traverse", func() {
   323  		upperBlock := s.byHeight[5].ID()
   324  		lowerBlock := s.byHeight[5].ID()
   325  
   326  		err := TraverseForward(s.headers, upperBlock, func(header *flow.Header) error {
   327  			s.Require().Fail("visitor should not be called")
   328  			return nil
   329  		}, ExcludingBlock(lowerBlock))
   330  		s.Require().NoError(err)
   331  	})
   332  
   333  	// should call the callback exactly once and not return an error when start == end
   334  	s.Run("single-block traversal", func() {
   335  		upperBlock := s.byHeight[5].ID()
   336  		lowerBlock := s.byHeight[5].ID()
   337  
   338  		called := 0
   339  		err := TraverseForward(s.headers, upperBlock, func(header *flow.Header) error {
   340  			// should call callback for single block in traversal path
   341  			s.Require().Equal(upperBlock, header.ID())
   342  			// track calls - should only be called once
   343  			called++
   344  			return nil
   345  		}, IncludingBlock(lowerBlock))
   346  		s.Require().NoError(err)
   347  		s.Require().Equal(1, called)
   348  	})
   349  
   350  	// should call the callback exactly once for each block in traversal path
   351  	// and not return an error
   352  	s.Run("multi-block traversal including terminal block", func() {
   353  		upperHeight := uint64(8)
   354  		lowerHeight := uint64(4)
   355  
   356  		upperBlock := s.byHeight[upperHeight].ID()
   357  		lowerBlock := s.byHeight[lowerHeight].ID()
   358  
   359  		// assert that we are receiving the correct block at each height
   360  		height := lowerHeight
   361  		err := TraverseForward(s.headers, upperBlock, func(header *flow.Header) error {
   362  			expectedID := s.byHeight[height].ID()
   363  			s.Require().Equal(height, header.Height)
   364  			s.Require().Equal(expectedID, header.ID())
   365  			height++
   366  			return nil
   367  		}, IncludingBlock(lowerBlock))
   368  		s.Require().NoError(err)
   369  		s.Require().Equal(height, upperHeight+1)
   370  	})
   371  
   372  	// should call the callback exactly once for each block in traversal path
   373  	// and not return an error
   374  	s.Run("multi-block traversal excluding terminal block", func() {
   375  		upperHeight := uint64(8)
   376  		lowerHeight := uint64(4)
   377  
   378  		upperBlock := s.byHeight[upperHeight].ID()
   379  		lowerBlock := s.byHeight[lowerHeight].ID()
   380  
   381  		// assert that we are receiving the correct block at each height
   382  		height := lowerHeight + 1
   383  		err := TraverseForward(s.headers, upperBlock, func(header *flow.Header) error {
   384  			expectedID := s.byHeight[height].ID()
   385  			s.Require().Equal(height, header.Height)
   386  			s.Require().Equal(expectedID, header.ID())
   387  			height++
   388  			return nil
   389  		}, ExcludingBlock(lowerBlock))
   390  		s.Require().NoError(err)
   391  		s.Require().Equal(height, upperHeight+1)
   392  	})
   393  
   394  	// edge case where we traverse only the genesis block
   395  	s.Run("traversing only genesis block", func() {
   396  		genesisID := s.genesis.ID()
   397  
   398  		called := 0
   399  		err := TraverseForward(s.headers, genesisID, func(header *flow.Header) error {
   400  			// should call callback for single block in traversal path
   401  			s.Require().Equal(genesisID, header.ID())
   402  			// track calls - should only be called once
   403  			called++
   404  			return nil
   405  		}, IncludingBlock(genesisID))
   406  		s.Require().NoError(err)
   407  		s.Require().Equal(1, called)
   408  	})
   409  }
   410  
   411  // TestTraverseForward_UpFromHeight tests different happy-path scenarios for parent-first
   412  // block traversing where the terminal block (lowest block) is specified by height
   413  func (s *TraverseSuite) TestTraverseForward_UpFromHeight() {
   414  
   415  	// edge case where start == end and the terminal block is _excluded_
   416  	s.Run("zero blocks to traverse", func() {
   417  		upperHeight := uint64(5)
   418  		upperBlock := s.byHeight[upperHeight].ID()
   419  
   420  		err := TraverseForward(s.headers, upperBlock, func(header *flow.Header) error {
   421  			s.Require().Fail("visitor should not be called")
   422  			return nil
   423  		}, ExcludingHeight(upperHeight))
   424  		s.Require().NoError(err)
   425  	})
   426  
   427  	// should call the callback exactly once and not return an error when start == end
   428  	s.Run("single-block traversal", func() {
   429  		upperHeight := uint64(5)
   430  		upperBlock := s.byHeight[upperHeight].ID()
   431  
   432  		called := 0
   433  		err := TraverseForward(s.headers, upperBlock, func(header *flow.Header) error {
   434  			// should call callback for single block in traversal path
   435  			s.Require().Equal(upperBlock, header.ID())
   436  			// track calls - should only be called once
   437  			called++
   438  			return nil
   439  		}, IncludingHeight(upperHeight))
   440  		s.Require().NoError(err)
   441  		s.Require().Equal(1, called)
   442  	})
   443  
   444  	// should call the callback exactly once for each block in traversal path
   445  	// and not return an error
   446  	s.Run("multi-block traversal including terminal block", func() {
   447  		upperHeight := uint64(8)
   448  		lowerHeight := uint64(4)
   449  		upperBlock := s.byHeight[upperHeight].ID()
   450  
   451  		// assert that we are receiving the correct block at each height
   452  		height := lowerHeight
   453  		err := TraverseForward(s.headers, upperBlock, func(header *flow.Header) error {
   454  			expectedID := s.byHeight[height].ID()
   455  			s.Require().Equal(height, header.Height)
   456  			s.Require().Equal(expectedID, header.ID())
   457  			height++
   458  			return nil
   459  		}, IncludingHeight(lowerHeight))
   460  		s.Require().NoError(err)
   461  		s.Require().Equal(height, upperHeight+1)
   462  	})
   463  
   464  	// should call the callback exactly once for each block in traversal path
   465  	// and not return an error
   466  	s.Run("multi-block traversal excluding terminal block", func() {
   467  		upperHeight := uint64(8)
   468  		lowerHeight := uint64(4)
   469  		upperBlock := s.byHeight[upperHeight].ID()
   470  
   471  		// assert that we are receiving the correct block at each height
   472  		height := lowerHeight + 1
   473  		err := TraverseForward(s.headers, upperBlock, func(header *flow.Header) error {
   474  			expectedID := s.byHeight[height].ID()
   475  			s.Require().Equal(height, header.Height)
   476  			s.Require().Equal(expectedID, header.ID())
   477  			height++
   478  			return nil
   479  		}, ExcludingHeight(lowerHeight))
   480  		s.Require().NoError(err)
   481  		s.Require().Equal(height, upperHeight+1)
   482  	})
   483  
   484  	// edge case where we traverse only the genesis block
   485  	s.Run("traversing only genesis block", func() {
   486  		genesisID := s.genesis.ID()
   487  
   488  		called := 0
   489  		err := TraverseForward(s.headers, genesisID, func(header *flow.Header) error {
   490  			// should call callback for single block in traversal path
   491  			s.Require().Equal(genesisID, header.ID())
   492  			// track calls - should only be called once
   493  			called++
   494  			return nil
   495  		}, IncludingHeight(s.genesis.Height))
   496  		s.Require().NoError(err)
   497  		s.Require().Equal(1, called)
   498  	})
   499  }
   500  
   501  // TestTraverse_OnDifferentForkThanTerminalBlock tests that block traversing
   502  // errors if the end block is on a different Fork. This is only applicable
   503  // when terminal block (lowest block) is specified by its ID.
   504  func (s *TraverseSuite) TestTraverse_OnDifferentForkThanTerminalBlock() {
   505  	forkHead := s.byHeight[8].ID()
   506  	noopVisitor := func(header *flow.Header) error { return nil }
   507  
   508  	// make other fork
   509  	otherForkHead := s.genesis
   510  	otherForkByHeight := make(map[uint64]*flow.Header)
   511  	for i := 0; i < 10; i++ {
   512  		child := unittest.BlockHeaderWithParentFixture(otherForkHead)
   513  		s.byID[child.ID()] = child
   514  		otherForkByHeight[child.Height] = child
   515  		otherForkHead = child
   516  	}
   517  	terminalBlockID := otherForkByHeight[2].ID()
   518  
   519  	s.Run("forwards traversal with terminal block (on different fork) included ", func() {
   520  		// assert that we are receiving the correct block at each height
   521  		err := TraverseForward(s.headers, forkHead, noopVisitor, ExcludingBlock(terminalBlockID))
   522  		s.Require().Error(err)
   523  	})
   524  
   525  	s.Run("forwards traversal with terminal block (on different fork) excluded ", func() {
   526  		// assert that we are receiving the correct block at each height
   527  		err := TraverseForward(s.headers, forkHead, noopVisitor, IncludingBlock(terminalBlockID))
   528  		s.Require().Error(err)
   529  	})
   530  
   531  	s.Run("backwards traversal with terminal block (on different fork) included ", func() {
   532  		// assert that we are receiving the correct block at each height
   533  		err := TraverseBackward(s.headers, forkHead, noopVisitor, ExcludingBlock(terminalBlockID))
   534  		s.Require().Error(err)
   535  	})
   536  
   537  	s.Run("backwards traversal with terminal block (on different fork) excluded ", func() {
   538  		// assert that we are receiving the correct block at each height
   539  		err := TraverseBackward(s.headers, forkHead, noopVisitor, IncludingBlock(terminalBlockID))
   540  		s.Require().Error(err)
   541  	})
   542  
   543  }