github.com/grafana/pyroscope@v1.18.0/pkg/querybackend/report_aggregator_test.go (about) 1 package querybackend 2 3 import ( 4 "sync" 5 "testing" 6 7 "github.com/stretchr/testify/assert" 8 "github.com/stretchr/testify/require" 9 10 queryv1 "github.com/grafana/pyroscope/api/gen/proto/go/query/v1" 11 ) 12 13 type mockAggregator struct { 14 reports []*queryv1.Report 15 mu sync.Mutex 16 } 17 18 func (m *mockAggregator) aggregate(r *queryv1.Report) error { 19 m.mu.Lock() 20 defer m.mu.Unlock() 21 m.reports = append(m.reports, r) 22 return nil 23 } 24 25 func (m *mockAggregator) build() *queryv1.Report { 26 m.mu.Lock() 27 defer m.mu.Unlock() 28 29 if len(m.reports) == 0 { 30 return &queryv1.Report{} 31 } 32 33 result := &queryv1.Report{ 34 ReportType: m.reports[0].ReportType, 35 } 36 return result 37 } 38 39 func (m *mockAggregator) getReportCount() int { 40 m.mu.Lock() 41 defer m.mu.Unlock() 42 return len(m.reports) 43 } 44 45 func mockAggregatorProvider(req *queryv1.InvokeRequest) aggregator { 46 return &mockAggregator{ 47 reports: make([]*queryv1.Report, 0), 48 } 49 } 50 51 func TestReportAggregator_SingleReport(t *testing.T) { 52 reportType := queryv1.ReportType(999) // use a high number that won't conflict with other registrations 53 registerAggregator(reportType, mockAggregatorProvider, false) 54 defer func() { 55 aggregatorMutex.Lock() 56 delete(aggregators, reportType) 57 aggregatorMutex.Unlock() 58 }() 59 60 request := &queryv1.InvokeRequest{} 61 ra := newAggregator(request) 62 63 report := &queryv1.Report{ReportType: reportType} 64 err := ra.aggregateReport(report) 65 require.NoError(t, err) 66 67 // a single report should be staged and no aggregators should be created 68 assert.Len(t, ra.staged, 1) 69 assert.Len(t, ra.aggregators, 0) 70 assert.Equal(t, report, ra.staged[reportType]) 71 72 // the response should contain the single report 73 resp, err := ra.response() 74 require.NoError(t, err) 75 require.Len(t, resp.Reports, 1) 76 assert.Equal(t, report, resp.Reports[0]) 77 } 78 79 func TestReportAggregator_TwoReports(t *testing.T) { 80 reportType := queryv1.ReportType(999) 81 registerAggregator(reportType, mockAggregatorProvider, false) 82 defer func() { 83 aggregatorMutex.Lock() 84 delete(aggregators, reportType) 85 aggregatorMutex.Unlock() 86 }() 87 88 request := &queryv1.InvokeRequest{} 89 ra := newAggregator(request) 90 91 // the first report should be staged 92 report1 := &queryv1.Report{ReportType: reportType} 93 err := ra.aggregateReport(report1) 94 require.NoError(t, err) 95 assert.Len(t, ra.staged, 1) 96 assert.Len(t, ra.aggregators, 0) 97 98 // the second report should trigger aggregation 99 report2 := &queryv1.Report{ReportType: reportType} 100 err = ra.aggregateReport(report2) 101 require.NoError(t, err) 102 assert.Len(t, ra.aggregators, 1) 103 assert.Nil(t, ra.staged[reportType]) // staged entry should be nil after aggregation 104 agg := ra.aggregators[reportType].(*mockAggregator) 105 assert.Equal(t, 2, agg.getReportCount()) 106 107 // the response should contain the aggregated result 108 resp, err := ra.response() 109 require.NoError(t, err) 110 require.Len(t, resp.Reports, 1) 111 assert.Equal(t, reportType, resp.Reports[0].ReportType) 112 } 113 114 func TestReportAggregator_MultipleTypes(t *testing.T) { 115 type1 := queryv1.ReportType(999) 116 type2 := queryv1.ReportType(998) 117 118 registerAggregator(type1, mockAggregatorProvider, false) 119 registerAggregator(type2, mockAggregatorProvider, false) 120 defer func() { 121 aggregatorMutex.Lock() 122 delete(aggregators, type1) 123 delete(aggregators, type2) 124 aggregatorMutex.Unlock() 125 }() 126 127 request := &queryv1.InvokeRequest{} 128 ra := newAggregator(request) 129 130 report1Type1 := &queryv1.Report{ReportType: type1} 131 report2Type2 := &queryv1.Report{ReportType: type2} 132 report3Type1 := &queryv1.Report{ReportType: type1} 133 134 err := ra.aggregateReport(report1Type1) 135 require.NoError(t, err) 136 err = ra.aggregateReport(report2Type2) 137 require.NoError(t, err) 138 err = ra.aggregateReport(report3Type1) 139 require.NoError(t, err) 140 141 // should have one staged report and one aggregator 142 assert.Equal(t, report2Type2, ra.staged[type2]) 143 assert.Nil(t, ra.staged[type1]) 144 assert.Len(t, ra.aggregators, 1) 145 146 resp, err := ra.response() 147 require.NoError(t, err) 148 require.Len(t, resp.Reports, 2) 149 150 reportTypes := make(map[queryv1.ReportType]bool) 151 for _, r := range resp.Reports { 152 reportTypes[r.ReportType] = true 153 } 154 assert.True(t, reportTypes[type1]) 155 assert.True(t, reportTypes[type2]) 156 } 157 158 func TestReportAggregator_NilReport(t *testing.T) { 159 request := &queryv1.InvokeRequest{} 160 ra := newAggregator(request) 161 162 err := ra.aggregateReport(nil) 163 require.NoError(t, err) 164 assert.Len(t, ra.staged, 0) 165 assert.Len(t, ra.aggregators, 0) 166 } 167 168 func TestReportAggregator_AggregateResponse(t *testing.T) { 169 reportType := queryv1.ReportType(999) 170 registerAggregator(reportType, mockAggregatorProvider, false) 171 defer func() { 172 aggregatorMutex.Lock() 173 delete(aggregators, reportType) 174 aggregatorMutex.Unlock() 175 }() 176 177 request := &queryv1.InvokeRequest{} 178 ra := newAggregator(request) 179 180 resp := &queryv1.InvokeResponse{ 181 Reports: []*queryv1.Report{ 182 {ReportType: reportType}, 183 {ReportType: reportType}, 184 }, 185 } 186 187 err := ra.aggregateResponse(resp, nil) 188 require.NoError(t, err) 189 190 assert.Len(t, ra.aggregators, 1) 191 agg := ra.aggregators[reportType].(*mockAggregator) 192 assert.Equal(t, 2, agg.getReportCount()) 193 } 194 195 func TestReportAggregator_ConcurrentAccess(t *testing.T) { 196 reportType := queryv1.ReportType(999) 197 registerAggregator(reportType, mockAggregatorProvider, false) 198 defer func() { 199 aggregatorMutex.Lock() 200 delete(aggregators, reportType) 201 aggregatorMutex.Unlock() 202 }() 203 204 request := &queryv1.InvokeRequest{} 205 ra := newAggregator(request) 206 207 const numGoroutines = 10 208 const reportsPerGoroutine = 5 209 210 var wg sync.WaitGroup 211 wg.Add(numGoroutines) 212 213 for i := 0; i < numGoroutines; i++ { 214 go func() { 215 defer wg.Done() 216 for j := 0; j < reportsPerGoroutine; j++ { 217 report := &queryv1.Report{ReportType: reportType} 218 err := ra.aggregateReport(report) 219 assert.NoError(t, err) 220 } 221 }() 222 } 223 224 wg.Wait() 225 226 resp, err := ra.response() 227 require.NoError(t, err) 228 assert.Len(t, resp.Reports, 1) 229 } 230 231 func TestGetAggregator(t *testing.T) { 232 reportType := queryv1.ReportType(999) 233 registerAggregator(reportType, mockAggregatorProvider, false) 234 defer func() { 235 aggregatorMutex.Lock() 236 delete(aggregators, reportType) 237 aggregatorMutex.Unlock() 238 }() 239 240 request := &queryv1.InvokeRequest{} 241 report := &queryv1.Report{ReportType: reportType} 242 243 agg, err := getAggregator(request, report) 244 require.NoError(t, err) 245 assert.NotNil(t, agg) 246 } 247 248 func TestGetAggregator_UnknownReportType(t *testing.T) { 249 request := &queryv1.InvokeRequest{} 250 unknownReport := &queryv1.Report{ReportType: queryv1.ReportType(996)} 251 _, err := getAggregator(request, unknownReport) 252 assert.Error(t, err) 253 assert.Contains(t, err.Error(), "unknown build type") 254 } 255 256 func TestRegisterAggregator_Duplicate(t *testing.T) { 257 reportType := queryv1.ReportType(999) 258 259 registerAggregator(reportType, mockAggregatorProvider, false) 260 assert.Panics(t, func() { 261 registerAggregator(reportType, mockAggregatorProvider, false) 262 }) 263 264 aggregatorMutex.Lock() 265 delete(aggregators, reportType) 266 aggregatorMutex.Unlock() 267 } 268 269 func TestQueryReportType(t *testing.T) { 270 queryType := queryv1.QueryType(999) 271 reportType := queryv1.ReportType(999) 272 273 registerQueryReportType(queryType, reportType) 274 defer func() { 275 aggregatorMutex.Lock() 276 delete(queryReportType, queryType) 277 aggregatorMutex.Unlock() 278 }() 279 280 result := QueryReportType(queryType) 281 assert.Equal(t, reportType, result) 282 283 assert.Panics(t, func() { 284 QueryReportType(queryv1.QueryType(889)) // Use an unregistered query type 285 }) 286 } 287 288 func TestRegisterQueryReportType_Duplicate(t *testing.T) { 289 queryType := queryv1.QueryType(999) 290 reportType := queryv1.ReportType(999) 291 292 registerQueryReportType(queryType, reportType) 293 assert.Panics(t, func() { 294 registerQueryReportType(queryType, queryv1.ReportType_REPORT_PPROF) 295 }) 296 297 aggregatorMutex.Lock() 298 delete(queryReportType, queryType) 299 aggregatorMutex.Unlock() 300 }