github.com/grafana/pyroscope@v1.18.0/pkg/frontend/frontend_test.go (about)

     1  // SPDX-License-Identifier: AGPL-3.0-only
     2  // Provenance-includes-location: https://github.com/cortexproject/cortex/blob/master/pkg/frontend/v2/frontend_test.go
     3  // Provenance-includes-license: Apache-2.0
     4  // Provenance-includes-copyright: The Cortex Authors.
     5  
     6  package frontend
     7  
     8  import (
     9  	"bytes"
    10  	"context"
    11  	"fmt"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"net/url"
    15  	"os"
    16  	"runtime"
    17  	"strconv"
    18  	"strings"
    19  	"sync"
    20  	"testing"
    21  	"time"
    22  
    23  	"connectrpc.com/connect"
    24  	"github.com/go-kit/log"
    25  	"github.com/gorilla/mux"
    26  	"github.com/grafana/dskit/flagext"
    27  	"github.com/grafana/dskit/services"
    28  	"github.com/grafana/dskit/test"
    29  	"github.com/grafana/dskit/user"
    30  	"github.com/opentracing/opentracing-go"
    31  	"github.com/prometheus/client_golang/prometheus"
    32  	"github.com/prometheus/client_golang/prometheus/testutil"
    33  	"github.com/stretchr/testify/assert"
    34  	"github.com/stretchr/testify/require"
    35  	"go.uber.org/atomic"
    36  	"golang.org/x/net/http2"
    37  	"golang.org/x/net/http2/h2c"
    38  
    39  	"github.com/grafana/pyroscope/api/gen/proto/go/querier/v1/querierv1connect"
    40  	typesv1 "github.com/grafana/pyroscope/api/gen/proto/go/types/v1"
    41  	connectapi "github.com/grafana/pyroscope/pkg/api/connect"
    42  	"github.com/grafana/pyroscope/pkg/frontend/frontendpb"
    43  	"github.com/grafana/pyroscope/pkg/frontend/frontendpb/frontendpbconnect"
    44  	"github.com/grafana/pyroscope/pkg/querier/stats"
    45  	"github.com/grafana/pyroscope/pkg/querier/worker"
    46  	"github.com/grafana/pyroscope/pkg/scheduler"
    47  	"github.com/grafana/pyroscope/pkg/scheduler/schedulerdiscovery"
    48  	"github.com/grafana/pyroscope/pkg/scheduler/schedulerpb"
    49  	"github.com/grafana/pyroscope/pkg/scheduler/schedulerpb/schedulerpbconnect"
    50  	"github.com/grafana/pyroscope/pkg/util/connectgrpc"
    51  	"github.com/grafana/pyroscope/pkg/util/httpgrpc"
    52  	"github.com/grafana/pyroscope/pkg/util/servicediscovery"
    53  	"github.com/grafana/pyroscope/pkg/validation"
    54  )
    55  
// testFrontendWorkerConcurrency is the number of scheduler worker loops each
// test frontend runs. Some tests derive retry/request counts from it.
const testFrontendWorkerConcurrency = 5
    57  
