github.com/grafana/pyroscope@v1.18.0/pkg/frontend/vcs/token_test.go (about)

     1  package vcs
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"net/http"
     8  	"net/url"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"connectrpc.com/connect"
    14  	"github.com/stretchr/testify/require"
    15  	"golang.org/x/oauth2"
    16  
    17  	vcsv1 "github.com/grafana/pyroscope/api/gen/proto/go/vcs/v1"
    18  	"github.com/grafana/pyroscope/pkg/tenant"
    19  )
    20  
    21  func Test_getStringValueFrom(t *testing.T) {
    22  	tests := []struct {
    23  		Name       string
    24  		Query      url.Values
    25  		Key        string
    26  		Want       string
    27  		WantErrMsg string
    28  	}{
    29  		{
    30  			Name: "key exists",
    31  			Query: url.Values{
    32  				"my_key": {"my_value"},
    33  			},
    34  			Key:  "my_key",
    35  			Want: "my_value",
    36  		},
    37  		{
    38  			Name: "key exists with multiple values",
    39  			Query: url.Values{
    40  				"my_key": {"my_value1", "my_value2"},
    41  			},
    42  			Key:  "my_key",
    43  			Want: "my_value1",
    44  		},
    45  		{
    46  			Name: "key is missing",
    47  			Query: url.Values{
    48  				"my_key": {"my_value"},
    49  			},
    50  			Key:        "my_missing_key",
    51  			WantErrMsg: "missing key: my_missing_key",
    52  		},
    53  	}
    54  
    55  	for _, tt := range tests {
    56  		t.Run(tt.Name, func(t *testing.T) {
    57  			got, err := getStringValueFrom(tt.Query, tt.Key)
    58  			if tt.WantErrMsg != "" {
    59  				require.Error(t, err)
    60  				require.EqualError(t, err, tt.WantErrMsg)
    61  			} else {
    62  				require.NoError(t, err)
    63  				require.Equal(t, tt.Want, got)
    64  			}
    65  		})
    66  	}
    67  }
    68  
    69  func Test_getDurationValueFrom(t *testing.T) {
    70  	tests := []struct {
    71  		Name       string
    72  		Query      url.Values
    73  		Key        string
    74  		Scalar     time.Duration
    75  		Want       time.Duration
    76  		WantErrMsg string
    77  	}{
    78  		{
    79  			Name: "key exists",
    80  			Query: url.Values{
    81  				"my_key": {"100"},
    82  			},
    83  			Key:    "my_key",
    84  			Scalar: time.Second,
    85  			Want:   100 * time.Second,
    86  		},
    87  		{
    88  			Name: "key exists with multiple values",
    89  			Query: url.Values{
    90  				"my_key": {"100", "200"},
    91  			},
    92  			Key:    "my_key",
    93  			Scalar: time.Second,
    94  			Want:   100 * time.Second,
    95  		},
    96  		{
    97  			Name: "scalar less than 1",
    98  			Query: url.Values{
    99  				"my_key": {"100"},
   100  			},
   101  			Key:        "my_key",
   102  			Scalar:     0,
   103  			WantErrMsg: "cannot use scalar less than 1",
   104  		},
   105  		{
   106  			Name: "value is not a duration",
   107  			Query: url.Values{
   108  				"my_key": {"not_a_number"},
   109  			},
   110  			Key:        "my_key",
   111  			Scalar:     time.Second,
   112  			WantErrMsg: "failed to parse my_key: strconv.Atoi: parsing \"not_a_number\": invalid syntax",
   113  		},
   114  		{
   115  			Name: "key is missing",
   116  			Query: url.Values{
   117  				"my_key": {"my_value"},
   118  			},
   119  			Scalar:     time.Second,
   120  			Key:        "my_missing_key",
   121  			WantErrMsg: "missing key: my_missing_key",
   122  		},
   123  	}
   124  
   125  	for _, tt := range tests {
   126  		t.Run(tt.Name, func(t *testing.T) {
   127  			got, err := getDurationValueFrom(tt.Query, tt.Key, tt.Scalar)
   128  			if tt.WantErrMsg != "" {
   129  				require.Error(t, err)
   130  				require.EqualError(t, err, tt.WantErrMsg)
   131  			} else {
   132  				require.NoError(t, err)
   133  				require.Equal(t, tt.Want, got)
   134  			}
   135  		})
   136  	}
   137  }
   138  
   139  func Test_tokenFromRequest(t *testing.T) {
   140  	ctx := newTestContext()
   141  
   142  	t.Run("token exists in request", func(t *testing.T) {
   143  		githubSessionSecret = []byte("16_byte_key_XXXX")
   144  
   145  		derivedKey, err := deriveEncryptionKeyForContext(ctx)
   146  		require.NoError(t, err)
   147  
   148  		wantToken := &oauth2.Token{
   149  			AccessToken:  "my_access_token",
   150  			TokenType:    "my_token_type",
   151  			RefreshToken: "my_refresh_token",
   152  			Expiry:       time.Unix(1713298947, 0).UTC(), // 2024-04-16T20:22:27.346Z
   153  		}
   154  
   155  		// The type of request here doesn't matter.
   156  		req := connect.NewRequest(&vcsv1.GetFileRequest{})
   157  		req.Header().Add("Cookie", testEncodeCookie(t, derivedKey, wantToken).String())
   158  
   159  		gotToken, err := tokenFromRequest(ctx, req)
   160  		require.NoError(t, err)
   161  		require.Equal(t, *wantToken, *gotToken)
   162  	})
   163  
   164  	t.Run("legacy token exists in request", func(t *testing.T) {
   165  		githubSessionSecret = []byte("16_byte_key_XXXX")
   166  
   167  		derivedKey, err := deriveEncryptionKeyForContext(ctx)
   168  		require.NoError(t, err)
   169  
   170  		wantToken := &oauth2.Token{
   171  			AccessToken:  "my_access_token",
   172  			TokenType:    "my_token_type",
   173  			RefreshToken: "my_refresh_token",
   174  			Expiry:       time.Unix(1713298947, 0).UTC(), // 2024-04-16T20:22:27.346Z
   175  		}
   176  
   177  		// The type of request here doesn't matter.
   178  		req := connect.NewRequest(&vcsv1.GetFileRequest{})
   179  		req.Header().Add("Cookie", testEncodeLegacyCookie(t, derivedKey, wantToken).String())
   180  
   181  		gotToken, err := tokenFromRequest(ctx, req)
   182  		require.NoError(t, err)
   183  		require.Equal(t, *wantToken, *gotToken)
   184  	})
   185  
   186  	t.Run("token does not exist in request", func(t *testing.T) {
   187  		githubSessionSecret = []byte("16_byte_key_XXXX")
   188  		wantErr := "failed to read cookie pyroscope_git_session: http: named cookie not present"
   189  
   190  		// The type of request here doesn't matter.
   191  		req := connect.NewRequest(&vcsv1.GetFileRequest{})
   192  
   193  		_, err := tokenFromRequest(ctx, req)
   194  		require.Error(t, err)
   195  		require.EqualError(t, err, wantErr)
   196  	})
   197  
   198  	t.Run("token with garbage should fail", func(t *testing.T) {
   199  		githubSessionSecret = []byte("16_byte_key_XXXX")
   200  
   201  		// a cookie with false metadata
   202  		cookieData, err := json.Marshal(map[string]interface{}{
   203  			"metadata": base64.StdEncoding.EncodeToString([]byte(strings.Repeat("x", 128))),
   204  			"expiry":   1234,
   205  		})
   206  		require.NoError(t, err)
   207  
   208  		req := connect.NewRequest(&vcsv1.GetFileRequest{})
   209  		req.Header().Add("Cookie", "pyroscope_git_session="+base64.RawStdEncoding.EncodeToString(cookieData))
   210  
   211  		token, err := tokenFromRequest(ctx, req)
   212  		require.Error(t, err)
   213  		require.EqualError(t, err, "cipher: message authentication failed")
   214  		require.Nil(t, token)
   215  
   216  	})
   217  }
   218  
   219  func Test_encodeTokenInCookie(t *testing.T) {
   220  	githubSessionSecret = []byte("16_byte_key_XXXX")
   221  	ctx := newTestContext()
   222  
   223  	derivedKey, err := deriveEncryptionKeyForContext(ctx)
   224  	require.NoError(t, err)
   225  
   226  	token := &oauth2.Token{
   227  		AccessToken:  "my_access_token",
   228  		TokenType:    "my_token_type",
   229  		RefreshToken: "my_refresh_token",
   230  		Expiry:       time.Unix(1713298947, 0).UTC(), // 2024-04-16T20:22:27.346Z
   231  	}
   232  
   233  	got, err := encodeTokenInCookie(token, derivedKey)
   234  	require.NoError(t, err)
   235  	require.Equal(t, sessionCookieName, got.Name)
   236  	require.NotEmpty(t, got.Value)
   237  	require.NotZero(t, got.Expires)
   238  	require.True(t, got.Secure)
   239  	require.Equal(t, http.SameSiteLaxMode, got.SameSite)
   240  }
   241  
   242  func Test_decodeToken(t *testing.T) {
   243  	githubSessionSecret = []byte("16_byte_key_XXXX")
   244  
   245  	ctx := newTestContext()
   246  	derivedKey, err := deriveEncryptionKeyForContext(ctx)
   247  	require.NoError(t, err)
   248  
   249  	t.Run("valid token", func(t *testing.T) {
   250  		want := &oauth2.Token{
   251  			AccessToken:  "my_access_token",
   252  			TokenType:    "my_token_type",
   253  			RefreshToken: "my_refresh_token",
   254  			Expiry:       time.Unix(1713298947, 0).UTC(), // 2024-04-16T20:22:27.346Z
   255  		}
   256  		cookie := testEncodeCookie(t, derivedKey, want)
   257  
   258  		got, err := decodeToken(cookie.Value, derivedKey)
   259  		require.NoError(t, err)
   260  		require.Equal(t, want, got)
   261  	})
   262  
   263  	t.Run("valid legacy token", func(t *testing.T) {
   264  		want := &oauth2.Token{
   265  			AccessToken:  "my_access_token",
   266  			TokenType:    "my_token_type",
   267  			RefreshToken: "my_refresh_token",
   268  			Expiry:       time.Unix(1713298947, 0).UTC(), // 2024-04-16T20:22:27.346Z
   269  		}
   270  		cookie := testEncodeLegacyCookie(t, derivedKey, want)
   271  
   272  		got, err := decodeToken(cookie.Value, derivedKey)
   273  		require.NoError(t, err)
   274  		require.Equal(t, want, got)
   275  	})
   276  
   277  	t.Run("invalid base64 encoding", func(t *testing.T) {
   278  		illegalBase64Encoding := "xx==="
   279  
   280  		_, err := decodeToken(illegalBase64Encoding, derivedKey)
   281  		require.Error(t, err)
   282  		require.EqualError(t, err, "illegal base64 data at input byte 4")
   283  	})
   284  
   285  	t.Run("invalid json encoding", func(t *testing.T) {
   286  		illegalJSON := base64.StdEncoding.EncodeToString([]byte("illegal json value"))
   287  
   288  		_, err := decodeToken(illegalJSON, derivedKey)
   289  		require.Error(t, err)
   290  		require.EqualError(t, err, "invalid character 'i' looking for beginning of value")
   291  	})
   292  }
   293  
   294  func Test_tenantIsolation(t *testing.T) {
   295  	githubSessionSecret = []byte("16_byte_key_XXXX")
   296  
   297  	var (
   298  		ctxA = newTestContextWithTenantID("tenant_a")
   299  		ctxB = newTestContextWithTenantID("tenant_b")
   300  	)
   301  
   302  	derivedKeyA, err := deriveEncryptionKeyForContext(ctxA)
   303  	require.NoError(t, err)
   304  
   305  	encodedTokenA := testEncodeCookie(t, derivedKeyA, &oauth2.Token{
   306  		AccessToken: "so_secret",
   307  	})
   308  
   309  	req := connect.NewRequest(&vcsv1.GetFileRequest{})
   310  	req.Header().Add("Cookie", encodedTokenA.String())
   311  
   312  	tA, err := tokenFromRequest(ctxA, req)
   313  	require.NoError(t, err)
   314  	require.Equal(t, "so_secret", tA.AccessToken)
   315  
   316  	_, err = tokenFromRequest(ctxB, req)
   317  	require.ErrorContains(t, err, "message authentication failed")
   318  
   319  }
   320  
   321  func Test_StillCompatible(t *testing.T) {
   322  	githubSessionSecret = []byte("16_byte_key_XXXX")
   323  
   324  	ctx := newTestContextWithTenantID("tenant_a")
   325  	req := connect.NewRequest(&vcsv1.GetFileRequest{})
   326  	req.Header().Add("Cookie", "pyroscope_git_session=eyJtZXRhZGF0YSI6Im12N0d1OHlIanZxdWdQMmF5TnJaYXd1SXNyQXFmUUVIMVhGS1RkejVlZWtob1NRV1JUM3hVZGRuMndUemhQZ05oWktRVkpjcVh5SVJDSnFmTTV3WTJyNmR3R21rZkRhL2FORjhRZ0lJcU1oa1hPbGFEdXNwcFE9PSJ9Cg==")
   327  
   328  	realToken, err := tokenFromRequest(ctx, req)
   329  	require.NoError(t, err)
   330  	require.Equal(t, "so_secret", realToken.AccessToken)
   331  }
   332  
   333  func newTestContext() context.Context {
   334  	return newTestContextWithTenantID("test_tenant_id")
   335  }
   336  
   337  func newTestContextWithTenantID(tenantID string) context.Context {
   338  	return tenant.InjectTenantID(context.Background(), tenantID)
   339  }
   340  
   341  func testEncodeCookie(t *testing.T, key []byte, token *oauth2.Token) *http.Cookie {
   342  	t.Helper()
   343  
   344  	encrypted, err := encryptToken(token, key)
   345  	require.NoError(t, err)
   346  
   347  	cookieValue := gitSessionTokenCookie{
   348  		Token: &encrypted,
   349  	}
   350  
   351  	jsonString, err := json.Marshal(cookieValue)
   352  	require.NoError(t, err)
   353  
   354  	encoded := base64.StdEncoding.EncodeToString(jsonString)
   355  	return &http.Cookie{
   356  		Name:     sessionCookieName,
   357  		Value:    encoded,
   358  		Expires:  token.Expiry,
   359  		HttpOnly: false,
   360  		Secure:   true,
   361  		SameSite: http.SameSiteLaxMode,
   362  	}
   363  }
   364  
   365  func testEncodeLegacyCookie(t *testing.T, key []byte, token *oauth2.Token) *http.Cookie {
   366  	t.Helper()
   367  
   368  	encoded, err := encodeTokenInCookie(token, key)
   369  	require.NoError(t, err)
   370  
   371  	return encoded
   372  }