github.com/grafana/pyroscope@v1.18.0/pkg/querier/worker/scheduler_processor_test.go (about) 1 // SPDX-License-Identifier: AGPL-3.0-only 2 3 package worker 4 5 import ( 6 "context" 7 "testing" 8 "time" 9 10 "github.com/go-kit/log" 11 "github.com/gogo/status" 12 "github.com/grafana/dskit/concurrency" 13 "github.com/stretchr/testify/assert" 14 "github.com/stretchr/testify/mock" 15 "github.com/stretchr/testify/require" 16 "go.uber.org/atomic" 17 "google.golang.org/grpc" 18 "google.golang.org/grpc/codes" 19 "google.golang.org/grpc/metadata" 20 21 "github.com/grafana/pyroscope/pkg/scheduler/schedulerpb" 22 "github.com/grafana/pyroscope/pkg/util/httpgrpc" 23 ) 24 25 func TestSchedulerProcessor_processQueriesOnSingleStream(t *testing.T) { 26 t.Run("should immediately return if worker context is canceled and there's no inflight query", func(t *testing.T) { 27 sp, loopClient, requestHandler := prepareSchedulerProcessor() 28 29 workerCtx, workerCancel := context.WithCancel(context.Background()) 30 31 loopClient.On("Recv").Return(func() (*schedulerpb.SchedulerToQuerier, error) { 32 // Simulate the querier received a SIGTERM while waiting for a query to execute. 33 workerCancel() 34 35 // No query to execute, so wait until terminated. 36 <-loopClient.Context().Done() 37 return nil, loopClient.Context().Err() 38 }) 39 40 requestHandler.On("Handle", mock.Anything, mock.Anything).Return(&httpgrpc.HTTPResponse{}, nil) 41 42 sp.processQueriesOnSingleStream(workerCtx, nil, "127.0.0.1") 43 44 // We expect at this point, the execution context has been canceled too. 45 require.Error(t, loopClient.Context().Err()) 46 47 // We expect Send() has been called only once, to send the querier ID to scheduler. 48 loopClient.AssertNumberOfCalls(t, "Send", 1) 49 loopClient.AssertCalled(t, "Send", &schedulerpb.QuerierToScheduler{QuerierID: "test-querier-id"}) 50 }) 51 52 t.Run("should wait until inflight query execution is completed before returning when worker context is canceled", func(t *testing.T) { 53 sp, loopClient, requestHandler := prepareSchedulerProcessor() 54 55 recvCount := atomic.NewInt64(0) 56 57 loopClient.On("Recv").Return(func() (*schedulerpb.SchedulerToQuerier, error) { 58 switch recvCount.Inc() { 59 case 1: 60 return &schedulerpb.SchedulerToQuerier{ 61 QueryID: 1, 62 HttpRequest: nil, 63 FrontendAddress: "127.0.0.2", 64 UserID: "user-1", 65 }, nil 66 default: 67 // No more messages to process, so waiting until terminated. 68 <-loopClient.Context().Done() 69 return nil, loopClient.Context().Err() 70 } 71 }) 72 73 workerCtx, workerCancel := context.WithCancel(context.Background()) 74 75 requestHandler.On("Handle", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { 76 // Cancel the worker context while the query execution is in progress. 77 workerCancel() 78 79 // Ensure the execution context hasn't been canceled yet. 80 require.Nil(t, loopClient.Context().Err()) 81 82 // Intentionally slow down the query execution, to double check the worker waits until done. 83 time.Sleep(time.Second) 84 }).Return(&httpgrpc.HTTPResponse{}, nil) 85 86 startTime := time.Now() 87 sp.processQueriesOnSingleStream(workerCtx, nil, "127.0.0.1") 88 assert.GreaterOrEqual(t, time.Since(startTime), time.Second) 89 90 // We expect at this point, the execution context has been canceled too. 91 require.Error(t, loopClient.Context().Err()) 92 93 // We expect Send() to be called twice: first to send the querier ID to scheduler 94 // and then to send the query result. 95 loopClient.AssertNumberOfCalls(t, "Send", 2) 96 loopClient.AssertCalled(t, "Send", &schedulerpb.QuerierToScheduler{QuerierID: "test-querier-id"}) 97 }) 98 99 t.Run("should not log an error when the query-scheduler is terminates while waiting for the next query to run", func(t *testing.T) { 100 sp, loopClient, requestHandler := prepareSchedulerProcessor() 101 102 // Override the logger to capture the logs. 103 logs := &concurrency.SyncBuffer{} 104 sp.log = log.NewLogfmtLogger(logs) 105 106 workerCtx, workerCancel := context.WithCancel(context.Background()) 107 108 // As soon as the Recv() is called for the first time, we cancel the worker context and 109 // return the "scheduler not running" error. The reason why we cancel the worker context 110 // is to let processQueriesOnSingleStream() terminate. 111 loopClient.On("Recv").Return(func() (*schedulerpb.SchedulerToQuerier, error) { 112 workerCancel() 113 return nil, status.Error(codes.Unknown, schedulerpb.ErrSchedulerIsNotRunning.Error()) 114 }) 115 116 requestHandler.On("Handle", mock.Anything, mock.Anything).Return(&httpgrpc.HTTPResponse{}, nil) 117 118 sp.processQueriesOnSingleStream(workerCtx, nil, "127.0.0.1") 119 120 // We expect no error in the log. 121 assert.NotContains(t, logs.String(), "error") 122 assert.NotContains(t, logs.String(), schedulerpb.ErrSchedulerIsNotRunning) 123 }) 124 } 125 126 func prepareSchedulerProcessor() (*schedulerProcessor, *querierLoopClientMock, *requestHandlerMock) { 127 var querierLoopCtx context.Context 128 129 loopClient := &querierLoopClientMock{} 130 loopClient.On("Send", mock.Anything).Return(nil) 131 loopClient.On("Context").Return(func() context.Context { 132 return querierLoopCtx 133 }) 134 135 schedulerClient := &schedulerForQuerierClientMock{} 136 schedulerClient.On("QuerierLoop", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { 137 querierLoopCtx = args.Get(0).(context.Context) 138 }).Return(loopClient, nil) 139 140 requestHandler := &requestHandlerMock{} 141 142 sp, _ := newSchedulerProcessor(Config{QuerierID: "test-querier-id"}, requestHandler, log.NewNopLogger(), nil) 143 sp.schedulerClientFactory = func(_ *grpc.ClientConn) schedulerpb.SchedulerForQuerierClient { 144 return schedulerClient 145 } 146 147 return sp, loopClient, requestHandler 148 } 149 150 type schedulerForQuerierClientMock struct { 151 mock.Mock 152 } 153 154 func (m *schedulerForQuerierClientMock) QuerierLoop(ctx context.Context, opts ...grpc.CallOption) (schedulerpb.SchedulerForQuerier_QuerierLoopClient, error) { 155 args := m.Called(ctx, opts) 156 return args.Get(0).(schedulerpb.SchedulerForQuerier_QuerierLoopClient), args.Error(1) 157 } 158 159 func (m *schedulerForQuerierClientMock) NotifyQuerierShutdown(ctx context.Context, in *schedulerpb.NotifyQuerierShutdownRequest, opts ...grpc.CallOption) (*schedulerpb.NotifyQuerierShutdownResponse, error) { 160 args := m.Called(ctx, in, opts) 161 return args.Get(0).(*schedulerpb.NotifyQuerierShutdownResponse), args.Error(1) 162 } 163 164 type querierLoopClientMock struct { 165 mock.Mock 166 } 167 168 func (m *querierLoopClientMock) Send(msg *schedulerpb.QuerierToScheduler) error { 169 args := m.Called(msg) 170 return args.Error(0) 171 } 172 173 func (m *querierLoopClientMock) Recv() (*schedulerpb.SchedulerToQuerier, error) { 174 args := m.Called() 175 176 // Allow to mock the Recv() with a function which is called each time. 177 if fn, ok := args.Get(0).(func() (*schedulerpb.SchedulerToQuerier, error)); ok { 178 return fn() 179 } 180 181 return args.Get(0).(*schedulerpb.SchedulerToQuerier), args.Error(1) 182 } 183 184 func (m *querierLoopClientMock) Header() (metadata.MD, error) { 185 args := m.Called() 186 return args.Get(0).(metadata.MD), args.Error(1) 187 } 188 189 func (m *querierLoopClientMock) Trailer() metadata.MD { 190 args := m.Called() 191 return args.Get(0).(metadata.MD) 192 } 193 194 func (m *querierLoopClientMock) CloseSend() error { 195 args := m.Called() 196 return args.Error(0) 197 } 198 199 func (m *querierLoopClientMock) Context() context.Context { 200 args := m.Called() 201 202 // Allow to mock the Context() with a function which is called each time. 203 if fn, ok := args.Get(0).(func() context.Context); ok { 204 return fn() 205 } 206 207 return args.Get(0).(context.Context) 208 } 209 210 func (m *querierLoopClientMock) SendMsg(msg interface{}) error { 211 args := m.Called(msg) 212 return args.Error(0) 213 } 214 215 func (m *querierLoopClientMock) RecvMsg(msg interface{}) error { 216 args := m.Called(msg) 217 return args.Error(0) 218 } 219 220 type requestHandlerMock struct { 221 mock.Mock 222 } 223 224 func (m *requestHandlerMock) Handle(ctx context.Context, req *httpgrpc.HTTPRequest) (*httpgrpc.HTTPResponse, error) { 225 args := m.Called(ctx, req) 226 return args.Get(0).(*httpgrpc.HTTPResponse), args.Error(1) 227 }