// setupFrontend starts a Frontend wired to a mockScheduler using the default
// test worker concurrency. schedulerReplyFunc (may be nil) controls how the
// mock scheduler answers each message from the frontend.
func setupFrontend(t *testing.T, reg prometheus.Registerer, schedulerReplyFunc func(f *Frontend, msg *schedulerpb.FrontendToScheduler) *schedulerpb.SchedulerToFrontend) (*Frontend, *mockScheduler) {
	return setupFrontendWithConcurrencyAndServerOptions(t, reg, schedulerReplyFunc, testFrontendWorkerConcurrency)
}
    61  
    62  func cfgFromURL(t *testing.T, urlS string) Config {
    63  	u, err := url.Parse(urlS)
    64  	require.NoError(t, err)
    65  
    66  	port, err := strconv.Atoi(u.Port())
    67  	require.NoError(t, err)
    68  
    69  	cfg := Config{}
    70  	flagext.DefaultValues(&cfg)
    71  	cfg.SchedulerAddress = u.Hostname() + ":" + u.Port()
    72  	cfg.Addr = u.Hostname()
    73  	cfg.Port = port
    74  	return cfg
    75  }
    76  
    77  func setupFrontendWithConcurrencyAndServerOptions(t *testing.T, reg prometheus.Registerer, schedulerReplyFunc func(f *Frontend, msg *schedulerpb.FrontendToScheduler) *schedulerpb.SchedulerToFrontend, concurrency int) (*Frontend, *mockScheduler) {
    78  	s := httptest.NewUnstartedServer(nil)
    79  	mux := mux.NewRouter()
    80  	s.Config.Handler = h2c.NewHandler(mux, &http2.Server{})
    81  
    82  	s.Start()
    83  
    84  	cfg := cfgFromURL(t, s.URL)
    85  	cfg.WorkerConcurrency = concurrency
    86  
    87  	logger := log.NewLogfmtLogger(os.Stdout)
    88  	f, err := NewFrontend(cfg, validation.MockLimits{MaxQueryParallelismValue: 1}, logger, reg)
    89  	require.NoError(t, err)
    90  
    91  	frontendpbconnect.RegisterFrontendForQuerierHandler(mux, f)
    92  
    93  	ms := newMockScheduler(t, f, schedulerReplyFunc)
    94  
    95  	schedulerpbconnect.RegisterSchedulerForFrontendHandler(mux, ms)
    96  
    97  	t.Cleanup(func() {
    98  		s.Close()
    99  	})
   100  
   101  	require.NoError(t, services.StartAndAwaitRunning(context.Background(), f))
   102  	t.Cleanup(func() {
   103  		_ = services.StopAndAwaitTerminated(context.Background(), f)
   104  	})
   105  
   106  	// Wait for frontend to connect to scheduler.
   107  	test.Poll(t, 1*time.Second, 1, func() interface{} {
   108  		ms.mu.Lock()
   109  		defer ms.mu.Unlock()
   110  
   111  		return len(ms.frontendAddr)
   112  	})
   113  
   114  	return f, ms
   115  }
   116  
   117  func sendResponseWithDelay(f *Frontend, delay time.Duration, userID string, queryID uint64, resp *httpgrpc.HTTPResponse) {
   118  	if delay > 0 {
   119  		time.Sleep(delay)
   120  	}
   121  
   122  	ctx := user.InjectOrgID(context.Background(), userID)
   123  	_, _ = f.QueryResult(ctx, connect.NewRequest(&frontendpb.QueryResultRequest{
   124  		QueryID:      queryID,
   125  		HttpResponse: resp,
   126  		Stats:        &stats.Stats{},
   127  	}))
   128  }
   129  
   130  func TestFrontendBasicWorkflow(t *testing.T) {
   131  	const (
   132  		body   = "all fine here"
   133  		userID = "test"
   134  	)
   135  
   136  	f, _ := setupFrontend(t, nil, func(f *Frontend, msg *schedulerpb.FrontendToScheduler) *schedulerpb.SchedulerToFrontend {
   137  		// We cannot call QueryResult directly, as Frontend is not yet waiting for the response.
   138  		// It first needs to be told that enqueuing has succeeded.
   139  		go sendResponseWithDelay(f, 100*time.Millisecond, userID, msg.QueryID, &httpgrpc.HTTPResponse{
   140  			Code: 200,
   141  			Body: []byte(body),
   142  		})
   143  
   144  		return &schedulerpb.SchedulerToFrontend{Status: schedulerpb.SchedulerToFrontendStatus_OK}
   145  	})
   146  
   147  	resp, err := f.RoundTripGRPC(user.InjectOrgID(context.Background(), userID), &httpgrpc.HTTPRequest{})
   148  	require.NoError(t, err)
   149  	require.Equal(t, int32(200), resp.Code)
   150  	require.Equal(t, []byte(body), resp.Body)
   151  }
   152  
