github.com/grafana/pyroscope@v1.18.0/pkg/scheduler/queue/queue_test.go (about)

     1  // SPDX-License-Identifier: AGPL-3.0-only
     2  // Provenance-includes-location: https://github.com/cortexproject/cortex/blob/master/pkg/scheduler/queue/queue_test.go
     3  // Provenance-includes-license: Apache-2.0
     4  // Provenance-includes-copyright: The Cortex Authors.
     5  
     6  package queue
     7  
     8  import (
     9  	"context"
    10  	"fmt"
    11  	"strconv"
    12  	"sync"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/grafana/dskit/services"
    17  	"github.com/prometheus/client_golang/prometheus"
    18  	"github.com/prometheus/client_golang/prometheus/promauto"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  )
    22  
    23  func BenchmarkGetNextRequest(b *testing.B) {
    24  	const maxOutstandingPerTenant = 2
    25  	const numTenants = 50
    26  	const queriers = 5
    27  
    28  	queues := make([]*RequestQueue, 0, b.N)
    29  
    30  	for n := 0; n < b.N; n++ {
    31  		queue := NewRequestQueue(maxOutstandingPerTenant, 0,
    32  			promauto.With(nil).NewGaugeVec(prometheus.GaugeOpts{}, []string{"tenant"}),
    33  			promauto.With(nil).NewCounterVec(prometheus.CounterOpts{}, []string{"tenant"}),
    34  		)
    35  		queues = append(queues, queue)
    36  
    37  		for ix := 0; ix < queriers; ix++ {
    38  			queue.RegisterQuerierConnection(fmt.Sprintf("querier-%d", ix))
    39  		}
    40  
    41  		for i := 0; i < maxOutstandingPerTenant; i++ {
    42  			for j := 0; j < numTenants; j++ {
    43  				tenantID := strconv.Itoa(j)
    44  
    45  				err := queue.EnqueueRequest(tenantID, "request", 0, nil)
    46  				if err != nil {
    47  					b.Fatal(err)
    48  				}
    49  			}
    50  		}
    51  	}
    52  
    53  	ctx := context.Background()
    54  	b.ResetTimer()
    55  
    56  	for i := 0; i < b.N; i++ {
    57  		idx := FirstUser()
    58  		for j := 0; j < maxOutstandingPerTenant*numTenants; j++ {
    59  			querier := ""
    60  		b:
    61  			// Find querier with at least one request to avoid blocking in getNextRequestForQuerier.
    62  			for _, q := range queues[i].queues.userQueues {
    63  				for qid := range q.queriers {
    64  					querier = qid
    65  					break b
    66  				}
    67  			}
    68  
    69  			_, nidx, err := queues[i].GetNextRequestForQuerier(ctx, idx, querier)
    70  			if err != nil {
    71  				b.Fatal(err)
    72  			}
    73  			idx = nidx
    74  		}
    75  	}
    76  }
    77  
    78  func BenchmarkQueueRequest(b *testing.B) {
    79  	const maxOutstandingPerTenant = 2
    80  	const numTenants = 50
    81  	const queriers = 5
    82  
    83  	queues := make([]*RequestQueue, 0, b.N)
    84  	users := make([]string, 0, numTenants)
    85  	requests := make([]string, 0, numTenants)
    86  
    87  	for n := 0; n < b.N; n++ {
    88  		q := NewRequestQueue(maxOutstandingPerTenant, 0,
    89  			promauto.With(nil).NewGaugeVec(prometheus.GaugeOpts{}, []string{"tenant"}),
    90  			promauto.With(nil).NewCounterVec(prometheus.CounterOpts{}, []string{"user"}),
    91  		)
    92  
    93  		for ix := 0; ix < queriers; ix++ {
    94  			q.RegisterQuerierConnection(fmt.Sprintf("querier-%d", ix))
    95  		}
    96  
    97  		queues = append(queues, q)
    98  
    99  		for j := 0; j < numTenants; j++ {
   100  			requests = append(requests, fmt.Sprintf("%d-%d", n, j))
   101  			users = append(users, strconv.Itoa(j))
   102  		}
   103  	}
   104  
   105  	b.ResetTimer()
   106  	for n := 0; n < b.N; n++ {
   107  		for i := 0; i < maxOutstandingPerTenant; i++ {
   108  			for j := 0; j < numTenants; j++ {
   109  				err := queues[n].EnqueueRequest(users[j], requests[j], 0, nil)
   110  				if err != nil {
   111  					b.Fatal(err)
   112  				}
   113  			}
   114  		}
   115  	}
   116  }
   117  
   118  func TestRequestQueue_GetNextRequestForQuerier_ShouldGetRequestAfterReshardingBecauseQuerierHasBeenForgotten(t *testing.T) {
   119  	const forgetDelay = 3 * time.Second
   120  
   121  	queue := NewRequestQueue(1, forgetDelay,
   122  		promauto.With(nil).NewGaugeVec(prometheus.GaugeOpts{}, []string{"user"}),
   123  		promauto.With(nil).NewCounterVec(prometheus.CounterOpts{}, []string{"user"}))
   124  
   125  	// Start the queue service.
   126  	ctx := context.Background()
   127  	require.NoError(t, services.StartAndAwaitRunning(ctx, queue))
   128  	t.Cleanup(func() {
   129  		require.NoError(t, services.StopAndAwaitTerminated(ctx, queue))
   130  	})
   131  
   132  	// Two queriers connect.
   133  	queue.RegisterQuerierConnection("querier-1")
   134  	queue.RegisterQuerierConnection("querier-2")
   135  
   136  	// Querier-2 waits for a new request.
   137  	querier2wg := sync.WaitGroup{}
   138  	querier2wg.Add(1)
   139  	go func() {
   140  		defer querier2wg.Done()
   141  		_, _, err := queue.GetNextRequestForQuerier(ctx, FirstUser(), "querier-2")
   142  		require.NoError(t, err)
   143  	}()
   144  
   145  	// Querier-1 crashes (no graceful shutdown notification).
   146  	queue.UnregisterQuerierConnection("querier-1")
   147  
   148  	// Enqueue a request from an user which would be assigned to querier-1.
   149  	// NOTE: "user-1" hash falls in the querier-1 shard.
   150  	require.NoError(t, queue.EnqueueRequest("user-1", "request", 1, nil))
   151  
   152  	startTime := time.Now()
   153  	querier2wg.Wait()
   154  	waitTime := time.Since(startTime)
   155  
   156  	// We expect that querier-2 got the request only after querier-1 forget delay is passed.
   157  	assert.GreaterOrEqual(t, waitTime.Milliseconds(), forgetDelay.Milliseconds())
   158  }
   159  
