github.com/grafana/pyroscope@v1.18.0/pkg/symbolizer/debuginfod_client_test.go (about)

     1  package symbolizer
     2  
import (
	"context"
	"errors"
	"io"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync/atomic"
	"testing"
	"time"

	"github.com/go-kit/log"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/grafana/pyroscope/pkg/tenant"
	"github.com/grafana/pyroscope/pkg/validation"
)
    21  
    22  func TestDebuginfodClient(t *testing.T) {
    23  	// Create a test server that returns different responses based on the request
    24  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    25  		buildID := r.URL.Path[len("/buildid/"):]
    26  		buildID = buildID[:len(buildID)-len("/debuginfo")]
    27  
    28  		switch buildID {
    29  		case "valid-build-id":
    30  			w.WriteHeader(http.StatusOK)
    31  			_, _ = w.Write([]byte("mock debug info"))
    32  		case "not-found":
    33  			w.WriteHeader(http.StatusNotFound)
    34  		case "server-error":
    35  			w.WriteHeader(http.StatusInternalServerError)
    36  		case "rate-limited":
    37  			w.WriteHeader(http.StatusTooManyRequests)
    38  		default:
    39  			w.WriteHeader(http.StatusBadRequest)
    40  		}
    41  	}))
    42  	defer server.Close()
    43  
    44  	limits := validation.MockOverrides(func(defaults *validation.Limits, tenantLimits map[string]*validation.Limits) {
    45  		l := validation.MockDefaultLimits()
    46  		l.Symbolizer.MaxSymbolSizeBytes = 4
    47  		tenantLimits["tenant-limited"] = l
    48  	})
    49  
    50  	// Create a client with the test server URL
    51  	metrics := newMetrics(prometheus.NewRegistry())
    52  	client, err := NewDebuginfodClient(log.NewNopLogger(), server.URL, metrics, limits)
    53  	require.NoError(t, err)
    54  
    55  	// Test cases
    56  	tests := []struct {
    57  		name          string
    58  		buildID       string
    59  		tenantID      string
    60  		expectedError bool
    61  		expectedData  string
    62  		errorCheck    func(error) bool
    63  	}{
    64  		{
    65  			name:          "valid build ID",
    66  			buildID:       "valid-build-id",
    67  			expectedError: false,
    68  			expectedData:  "mock debug info",
    69  		},
    70  		{
    71  			name:          "not found",
    72  			buildID:       "not-found",
    73  			expectedError: true,
    74  			errorCheck: func(err error) bool {
    75  				var notFoundErr buildIDNotFoundError
    76  				return errors.As(err, &notFoundErr)
    77  			},
    78  		},
    79  		{
    80  			name:          "server error",
    81  			buildID:       "server-error",
    82  			expectedError: true,
    83  			errorCheck: func(err error) bool {
    84  				return err != nil && err.Error() != "" &&
    85  					(err.Error() == "HTTP error 500" ||
    86  						err.Error() == "failed to fetch debuginfo after 3 attempts: HTTP error 500")
    87  			},
    88  		},
    89  		{
    90  			name:          "rate limited",
    91  			buildID:       "rate-limited",
    92  			expectedError: true,
    93  			errorCheck: func(err error) bool {
    94  				return err != nil && err.Error() != "" &&
    95  					(err.Error() == "HTTP error 429" ||
    96  						err.Error() == "failed to fetch debuginfo after 3 attempts: HTTP error 429")
    97  			},
    98  		},
    99  		{
   100  			name:          "invalid build ID",
   101  			buildID:       "invalid/build/id",
   102  			expectedError: true,
   103  			errorCheck: func(err error) bool {
   104  				return isInvalidBuildIDError(err)
   105  			},
   106  		},
   107  		{
   108  			name:          "size limit",
   109  			buildID:       "valid-build-id",
   110  			tenantID:      "tenant-limited",
   111  			expectedError: true,
   112  			errorCheck: func(err error) bool {
   113  				return err != nil && strings.Contains(err.Error(), "symbol size exceeds maximum allowed size of 4 bytes")
   114  			},
   115  		},
   116  	}
   117  
   118  	for _, tc := range tests {
   119  		t.Run(tc.name, func(t *testing.T) {
   120  			// Fetch debug info
   121  			tenantID := "tenant"
   122  			if tc.tenantID != "" {
   123  				tenantID = tc.tenantID
   124  			}
   125  			ctx := tenant.InjectTenantID(context.Background(), tenantID)
   126  			reader, err := client.FetchDebuginfo(ctx, tc.buildID)
   127  
   128  			// Check error
   129  			if tc.expectedError {
   130  				assert.Error(t, err)
   131  				if tc.errorCheck != nil {
   132  					assert.True(t, tc.errorCheck(err), "Error type check failed: %v", err)
   133  				}
   134  				return
   135  			}
   136  
   137  			// Check success case
   138  			require.NoError(t, err)
   139  			defer reader.Close()
   140  
   141  			// Read the data
   142  			data, err := io.ReadAll(reader)
   143  			require.NoError(t, err)
   144  			assert.Equal(t, tc.expectedData, string(data))
   145  		})
   146  	}
   147  }
   148  
   149  func TestDebuginfodClientSingleflight(t *testing.T) {
   150  	// Create a test server that sleeps to simulate a slow response
   151  	requestCount := 0
   152  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   153  		requestCount++
   154  		time.Sleep(100 * time.Millisecond)
   155  		w.WriteHeader(http.StatusOK)
   156  		_, _ = w.Write([]byte("mock debug info"))
   157  	}))
   158  	defer server.Close()
   159  
   160  	// Create a client with the test server URL
   161  	metrics := newMetrics(prometheus.NewRegistry())
   162  	client, err := NewDebuginfodClient(log.NewNopLogger(), server.URL, metrics, validation.MockDefaultOverrides())
   163  	require.NoError(t, err)
   164  
   165  	// Make concurrent requests with the same build ID
   166  	buildID := "singleflight-test-id"
   167  	ctx := tenant.InjectTenantID(context.Background(), "tenant")
   168  
   169  	// Channel to synchronize goroutines
   170  	done := make(chan struct{})
   171  	results := make(chan []byte, 3)
   172  	errors := make(chan error, 3)
   173  
   174  	// Start 3 concurrent requests
   175  	for i := 0; i < 3; i++ {
   176  		go func() {
   177  			reader, err := client.FetchDebuginfo(ctx, buildID)
   178  			if err != nil {
   179  				errors <- err
   180  				done <- struct{}{}
   181  				return
   182  			}
   183  			data, err := io.ReadAll(reader)
   184  			reader.Close()
   185  			if err != nil {
   186  				errors <- err
   187  			} else {
   188  				results <- data
   189  			}
   190  			done <- struct{}{}
   191  		}()
   192  	}
   193  
   194  	// Wait for all requests to complete
   195  	for i := 0; i < 3; i++ {
   196  		<-done
   197  	}
   198  
   199  	// Check results
   200  	close(results)
   201  	close(errors)
   202  
   203  	// Should have no errors
   204  	for err := range errors {
   205  		t.Errorf("Unexpected error: %v", err)
   206  	}
   207  
   208  	// All results should be the same
   209  	var data []byte
   210  	for result := range results {
   211  		if data == nil {
   212  			data = result
   213  		} else {
   214  			assert.Equal(t, data, result)
   215  		}
   216  	}
   217  
   218  	// Should have made only one HTTP request
   219  	assert.Equal(t, 1, requestCount, "Expected only one HTTP request")
   220  }
   221  
   222  func TestSanitizeBuildID(t *testing.T) {
   223  	tests := []struct {
   224  		name        string
   225  		buildID     string
   226  		expected    string
   227  		expectError bool
   228  	}{
   229  		{
   230  			name:        "valid build ID",
   231  			buildID:     "abcdef1234567890",
   232  			expected:    "abcdef1234567890",
   233  			expectError: false,
   234  		},
   235  		{
   236  			name:        "valid build ID with dashes and underscores",
   237  			buildID:     "abcdef-1234_7890",
   238  			expected:    "abcdef-1234_7890",
   239  			expectError: false,
   240  		},
   241  		{
   242  			name:        "invalid build ID with slashes",
   243  			buildID:     "abcdef/1234",
   244  			expected:    "",
   245  			expectError: true,
   246  		},
   247  		{
   248  			name:        "invalid build ID with spaces",
   249  			buildID:     "abcdef 1234",
   250  			expected:    "",
   251  			expectError: true,
   252  		},
   253  		{
   254  			name:        "invalid build ID with special characters",
   255  			buildID:     "abcdef#1234",
   256  			expected:    "",
   257  			expectError: true,
   258  		},
   259  	}
   260  
   261  	for _, tc := range tests {
   262  		t.Run(tc.name, func(t *testing.T) {
   263  			result, err := sanitizeBuildID(tc.buildID)
   264  			if tc.expectError {
   265  				assert.Error(t, err)
   266  			} else {
   267  				assert.NoError(t, err)
   268  				assert.Equal(t, tc.expected, result)
   269  			}
   270  		})
   271  	}
   272  }
   273  
   274  func TestIsRetryableError(t *testing.T) {
   275  	tests := []struct {
   276  		name     string
   277  		err      error
   278  		expected bool
   279  	}{
   280  		{
   281  			name:     "nil error",
   282  			err:      nil,
   283  			expected: false,
   284  		},
   285  		{
   286  			name:     "context canceled",
   287  			err:      context.Canceled,
   288  			expected: false,
   289  		},
   290  		{
   291  			name:     "context deadline exceeded",
   292  			err:      context.DeadlineExceeded,
   293  			expected: false,
   294  		},
   295  		{
   296  			name:     "invalid build ID",
   297  			err:      invalidBuildIDError{buildID: "invalid"},
   298  			expected: false,
   299  		},
   300  		{
   301  			name:     "build ID not found",
   302  			err:      buildIDNotFoundError{buildID: "not-found"},
   303  			expected: false,
   304  		},
   305  		{
   306  			name:     "HTTP 404",
   307  			err:      httpStatusError{statusCode: http.StatusNotFound},
   308  			expected: false,
   309  		},
   310  		{
   311  			name:     "HTTP 429",
   312  			err:      httpStatusError{statusCode: http.StatusTooManyRequests},
   313  			expected: true,
   314  		},
   315  		{
   316  			name:     "HTTP 500",
   317  			err:      httpStatusError{statusCode: http.StatusInternalServerError},
   318  			expected: true,
   319  		},
   320  	}
   321  
   322  	for _, tc := range tests {
   323  		t.Run(tc.name, func(t *testing.T) {
   324  			result := isRetryableError(tc.err)
   325  			assert.Equal(t, tc.expected, result)
   326  		})
   327  	}
   328  }
   329  
   330  func TestDebuginfodClientNotFoundCache(t *testing.T) {
   331  	requestCount := 0
   332  
   333  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   334  		requestCount++
   335  		buildID := r.URL.Path[len("/buildid/"):]
   336  		buildID = buildID[:len(buildID)-len("/debuginfo")]
   337  		if buildID == "not-found-cached" {
   338  			w.WriteHeader(http.StatusNotFound)
   339  			return
   340  		}
   341  		w.WriteHeader(http.StatusOK)
   342  		_, _ = w.Write([]byte("mock debug info"))
   343  	}))
   344  	defer server.Close()
   345  
   346  	client, err := NewDebuginfodClientWithConfig(log.NewNopLogger(), DebuginfodClientConfig{
   347  		BaseURL:               server.URL,
   348  		NotFoundCacheMaxItems: 100,
   349  		NotFoundCacheTTL:      10 * time.Second,
   350  	}, newMetrics(nil), validation.MockDefaultOverrides())
   351  	require.NoError(t, err)
   352  
   353  	ctx := tenant.InjectTenantID(context.Background(), "tenant")
   354  	buildID := "not-found-cached"
   355  
   356  	// First request should hit the server and get a 404
   357  	reader, err := client.FetchDebuginfo(ctx, buildID)
   358  	assert.Error(t, err)
   359  	assert.Nil(t, reader)
   360  
   361  	var notFoundErr buildIDNotFoundError
   362  	assert.True(t, errors.As(err, &notFoundErr))
   363  	assert.Equal(t, 1, requestCount)
   364  
   365  	client.notFoundCache.Wait()
   366  
   367  	// Second request should get 404 from cache without hitting server
   368  	reader, err = client.FetchDebuginfo(ctx, buildID)
   369  	assert.Error(t, err)
   370  	assert.Nil(t, reader)
   371  	assert.True(t, errors.As(err, &notFoundErr))
   372  	assert.Equal(t, 1, requestCount)
   373  
   374  	// Third request should hit the server
   375  	reader, err = client.FetchDebuginfo(ctx, "valid-id")
   376  	assert.NoError(t, err)
   377  	require.NotNil(t, reader)
   378  
   379  	data, err := io.ReadAll(reader)
   380  	require.NoError(t, err)
   381  	reader.Close()
   382  	assert.Equal(t, "mock debug info", string(data))
   383  
   384  	assert.Equal(t, 2, requestCount)
   385  }