// TestFrontendRequestsPerWorkerMetric verifies that the per-scheduler-address
// pyroscope_query_frontend_workers_enqueued_requests_total counter starts at
// zero, is incremented by a roundtrip, and is deleted when the scheduler
// instance is removed.
func TestFrontendRequestsPerWorkerMetric(t *testing.T) {
	const (
		body   = "all fine here"
		userID = "test"
	)

	reg := prometheus.NewRegistry()

	f, _ := setupFrontend(t, reg, func(f *Frontend, msg *schedulerpb.FrontendToScheduler) *schedulerpb.SchedulerToFrontend {
		// We cannot call QueryResult directly, as Frontend is not yet waiting for the response.
		// It first needs to be told that enqueuing has succeeded.
		go sendResponseWithDelay(f, 100*time.Millisecond, userID, msg.QueryID, &httpgrpc.HTTPResponse{
			Code: 200,
			Body: []byte(body),
		})

		return &schedulerpb.SchedulerToFrontend{Status: schedulerpb.SchedulerToFrontendStatus_OK}
	})

	// Before any request: the counter exists but is zero.
	expectedMetrics := fmt.Sprintf(`
		# HELP pyroscope_query_frontend_workers_enqueued_requests_total Total number of requests enqueued by each query frontend worker (regardless of the result), labeled by scheduler address.
		# TYPE pyroscope_query_frontend_workers_enqueued_requests_total counter
		pyroscope_query_frontend_workers_enqueued_requests_total{scheduler_address="%s"} 0
	`, f.cfg.SchedulerAddress)
	require.NoError(t, testutil.GatherAndCompare(reg, strings.NewReader(expectedMetrics), "pyroscope_query_frontend_workers_enqueued_requests_total"))

	resp, err := f.RoundTripGRPC(user.InjectOrgID(context.Background(), userID), &httpgrpc.HTTPRequest{})
	require.NoError(t, err)
	require.Equal(t, int32(200), resp.Code)
	require.Equal(t, []byte(body), resp.Body)

	// One successful roundtrip bumps the counter to exactly 1.
	expectedMetrics = fmt.Sprintf(`
		# HELP pyroscope_query_frontend_workers_enqueued_requests_total Total number of requests enqueued by each query frontend worker (regardless of the result), labeled by scheduler address.
		# TYPE pyroscope_query_frontend_workers_enqueued_requests_total counter
		pyroscope_query_frontend_workers_enqueued_requests_total{scheduler_address="%s"} 1
	`, f.cfg.SchedulerAddress)
	require.NoError(t, testutil.GatherAndCompare(reg, strings.NewReader(expectedMetrics), "pyroscope_query_frontend_workers_enqueued_requests_total"))

	// Manually remove the address, check that label is removed.
	f.schedulerWorkers.InstanceRemoved(servicediscovery.Instance{Address: f.cfg.SchedulerAddress, InUse: true})
	expectedMetrics = ``
	require.NoError(t, testutil.GatherAndCompare(reg, strings.NewReader(expectedMetrics), "pyroscope_query_frontend_workers_enqueued_requests_total"))
}
   196  
   197  func newFakeQuerierGRPCHandler() connectgrpc.GRPCHandler {
   198  	q := &fakeQuerier{}
   199  	mux := http.NewServeMux()
   200  	mux.Handle(querierv1connect.NewQuerierServiceHandler(q, connectapi.DefaultHandlerOptions()...))
   201  	return connectgrpc.NewHandler(mux)
   202  }
   203  
// fakeQuerier implements only LabelNames. The embedded (nil) handler
// interface satisfies the rest of QuerierServiceHandler; calling any other
// RPC would fail, which is fine for these tests.
type fakeQuerier struct {
	querierv1connect.QuerierServiceHandler
}
   207  
   208  func (f *fakeQuerier) LabelNames(ctx context.Context, req *connect.Request[typesv1.LabelNamesRequest]) (*connect.Response[typesv1.LabelNamesResponse], error) {
   209  	return connect.NewResponse(&typesv1.LabelNamesResponse{
   210  		Names: []string{"i", "have", "labels"},
   211  	}), nil
   212  }
   213  
   214  func headerToSlice(t testing.TB, header http.Header) []string {
   215  	buf := new(bytes.Buffer)
   216  	excludeHeaders := map[string]bool{"Content-Length": true, "Date": true}
   217  	require.NoError(t, header.WriteSubset(buf, excludeHeaders))
   218  	sl := strings.Split(strings.ReplaceAll(buf.String(), "\r\n", "\n"), "\n")
   219  	if len(sl) > 0 && sl[len(sl)-1] == "" {
   220  		sl = sl[:len(sl)-1]
   221  	}
   222  	return sl
   223  }
   224  
