github.com/grafana/pyroscope@v1.18.0/pkg/querybackend/report_aggregator_test.go (about)

     1  package querybackend
     2  
     3  import (
     4  	"sync"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  	"github.com/stretchr/testify/require"
     9  
    10  	queryv1 "github.com/grafana/pyroscope/api/gen/proto/go/query/v1"
    11  )
    12  
    13  type mockAggregator struct {
    14  	reports []*queryv1.Report
    15  	mu      sync.Mutex
    16  }
    17  
    18  func (m *mockAggregator) aggregate(r *queryv1.Report) error {
    19  	m.mu.Lock()
    20  	defer m.mu.Unlock()
    21  	m.reports = append(m.reports, r)
    22  	return nil
    23  }
    24  
    25  func (m *mockAggregator) build() *queryv1.Report {
    26  	m.mu.Lock()
    27  	defer m.mu.Unlock()
    28  
    29  	if len(m.reports) == 0 {
    30  		return &queryv1.Report{}
    31  	}
    32  
    33  	result := &queryv1.Report{
    34  		ReportType: m.reports[0].ReportType,
    35  	}
    36  	return result
    37  }
    38  
    39  func (m *mockAggregator) getReportCount() int {
    40  	m.mu.Lock()
    41  	defer m.mu.Unlock()
    42  	return len(m.reports)
    43  }
    44  
    45  func mockAggregatorProvider(req *queryv1.InvokeRequest) aggregator {
    46  	return &mockAggregator{
    47  		reports: make([]*queryv1.Report, 0),
    48  	}
    49  }
    50  
    51  func TestReportAggregator_SingleReport(t *testing.T) {
    52  	reportType := queryv1.ReportType(999) // use a high number that won't conflict with other registrations
    53  	registerAggregator(reportType, mockAggregatorProvider, false)
    54  	defer func() {
    55  		aggregatorMutex.Lock()
    56  		delete(aggregators, reportType)
    57  		aggregatorMutex.Unlock()
    58  	}()
    59  
    60  	request := &queryv1.InvokeRequest{}
    61  	ra := newAggregator(request)
    62  
    63  	report := &queryv1.Report{ReportType: reportType}
    64  	err := ra.aggregateReport(report)
    65  	require.NoError(t, err)
    66  
    67  	// a single report should be staged and no aggregators should be created
    68  	assert.Len(t, ra.staged, 1)
    69  	assert.Len(t, ra.aggregators, 0)
    70  	assert.Equal(t, report, ra.staged[reportType])
    71  
    72  	// the response should contain the single report
    73  	resp, err := ra.response()
    74  	require.NoError(t, err)
    75  	require.Len(t, resp.Reports, 1)
    76  	assert.Equal(t, report, resp.Reports[0])
    77  }
    78  
    79  func TestReportAggregator_TwoReports(t *testing.T) {
    80  	reportType := queryv1.ReportType(999)
    81  	registerAggregator(reportType, mockAggregatorProvider, false)
    82  	defer func() {
    83  		aggregatorMutex.Lock()
    84  		delete(aggregators, reportType)
    85  		aggregatorMutex.Unlock()
    86  	}()
    87  
    88  	request := &queryv1.InvokeRequest{}
    89  	ra := newAggregator(request)
    90  
    91  	// the first report should be staged
    92  	report1 := &queryv1.Report{ReportType: reportType}
    93  	err := ra.aggregateReport(report1)
    94  	require.NoError(t, err)
    95  	assert.Len(t, ra.staged, 1)
    96  	assert.Len(t, ra.aggregators, 0)
    97  
    98  	// the second report should trigger aggregation
    99  	report2 := &queryv1.Report{ReportType: reportType}
   100  	err = ra.aggregateReport(report2)
   101  	require.NoError(t, err)
   102  	assert.Len(t, ra.aggregators, 1)
   103  	assert.Nil(t, ra.staged[reportType]) // staged entry should be nil after aggregation
   104  	agg := ra.aggregators[reportType].(*mockAggregator)
   105  	assert.Equal(t, 2, agg.getReportCount())
   106  
   107  	// the response should contain the aggregated result
   108  	resp, err := ra.response()
   109  	require.NoError(t, err)
   110  	require.Len(t, resp.Reports, 1)
   111  	assert.Equal(t, reportType, resp.Reports[0].ReportType)
   112  }
   113  
   114  func TestReportAggregator_MultipleTypes(t *testing.T) {
   115  	type1 := queryv1.ReportType(999)
   116  	type2 := queryv1.ReportType(998)
   117  
   118  	registerAggregator(type1, mockAggregatorProvider, false)
   119  	registerAggregator(type2, mockAggregatorProvider, false)
   120  	defer func() {
   121  		aggregatorMutex.Lock()
   122  		delete(aggregators, type1)
   123  		delete(aggregators, type2)
   124  		aggregatorMutex.Unlock()
   125  	}()
   126  
   127  	request := &queryv1.InvokeRequest{}
   128  	ra := newAggregator(request)
   129  
   130  	report1Type1 := &queryv1.Report{ReportType: type1}
   131  	report2Type2 := &queryv1.Report{ReportType: type2}
   132  	report3Type1 := &queryv1.Report{ReportType: type1}
   133  
   134  	err := ra.aggregateReport(report1Type1)
   135  	require.NoError(t, err)
   136  	err = ra.aggregateReport(report2Type2)
   137  	require.NoError(t, err)
   138  	err = ra.aggregateReport(report3Type1)
   139  	require.NoError(t, err)
   140  
   141  	// should have one staged report and one aggregator
   142  	assert.Equal(t, report2Type2, ra.staged[type2])
   143  	assert.Nil(t, ra.staged[type1])
   144  	assert.Len(t, ra.aggregators, 1)
   145  
   146  	resp, err := ra.response()
   147  	require.NoError(t, err)
   148  	require.Len(t, resp.Reports, 2)
   149  
   150  	reportTypes := make(map[queryv1.ReportType]bool)
   151  	for _, r := range resp.Reports {
   152  		reportTypes[r.ReportType] = true
   153  	}
   154  	assert.True(t, reportTypes[type1])
   155  	assert.True(t, reportTypes[type2])
   156  }
   157  
   158  func TestReportAggregator_NilReport(t *testing.T) {
   159  	request := &queryv1.InvokeRequest{}
   160  	ra := newAggregator(request)
   161  
   162  	err := ra.aggregateReport(nil)
   163  	require.NoError(t, err)
   164  	assert.Len(t, ra.staged, 0)
   165  	assert.Len(t, ra.aggregators, 0)
   166  }
   167  
   168  func TestReportAggregator_AggregateResponse(t *testing.T) {
   169  	reportType := queryv1.ReportType(999)
   170  	registerAggregator(reportType, mockAggregatorProvider, false)
   171  	defer func() {
   172  		aggregatorMutex.Lock()
   173  		delete(aggregators, reportType)
   174  		aggregatorMutex.Unlock()
   175  	}()
   176  
   177  	request := &queryv1.InvokeRequest{}
   178  	ra := newAggregator(request)
   179  
   180  	resp := &queryv1.InvokeResponse{
   181  		Reports: []*queryv1.Report{
   182  			{ReportType: reportType},
   183  			{ReportType: reportType},
   184  		},
   185  	}
   186  
   187  	err := ra.aggregateResponse(resp, nil)
   188  	require.NoError(t, err)
   189  
   190  	assert.Len(t, ra.aggregators, 1)
   191  	agg := ra.aggregators[reportType].(*mockAggregator)
   192  	assert.Equal(t, 2, agg.getReportCount())
   193  }
   194  
   195  func TestReportAggregator_ConcurrentAccess(t *testing.T) {
   196  	reportType := queryv1.ReportType(999)
   197  	registerAggregator(reportType, mockAggregatorProvider, false)
   198  	defer func() {
   199  		aggregatorMutex.Lock()
   200  		delete(aggregators, reportType)
   201  		aggregatorMutex.Unlock()
   202  	}()
   203  
   204  	request := &queryv1.InvokeRequest{}
   205  	ra := newAggregator(request)
   206  
   207  	const numGoroutines = 10
   208  	const reportsPerGoroutine = 5
   209  
   210  	var wg sync.WaitGroup
   211  	wg.Add(numGoroutines)
   212  
   213  	for i := 0; i < numGoroutines; i++ {
   214  		go func() {
   215  			defer wg.Done()
   216  			for j := 0; j < reportsPerGoroutine; j++ {
   217  				report := &queryv1.Report{ReportType: reportType}
   218  				err := ra.aggregateReport(report)
   219  				assert.NoError(t, err)
   220  			}
   221  		}()
   222  	}
   223  
   224  	wg.Wait()
   225  
   226  	resp, err := ra.response()
   227  	require.NoError(t, err)
   228  	assert.Len(t, resp.Reports, 1)
   229  }
   230  
   231  func TestGetAggregator(t *testing.T) {
   232  	reportType := queryv1.ReportType(999)
   233  	registerAggregator(reportType, mockAggregatorProvider, false)
   234  	defer func() {
   235  		aggregatorMutex.Lock()
   236  		delete(aggregators, reportType)
   237  		aggregatorMutex.Unlock()
   238  	}()
   239  
   240  	request := &queryv1.InvokeRequest{}
   241  	report := &queryv1.Report{ReportType: reportType}
   242  
   243  	agg, err := getAggregator(request, report)
   244  	require.NoError(t, err)
   245  	assert.NotNil(t, agg)
   246  }
   247  
   248  func TestGetAggregator_UnknownReportType(t *testing.T) {
   249  	request := &queryv1.InvokeRequest{}
   250  	unknownReport := &queryv1.Report{ReportType: queryv1.ReportType(996)}
   251  	_, err := getAggregator(request, unknownReport)
   252  	assert.Error(t, err)
   253  	assert.Contains(t, err.Error(), "unknown build type")
   254  }
   255  
   256  func TestRegisterAggregator_Duplicate(t *testing.T) {
   257  	reportType := queryv1.ReportType(999)
   258  
   259  	registerAggregator(reportType, mockAggregatorProvider, false)
   260  	assert.Panics(t, func() {
   261  		registerAggregator(reportType, mockAggregatorProvider, false)
   262  	})
   263  
   264  	aggregatorMutex.Lock()
   265  	delete(aggregators, reportType)
   266  	aggregatorMutex.Unlock()
   267  }
   268  
   269  func TestQueryReportType(t *testing.T) {
   270  	queryType := queryv1.QueryType(999)
   271  	reportType := queryv1.ReportType(999)
   272  
   273  	registerQueryReportType(queryType, reportType)
   274  	defer func() {
   275  		aggregatorMutex.Lock()
   276  		delete(queryReportType, queryType)
   277  		aggregatorMutex.Unlock()
   278  	}()
   279  
   280  	result := QueryReportType(queryType)
   281  	assert.Equal(t, reportType, result)
   282  
   283  	assert.Panics(t, func() {
   284  		QueryReportType(queryv1.QueryType(889)) // Use an unregistered query type
   285  	})
   286  }
   287  
   288  func TestRegisterQueryReportType_Duplicate(t *testing.T) {
   289  	queryType := queryv1.QueryType(999)
   290  	reportType := queryv1.ReportType(999)
   291  
   292  	registerQueryReportType(queryType, reportType)
   293  	assert.Panics(t, func() {
   294  		registerQueryReportType(queryType, queryv1.ReportType_REPORT_PPROF)
   295  	})
   296  
   297  	aggregatorMutex.Lock()
   298  	delete(queryReportType, queryType)
   299  	aggregatorMutex.Unlock()
   300  }