github.com/grafana/pyroscope@v1.18.0/pkg/querier/worker/scheduler_processor_test.go (about)

     1  // SPDX-License-Identifier: AGPL-3.0-only
     2  
     3  package worker
     4  
     5  import (
     6  	"context"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/go-kit/log"
    11  	"github.com/gogo/status"
    12  	"github.com/grafana/dskit/concurrency"
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/mock"
    15  	"github.com/stretchr/testify/require"
    16  	"go.uber.org/atomic"
    17  	"google.golang.org/grpc"
    18  	"google.golang.org/grpc/codes"
    19  	"google.golang.org/grpc/metadata"
    20  
    21  	"github.com/grafana/pyroscope/pkg/scheduler/schedulerpb"
    22  	"github.com/grafana/pyroscope/pkg/util/httpgrpc"
    23  )
    24  
    25  func TestSchedulerProcessor_processQueriesOnSingleStream(t *testing.T) {
    26  	t.Run("should immediately return if worker context is canceled and there's no inflight query", func(t *testing.T) {
    27  		sp, loopClient, requestHandler := prepareSchedulerProcessor()
    28  
    29  		workerCtx, workerCancel := context.WithCancel(context.Background())
    30  
    31  		loopClient.On("Recv").Return(func() (*schedulerpb.SchedulerToQuerier, error) {
    32  			// Simulate the querier received a SIGTERM while waiting for a query to execute.
    33  			workerCancel()
    34  
    35  			// No query to execute, so wait until terminated.
    36  			<-loopClient.Context().Done()
    37  			return nil, loopClient.Context().Err()
    38  		})
    39  
    40  		requestHandler.On("Handle", mock.Anything, mock.Anything).Return(&httpgrpc.HTTPResponse{}, nil)
    41  
    42  		sp.processQueriesOnSingleStream(workerCtx, nil, "127.0.0.1")
    43  
    44  		// We expect at this point, the execution context has been canceled too.
    45  		require.Error(t, loopClient.Context().Err())
    46  
    47  		// We expect Send() has been called only once, to send the querier ID to scheduler.
    48  		loopClient.AssertNumberOfCalls(t, "Send", 1)
    49  		loopClient.AssertCalled(t, "Send", &schedulerpb.QuerierToScheduler{QuerierID: "test-querier-id"})
    50  	})
    51  
    52  	t.Run("should wait until inflight query execution is completed before returning when worker context is canceled", func(t *testing.T) {
    53  		sp, loopClient, requestHandler := prepareSchedulerProcessor()
    54  
    55  		recvCount := atomic.NewInt64(0)
    56  
    57  		loopClient.On("Recv").Return(func() (*schedulerpb.SchedulerToQuerier, error) {
    58  			switch recvCount.Inc() {
    59  			case 1:
    60  				return &schedulerpb.SchedulerToQuerier{
    61  					QueryID:         1,
    62  					HttpRequest:     nil,
    63  					FrontendAddress: "127.0.0.2",
    64  					UserID:          "user-1",
    65  				}, nil
    66  			default:
    67  				// No more messages to process, so waiting until terminated.
    68  				<-loopClient.Context().Done()
    69  				return nil, loopClient.Context().Err()
    70  			}
    71  		})
    72  
    73  		workerCtx, workerCancel := context.WithCancel(context.Background())
    74  
    75  		requestHandler.On("Handle", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
    76  			// Cancel the worker context while the query execution is in progress.
    77  			workerCancel()
    78  
    79  			// Ensure the execution context hasn't been canceled yet.
    80  			require.Nil(t, loopClient.Context().Err())
    81  
    82  			// Intentionally slow down the query execution, to double check the worker waits until done.
    83  			time.Sleep(time.Second)
    84  		}).Return(&httpgrpc.HTTPResponse{}, nil)
    85  
    86  		startTime := time.Now()
    87  		sp.processQueriesOnSingleStream(workerCtx, nil, "127.0.0.1")
    88  		assert.GreaterOrEqual(t, time.Since(startTime), time.Second)
    89  
    90  		// We expect at this point, the execution context has been canceled too.
    91  		require.Error(t, loopClient.Context().Err())
    92  
    93  		// We expect Send() to be called twice: first to send the querier ID to scheduler
    94  		// and then to send the query result.
    95  		loopClient.AssertNumberOfCalls(t, "Send", 2)
    96  		loopClient.AssertCalled(t, "Send", &schedulerpb.QuerierToScheduler{QuerierID: "test-querier-id"})
    97  	})
    98  
    99  	t.Run("should not log an error when the query-scheduler is terminates while waiting for the next query to run", func(t *testing.T) {
   100  		sp, loopClient, requestHandler := prepareSchedulerProcessor()
   101  
   102  		// Override the logger to capture the logs.
   103  		logs := &concurrency.SyncBuffer{}
   104  		sp.log = log.NewLogfmtLogger(logs)
   105  
   106  		workerCtx, workerCancel := context.WithCancel(context.Background())
   107  
   108  		// As soon as the Recv() is called for the first time, we cancel the worker context and
   109  		// return the "scheduler not running" error. The reason why we cancel the worker context
   110  		// is to let processQueriesOnSingleStream() terminate.
   111  		loopClient.On("Recv").Return(func() (*schedulerpb.SchedulerToQuerier, error) {
   112  			workerCancel()
   113  			return nil, status.Error(codes.Unknown, schedulerpb.ErrSchedulerIsNotRunning.Error())
   114  		})
   115  
   116  		requestHandler.On("Handle", mock.Anything, mock.Anything).Return(&httpgrpc.HTTPResponse{}, nil)
   117  
   118  		sp.processQueriesOnSingleStream(workerCtx, nil, "127.0.0.1")
   119  
   120  		// We expect no error in the log.
   121  		assert.NotContains(t, logs.String(), "error")
   122  		assert.NotContains(t, logs.String(), schedulerpb.ErrSchedulerIsNotRunning)
   123  	})
   124  }
   125  
   126  func prepareSchedulerProcessor() (*schedulerProcessor, *querierLoopClientMock, *requestHandlerMock) {
   127  	var querierLoopCtx context.Context
   128  
   129  	loopClient := &querierLoopClientMock{}
   130  	loopClient.On("Send", mock.Anything).Return(nil)
   131  	loopClient.On("Context").Return(func() context.Context {
   132  		return querierLoopCtx
   133  	})
   134  
   135  	schedulerClient := &schedulerForQuerierClientMock{}
   136  	schedulerClient.On("QuerierLoop", mock.Anything, mock.Anything).Run(func(args mock.Arguments) {
   137  		querierLoopCtx = args.Get(0).(context.Context)
   138  	}).Return(loopClient, nil)
   139  
   140  	requestHandler := &requestHandlerMock{}
   141  
   142  	sp, _ := newSchedulerProcessor(Config{QuerierID: "test-querier-id"}, requestHandler, log.NewNopLogger(), nil)
   143  	sp.schedulerClientFactory = func(_ *grpc.ClientConn) schedulerpb.SchedulerForQuerierClient {
   144  		return schedulerClient
   145  	}
   146  
   147  	return sp, loopClient, requestHandler
   148  }
   149  
   150  type schedulerForQuerierClientMock struct {
   151  	mock.Mock
   152  }
   153  
   154  func (m *schedulerForQuerierClientMock) QuerierLoop(ctx context.Context, opts ...grpc.CallOption) (schedulerpb.SchedulerForQuerier_QuerierLoopClient, error) {
   155  	args := m.Called(ctx, opts)
   156  	return args.Get(0).(schedulerpb.SchedulerForQuerier_QuerierLoopClient), args.Error(1)
   157  }
   158  
   159  func (m *schedulerForQuerierClientMock) NotifyQuerierShutdown(ctx context.Context, in *schedulerpb.NotifyQuerierShutdownRequest, opts ...grpc.CallOption) (*schedulerpb.NotifyQuerierShutdownResponse, error) {
   160  	args := m.Called(ctx, in, opts)
   161  	return args.Get(0).(*schedulerpb.NotifyQuerierShutdownResponse), args.Error(1)
   162  }
   163  
   164  type querierLoopClientMock struct {
   165  	mock.Mock
   166  }
   167  
   168  func (m *querierLoopClientMock) Send(msg *schedulerpb.QuerierToScheduler) error {
   169  	args := m.Called(msg)
   170  	return args.Error(0)
   171  }
   172  
   173  func (m *querierLoopClientMock) Recv() (*schedulerpb.SchedulerToQuerier, error) {
   174  	args := m.Called()
   175  
   176  	// Allow to mock the Recv() with a function which is called each time.
   177  	if fn, ok := args.Get(0).(func() (*schedulerpb.SchedulerToQuerier, error)); ok {
   178  		return fn()
   179  	}
   180  
   181  	return args.Get(0).(*schedulerpb.SchedulerToQuerier), args.Error(1)
   182  }
   183  
   184  func (m *querierLoopClientMock) Header() (metadata.MD, error) {
   185  	args := m.Called()
   186  	return args.Get(0).(metadata.MD), args.Error(1)
   187  }
   188  
   189  func (m *querierLoopClientMock) Trailer() metadata.MD {
   190  	args := m.Called()
   191  	return args.Get(0).(metadata.MD)
   192  }
   193  
   194  func (m *querierLoopClientMock) CloseSend() error {
   195  	args := m.Called()
   196  	return args.Error(0)
   197  }
   198  
   199  func (m *querierLoopClientMock) Context() context.Context {
   200  	args := m.Called()
   201  
   202  	// Allow to mock the Context() with a function which is called each time.
   203  	if fn, ok := args.Get(0).(func() context.Context); ok {
   204  		return fn()
   205  	}
   206  
   207  	return args.Get(0).(context.Context)
   208  }
   209  
   210  func (m *querierLoopClientMock) SendMsg(msg interface{}) error {
   211  	args := m.Called(msg)
   212  	return args.Error(0)
   213  }
   214  
   215  func (m *querierLoopClientMock) RecvMsg(msg interface{}) error {
   216  	args := m.Called(msg)
   217  	return args.Error(0)
   218  }
   219  
   220  type requestHandlerMock struct {
   221  	mock.Mock
   222  }
   223  
   224  func (m *requestHandlerMock) Handle(ctx context.Context, req *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error) {
   225  	args := m.Called(ctx, req)
   226  	return args.Get(0).(*httpgrpc.HTTPResponse), args.Error(1)
   227  }