// TestFrontendFullRoundtrip tests the full roundtrip of a request from the frontend to a fake querier and back, with using an actual scheduler.
func TestFrontendFullRoundtrip(t *testing.T) {
	var (
		logger = log.NewNopLogger()
		reg    = prometheus.NewRegistry()
		tenant = "tenant-a"
	)
	if testing.Verbose() {
		logger = log.NewLogfmtLogger(os.Stderr)
	}

	// create server for frontend and scheduler
	mux := mux.NewRouter()
	// inject a span/tenant into the context
	mux.Use(func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := user.InjectOrgID(r.Context(), tenant)
			_, ctx = opentracing.StartSpanFromContext(ctx, "test")
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	})
	s := httptest.NewServer(h2c.NewHandler(mux, &http2.Server{}))
	defer s.Close()

	// initialize the scheduler
	schedCfg := scheduler.Config{}
	flagext.DefaultValues(&schedCfg)
	sched, err := scheduler.NewScheduler(schedCfg, validation.MockLimits{}, logger, reg)
	require.NoError(t, err)
	schedulerpbconnect.RegisterSchedulerForFrontendHandler(mux, sched)
	schedulerpbconnect.RegisterSchedulerForQuerierHandler(mux, sched)

	// initialize the frontend; it talks to the scheduler via the same test server
	fCfg := cfgFromURL(t, s.URL)
	f, err := NewFrontend(fCfg, validation.MockLimits{MaxQueryParallelismValue: 1}, logger, reg)
	require.NoError(t, err)
	frontendpbconnect.RegisterFrontendForQuerierHandler(mux, f) // probably not needed
	querierv1connect.RegisterQuerierServiceHandler(mux, f)

	// create a querier worker that pulls work from the scheduler and answers
	// with the fake querier
	qWorkerCfg := worker.Config{}
	flagext.DefaultValues(&qWorkerCfg)
	qWorkerCfg.SchedulerAddress = fCfg.SchedulerAddress
	qWorker, err := worker.NewQuerierWorker(qWorkerCfg, newFakeQuerierGRPCHandler(), log.NewLogfmtLogger(os.Stderr), prometheus.NewRegistry())
	require.NoError(t, err)

	// start services
	svc, err := services.NewManager(sched, f, qWorker)
	require.NoError(t, err)
	require.NoError(t, svc.StartAsync(context.Background()))
	require.NoError(t, svc.AwaitHealthy(context.Background()))
	defer func() {
		svc.StopAsync()
		require.NoError(t, svc.AwaitStopped(context.Background()))
	}()

	// Each sub-test exercises the same roundtrip through a different connect
	// protocol and pins the response headers that protocol produces.
	t.Run("using protocol grpc", func(t *testing.T) {
		client := querierv1connect.NewQuerierServiceClient(http.DefaultClient, s.URL, connect.WithGRPC())

		resp, err := client.LabelNames(context.Background(), connect.NewRequest(&typesv1.LabelNamesRequest{}))
		require.NoError(t, err)

		require.Equal(t, []string{"i", "have", "labels"}, resp.Msg.Names)

		assert.Equal(t, []string{
			"Content-Type: application/grpc",
			"Grpc-Accept-Encoding: gzip",
			"Grpc-Encoding: gzip",
		}, headerToSlice(t, resp.Header()))
	})

	t.Run("using protocol grpc-web", func(t *testing.T) {
		client := querierv1connect.NewQuerierServiceClient(http.DefaultClient, s.URL, connect.WithGRPCWeb())

		resp, err := client.LabelNames(context.Background(), connect.NewRequest(&typesv1.LabelNamesRequest{}))
		require.NoError(t, err)

		require.Equal(t, []string{"i", "have", "labels"}, resp.Msg.Names)

		assert.Equal(t, []string{
			"Content-Type: application/grpc-web+proto",
			"Grpc-Accept-Encoding: gzip",
			"Grpc-Encoding: gzip",
		}, headerToSlice(t, resp.Header()))
	})

	t.Run("using protocol json", func(t *testing.T) {
		client := querierv1connect.NewQuerierServiceClient(http.DefaultClient, s.URL, connect.WithProtoJSON())

		resp, err := client.LabelNames(context.Background(), connect.NewRequest(&typesv1.LabelNamesRequest{}))
		require.NoError(t, err)

		require.Equal(t, []string{"i", "have", "labels"}, resp.Msg.Names)

		assert.Equal(t, []string{
			"Accept-Encoding: gzip",
			"Content-Encoding: gzip",
			"Content-Type: application/json",
		}, headerToSlice(t, resp.Header()))
	})

}
   327  
   328  func TestFrontendRetryEnqueue(t *testing.T) {
   329  	// Frontend uses worker concurrency to compute number of retries. We use one less failure.
   330  	failures := atomic.NewInt64(testFrontendWorkerConcurrency - 1)
   331  	const (
   332  		body   = "hello world"
   333  		userID = "test"
   334  	)
   335  
   336  	f, _ := setupFrontend(t, nil, func(f *Frontend, msg *schedulerpb.FrontendToScheduler) *schedulerpb.SchedulerToFrontend {
   337  		fail := failures.Dec()
   338  		if fail >= 0 {
   339  			return &schedulerpb.SchedulerToFrontend{Status: schedulerpb.SchedulerToFrontendStatus_SHUTTING_DOWN}
   340  		}
   341  
   342  		go sendResponseWithDelay(f, 100*time.Millisecond, userID, msg.QueryID, &httpgrpc.HTTPResponse{
   343  			Code: 200,
   344  			Body: []byte(body),
   345  		})
   346  
   347  		return &schedulerpb.SchedulerToFrontend{Status: schedulerpb.SchedulerToFrontendStatus_OK}
   348  	})
   349  	_, err := f.RoundTripGRPC(user.InjectOrgID(context.Background(), userID), &httpgrpc.HTTPRequest{})
   350  	require.NoError(t, err)
   351  }
   352  
   353  func TestFrontendTooManyRequests(t *testing.T) {
   354  	f, _ := setupFrontend(t, nil, func(f *Frontend, msg *schedulerpb.FrontendToScheduler) *schedulerpb.SchedulerToFrontend {
   355  		return &schedulerpb.SchedulerToFrontend{Status: schedulerpb.SchedulerToFrontendStatus_TOO_MANY_REQUESTS_PER_TENANT}
   356  	})
   357  
   358  	resp, err := f.RoundTripGRPC(user.InjectOrgID(context.Background(), "test"), &httpgrpc.HTTPRequest{})
   359  	require.NoError(t, err)
   360  	require.Equal(t, int32(http.StatusTooManyRequests), resp.Code)
   361  }
   362  
   363  func TestFrontendEnqueueFailure(t *testing.T) {
   364  	f, _ := setupFrontend(t, nil, func(f *Frontend, msg *schedulerpb.FrontendToScheduler) *schedulerpb.SchedulerToFrontend {
   365  		return &schedulerpb.SchedulerToFrontend{Status: schedulerpb.SchedulerToFrontendStatus_SHUTTING_DOWN}
   366  	})
   367  
   368  	_, err := f.RoundTripGRPC(user.InjectOrgID(context.Background(), "test"), &httpgrpc.HTTPRequest{})
   369  	require.Error(t, err)
   370  	require.True(t, strings.Contains(err.Error(), "failed to enqueue request"))
   371  }
   372  
   373  func TestFrontendCancellation(t *testing.T) {
   374  	f, ms := setupFrontend(t, nil, nil)
   375  
   376  	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
   377  	defer cancel()
   378  
   379  	resp, err := f.RoundTripGRPC(user.InjectOrgID(ctx, "test"), &httpgrpc.HTTPRequest{})
   380  	require.EqualError(t, err, context.DeadlineExceeded.Error())
   381  	require.Nil(t, resp)
   382  
   383  	// We wait a bit to make sure scheduler receives the cancellation request.
   384  	test.Poll(t, time.Second, 2, func() interface{} {
   385  		ms.mu.Lock()
   386  		defer ms.mu.Unlock()
   387  
   388  		return len(ms.msgs)
   389  	})
   390  
   391  	ms.checkWithLock(func() {
   392  		require.Equal(t, 2, len(ms.msgs))
   393  		require.True(t, ms.msgs[0].Type == schedulerpb.FrontendToSchedulerType_ENQUEUE)
   394  		require.True(t, ms.msgs[1].Type == schedulerpb.FrontendToSchedulerType_CANCEL)
   395  		require.True(t, ms.msgs[0].QueryID == ms.msgs[1].QueryID)
   396  	})
   397  }
   398  
   399  // When frontendWorker that processed the request is busy (processing a new request or cancelling a previous one)
   400  // we still need to make sure that the cancellation reach the scheduler at some point.
   401  // Issue: https://github.com/grafana/mimir/issues/740
   402  func TestFrontendWorkerCancellation(t *testing.T) {
   403  	f, ms := setupFrontend(t, nil, nil)
   404  
   405  	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
   406  	defer cancel()
   407  
   408  	// send multiple requests > maxconcurrency of scheduler. So that it keeps all the frontend worker busy in serving requests.
   409  	reqCount := testFrontendWorkerConcurrency + 5
   410  	var wg sync.WaitGroup
   411  	for i := 0; i < reqCount; i++ {
   412  		wg.Add(1)
   413  		go func() {
   414  			defer wg.Done()
   415  			resp, err := f.RoundTripGRPC(user.InjectOrgID(ctx, "test"), &httpgrpc.HTTPRequest{})
   416  			require.EqualError(t, err, context.DeadlineExceeded.Error())
   417  			require.Nil(t, resp)
   418  		}()
   419  	}
   420  
   421  	wg.Wait()
   422  
   423  	// We wait a bit to make sure scheduler receives the cancellation request.
   424  	// 2 * reqCount because for every request, should also be corresponding cancel request
   425  	test.Poll(t, 5*time.Second, 2*reqCount, func() interface{} {
   426  		ms.mu.Lock()
   427  		defer ms.mu.Unlock()
   428  
   429  		return len(ms.msgs)
   430  	})
   431  
   432  	ms.checkWithLock(func() {
   433  		require.Equal(t, 2*reqCount, len(ms.msgs))
   434  		msgTypeCounts := map[schedulerpb.FrontendToSchedulerType]int{}
   435  		for _, msg := range ms.msgs {
   436  			msgTypeCounts[msg.Type]++
   437  		}
   438  		expectedMsgTypeCounts := map[schedulerpb.FrontendToSchedulerType]int{
   439  			schedulerpb.FrontendToSchedulerType_ENQUEUE: reqCount,
   440  			schedulerpb.FrontendToSchedulerType_CANCEL:  reqCount,
   441  		}
   442  		require.Equalf(t, expectedMsgTypeCounts, msgTypeCounts,
   443  			"Should receive %d enqueue (%d) requests, and %d cancel (%d) requests.", reqCount, schedulerpb.FrontendToSchedulerType_ENQUEUE, reqCount, schedulerpb.FrontendToSchedulerType_CANCEL,
   444  		)
   445  	})
   446  }
   447  
