github.com/google/osv-scalibr@v0.4.1/veles/secrets/common/jwt/jwt_test.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package jwt_test
    16  
    17  import (
    18  	"encoding/base64"
    19  	"testing"
    20  
    21  	"github.com/google/go-cmp/cmp"
    22  	"github.com/google/go-cmp/cmp/cmpopts"
    23  	"github.com/google/osv-scalibr/veles/secrets/common/jwt"
    24  )
    25  
    26  // TestExtractTokens_validTokens tests for cases where we expect to
    27  // successfully extract tokens.
    28  func TestExtractTokens_validTokens(t *testing.T) {
    29  	cases := []struct {
    30  		name          string
    31  		input         []byte
    32  		wantRaw       []string
    33  		wantHeader    []map[string]any
    34  		wantPayload   []map[string]any
    35  		wantSignature []string
    36  		wantPos       []int
    37  	}{
    38  		{
    39  			name: "basic_jwt_with_simple_claims",
    40  			input: []byte("eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
    41  				"eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.signature"),
    42  			wantRaw: []string{
    43  				"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
    44  					"eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.signature",
    45  			},
    46  			wantHeader: []map[string]any{
    47  				{
    48  					"alg": "RS256",
    49  					"typ": "JWT",
    50  				},
    51  			},
    52  			wantPayload: []map[string]any{
    53  				{
    54  					"sub":  "1234567890",
    55  					"name": "John Doe",
    56  					"iat":  float64(1516239022),
    57  				},
    58  			},
    59  
    60  			wantPos: []int{0},
    61  		},
    62  		{
    63  			name: "azure_token",
    64  			input: []byte("prefix eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
    65  				"eyJpc3MiOiJodHRwczovL2xvZ2luLm1pY3Jvc29mdG9ubGluZS5jb20vdGVuYW50L3YyLjAiLCJzdWIiOiJ1c2VyMTIzIn0.signature suffix"),
    66  			wantRaw: []string{
    67  				"eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." +
    68  					"eyJpc3MiOiJodHRwczovL2xvZ2luLm1pY3Jvc29mdG9ubGluZS5jb20vdGVuYW50L3YyLjAiLCJzdWIiOiJ1c2VyMTIzIn0.signature",
    69  			},
    70  			wantHeader: []map[string]any{
    71  				{
    72  					"alg": "RS256",
    73  					"typ": "JWT",
    74  				},
    75  			},
    76  			wantPayload: []map[string]any{
    77  				{
    78  					"iss": "https://login.microsoftonline.com/tenant/v2.0",
    79  					"sub": "user123",
    80  				},
    81  			},
    82  			wantSignature: []string{"signature"},
    83  			wantPos:       []int{7},
    84  		},
    85  	}
    86  
    87  	for _, tc := range cases {
    88  		t.Run(tc.name, func(t *testing.T) {
    89  			gotTokens, gotPos := jwt.ExtractTokens(tc.input)
    90  
    91  			if diff := cmp.Diff(tc.wantPos, gotPos); diff != "" {
    92  				t.Errorf("ExtractTokens(): diff position mismatch (-want +got):\n%s", diff)
    93  			}
    94  
    95  			if len(gotTokens) != len(tc.wantRaw) {
    96  				t.Fatalf("ExtractTokens(): diff number of tokens: got %d, want %d", len(gotTokens), len(tc.wantRaw))
    97  			}
    98  
    99  			for i, got := range gotTokens {
   100  				if got.Raw() != tc.wantRaw[i] {
   101  					t.Errorf("ExtractTokens(): diff %d Raw() = %q; want %q", i, got.Raw(), tc.wantRaw[i])
   102  				}
   103  
   104  				if diff := cmp.Diff(tc.wantHeader[i], got.Header(), cmpopts.EquateEmpty()); diff != "" {
   105  					t.Errorf("ExtractTokens(): diff %d Header() mismatch (-want +got):\n%s", i, diff)
   106  				}
   107  
   108  				if diff := cmp.Diff(tc.wantPayload[i], got.Payload(), cmpopts.EquateEmpty()); diff != "" {
   109  					t.Errorf("ExtractTokens(): diff %d Payload() mismatch (-want +got):\n%s", i, diff)
   110  				}
   111  			}
   112  		})
   113  	}
   114  }
   115  
   116  // TestExtractTokens_invalidTokens tests for cases where we expect to return nil.
   117  func TestExtractTokens_invalidTokens(t *testing.T) {
   118  	cases := []struct {
   119  		name  string
   120  		input []byte
   121  	}{
   122  		{
   123  			name:  "empty string",
   124  			input: []byte(""),
   125  		},
   126  		{
   127  			name:  "only one part",
   128  			input: []byte("header"),
   129  		},
   130  		{
   131  			name:  "only two parts",
   132  			input: []byte("header.payload"),
   133  		},
   134  		{
   135  			name:  "too many parts",
   136  			input: []byte("header.payload.signature.extra"),
   137  		},
   138  		{
   139  			name:  "invalid base64 in payload",
   140  			input: []byte("eyJhbGciOiJSUzI1NiJ9.invalid_base64!.signature"),
   141  		},
   142  		{
   143  			name: "payload_is_not_json",
   144  			input: []byte("eyJhbGciOiJSUzI1NiJ9." +
   145  				base64.RawStdEncoding.EncodeToString([]byte("not json")) +
   146  				".signature"),
   147  		},
   148  		{
   149  			name:  "not valid regex",
   150  			input: []byte("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.e30.signature"),
   151  		},
   152  	}
   153  
   154  	for _, tc := range cases {
   155  		t.Run(tc.name, func(t *testing.T) {
   156  			gotTokens, gotPos := jwt.ExtractTokens(tc.input)
   157  
   158  			if len(gotTokens) != 0 {
   159  				t.Errorf("ExtractTokens(): diff returned %d tokens; want 0", len(gotTokens))
   160  			}
   161  
   162  			if len(gotPos) != 0 {
   163  				t.Errorf("ExtractTokens(): diff returned %d positions; want 0", len(gotPos))
   164  			}
   165  		})
   166  	}
   167  }