github.com/MetalBlockchain/metalgo@v1.11.9/snow/networking/router/chain_router_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package router
     5  
     6  import (
     7  	"context"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/prometheus/client_golang/prometheus"
    13  	"github.com/stretchr/testify/require"
    14  	"go.uber.org/mock/gomock"
    15  
    16  	"github.com/MetalBlockchain/metalgo/ids"
    17  	"github.com/MetalBlockchain/metalgo/message"
    18  	"github.com/MetalBlockchain/metalgo/network/p2p"
    19  	"github.com/MetalBlockchain/metalgo/snow"
    20  	"github.com/MetalBlockchain/metalgo/snow/engine/common"
    21  	"github.com/MetalBlockchain/metalgo/snow/networking/benchlist"
    22  	"github.com/MetalBlockchain/metalgo/snow/networking/handler"
    23  	"github.com/MetalBlockchain/metalgo/snow/networking/timeout"
    24  	"github.com/MetalBlockchain/metalgo/snow/networking/tracker"
    25  	"github.com/MetalBlockchain/metalgo/snow/snowtest"
    26  	"github.com/MetalBlockchain/metalgo/snow/validators"
    27  	"github.com/MetalBlockchain/metalgo/subnets"
    28  	"github.com/MetalBlockchain/metalgo/utils/constants"
    29  	"github.com/MetalBlockchain/metalgo/utils/logging"
    30  	"github.com/MetalBlockchain/metalgo/utils/math/meter"
    31  	"github.com/MetalBlockchain/metalgo/utils/resource"
    32  	"github.com/MetalBlockchain/metalgo/utils/set"
    33  	"github.com/MetalBlockchain/metalgo/utils/timer"
    34  	"github.com/MetalBlockchain/metalgo/version"
    35  
    36  	p2ppb "github.com/MetalBlockchain/metalgo/proto/pb/p2p"
    37  	commontracker "github.com/MetalBlockchain/metalgo/snow/engine/common/tracker"
    38  )
    39  