// TestFrontendFailedCancellation verifies that the frontend still returns a
// clean context.Canceled error to the caller even when it cannot deliver the
// cancellation to the scheduler (because its workers were stopped).
func TestFrontendFailedCancellation(t *testing.T) {
	f, ms := setupFrontend(t, nil, nil)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go func() {
		// Give the request below time to be enqueued first.
		time.Sleep(100 * time.Millisecond)

		// stop scheduler workers
		addr := ""
		f.schedulerWorkers.mu.Lock()
		for k := range f.schedulerWorkers.workers {
			addr = k
			break
		}
		f.schedulerWorkers.mu.Unlock()

		f.schedulerWorkers.InstanceRemoved(servicediscovery.Instance{Address: addr, InUse: true})

		// Wait for worker goroutines to stop.
		time.Sleep(100 * time.Millisecond)

		// Cancel request. Frontend will try to send cancellation to scheduler, but that will fail (not visible to user).
		// Everything else should still work fine.
		cancel()
	}()

	// send request
	resp, err := f.RoundTripGRPC(user.InjectOrgID(ctx, "test"), &httpgrpc.HTTPRequest{})
	require.EqualError(t, err, context.Canceled.Error())
	require.Nil(t, resp)

	// Only the ENQUEUE reached the scheduler; the CANCEL could not be sent.
	ms.checkWithLock(func() {
		require.Equal(t, 1, len(ms.msgs))
	})
}
   485  