// TestContextCond exercises contextCond, a sync.Cond wrapper whose Wait can
// also be woken by context cancellation. Each subtest covers one wake-up
// path; all of them follow the standard sync.Cond protocol of entering Wait
// with the mutex held and unlocking it after Wait returns.
func TestContextCond(t *testing.T) {
	t.Run("wait until broadcast", func(t *testing.T) {
		t.Parallel()
		mtx := &sync.Mutex{}
		cond := contextCond{Cond: sync.NewCond(mtx)}

		doneWaiting := make(chan struct{})

		// Lock is taken here and handed to the waiting goroutine, which
		// releases it once Wait returns.
		mtx.Lock()
		go func() {
			cond.Wait(context.Background())
			mtx.Unlock()
			close(doneWaiting)
		}()

		// Wait must still be blocked: nothing has been broadcast yet.
		assertChanNotReceived(t, doneWaiting, 100*time.Millisecond, "cond.Wait returned, but it should not because we did not broadcast yet")

		cond.Broadcast()
		assertChanReceived(t, doneWaiting, 250*time.Millisecond, "cond.Wait did not return after broadcast")
	})

	t.Run("wait until context deadline", func(t *testing.T) {
		t.Parallel()
		mtx := &sync.Mutex{}
		cond := contextCond{Cond: sync.NewCond(mtx)}
		doneWaiting := make(chan struct{})

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		mtx.Lock()
		go func() {
			cond.Wait(ctx)
			mtx.Unlock()
			close(doneWaiting)
		}()

		// Neither broadcast nor cancellation happened yet, so Wait blocks.
		assertChanNotReceived(t, doneWaiting, 100*time.Millisecond, "cond.Wait returned, but it should not because we did not broadcast yet and didn't cancel the context")

		// Cancelling the context must wake the waiter without a Broadcast.
		cancel()
		assertChanReceived(t, doneWaiting, 250*time.Millisecond, "cond.Wait did not return after cancelling the context")
	})

	t.Run("wait on already canceled context", func(t *testing.T) {
		// This test represents the racy real world scenario,
		// we don't know whether it's going to wait before the broadcast triggered by the context cancellation.
		t.Parallel()
		mtx := &sync.Mutex{}
		cond := contextCond{Cond: sync.NewCond(mtx)}
		doneWaiting := make(chan struct{})

		alreadyCanceledContext, cancel := context.WithCancel(context.Background())
		cancel()

		mtx.Lock()
		go func() {
			cond.Wait(alreadyCanceledContext)
			mtx.Unlock()
			close(doneWaiting)
		}()

		// Wait on a dead context must return promptly, not block forever.
		assertChanReceived(t, doneWaiting, 250*time.Millisecond, "cond.Wait did not return after cancelling the context")
	})

	t.Run("wait on already canceled context, but it takes a while to wait", func(t *testing.T) {
		t.Parallel()
		mtx := &sync.Mutex{}
		cond := contextCond{
			Cond: sync.NewCond(mtx),
			testHookBeforeWaiting: func() {
				// This makes the waiting goroutine so slow that our Wait(ctx) will need to broadcast once it sees it waiting.
				time.Sleep(250 * time.Millisecond)
			},
		}
		doneWaiting := make(chan struct{})

		alreadyCanceledContext, cancel := context.WithCancel(context.Background())
		cancel()

		mtx.Lock()
		go func() {
			cond.Wait(alreadyCanceledContext)
			mtx.Unlock()
			close(doneWaiting)
		}()

		assertChanReceived(t, doneWaiting, 500*time.Millisecond, "cond.Wait did not return after 500ms")
	})

	t.Run("lots of goroutines waiting at the same time, none of them misses it's broadcast from cancel", func(t *testing.T) {
		t.Parallel()
		mtx := &sync.Mutex{}
		cond := contextCond{
			Cond: sync.NewCond(mtx),
			testHookBeforeWaiting: func() {
				// Make every goroutine a little bit more racy by introducing a delay before its inner Wait call.
				time.Sleep(time.Millisecond)
			},
		}
		const goroutines = 100

		// Buffered so every waiter can report completion without blocking.
		doneWaiting := make(chan struct{}, goroutines)
		// Closed to start all waiters and the canceller at the same time.
		release := make(chan struct{})

		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		for i := 0; i < goroutines; i++ {
			go func() {
				<-release

				mtx.Lock()
				cond.Wait(ctx)
				mtx.Unlock()

				doneWaiting <- struct{}{}
			}()
		}
		go func() {
			<-release
			cancel()
		}()

		close(release)

		// Every single waiter must eventually be woken by the cancellation.
		assert.Eventually(t, func() bool {
			return len(doneWaiting) == goroutines
		}, time.Second, 10*time.Millisecond)
	})
}
   290  
   291  func assertChanReceived(t *testing.T, c chan struct{}, timeout time.Duration, msg string) {
   292  	t.Helper()
   293  
   294  	select {
   295  	case <-c:
   296  	case <-time.After(timeout):
   297  		t.Fatalf("%s", msg)
   298  	}
   299  }
   300  
   301  func assertChanNotReceived(t *testing.T, c chan struct{}, wait time.Duration, msg string) {
   302  	t.Helper()
   303  
   304  	select {
   305  	case <-c:
   306  		t.Fatal(msg)
   307  	case <-time.After(wait):
   308  		// OK!
   309  	}
   310  }