const (
	// engineType is the engine type the tests in this file use when setting a
	// chain's engine state and when registering requests with the router.
	engineType         = p2ppb.EngineType_ENGINE_TYPE_AVALANCHE
	// testThreadPoolSize is the value passed to handler.New by the tests in
	// this file (presumably the handler's worker pool size — confirm against
	// handler.New).
	testThreadPoolSize = 2
)
    44  
    45  // TODO refactor tests in this file
    46  
    47  func TestShutdown(t *testing.T) {
    48  	require := require.New(t)
    49  
    50  	snowCtx := snowtest.Context(t, snowtest.CChainID)
    51  	chainCtx := snowtest.ConsensusContext(snowCtx)
    52  	vdrs := validators.NewManager()
    53  	require.NoError(vdrs.AddStaker(chainCtx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1))
    54  	benchlist := benchlist.NewNoBenchlist()
    55  	tm, err := timeout.NewManager(
    56  		&timer.AdaptiveTimeoutConfig{
    57  			InitialTimeout:     time.Millisecond,
    58  			MinimumTimeout:     time.Millisecond,
    59  			MaximumTimeout:     10 * time.Second,
    60  			TimeoutCoefficient: 1.25,
    61  			TimeoutHalflife:    5 * time.Minute,
    62  		},
    63  		benchlist,
    64  		prometheus.NewRegistry(),
    65  		prometheus.NewRegistry(),
    66  	)
    67  	require.NoError(err)
    68  
    69  	go tm.Dispatch()
    70  	defer tm.Stop()
    71  
    72  	chainRouter := ChainRouter{}
    73  	require.NoError(chainRouter.Initialize(
    74  		ids.EmptyNodeID,
    75  		logging.NoLog{},
    76  		tm,
    77  		time.Second,
    78  		set.Set[ids.ID]{},
    79  		true,
    80  		set.Set[ids.ID]{},
    81  		nil,
    82  		HealthConfig{},
    83  		prometheus.NewRegistry(),
    84  	))
    85  
    86  	shutdownCalled := make(chan struct{}, 1)
    87  
    88  	resourceTracker, err := tracker.NewResourceTracker(
    89  		prometheus.NewRegistry(),
    90  		resource.NoUsage,
    91  		meter.ContinuousFactory{},
    92  		time.Second,
    93  	)
    94  	require.NoError(err)
    95  
    96  	p2pTracker, err := p2p.NewPeerTracker(
    97  		logging.NoLog{},
    98  		"",
    99  		prometheus.NewRegistry(),
   100  		nil,
   101  		version.CurrentApp,
   102  	)
   103  	require.NoError(err)
   104  
   105  	h, err := handler.New(
   106  		chainCtx,
   107  		vdrs,
   108  		nil,
   109  		time.Second,
   110  		testThreadPoolSize,
   111  		resourceTracker,
   112  		validators.UnhandledSubnetConnector,
   113  		subnets.New(chainCtx.NodeID, subnets.Config{}),
   114  		commontracker.NewPeers(),
   115  		p2pTracker,
   116  		prometheus.NewRegistry(),
   117  	)
   118  	require.NoError(err)
   119  
   120  	bootstrapper := &common.BootstrapperTest{
   121  		EngineTest: common.EngineTest{
   122  			T: t,
   123  		},
   124  	}
   125  	bootstrapper.Default(true)
   126  	bootstrapper.CantGossip = false
   127  	bootstrapper.ContextF = func() *snow.ConsensusContext {
   128  		return chainCtx
   129  	}
   130  	bootstrapper.ShutdownF = func(context.Context) error {
   131  		shutdownCalled <- struct{}{}
   132  		return nil
   133  	}
   134  	bootstrapper.ConnectedF = func(context.Context, ids.NodeID, *version.Application) error {
   135  		return nil
   136  	}
   137  	bootstrapper.HaltF = func(context.Context) {}
   138  
   139  	engine := &common.EngineTest{T: t}
   140  	engine.Default(true)
   141  	engine.CantGossip = false
   142  	engine.ContextF = func() *snow.ConsensusContext {
   143  		return chainCtx
   144  	}
   145  	engine.ShutdownF = func(context.Context) error {
   146  		shutdownCalled <- struct{}{}
   147  		return nil
   148  	}
   149  	engine.ConnectedF = func(context.Context, ids.NodeID, *version.Application) error {
   150  		return nil
   151  	}
   152  	engine.HaltF = func(context.Context) {}
   153  	h.SetEngineManager(&handler.EngineManager{
   154  		Avalanche: &handler.Engine{
   155  			StateSyncer:  nil,
   156  			Bootstrapper: bootstrapper,
   157  			Consensus:    engine,
   158  		},
   159  		Snowman: &handler.Engine{
   160  			StateSyncer:  nil,
   161  			Bootstrapper: bootstrapper,
   162  			Consensus:    engine,
   163  		},
   164  	})
   165  	chainCtx.State.Set(snow.EngineState{
   166  		Type:  engineType,
   167  		State: snow.NormalOp, // assumed bootstrapping is done
   168  	})
   169  
   170  	chainRouter.AddChain(context.Background(), h)
   171  
   172  	bootstrapper.StartF = func(context.Context, uint32) error {
   173  		return nil
   174  	}
   175  	h.Start(context.Background(), false)
   176  
   177  	chainRouter.Shutdown(context.Background())
   178  
   179  	ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
   180  	defer cancel()
   181  
   182  	select {
   183  	case <-ctx.Done():
   184  		require.FailNow("Handler shutdown was not called or timed out after 250ms during chainRouter shutdown")
   185  	case <-shutdownCalled:
   186  	}
   187  
   188  	shutdownDuration, err := h.AwaitStopped(ctx)
   189  	require.NoError(err)
   190  	require.GreaterOrEqual(shutdownDuration, time.Duration(0))
   191  	require.Less(shutdownDuration, 250*time.Millisecond)
   192  }
   193  
   194  func TestConnectedAfterShutdownErrorLogRegression(t *testing.T) {
   195  	require := require.New(t)
   196  
   197  	snowCtx := snowtest.Context(t, snowtest.PChainID)
   198  	chainCtx := snowtest.ConsensusContext(snowCtx)
   199  
   200  	chainRouter := ChainRouter{}
   201  	require.NoError(chainRouter.Initialize(
   202  		ids.EmptyNodeID,
   203  		logging.NoWarn{}, // If an error log is emitted, the test will fail
   204  		nil,
   205  		time.Second,
   206  		set.Set[ids.ID]{},
   207  		true,
   208  		set.Set[ids.ID]{},
   209  		nil,
   210  		HealthConfig{},
   211  		prometheus.NewRegistry(),
   212  	))
   213  
   214  	resourceTracker, err := tracker.NewResourceTracker(
   215  		prometheus.NewRegistry(),
   216  		resource.NoUsage,
   217  		meter.ContinuousFactory{},
   218  		time.Second,
   219  	)
   220  	require.NoError(err)
   221  
   222  	p2pTracker, err := p2p.NewPeerTracker(
   223  		logging.NoLog{},
   224  		"",
   225  		prometheus.NewRegistry(),
   226  		nil,
   227  		version.CurrentApp,
   228  	)
   229  	require.NoError(err)
   230  
   231  	h, err := handler.New(
   232  		chainCtx,
   233  		nil,
   234  		nil,
   235  		time.Second,
   236  		testThreadPoolSize,
   237  		resourceTracker,
   238  		validators.UnhandledSubnetConnector,
   239  		subnets.New(chainCtx.NodeID, subnets.Config{}),
   240  		commontracker.NewPeers(),
   241  		p2pTracker,
   242  		prometheus.NewRegistry(),
   243  	)
   244  	require.NoError(err)
   245  
   246  	engine := common.EngineTest{
   247  		T: t,
   248  		StartF: func(context.Context, uint32) error {
   249  			return nil
   250  		},
   251  		ContextF: func() *snow.ConsensusContext {
   252  			return chainCtx
   253  		},
   254  		HaltF: func(context.Context) {},
   255  		ShutdownF: func(context.Context) error {
   256  			return nil
   257  		},
   258  		ConnectedF: func(context.Context, ids.NodeID, *version.Application) error {
   259  			return nil
   260  		},
   261  	}
   262  	engine.Default(true)
   263  	engine.CantGossip = false
   264  
   265  	bootstrapper := &common.BootstrapperTest{
   266  		EngineTest: engine,
   267  		CantClear:  true,
   268  	}
   269  
   270  	h.SetEngineManager(&handler.EngineManager{
   271  		Avalanche: &handler.Engine{
   272  			StateSyncer:  nil,
   273  			Bootstrapper: bootstrapper,
   274  			Consensus:    &engine,
   275  		},
   276  		Snowman: &handler.Engine{
   277  			StateSyncer:  nil,
   278  			Bootstrapper: bootstrapper,
   279  			Consensus:    &engine,
   280  		},
   281  	})
   282  	chainCtx.State.Set(snow.EngineState{
   283  		Type:  engineType,
   284  		State: snow.NormalOp, // assumed bootstrapping is done
   285  	})
   286  
   287  	chainRouter.AddChain(context.Background(), h)
   288  
   289  	h.Start(context.Background(), false)
   290  
   291  	chainRouter.Shutdown(context.Background())
   292  
   293  	shutdownDuration, err := h.AwaitStopped(context.Background())
   294  	require.NoError(err)
   295  	require.GreaterOrEqual(shutdownDuration, time.Duration(0))
   296  
   297  	// Calling connected after shutdown should result in an error log.
   298  	chainRouter.Connected(
   299  		ids.GenerateTestNodeID(),
   300  		version.CurrentApp,
   301  		ids.GenerateTestID(),
   302  	)
   303  }
   304  
   305  func TestShutdownTimesOut(t *testing.T) {
   306  	require := require.New(t)
   307  
   308  	snowCtx := snowtest.Context(t, snowtest.CChainID)
   309  	ctx := snowtest.ConsensusContext(snowCtx)
   310  	nodeID := ids.EmptyNodeID
   311  	vdrs := validators.NewManager()
   312  	require.NoError(vdrs.AddStaker(ctx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1))
   313  	benchlist := benchlist.NewNoBenchlist()
   314  	// Ensure that the Ancestors request does not timeout
   315  	tm, err := timeout.NewManager(
   316  		&timer.AdaptiveTimeoutConfig{
   317  			InitialTimeout:     time.Second,
   318  			MinimumTimeout:     500 * time.Millisecond,
   319  			MaximumTimeout:     10 * time.Second,
   320  			TimeoutCoefficient: 1.25,
   321  			TimeoutHalflife:    5 * time.Minute,
   322  		},
   323  		benchlist,
   324  		prometheus.NewRegistry(),
   325  		prometheus.NewRegistry(),
   326  	)
   327  	require.NoError(err)
   328  
   329  	go tm.Dispatch()
   330  	defer tm.Stop()
   331  
   332  	chainRouter := ChainRouter{}
   333  
   334  	require.NoError(chainRouter.Initialize(
   335  		ids.EmptyNodeID,
   336  		logging.NoLog{},
   337  		tm,
   338  		time.Millisecond,
   339  		set.Set[ids.ID]{},
   340  		true,
   341  		set.Set[ids.ID]{},
   342  		nil,
   343  		HealthConfig{},
   344  		prometheus.NewRegistry(),
   345  	))
   346  
   347  	resourceTracker, err := tracker.NewResourceTracker(
   348  		prometheus.NewRegistry(),
   349  		resource.NoUsage,
   350  		meter.ContinuousFactory{},
   351  		time.Second,
   352  	)
   353  	require.NoError(err)
   354  
   355  	p2pTracker, err := p2p.NewPeerTracker(
   356  		logging.NoLog{},
   357  		"",
   358  		prometheus.NewRegistry(),
   359  		nil,
   360  		version.CurrentApp,
   361  	)
   362  	require.NoError(err)
   363  
   364  	h, err := handler.New(
   365  		ctx,
   366  		vdrs,
   367  		nil,
   368  		time.Second,
   369  		testThreadPoolSize,
   370  		resourceTracker,
   371  		validators.UnhandledSubnetConnector,
   372  		subnets.New(ctx.NodeID, subnets.Config{}),
   373  		commontracker.NewPeers(),
   374  		p2pTracker,
   375  		prometheus.NewRegistry(),
   376  	)
   377  	require.NoError(err)
   378  
   379  	bootstrapFinished := make(chan struct{}, 1)
   380  	bootstrapper := &common.BootstrapperTest{
   381  		EngineTest: common.EngineTest{
   382  			T: t,
   383  		},
   384  	}
   385  	bootstrapper.Default(true)
   386  	bootstrapper.CantGossip = false
   387  	bootstrapper.ContextF = func() *snow.ConsensusContext {
   388  		return ctx
   389  	}
   390  	bootstrapper.ConnectedF = func(context.Context, ids.NodeID, *version.Application) error {
   391  		return nil
   392  	}
   393  	bootstrapper.HaltF = func(context.Context) {}
   394  	bootstrapper.PullQueryF = func(context.Context, ids.NodeID, uint32, ids.ID, uint64) error {
   395  		// Ancestors blocks for two seconds
   396  		time.Sleep(2 * time.Second)
   397  		bootstrapFinished <- struct{}{}
   398  		return nil
   399  	}
   400  
   401  	engine := &common.EngineTest{T: t}
   402  	engine.Default(false)
   403  	engine.ContextF = func() *snow.ConsensusContext {
   404  		return ctx
   405  	}
   406  	closed := new(int)
   407  	engine.ShutdownF = func(context.Context) error {
   408  		*closed++
   409  		return nil
   410  	}
   411  	h.SetEngineManager(&handler.EngineManager{
   412  		Avalanche: &handler.Engine{
   413  			StateSyncer:  nil,
   414  			Bootstrapper: bootstrapper,
   415  			Consensus:    engine,
   416  		},
   417  		Snowman: &handler.Engine{
   418  			StateSyncer:  nil,
   419  			Bootstrapper: bootstrapper,
   420  			Consensus:    engine,
   421  		},
   422  	})
   423  	ctx.State.Set(snow.EngineState{
   424  		Type:  engineType,
   425  		State: snow.NormalOp, // assumed bootstrapping is done
   426  	})
   427  
   428  	chainRouter.AddChain(context.Background(), h)
   429  
   430  	bootstrapper.StartF = func(context.Context, uint32) error {
   431  		return nil
   432  	}
   433  	h.Start(context.Background(), false)
   434  
   435  	shutdownFinished := make(chan struct{}, 1)
   436  
   437  	go func() {
   438  		chainID := ids.Empty
   439  		msg := handler.Message{
   440  			InboundMessage: message.InboundPullQuery(chainID, 1, time.Hour, ids.GenerateTestID(), 0, nodeID),
   441  			EngineType:     p2ppb.EngineType_ENGINE_TYPE_UNSPECIFIED,
   442  		}
   443  		h.Push(context.Background(), msg)
   444  
   445  		time.Sleep(50 * time.Millisecond) // Pause to ensure message gets processed
   446  
   447  		chainRouter.Shutdown(context.Background())
   448  		shutdownFinished <- struct{}{}
   449  	}()
   450  
   451  	select {
   452  	case <-bootstrapFinished:
   453  		require.FailNow("Shutdown should have finished in one millisecond before timing out instead of waiting for engine to finish shutting down.")
   454  	case <-shutdownFinished:
   455  	}
   456  }
   457  
   458  // Ensure that a timeout fires if we don't get a response to a request
   459  func TestRouterTimeout(t *testing.T) {
   460  	require := require.New(t)
   461  
   462  	// Create a timeout manager
   463  	maxTimeout := 25 * time.Millisecond
   464  	tm, err := timeout.NewManager(
   465  		&timer.AdaptiveTimeoutConfig{
   466  			InitialTimeout:     10 * time.Millisecond,
   467  			MinimumTimeout:     10 * time.Millisecond,
   468  			MaximumTimeout:     maxTimeout,
   469  			TimeoutCoefficient: 1,
   470  			TimeoutHalflife:    5 * time.Minute,
   471  		},
   472  		benchlist.NewNoBenchlist(),
   473  		prometheus.NewRegistry(),
   474  		prometheus.NewRegistry(),
   475  	)
   476  	require.NoError(err)
   477  
   478  	go tm.Dispatch()
   479  	defer tm.Stop()
   480  
   481  	// Create a router
   482  	chainRouter := ChainRouter{}
   483  	require.NoError(chainRouter.Initialize(
   484  		ids.EmptyNodeID,
   485  		logging.NoLog{},
   486  		tm,
   487  		time.Millisecond,
   488  		set.Set[ids.ID]{},
   489  		true,
   490  		set.Set[ids.ID]{},
   491  		nil,
   492  		HealthConfig{},
   493  		prometheus.NewRegistry(),
   494  	))
   495  	defer chainRouter.Shutdown(context.Background())
   496  
   497  	// Create bootstrapper, engine and handler
   498  	var (
   499  		calledGetStateSummaryFrontierFailed,
   500  		calledGetAcceptedStateSummaryFailed,
   501  		calledGetAcceptedFrontierFailed,
   502  		calledGetAcceptedFailed,
   503  		calledGetAncestorsFailed,
   504  		calledGetFailed,
   505  		calledQueryFailed,
   506  		calledAppRequestFailed,
   507  		calledCrossChainAppRequestFailed bool
   508  
   509  		wg = sync.WaitGroup{}
   510  	)
   511  
   512  	snowCtx := snowtest.Context(t, snowtest.CChainID)
   513  	ctx := snowtest.ConsensusContext(snowCtx)
   514  	vdrs := validators.NewManager()
   515  	require.NoError(vdrs.AddStaker(ctx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1))
   516  
   517  	resourceTracker, err := tracker.NewResourceTracker(
   518  		prometheus.NewRegistry(),
   519  		resource.NoUsage,
   520  		meter.ContinuousFactory{},
   521  		time.Second,
   522  	)
   523  	require.NoError(err)
   524  
   525  	p2pTracker, err := p2p.NewPeerTracker(
   526  		logging.NoLog{},
   527  		"",
   528  		prometheus.NewRegistry(),
   529  		nil,
   530  		version.CurrentApp,
   531  	)
   532  	require.NoError(err)
   533  
   534  	h, err := handler.New(
   535  		ctx,
   536  		vdrs,
   537  		nil,
   538  		time.Second,
   539  		testThreadPoolSize,
   540  		resourceTracker,
   541  		validators.UnhandledSubnetConnector,
   542  		subnets.New(ctx.NodeID, subnets.Config{}),
   543  		commontracker.NewPeers(),
   544  		p2pTracker,
   545  		prometheus.NewRegistry(),
   546  	)
   547  	require.NoError(err)
   548  
   549  	bootstrapper := &common.BootstrapperTest{
   550  		EngineTest: common.EngineTest{
   551  			T: t,
   552  		},
   553  	}
   554  	bootstrapper.Default(true)
   555  	bootstrapper.CantGossip = false
   556  	bootstrapper.ContextF = func() *snow.ConsensusContext {
   557  		return ctx
   558  	}
   559  	bootstrapper.ConnectedF = func(context.Context, ids.NodeID, *version.Application) error {
   560  		return nil
   561  	}
   562  	bootstrapper.HaltF = func(context.Context) {}
   563  	bootstrapper.ShutdownF = func(context.Context) error { return nil }
   564  
   565  	bootstrapper.GetStateSummaryFrontierFailedF = func(context.Context, ids.NodeID, uint32) error {
   566  		defer wg.Done()
   567  		calledGetStateSummaryFrontierFailed = true
   568  		return nil
   569  	}
   570  	bootstrapper.GetAcceptedStateSummaryFailedF = func(context.Context, ids.NodeID, uint32) error {
   571  		defer wg.Done()
   572  		calledGetAcceptedStateSummaryFailed = true
   573  		return nil
   574  	}
   575  	bootstrapper.GetAcceptedFrontierFailedF = func(context.Context, ids.NodeID, uint32) error {
   576  		defer wg.Done()
   577  		calledGetAcceptedFrontierFailed = true
   578  		return nil
   579  	}
   580  	bootstrapper.GetAncestorsFailedF = func(context.Context, ids.NodeID, uint32) error {
   581  		defer wg.Done()
   582  		calledGetAncestorsFailed = true
   583  		return nil
   584  	}
   585  	bootstrapper.GetAcceptedFailedF = func(context.Context, ids.NodeID, uint32) error {
   586  		defer wg.Done()
   587  		calledGetAcceptedFailed = true
   588  		return nil
   589  	}
   590  	bootstrapper.GetFailedF = func(context.Context, ids.NodeID, uint32) error {
   591  		defer wg.Done()
   592  		calledGetFailed = true
   593  		return nil
   594  	}
   595  	bootstrapper.QueryFailedF = func(context.Context, ids.NodeID, uint32) error {
   596  		defer wg.Done()
   597  		calledQueryFailed = true
   598  		return nil
   599  	}
   600  	bootstrapper.AppRequestFailedF = func(context.Context, ids.NodeID, uint32, *common.AppError) error {
   601  		defer wg.Done()
   602  		calledAppRequestFailed = true
   603  		return nil
   604  	}
   605  	bootstrapper.CrossChainAppRequestFailedF = func(context.Context, ids.ID, uint32, *common.AppError) error {
   606  		defer wg.Done()
   607  		calledCrossChainAppRequestFailed = true
   608  		return nil
   609  	}
   610  	ctx.State.Set(snow.EngineState{
   611  		Type:  p2ppb.EngineType_ENGINE_TYPE_SNOWMAN,
   612  		State: snow.Bootstrapping, // assumed bootstrapping is ongoing
   613  	})
   614  
   615  	chainRouter.AddChain(context.Background(), h)
   616  
   617  	bootstrapper.StartF = func(context.Context, uint32) error {
   618  		return nil
   619  	}
   620  	h.SetEngineManager(&handler.EngineManager{
   621  		Avalanche: &handler.Engine{
   622  			StateSyncer:  nil,
   623  			Bootstrapper: bootstrapper,
   624  			Consensus:    nil,
   625  		},
   626  		Snowman: &handler.Engine{
   627  			StateSyncer:  nil,
   628  			Bootstrapper: bootstrapper,
   629  			Consensus:    nil,
   630  		},
   631  	})
   632  	h.Start(context.Background(), false)
   633  
   634  	nodeID := ids.GenerateTestNodeID()
   635  	requestID := uint32(0)
   636  	{
   637  		wg.Add(1)
   638  		chainRouter.RegisterRequest(
   639  			context.Background(),
   640  			nodeID,
   641  			ctx.ChainID,
   642  			ctx.ChainID,
   643  			requestID,
   644  			message.StateSummaryFrontierOp,
   645  			message.InternalGetStateSummaryFrontierFailed(
   646  				nodeID,
   647  				ctx.ChainID,
   648  				requestID,
   649  			),
   650  			p2ppb.EngineType_ENGINE_TYPE_SNOWMAN,
   651  		)
   652  	}
   653  
   654  	{
   655  		wg.Add(1)
   656  		requestID++
   657  		chainRouter.RegisterRequest(
   658  			context.Background(),
   659  			nodeID,
   660  			ctx.ChainID,
   661  			ctx.ChainID,
   662  			requestID,
   663  			message.AcceptedStateSummaryOp,
   664  			message.InternalGetAcceptedStateSummaryFailed(
   665  				nodeID,
   666  				ctx.ChainID,
   667  				requestID,
   668  			),
   669  			p2ppb.EngineType_ENGINE_TYPE_SNOWMAN,
   670  		)
   671  	}
   672  
   673  	{
   674  		wg.Add(1)
   675  		requestID++
   676  		chainRouter.RegisterRequest(
   677  			context.Background(),
   678  			nodeID,
   679  			ctx.ChainID,
   680  			ctx.ChainID,
   681  			requestID,
   682  			message.AcceptedFrontierOp,
   683  			message.InternalGetAcceptedFrontierFailed(
   684  				nodeID,
   685  				ctx.ChainID,
   686  				requestID,
   687  			),
   688  			p2ppb.EngineType_ENGINE_TYPE_SNOWMAN,
   689  		)
   690  	}
   691  
   692  	{
   693  		wg.Add(1)
   694  		requestID++
   695  		chainRouter.RegisterRequest(
   696  			context.Background(),
   697  			nodeID,
   698  			ctx.ChainID,
   699  			ctx.ChainID,
   700  			requestID,
   701  			message.AcceptedOp,
   702  			message.InternalGetAcceptedFailed(
   703  				nodeID,
   704  				ctx.ChainID,
   705  				requestID,
   706  			),
   707  			p2ppb.EngineType_ENGINE_TYPE_SNOWMAN,
   708  		)
   709  	}
   710  
   711  	{
   712  		wg.Add(1)
   713  		requestID++
   714  		chainRouter.RegisterRequest(
   715  			context.Background(),
   716  			nodeID,
   717  			ctx.ChainID,
   718  			ctx.ChainID,
   719  			requestID,
   720  			message.AncestorsOp,
   721  			message.InternalGetAncestorsFailed(
   722  				nodeID,
   723  				ctx.ChainID,
   724  				requestID,
   725  				p2ppb.EngineType_ENGINE_TYPE_SNOWMAN,
   726  			),
   727  			p2ppb.EngineType_ENGINE_TYPE_SNOWMAN,
   728  		)
   729  	}
   730  
   731  	{
   732  		wg.Add(1)
   733  		requestID++
   734  		chainRouter.RegisterRequest(
   735  			context.Background(),
   736  			nodeID,
   737  			ctx.ChainID,
   738  			ctx.ChainID,
   739  			requestID,
   740  			message.PutOp,
   741  			message.InternalGetFailed(
   742  				nodeID,
   743  				ctx.ChainID,
   744  				requestID,
   745  			),
   746  			p2ppb.EngineType_ENGINE_TYPE_SNOWMAN,
   747  		)
   748  	}
   749  
   750  	{
   751  		wg.Add(1)
   752  		requestID++
   753  		chainRouter.RegisterRequest(
   754  			context.Background(),
   755  			nodeID,
   756  			ctx.ChainID,
   757  			ctx.ChainID,
   758  			requestID,
   759  			message.ChitsOp,
   760  			message.InternalQueryFailed(
   761  				nodeID,
   762  				ctx.ChainID,
   763  				requestID,
   764  			),
   765  			p2ppb.EngineType_ENGINE_TYPE_SNOWMAN,
   766  		)
   767  	}
   768  
   769  	{
   770  		wg.Add(1)
   771  		requestID++
   772  		chainRouter.RegisterRequest(
   773  			context.Background(),
   774  			nodeID,
   775  			ctx.ChainID,
   776  			ctx.ChainID,
   777  			requestID,
   778  			message.AppResponseOp,
   779  			message.InboundAppError(
   780  				nodeID,
   781  				ctx.ChainID,
   782  				requestID,
   783  				common.ErrTimeout.Code,
   784  				common.ErrTimeout.Message,
   785  			),
   786  			p2ppb.EngineType_ENGINE_TYPE_SNOWMAN,
   787  		)
   788  	}
   789  
   790  	{
   791  		wg.Add(1)
   792  		requestID++
   793  		chainRouter.RegisterRequest(
   794  			context.Background(),
   795  			nodeID,
   796  			ctx.ChainID,
   797  			ctx.ChainID,
   798  			requestID,
   799  			message.CrossChainAppResponseOp,
   800  			message.InternalCrossChainAppError(
   801  				nodeID,
   802  				ctx.ChainID,
   803  				ctx.ChainID,
   804  				requestID,
   805  				common.ErrTimeout.Code,
   806  				common.ErrTimeout.Message,
   807  			),
   808  			p2ppb.EngineType_ENGINE_TYPE_SNOWMAN,
   809  		)
   810  	}
   811  
   812  	wg.Wait()
   813  
   814  	chainRouter.lock.Lock()
   815  	defer chainRouter.lock.Unlock()
   816  
   817  	require.True(calledGetStateSummaryFrontierFailed)
   818  	require.True(calledGetAcceptedStateSummaryFailed)
   819  	require.True(calledGetAcceptedFrontierFailed)
   820  	require.True(calledGetAcceptedFailed)
   821  	require.True(calledGetAncestorsFailed)
   822  	require.True(calledGetFailed)
   823  	require.True(calledQueryFailed)
   824  	require.True(calledAppRequestFailed)
   825  	require.True(calledCrossChainAppRequestFailed)
   826  }
   827  
   828  func TestRouterHonorsRequestedEngine(t *testing.T) {
   829  	ctrl := gomock.NewController(t)
   830  	require := require.New(t)
   831  
   832  	// Create a timeout manager
   833  	tm, err := timeout.NewManager(
   834  		&timer.AdaptiveTimeoutConfig{
   835  			InitialTimeout:     3 * time.Second,
   836  			MinimumTimeout:     3 * time.Second,
   837  			MaximumTimeout:     5 * time.Minute,
   838  			TimeoutCoefficient: 1,
   839  			TimeoutHalflife:    5 * time.Minute,
   840  		},
   841  		benchlist.NewNoBenchlist(),
   842  		prometheus.NewRegistry(),
   843  		prometheus.NewRegistry(),
   844  	)
   845  	require.NoError(err)
   846  
   847  	go tm.Dispatch()
   848  	defer tm.Stop()
   849  
   850  	// Create a router
   851  	chainRouter := ChainRouter{}
   852  	require.NoError(chainRouter.Initialize(
   853  		ids.EmptyNodeID,
   854  		logging.NoLog{},
   855  		tm,
   856  		time.Millisecond,
   857  		set.Set[ids.ID]{},
   858  		true,
   859  		set.Set[ids.ID]{},
   860  		nil,
   861  		HealthConfig{},
   862  		prometheus.NewRegistry(),
   863  	))
   864  	defer chainRouter.Shutdown(context.Background())
   865  
   866  	h := handler.NewMockHandler(ctrl)
   867  
   868  	snowCtx := snowtest.Context(t, snowtest.CChainID)
   869  	ctx := snowtest.ConsensusContext(snowCtx)
   870  	h.EXPECT().Context().Return(ctx).AnyTimes()
   871  	h.EXPECT().SetOnStopped(gomock.Any()).AnyTimes()
   872  	h.EXPECT().Stop(gomock.Any()).AnyTimes()
   873  	h.EXPECT().AwaitStopped(gomock.Any()).AnyTimes()
   874  
   875  	h.EXPECT().Push(gomock.Any(), gomock.Any()).Times(1)
   876  	chainRouter.AddChain(context.Background(), h)
   877  
   878  	h.EXPECT().ShouldHandle(gomock.Any()).Return(true).AnyTimes()
   879  
   880  	nodeID := ids.GenerateTestNodeID()
   881  	requestID := uint32(0)
   882  	{
   883  		chainRouter.RegisterRequest(
   884  			context.Background(),
   885  			nodeID,
   886  			ctx.ChainID,
   887  			ctx.ChainID,
   888  			requestID,
   889  			message.StateSummaryFrontierOp,
   890  			message.InternalGetStateSummaryFrontierFailed(
   891  				nodeID,
   892  				ctx.ChainID,
   893  				requestID,
   894  			),
   895  			p2ppb.EngineType_ENGINE_TYPE_UNSPECIFIED,
   896  		)
   897  		msg := message.InboundStateSummaryFrontier(
   898  			ctx.ChainID,
   899  			requestID,
   900  			nil,
   901  			nodeID,
   902  		)
   903  
   904  		h.EXPECT().Push(gomock.Any(), gomock.Any()).Do(func(_ context.Context, msg handler.Message) {
   905  			require.Equal(p2ppb.EngineType_ENGINE_TYPE_UNSPECIFIED, msg.EngineType)
   906  		})
   907  		chainRouter.HandleInbound(context.Background(), msg)
   908  	}
   909  
   910  	{
   911  		requestID++
   912  		chainRouter.RegisterRequest(
   913  			context.Background(),
   914  			nodeID,
   915  			ctx.ChainID,
   916  			ctx.ChainID,
   917  			requestID,
   918  			message.AcceptedStateSummaryOp,
   919  			message.InternalGetAcceptedStateSummaryFailed(
   920  				nodeID,
   921  				ctx.ChainID,
   922  				requestID,
   923  			),
   924  			engineType,
   925  		)
   926  		msg := message.InboundAcceptedStateSummary(
   927  			ctx.ChainID,
   928  			requestID,
   929  			nil,
   930  			nodeID,
   931  		)
   932  
   933  		h.EXPECT().Push(gomock.Any(), gomock.Any()).Do(func(_ context.Context, msg handler.Message) {
   934  			require.Equal(engineType, msg.EngineType)
   935  		})
   936  		chainRouter.HandleInbound(context.Background(), msg)
   937  	}
   938  
   939  	{
   940  		requestID++
   941  		msg := message.InboundPushQuery(
   942  			ctx.ChainID,
   943  			requestID,
   944  			0,
   945  			nil,
   946  			0,
   947  			nodeID,
   948  		)
   949  
   950  		h.EXPECT().Push(gomock.Any(), gomock.Any()).Do(func(_ context.Context, msg handler.Message) {
   951  			require.Equal(p2ppb.EngineType_ENGINE_TYPE_UNSPECIFIED, msg.EngineType)
   952  		})
   953  		chainRouter.HandleInbound(context.Background(), msg)
   954  	}
   955  
   956  	chainRouter.lock.Lock()
   957  	require.Zero(chainRouter.timedRequests.Len())
   958  	chainRouter.lock.Unlock()
   959  }
   960  
   961  func TestRouterClearTimeouts(t *testing.T) {
   962  	requestID := uint32(123)
   963  
   964  	tests := []struct {
   965  		name        string
   966  		responseOp  message.Op
   967  		responseMsg message.InboundMessage
   968  		timeoutMsg  message.InboundMessage
   969  	}{
   970  		{
   971  			name:        "StateSummaryFrontier",
   972  			responseOp:  message.StateSummaryFrontierOp,
   973  			responseMsg: message.InboundStateSummaryFrontier(ids.Empty, requestID, []byte("summary"), ids.EmptyNodeID),
   974  			timeoutMsg:  message.InternalGetStateSummaryFrontierFailed(ids.EmptyNodeID, ids.Empty, requestID),
   975  		},
   976  		{
   977  			name:        "AcceptedStateSummary",
   978  			responseOp:  message.AcceptedStateSummaryOp,
   979  			responseMsg: message.InboundAcceptedStateSummary(ids.Empty, requestID, []ids.ID{ids.GenerateTestID()}, ids.EmptyNodeID),
   980  			timeoutMsg:  message.InternalGetAcceptedStateSummaryFailed(ids.EmptyNodeID, ids.Empty, requestID),
   981  		},
   982  		{
   983  			name:        "AcceptedFrontierOp",
   984  			responseOp:  message.AcceptedFrontierOp,
   985  			responseMsg: message.InboundAcceptedFrontier(ids.Empty, requestID, ids.GenerateTestID(), ids.EmptyNodeID),
   986  			timeoutMsg:  message.InternalGetAcceptedFrontierFailed(ids.EmptyNodeID, ids.Empty, requestID),
   987  		},
   988  		{
   989  			name:        "Accepted",
   990  			responseOp:  message.AcceptedOp,
   991  			responseMsg: message.InboundAccepted(ids.Empty, requestID, []ids.ID{ids.GenerateTestID()}, ids.EmptyNodeID),
   992  			timeoutMsg:  message.InternalGetAcceptedFailed(ids.EmptyNodeID, ids.Empty, requestID),
   993  		},
   994  		{
   995  			name:        "Chits",
   996  			responseOp:  message.ChitsOp,
   997  			responseMsg: message.InboundChits(ids.Empty, requestID, ids.GenerateTestID(), ids.GenerateTestID(), ids.GenerateTestID(), ids.EmptyNodeID),
   998  			timeoutMsg:  message.InternalQueryFailed(ids.EmptyNodeID, ids.Empty, requestID),
   999  		},
  1000  		{
  1001  			name:        "AppResponse",
  1002  			responseOp:  message.AppResponseOp,
  1003  			responseMsg: message.InboundAppResponse(ids.Empty, requestID, []byte("responseMsg"), ids.EmptyNodeID),
  1004  			timeoutMsg:  message.InboundAppError(ids.EmptyNodeID, ids.Empty, requestID, 123, "error"),
  1005  		},
  1006  		{
  1007  			name:        "AppError",
  1008  			responseOp:  message.AppResponseOp,
  1009  			responseMsg: message.InboundAppError(ids.EmptyNodeID, ids.Empty, requestID, 1234, "custom error"),
  1010  			timeoutMsg:  message.InboundAppError(ids.EmptyNodeID, ids.Empty, requestID, 123, "error"),
  1011  		},
  1012  		{
  1013  			name:        "CrossChainAppResponse",
  1014  			responseOp:  message.CrossChainAppResponseOp,
  1015  			responseMsg: message.InternalCrossChainAppResponse(ids.EmptyNodeID, ids.Empty, ids.Empty, requestID, []byte("responseMsg")),
  1016  			timeoutMsg:  message.InternalCrossChainAppError(ids.EmptyNodeID, ids.Empty, ids.Empty, requestID, 123, "error"),
  1017  		},
  1018  		{
  1019  			name:        "CrossChainAppError",
  1020  			responseOp:  message.CrossChainAppResponseOp,
  1021  			responseMsg: message.InternalCrossChainAppError(ids.EmptyNodeID, ids.Empty, ids.Empty, requestID, 1234, "custom error"),
  1022  			timeoutMsg:  message.InternalCrossChainAppError(ids.EmptyNodeID, ids.Empty, ids.Empty, requestID, 123, "error"),
  1023  		},
  1024  	}
  1025  
  1026  	for _, tt := range tests {
  1027  		t.Run(tt.name, func(t *testing.T) {
  1028  			require := require.New(t)
  1029  
  1030  			chainRouter, _ := newChainRouterTest(t)
  1031  
  1032  			chainRouter.RegisterRequest(
  1033  				context.Background(),
  1034  				ids.EmptyNodeID,
  1035  				ids.Empty,
  1036  				ids.Empty,
  1037  				requestID,
  1038  				tt.responseOp,
  1039  				tt.timeoutMsg,
  1040  				engineType,
  1041  			)
  1042  
  1043  			chainRouter.HandleInbound(context.Background(), tt.responseMsg)
  1044  
  1045  			chainRouter.lock.Lock()
  1046  			require.Zero(chainRouter.timedRequests.Len())
  1047  			chainRouter.lock.Unlock()
  1048  		})
  1049  	}
  1050  }
  1051  
  1052  func TestValidatorOnlyMessageDrops(t *testing.T) {
  1053  	require := require.New(t)
  1054  
  1055  	// Create a timeout manager
  1056  	maxTimeout := 25 * time.Millisecond
  1057  	tm, err := timeout.NewManager(
  1058  		&timer.AdaptiveTimeoutConfig{
  1059  			InitialTimeout:     10 * time.Millisecond,
  1060  			MinimumTimeout:     10 * time.Millisecond,
  1061  			MaximumTimeout:     maxTimeout,
  1062  			TimeoutCoefficient: 1,
  1063  			TimeoutHalflife:    5 * time.Minute,
  1064  		},
  1065  		benchlist.NewNoBenchlist(),
  1066  		prometheus.NewRegistry(),
  1067  		prometheus.NewRegistry(),
  1068  	)
  1069  	require.NoError(err)
  1070  
  1071  	go tm.Dispatch()
  1072  	defer tm.Stop()
  1073  
  1074  	// Create a router
  1075  	chainRouter := ChainRouter{}
  1076  	require.NoError(chainRouter.Initialize(
  1077  		ids.EmptyNodeID,
  1078  		logging.NoLog{},
  1079  		tm,
  1080  		time.Millisecond,
  1081  		set.Set[ids.ID]{},
  1082  		true,
  1083  		set.Set[ids.ID]{},
  1084  		nil,
  1085  		HealthConfig{},
  1086  		prometheus.NewRegistry(),
  1087  	))
  1088  	defer chainRouter.Shutdown(context.Background())
  1089  
  1090  	// Create bootstrapper, engine and handler
  1091  	calledF := false
  1092  	wg := sync.WaitGroup{}
  1093  
  1094  	snowCtx := snowtest.Context(t, snowtest.CChainID)
  1095  	ctx := snowtest.ConsensusContext(snowCtx)
  1096  	sb := subnets.New(ctx.NodeID, subnets.Config{ValidatorOnly: true})
  1097  	vdrs := validators.NewManager()
  1098  	vID := ids.GenerateTestNodeID()
  1099  	require.NoError(vdrs.AddStaker(ctx.SubnetID, vID, nil, ids.Empty, 1))
  1100  	resourceTracker, err := tracker.NewResourceTracker(
  1101  		prometheus.NewRegistry(),
  1102  		resource.NoUsage,
  1103  		meter.ContinuousFactory{},
  1104  		time.Second,
  1105  	)
  1106  	require.NoError(err)
  1107  
  1108  	p2pTracker, err := p2p.NewPeerTracker(
  1109  		logging.NoLog{},
  1110  		"",
  1111  		prometheus.NewRegistry(),
  1112  		nil,
  1113  		version.CurrentApp,
  1114  	)
  1115  	require.NoError(err)
  1116  
  1117  	h, err := handler.New(
  1118  		ctx,
  1119  		vdrs,
  1120  		nil,
  1121  		time.Second,
  1122  		testThreadPoolSize,
  1123  		resourceTracker,
  1124  		validators.UnhandledSubnetConnector,
  1125  		sb,
  1126  		commontracker.NewPeers(),
  1127  		p2pTracker,
  1128  		prometheus.NewRegistry(),
  1129  	)
  1130  	require.NoError(err)
  1131  
  1132  	bootstrapper := &common.BootstrapperTest{
  1133  		EngineTest: common.EngineTest{
  1134  			T: t,
  1135  		},
  1136  	}
  1137  	bootstrapper.Default(false)
  1138  	bootstrapper.ContextF = func() *snow.ConsensusContext {
  1139  		return ctx
  1140  	}
  1141  	bootstrapper.PullQueryF = func(context.Context, ids.NodeID, uint32, ids.ID, uint64) error {
  1142  		defer wg.Done()
  1143  		calledF = true
  1144  		return nil
  1145  	}
  1146  	ctx.State.Set(snow.EngineState{
  1147  		Type:  p2ppb.EngineType_ENGINE_TYPE_SNOWMAN,
  1148  		State: snow.Bootstrapping, // assumed bootstrapping is ongoing
  1149  	})
  1150  
  1151  	engine := &common.EngineTest{T: t}
  1152  	engine.ContextF = func() *snow.ConsensusContext {
  1153  		return ctx
  1154  	}
  1155  	engine.Default(false)
  1156  	h.SetEngineManager(&handler.EngineManager{
  1157  		Avalanche: &handler.Engine{
  1158  			StateSyncer:  nil,
  1159  			Bootstrapper: bootstrapper,
  1160  			Consensus:    engine,
  1161  		},
  1162  		Snowman: &handler.Engine{
  1163  			StateSyncer:  nil,
  1164  			Bootstrapper: bootstrapper,
  1165  			Consensus:    engine,
  1166  		},
  1167  	})
  1168  
  1169  	chainRouter.AddChain(context.Background(), h)
  1170  
  1171  	bootstrapper.StartF = func(context.Context, uint32) error {
  1172  		return nil
  1173  	}
  1174  	h.Start(context.Background(), false)
  1175  
  1176  	var inMsg message.InboundMessage
  1177  	dummyContainerID := ids.GenerateTestID()
  1178  	reqID := uint32(0)
  1179  
  1180  	// Non-validator case
  1181  	nID := ids.GenerateTestNodeID()
  1182  
  1183  	calledF = false
  1184  	inMsg = message.InboundPullQuery(
  1185  		ctx.ChainID,
  1186  		reqID,
  1187  		time.Hour,
  1188  		dummyContainerID,
  1189  		0,
  1190  		nID,
  1191  	)
  1192  	chainRouter.HandleInbound(context.Background(), inMsg)
  1193  
  1194  	require.False(calledF) // should not be called
  1195  
  1196  	// Validator case
  1197  	calledF = false
  1198  	reqID++
  1199  	inMsg = message.InboundPullQuery(
  1200  		ctx.ChainID,
  1201  		reqID,
  1202  		time.Hour,
  1203  		dummyContainerID,
  1204  		0,
  1205  		vID,
  1206  	)
  1207  	wg.Add(1)
  1208  	chainRouter.HandleInbound(context.Background(), inMsg)
  1209  
  1210  	wg.Wait()
  1211  	require.True(calledF) // should be called since this is a validator request
  1212  }
  1213  
  1214  func TestConnectedSubnet(t *testing.T) {
  1215  	require := require.New(t)
  1216  	ctrl := gomock.NewController(t)
  1217  
  1218  	tm, err := timeout.NewManager(
  1219  		&timer.AdaptiveTimeoutConfig{
  1220  			InitialTimeout:     3 * time.Second,
  1221  			MinimumTimeout:     3 * time.Second,
  1222  			MaximumTimeout:     5 * time.Minute,
  1223  			TimeoutCoefficient: 1,
  1224  			TimeoutHalflife:    5 * time.Minute,
  1225  		},
  1226  		benchlist.NewNoBenchlist(),
  1227  		prometheus.NewRegistry(),
  1228  		prometheus.NewRegistry(),
  1229  	)
  1230  	require.NoError(err)
  1231  
  1232  	go tm.Dispatch()
  1233  	defer tm.Stop()
  1234  
  1235  	// Create chain router
  1236  	myNodeID := ids.GenerateTestNodeID()
  1237  	peerNodeID := ids.GenerateTestNodeID()
  1238  	subnetID0 := ids.GenerateTestID()
  1239  	subnetID1 := ids.GenerateTestID()
  1240  	trackedSubnets := set.Of(subnetID0, subnetID1)
  1241  	chainRouter := ChainRouter{}
  1242  	require.NoError(chainRouter.Initialize(
  1243  		myNodeID,
  1244  		logging.NoLog{},
  1245  		tm,
  1246  		time.Millisecond,
  1247  		set.Set[ids.ID]{},
  1248  		true,
  1249  		trackedSubnets,
  1250  		nil,
  1251  		HealthConfig{},
  1252  		prometheus.NewRegistry(),
  1253  	))
  1254  
  1255  	// Create bootstrapper, engine and handler
  1256  	snowCtx := snowtest.Context(t, snowtest.PChainID)
  1257  	ctx := snowtest.ConsensusContext(snowCtx)
  1258  	ctx.Executing.Set(false)
  1259  	ctx.State.Set(snow.EngineState{
  1260  		Type:  engineType,
  1261  		State: snow.NormalOp,
  1262  	})
  1263  
  1264  	myConnectedMsg := handler.Message{
  1265  		InboundMessage: message.InternalConnected(myNodeID, version.CurrentApp),
  1266  		EngineType:     p2ppb.EngineType_ENGINE_TYPE_UNSPECIFIED,
  1267  	}
  1268  	mySubnetConnectedMsg0 := handler.Message{
  1269  		InboundMessage: message.InternalConnectedSubnet(myNodeID, subnetID0),
  1270  		EngineType:     p2ppb.EngineType_ENGINE_TYPE_UNSPECIFIED,
  1271  	}
  1272  	mySubnetConnectedMsg1 := handler.Message{
  1273  		InboundMessage: message.InternalConnectedSubnet(myNodeID, subnetID1),
  1274  		EngineType:     p2ppb.EngineType_ENGINE_TYPE_UNSPECIFIED,
  1275  	}
  1276  
  1277  	platformHandler := handler.NewMockHandler(ctrl)
  1278  	platformHandler.EXPECT().Context().Return(ctx).AnyTimes()
  1279  	platformHandler.EXPECT().SetOnStopped(gomock.Any()).AnyTimes()
  1280  	platformHandler.EXPECT().Push(gomock.Any(), myConnectedMsg).Times(1)
  1281  	platformHandler.EXPECT().Push(gomock.Any(), mySubnetConnectedMsg0).Times(1)
  1282  	platformHandler.EXPECT().Push(gomock.Any(), mySubnetConnectedMsg1).Times(1)
  1283  
  1284  	chainRouter.AddChain(context.Background(), platformHandler)
  1285  
  1286  	peerConnectedMsg := handler.Message{
  1287  		InboundMessage: message.InternalConnected(peerNodeID, version.CurrentApp),
  1288  		EngineType:     p2ppb.EngineType_ENGINE_TYPE_UNSPECIFIED,
  1289  	}
  1290  	platformHandler.EXPECT().Push(gomock.Any(), peerConnectedMsg).Times(1)
  1291  	chainRouter.Connected(peerNodeID, version.CurrentApp, constants.PrimaryNetworkID)
  1292  
  1293  	peerSubnetConnectedMsg0 := handler.Message{
  1294  		InboundMessage: message.InternalConnectedSubnet(peerNodeID, subnetID0),
  1295  		EngineType:     p2ppb.EngineType_ENGINE_TYPE_UNSPECIFIED,
  1296  	}
  1297  	platformHandler.EXPECT().Push(gomock.Any(), peerSubnetConnectedMsg0).Times(1)
  1298  	chainRouter.Connected(peerNodeID, version.CurrentApp, subnetID0)
  1299  
  1300  	myDisconnectedMsg := handler.Message{
  1301  		InboundMessage: message.InternalDisconnected(myNodeID),
  1302  		EngineType:     p2ppb.EngineType_ENGINE_TYPE_UNSPECIFIED,
  1303  	}
  1304  	platformHandler.EXPECT().Push(gomock.Any(), myDisconnectedMsg).Times(1)
  1305  	chainRouter.Benched(constants.PlatformChainID, myNodeID)
  1306  
  1307  	peerDisconnectedMsg := handler.Message{
  1308  		InboundMessage: message.InternalDisconnected(peerNodeID),
  1309  		EngineType:     p2ppb.EngineType_ENGINE_TYPE_UNSPECIFIED,
  1310  	}
  1311  	platformHandler.EXPECT().Push(gomock.Any(), peerDisconnectedMsg).Times(1)
  1312  	chainRouter.Benched(constants.PlatformChainID, peerNodeID)
  1313  
  1314  	platformHandler.EXPECT().Push(gomock.Any(), myConnectedMsg).Times(1)
  1315  	platformHandler.EXPECT().Push(gomock.Any(), mySubnetConnectedMsg0).Times(1)
  1316  	platformHandler.EXPECT().Push(gomock.Any(), mySubnetConnectedMsg1).Times(1)
  1317  
  1318  	chainRouter.Unbenched(constants.PlatformChainID, myNodeID)
  1319  
  1320  	platformHandler.EXPECT().Push(gomock.Any(), peerConnectedMsg).Times(1)
  1321  	platformHandler.EXPECT().Push(gomock.Any(), peerSubnetConnectedMsg0).Times(1)
  1322  
  1323  	chainRouter.Unbenched(constants.PlatformChainID, peerNodeID)
  1324  
  1325  	platformHandler.EXPECT().Push(gomock.Any(), peerDisconnectedMsg).Times(1)
  1326  	chainRouter.Disconnected(peerNodeID)
  1327  }
  1328  
  1329  func TestValidatorOnlyAllowedNodeMessageDrops(t *testing.T) {
  1330  	require := require.New(t)
  1331  
  1332  	// Create a timeout manager
  1333  	maxTimeout := 25 * time.Millisecond
  1334  	tm, err := timeout.NewManager(
  1335  		&timer.AdaptiveTimeoutConfig{
  1336  			InitialTimeout:     10 * time.Millisecond,
  1337  			MinimumTimeout:     10 * time.Millisecond,
  1338  			MaximumTimeout:     maxTimeout,
  1339  			TimeoutCoefficient: 1,
  1340  			TimeoutHalflife:    5 * time.Minute,
  1341  		},
  1342  		benchlist.NewNoBenchlist(),
  1343  		prometheus.NewRegistry(),
  1344  		prometheus.NewRegistry(),
  1345  	)
  1346  	require.NoError(err)
  1347  
  1348  	go tm.Dispatch()
  1349  	defer tm.Stop()
  1350  
  1351  	// Create a router
  1352  	chainRouter := ChainRouter{}
  1353  	require.NoError(chainRouter.Initialize(
  1354  		ids.EmptyNodeID,
  1355  		logging.NoLog{},
  1356  		tm,
  1357  		time.Millisecond,
  1358  		set.Set[ids.ID]{},
  1359  		true,
  1360  		set.Set[ids.ID]{},
  1361  		nil,
  1362  		HealthConfig{},
  1363  		prometheus.NewRegistry(),
  1364  	))
  1365  	defer chainRouter.Shutdown(context.Background())
  1366  
  1367  	// Create bootstrapper, engine and handler
  1368  	calledF := false
  1369  	wg := sync.WaitGroup{}
  1370  
  1371  	snowCtx := snowtest.Context(t, snowtest.CChainID)
  1372  	ctx := snowtest.ConsensusContext(snowCtx)
  1373  	allowedID := ids.GenerateTestNodeID()
  1374  	allowedSet := set.Of(allowedID)
  1375  	sb := subnets.New(ctx.NodeID, subnets.Config{ValidatorOnly: true, AllowedNodes: allowedSet})
  1376  
  1377  	vdrs := validators.NewManager()
  1378  	vID := ids.GenerateTestNodeID()
  1379  	require.NoError(vdrs.AddStaker(ctx.SubnetID, vID, nil, ids.Empty, 1))
  1380  
  1381  	resourceTracker, err := tracker.NewResourceTracker(
  1382  		prometheus.NewRegistry(),
  1383  		resource.NoUsage,
  1384  		meter.ContinuousFactory{},
  1385  		time.Second,
  1386  	)
  1387  	require.NoError(err)
  1388  
  1389  	p2pTracker, err := p2p.NewPeerTracker(
  1390  		logging.NoLog{},
  1391  		"",
  1392  		prometheus.NewRegistry(),
  1393  		nil,
  1394  		version.CurrentApp,
  1395  	)
  1396  	require.NoError(err)
  1397  
  1398  	h, err := handler.New(
  1399  		ctx,
  1400  		vdrs,
  1401  		nil,
  1402  		time.Second,
  1403  		testThreadPoolSize,
  1404  		resourceTracker,
  1405  		validators.UnhandledSubnetConnector,
  1406  		sb,
  1407  		commontracker.NewPeers(),
  1408  		p2pTracker,
  1409  		prometheus.NewRegistry(),
  1410  	)
  1411  	require.NoError(err)
  1412  
  1413  	bootstrapper := &common.BootstrapperTest{
  1414  		EngineTest: common.EngineTest{
  1415  			T: t,
  1416  		},
  1417  	}
  1418  	bootstrapper.Default(false)
  1419  	bootstrapper.ContextF = func() *snow.ConsensusContext {
  1420  		return ctx
  1421  	}
  1422  	bootstrapper.PullQueryF = func(context.Context, ids.NodeID, uint32, ids.ID, uint64) error {
  1423  		defer wg.Done()
  1424  		calledF = true
  1425  		return nil
  1426  	}
  1427  	ctx.State.Set(snow.EngineState{
  1428  		Type:  engineType,
  1429  		State: snow.Bootstrapping, // assumed bootstrapping is ongoing
  1430  	})
  1431  	engine := &common.EngineTest{T: t}
  1432  	engine.ContextF = func() *snow.ConsensusContext {
  1433  		return ctx
  1434  	}
  1435  	engine.Default(false)
  1436  
  1437  	h.SetEngineManager(&handler.EngineManager{
  1438  		Avalanche: &handler.Engine{
  1439  			Bootstrapper: bootstrapper,
  1440  			Consensus:    engine,
  1441  		},
  1442  	})
  1443  
  1444  	chainRouter.AddChain(context.Background(), h)
  1445  
  1446  	bootstrapper.StartF = func(context.Context, uint32) error {
  1447  		return nil
  1448  	}
  1449  	h.Start(context.Background(), false)
  1450  
  1451  	var inMsg message.InboundMessage
  1452  	dummyContainerID := ids.GenerateTestID()
  1453  	reqID := uint32(0)
  1454  
  1455  	// Non-validator case
  1456  	nID := ids.GenerateTestNodeID()
  1457  
  1458  	calledF = false
  1459  	inMsg = message.InboundPullQuery(
  1460  		ctx.ChainID,
  1461  		reqID,
  1462  		time.Hour,
  1463  		dummyContainerID,
  1464  		0,
  1465  		nID,
  1466  	)
  1467  	chainRouter.HandleInbound(context.Background(), inMsg)
  1468  
  1469  	require.False(calledF) // should not be called for unallowed node ID
  1470  
  1471  	// Allowed NodeID case
  1472  	calledF = false
  1473  	reqID++
  1474  	inMsg = message.InboundPullQuery(
  1475  		ctx.ChainID,
  1476  		reqID,
  1477  		time.Hour,
  1478  		dummyContainerID,
  1479  		0,
  1480  		allowedID,
  1481  	)
  1482  	wg.Add(1)
  1483  	chainRouter.HandleInbound(context.Background(), inMsg)
  1484  
  1485  	wg.Wait()
  1486  	require.True(calledF) // should be called since this is a allowed node request
  1487  
  1488  	// Validator case
  1489  	calledF = false
  1490  	reqID++
  1491  	inMsg = message.InboundPullQuery(
  1492  		ctx.ChainID,
  1493  		reqID,
  1494  		time.Hour,
  1495  		dummyContainerID,
  1496  		0,
  1497  		vID,
  1498  	)
  1499  	wg.Add(1)
  1500  	chainRouter.HandleInbound(context.Background(), inMsg)
  1501  
  1502  	wg.Wait()
  1503  	require.True(calledF) // should be called since this is a validator request
  1504  }
  1505  
  1506  // Tests that a response, peer error, or a timeout clears the timeout and calls
  1507  // the handler
  1508  func TestAppRequest(t *testing.T) {
  1509  	wantRequestID := uint32(123)
  1510  	wantResponse := []byte("response")
  1511  
  1512  	errFoo := common.AppError{
  1513  		Code:    456,
  1514  		Message: "foo",
  1515  	}
  1516  
  1517  	tests := []struct {
  1518  		name       string
  1519  		responseOp message.Op
  1520  		timeoutMsg message.InboundMessage
  1521  		inboundMsg message.InboundMessage
  1522  	}{
  1523  		{
  1524  			name:       "AppRequest - chain response",
  1525  			responseOp: message.AppResponseOp,
  1526  			timeoutMsg: message.InboundAppError(ids.EmptyNodeID, ids.Empty, wantRequestID, errFoo.Code, errFoo.Message),
  1527  			inboundMsg: message.InboundAppResponse(ids.Empty, wantRequestID, wantResponse, ids.EmptyNodeID),
  1528  		},
  1529  		{
  1530  			name:       "AppRequest - chain error",
  1531  			responseOp: message.AppResponseOp,
  1532  			timeoutMsg: message.InboundAppError(ids.EmptyNodeID, ids.Empty, wantRequestID, errFoo.Code, errFoo.Message),
  1533  			inboundMsg: message.InboundAppError(ids.EmptyNodeID, ids.Empty, wantRequestID, errFoo.Code, errFoo.Message),
  1534  		},
  1535  		{
  1536  			name:       "AppRequest - timeout",
  1537  			responseOp: message.AppResponseOp,
  1538  			timeoutMsg: message.InboundAppError(ids.EmptyNodeID, ids.Empty, wantRequestID, errFoo.Code, errFoo.Message),
  1539  		},
  1540  	}
  1541  
  1542  	for _, tt := range tests {
  1543  		t.Run(tt.name, func(t *testing.T) {
  1544  			require := require.New(t)
  1545  
  1546  			wg := &sync.WaitGroup{}
  1547  			chainRouter, engine := newChainRouterTest(t)
  1548  
  1549  			wg.Add(1)
  1550  			if tt.inboundMsg == nil || tt.inboundMsg.Op() == message.AppErrorOp {
  1551  				engine.AppRequestFailedF = func(_ context.Context, nodeID ids.NodeID, requestID uint32, appErr *common.AppError) error {
  1552  					defer wg.Done()
  1553  					chainRouter.lock.Lock()
  1554  					require.Zero(chainRouter.timedRequests.Len())
  1555  					chainRouter.lock.Unlock()
  1556  
  1557  					require.Equal(ids.EmptyNodeID, nodeID)
  1558  					require.Equal(wantRequestID, requestID)
  1559  					require.Equal(errFoo.Code, appErr.Code)
  1560  					require.Equal(errFoo.Message, appErr.Message)
  1561  
  1562  					return nil
  1563  				}
  1564  			} else if tt.inboundMsg.Op() == message.AppResponseOp {
  1565  				engine.AppResponseF = func(_ context.Context, nodeID ids.NodeID, requestID uint32, msg []byte) error {
  1566  					defer wg.Done()
  1567  					chainRouter.lock.Lock()
  1568  					require.Zero(chainRouter.timedRequests.Len())
  1569  					chainRouter.lock.Unlock()
  1570  
  1571  					require.Equal(ids.EmptyNodeID, nodeID)
  1572  					require.Equal(wantRequestID, requestID)
  1573  					require.Equal(wantResponse, msg)
  1574  
  1575  					return nil
  1576  				}
  1577  			}
  1578  
  1579  			ctx := context.Background()
  1580  			chainRouter.RegisterRequest(ctx, ids.EmptyNodeID, ids.Empty, ids.Empty, wantRequestID, tt.responseOp, tt.timeoutMsg, engineType)
  1581  			chainRouter.lock.Lock()
  1582  			require.Equal(1, chainRouter.timedRequests.Len())
  1583  			chainRouter.lock.Unlock()
  1584  
  1585  			if tt.inboundMsg != nil {
  1586  				chainRouter.HandleInbound(ctx, tt.inboundMsg)
  1587  			}
  1588  
  1589  			wg.Wait()
  1590  		})
  1591  	}
  1592  }
  1593  
  1594  // Tests that a response, peer error, or a timeout clears the timeout and calls
  1595  // the handler
  1596  func TestCrossChainAppRequest(t *testing.T) {
  1597  	wantRequestID := uint32(123)
  1598  	wantResponse := []byte("response")
  1599  
  1600  	errFoo := common.AppError{
  1601  		Code:    456,
  1602  		Message: "foo",
  1603  	}
  1604  
  1605  	tests := []struct {
  1606  		name       string
  1607  		responseOp message.Op
  1608  		timeoutMsg message.InboundMessage
  1609  		inboundMsg message.InboundMessage
  1610  	}{
  1611  		{
  1612  			name:       "CrossChainAppRequest - chain response",
  1613  			responseOp: message.CrossChainAppResponseOp,
  1614  			timeoutMsg: message.InternalCrossChainAppError(ids.EmptyNodeID, ids.Empty, ids.Empty, wantRequestID, errFoo.Code, errFoo.Message),
  1615  			inboundMsg: message.InternalCrossChainAppResponse(ids.EmptyNodeID, ids.Empty, ids.Empty, wantRequestID, wantResponse),
  1616  		},
  1617  		{
  1618  			name:       "CrossChainAppRequest - chain error",
  1619  			responseOp: message.CrossChainAppResponseOp,
  1620  			timeoutMsg: message.InternalCrossChainAppError(ids.EmptyNodeID, ids.Empty, ids.Empty, wantRequestID, errFoo.Code, errFoo.Message),
  1621  			inboundMsg: message.InternalCrossChainAppError(ids.EmptyNodeID, ids.Empty, ids.Empty, wantRequestID, errFoo.Code, errFoo.Message),
  1622  		},
  1623  		{
  1624  			name:       "CrossChainAppRequest - timeout",
  1625  			responseOp: message.CrossChainAppResponseOp,
  1626  			timeoutMsg: message.InternalCrossChainAppError(ids.EmptyNodeID, ids.Empty, ids.Empty, wantRequestID, errFoo.Code, errFoo.Message),
  1627  		},
  1628  	}
  1629  
  1630  	for _, tt := range tests {
  1631  		t.Run(tt.name, func(t *testing.T) {
  1632  			require := require.New(t)
  1633  
  1634  			wg := &sync.WaitGroup{}
  1635  			chainRouter, engine := newChainRouterTest(t)
  1636  
  1637  			wg.Add(1)
  1638  			if tt.inboundMsg == nil || tt.inboundMsg.Op() == message.CrossChainAppErrorOp {
  1639  				engine.CrossChainAppRequestFailedF = func(_ context.Context, chainID ids.ID, requestID uint32, appErr *common.AppError) error {
  1640  					defer wg.Done()
  1641  					chainRouter.lock.Lock()
  1642  					require.Zero(chainRouter.timedRequests.Len())
  1643  					chainRouter.lock.Unlock()
  1644  
  1645  					require.Equal(ids.Empty, chainID)
  1646  					require.Equal(wantRequestID, requestID)
  1647  					require.Equal(errFoo.Code, appErr.Code)
  1648  					require.Equal(errFoo.Message, appErr.Message)
  1649  
  1650  					return nil
  1651  				}
  1652  			} else if tt.inboundMsg.Op() == message.CrossChainAppResponseOp {
  1653  				engine.CrossChainAppResponseF = func(_ context.Context, chainID ids.ID, requestID uint32, msg []byte) error {
  1654  					defer wg.Done()
  1655  					chainRouter.lock.Lock()
  1656  					require.Zero(chainRouter.timedRequests.Len())
  1657  					chainRouter.lock.Unlock()
  1658  
  1659  					require.Equal(ids.Empty, chainID)
  1660  					require.Equal(wantRequestID, requestID)
  1661  					require.Equal(wantResponse, msg)
  1662  
  1663  					return nil
  1664  				}
  1665  			}
  1666  
  1667  			ctx := context.Background()
  1668  			chainRouter.RegisterRequest(ctx, ids.EmptyNodeID, ids.Empty, ids.Empty, wantRequestID, tt.responseOp, tt.timeoutMsg, engineType)
  1669  			chainRouter.lock.Lock()
  1670  			require.Equal(1, chainRouter.timedRequests.Len())
  1671  			chainRouter.lock.Unlock()
  1672  
  1673  			if tt.inboundMsg != nil {
  1674  				chainRouter.HandleInbound(ctx, tt.inboundMsg)
  1675  			}
  1676  
  1677  			wg.Wait()
  1678  		})
  1679  	}
  1680  }
  1681  
  1682  func newChainRouterTest(t *testing.T) (*ChainRouter, *common.EngineTest) {
  1683  	// Create a timeout manager
  1684  	tm, err := timeout.NewManager(
  1685  		&timer.AdaptiveTimeoutConfig{
  1686  			InitialTimeout:     3 * time.Second,
  1687  			MinimumTimeout:     3 * time.Second,
  1688  			MaximumTimeout:     5 * time.Minute,
  1689  			TimeoutCoefficient: 1,
  1690  			TimeoutHalflife:    5 * time.Minute,
  1691  		},
  1692  		benchlist.NewNoBenchlist(),
  1693  		prometheus.NewRegistry(),
  1694  		prometheus.NewRegistry(),
  1695  	)
  1696  	require.NoError(t, err)
  1697  
  1698  	go tm.Dispatch()
  1699  
  1700  	// Create a router
  1701  	chainRouter := &ChainRouter{}
  1702  	require.NoError(t, chainRouter.Initialize(
  1703  		ids.EmptyNodeID,
  1704  		logging.NoLog{},
  1705  		tm,
  1706  		time.Millisecond,
  1707  		set.Set[ids.ID]{},
  1708  		true,
  1709  		set.Set[ids.ID]{},
  1710  		nil,
  1711  		HealthConfig{},
  1712  		prometheus.NewRegistry(),
  1713  	))
  1714  
  1715  	// Create bootstrapper, engine and handler
  1716  	snowCtx := snowtest.Context(t, snowtest.PChainID)
  1717  	ctx := snowtest.ConsensusContext(snowCtx)
  1718  	vdrs := validators.NewManager()
  1719  	require.NoError(t, vdrs.AddStaker(ctx.SubnetID, ids.GenerateTestNodeID(), nil, ids.Empty, 1))
  1720  
  1721  	resourceTracker, err := tracker.NewResourceTracker(
  1722  		prometheus.NewRegistry(),
  1723  		resource.NoUsage,
  1724  		meter.ContinuousFactory{},
  1725  		time.Second,
  1726  	)
  1727  	require.NoError(t, err)
  1728  
  1729  	p2pTracker, err := p2p.NewPeerTracker(
  1730  		logging.NoLog{},
  1731  		"",
  1732  		prometheus.NewRegistry(),
  1733  		nil,
  1734  		version.CurrentApp,
  1735  	)
  1736  	require.NoError(t, err)
  1737  
  1738  	h, err := handler.New(
  1739  		ctx,
  1740  		vdrs,
  1741  		nil,
  1742  		time.Second,
  1743  		testThreadPoolSize,
  1744  		resourceTracker,
  1745  		validators.UnhandledSubnetConnector,
  1746  		subnets.New(ctx.NodeID, subnets.Config{}),
  1747  		commontracker.NewPeers(),
  1748  		p2pTracker,
  1749  		prometheus.NewRegistry(),
  1750  	)
  1751  	require.NoError(t, err)
  1752  
  1753  	bootstrapper := &common.BootstrapperTest{
  1754  		EngineTest: common.EngineTest{
  1755  			T: t,
  1756  		},
  1757  	}
  1758  	bootstrapper.Default(false)
  1759  	bootstrapper.ContextF = func() *snow.ConsensusContext {
  1760  		return ctx
  1761  	}
  1762  
  1763  	engine := &common.EngineTest{T: t}
  1764  	engine.Default(false)
  1765  	engine.ContextF = func() *snow.ConsensusContext {
  1766  		return ctx
  1767  	}
  1768  	h.SetEngineManager(&handler.EngineManager{
  1769  		Avalanche: &handler.Engine{
  1770  			StateSyncer:  nil,
  1771  			Bootstrapper: bootstrapper,
  1772  			Consensus:    engine,
  1773  		},
  1774  		Snowman: &handler.Engine{
  1775  			StateSyncer:  nil,
  1776  			Bootstrapper: bootstrapper,
  1777  			Consensus:    engine,
  1778  		},
  1779  	})
  1780  	ctx.State.Set(snow.EngineState{
  1781  		Type:  p2ppb.EngineType_ENGINE_TYPE_SNOWMAN,
  1782  		State: snow.NormalOp, // assumed bootstrapping is done
  1783  	})
  1784  
  1785  	chainRouter.AddChain(context.Background(), h)
  1786  
  1787  	bootstrapper.StartF = func(context.Context, uint32) error {
  1788  		return nil
  1789  	}
  1790  
  1791  	h.Start(context.Background(), false)
  1792  
  1793  	t.Cleanup(func() {
  1794  		tm.Stop()
  1795  		chainRouter.Shutdown(context.Background())
  1796  	})
  1797  
  1798  	return chainRouter, engine
  1799  }