// mockScheduler is a minimal in-process implementation of the scheduler's
// frontend-facing API. It records every message received from the frontend
// and answers via replyFunc, or with a plain OK when replyFunc is nil.
type mockScheduler struct {
	t *testing.T
	f *Frontend

	// replyFunc, when non-nil, decides the scheduler's answer to each
	// frontend message.
	replyFunc func(f *Frontend, msg *schedulerpb.FrontendToScheduler) *schedulerpb.SchedulerToFrontend

	// mu guards frontendAddr and msgs.
	mu sync.Mutex
	// frontendAddr counts INIT messages received per frontend address.
	frontendAddr map[string]int
	// msgs holds every post-INIT message received from the frontend.
	msgs []*schedulerpb.FrontendToScheduler

	schedulerpb.UnimplementedSchedulerForFrontendServer
}
   498  
   499  func newMockScheduler(t *testing.T, f *Frontend, replyFunc func(f *Frontend, msg *schedulerpb.FrontendToScheduler) *schedulerpb.SchedulerToFrontend) *mockScheduler {
   500  	return &mockScheduler{t: t, f: f, frontendAddr: map[string]int{}, replyFunc: replyFunc}
   501  }
   502  
   503  func (m *mockScheduler) checkWithLock(fn func()) {
   504  	m.mu.Lock()
   505  	defer m.mu.Unlock()
   506  
   507  	fn()
   508  }
   509  
// FrontendLoop implements the scheduler's frontend-facing stream: it records
// the frontend's INIT, acks it, then answers every subsequent message via
// replyFunc (or plain OK) until the stream ends.
func (m *mockScheduler) FrontendLoop(ctx context.Context, frontend *connect.BidiStream[schedulerpb.FrontendToScheduler, schedulerpb.SchedulerToFrontend]) error {
	// The first message on the stream is always INIT, carrying the
	// frontend's address.
	init, err := frontend.Receive()
	if err != nil {
		return err
	}

	m.mu.Lock()
	m.frontendAddr[init.FrontendAddress]++
	m.mu.Unlock()

	// Ack INIT from frontend.
	if err := frontend.Send(&schedulerpb.SchedulerToFrontend{Status: schedulerpb.SchedulerToFrontendStatus_OK}); err != nil {
		return err
	}

	// Echo loop: record each message and reply. Returning the receive/send
	// error (including EOF) terminates the stream.
	for {
		msg, err := frontend.Receive()
		if err != nil {
			return err
		}

		m.mu.Lock()
		m.msgs = append(m.msgs, msg)
		m.mu.Unlock()

		reply := &schedulerpb.SchedulerToFrontend{Status: schedulerpb.SchedulerToFrontendStatus_OK}
		if m.replyFunc != nil {
			reply = m.replyFunc(m.f, msg)
		}

		if err := frontend.Send(reply); err != nil {
			return err
		}
	}
}
   545  
   546  func TestConfig_Validate(t *testing.T) {
   547  	tests := map[string]struct {
   548  		setup       func(cfg *Config)
   549  		expectedErr string
   550  	}{
   551  		"should pass with default config": {
   552  			setup: func(cfg *Config) {},
   553  		},
   554  		"should pass if scheduler address is configured, and query-scheduler discovery mode is the default one": {
   555  			setup: func(cfg *Config) {
   556  				cfg.SchedulerAddress = "localhost:9095"
   557  			},
   558  		},
   559  		"should fail if query-scheduler service discovery is set to ring, and scheduler address is configured": {
   560  			setup: func(cfg *Config) {
   561  				cfg.QuerySchedulerDiscovery.Mode = schedulerdiscovery.ModeRing
   562  				cfg.SchedulerAddress = "localhost:9095"
   563  			},
   564  			expectedErr: `scheduler address cannot be specified when query-scheduler service discovery mode is set to 'ring'`,
   565  		},
   566  	}
   567  
   568  	for testName, testData := range tests {
   569  		t.Run(testName, func(t *testing.T) {
   570  			cfg := Config{}
   571  			flagext.DefaultValues(&cfg)
   572  			testData.setup(&cfg)
   573  
   574  			actualErr := cfg.Validate()
   575  			if testData.expectedErr == "" {
   576  				require.NoError(t, actualErr)
   577  			} else {
   578  				require.Error(t, actualErr)
   579  				assert.ErrorContains(t, actualErr, testData.expectedErr)
   580  			}
   581  		})
   582  	}
   583  }
   584  
   585  func TestWithClosingGrpcServer(t *testing.T) {
   586  	// This test is easier with single frontend worker.
   587  	const frontendConcurrency = 1
   588  	const userID = "test"
   589  
   590  	f, _ := setupFrontendWithConcurrencyAndServerOptions(t, nil, func(f *Frontend, msg *schedulerpb.FrontendToScheduler) *schedulerpb.SchedulerToFrontend {
   591  		return &schedulerpb.SchedulerToFrontend{Status: schedulerpb.SchedulerToFrontendStatus_TOO_MANY_REQUESTS_PER_TENANT}
   592  	}, frontendConcurrency)
   593  
   594  	// Connection will be established on the first roundtrip.
   595  	resp, err := f.RoundTripGRPC(user.InjectOrgID(context.Background(), userID), &httpgrpc.HTTPRequest{})
   596  	require.NoError(t, err)
   597  	require.Equal(t, int(resp.Code), http.StatusTooManyRequests)
   598  
   599  	// Verify that there is one stream open.
   600  	require.Equal(t, 1, checkStreamGoroutines())
   601  
   602  	// Wait a bit, to make sure that server closes connection.
   603  	time.Sleep(1 * time.Second)
   604  
   605  	// Despite server closing connections, stream-related goroutines still exist.
   606  	require.Equal(t, 1, checkStreamGoroutines())
   607  
   608  	// Another request will work as before, because worker will recreate connection.
   609  	resp, err = f.RoundTripGRPC(user.InjectOrgID(context.Background(), userID), &httpgrpc.HTTPRequest{})
   610  	require.NoError(t, err)
   611  	require.Equal(t, int(resp.Code), http.StatusTooManyRequests)
   612  
   613  	// There should still be only one stream open, and one goroutine created for it.
   614  	// Previously frontend leaked goroutine because stream that received "EOF" due to server closing the connection, never stopped its goroutine.
   615  	require.Equal(t, 1, checkStreamGoroutines())
   616  }
   617  
   618  func checkStreamGoroutines() int {
   619  	const streamGoroutineStackFrameTrailer = "created by google.golang.org/grpc.newClientStreamWithParams"
   620  
   621  	buf := make([]byte, 1000000)
   622  	stacklen := runtime.Stack(buf, true)
   623  
   624  	goroutineStacks := string(buf[:stacklen])
   625  	return strings.Count(goroutineStacks, streamGoroutineStackFrameTrailer)
   626  }