github.com/yogeshkumararora/slsa-github-generator@v1.10.1-0.20240520161934-11278bd5afb4/github/oidc_test.go (about)

     1  // Copyright 2023 SLSA Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package github
    16  
    17  import (
    18  	"context"
    19  	"encoding/base64"
    20  	"errors"
    21  	"fmt"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"os"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/google/go-cmp/cmp"
    29  	"github.com/google/go-cmp/cmp/cmpopts"
    30  )
    31  
    32  // tokenEqual returns whether the tokens are functionally equal for the purposes of the test.
    33  func tokenEqual(issuer string, wantToken, gotToken *OIDCToken) bool {
    34  	if wantToken == nil && gotToken == nil {
    35  		return true
    36  	}
    37  
    38  	if gotToken == nil || wantToken == nil {
    39  		return false
    40  	}
    41  
    42  	// NOTE: don't check the wantToken issuer because it's not known until the
    43  	// server is created and we can't use a dummy value because verification checks
    44  	// it.
    45  	if want, got := issuer, gotToken.Issuer; want != got {
    46  		return false
    47  	}
    48  
    49  	if want, got := wantToken.Audience, gotToken.Audience; !compareStringSlice(want, got) {
    50  		return false
    51  	}
    52  
    53  	if want, got := wantToken.Expiry, gotToken.Expiry; !want.Equal(got) {
    54  		return false
    55  	}
    56  
    57  	if want, got := wantToken.JobWorkflowRef, gotToken.JobWorkflowRef; want != got {
    58  		return false
    59  	}
    60  
    61  	return true
    62  }
    63  
    64  func TestNewOIDCClient(t *testing.T) {
    65  	// Tests that NewOIDCClient returns an error when the
    66  	// ACTIONS_ID_TOKEN_REQUEST_URL env var is empty.
    67  	t.Run("empty url", func(t *testing.T) {
    68  		if os.Getenv(requestURLEnvKey) != "" {
    69  			panic(fmt.Sprintf("expected %v to be empty", requestURLEnvKey))
    70  		}
    71  
    72  		_, err := NewOIDCClient()
    73  		if err == nil {
    74  			t.Fatalf("expected error")
    75  		}
    76  		if got, want := err, errURLError; !errors.Is(got, want) {
    77  			t.Fatalf("unexpected error, got: %#v, want: %#v", got, want)
    78  		}
    79  	})
    80  }
    81  
    82  func TestToken(t *testing.T) {
    83  	now := time.Date(2022, 4, 14, 12, 24, 0, 0, time.UTC)
    84  
    85  	errClaimsFunc := func(got error) {
    86  		want := errClaims
    87  		if !errors.Is(got, want) {
    88  			t.Fatalf("unexpected error: %v", cmp.Diff(got, want, cmpopts.EquateErrors()))
    89  		}
    90  	}
    91  
    92  	errVerifyFunc := func(got error) {
    93  		want := errVerify
    94  		if !errors.Is(got, want) {
    95  			t.Fatalf("unexpected error: %v", cmp.Diff(got, want, cmpopts.EquateErrors()))
    96  		}
    97  	}
    98  
    99  	errTokenFunc := func(got error) {
   100  		want := errToken
   101  		if !errors.Is(got, want) {
   102  			t.Fatalf("unexpected error: %v", cmp.Diff(got, want, cmpopts.EquateErrors()))
   103  		}
   104  	}
   105  
   106  	errRequestErrorFunc := func(got error) {
   107  		want := errRequestError
   108  		if !errors.Is(got, want) {
   109  			t.Fatalf("unexpected error: %v", cmp.Diff(got, want, cmpopts.EquateErrors()))
   110  		}
   111  	}
   112  
   113  	testCases := []struct {
   114  		name     string
   115  		raw      string
   116  		token    *OIDCToken
   117  		err      func(error)
   118  		audience []string
   119  		status   int
   120  	}{
   121  		{
   122  			name:     "basic token",
   123  			audience: []string{"hoge"},
   124  			token: &OIDCToken{
   125  				Audience:          []string{"hoge"},
   126  				Expiry:            now.Add(1 * time.Hour),
   127  				JobWorkflowRef:    "pico",
   128  				RepositoryID:      "1234",
   129  				RepositoryOwnerID: "4321",
   130  				ActorID:           "4567",
   131  			},
   132  		},
   133  		{
   134  			name:     "no repository id claim",
   135  			audience: []string{"hoge"},
   136  			token: &OIDCToken{
   137  				Audience:          []string{"hoge"},
   138  				Expiry:            now.Add(1 * time.Hour),
   139  				JobWorkflowRef:    "pico",
   140  				RepositoryOwnerID: "4321",
   141  				ActorID:           "4567",
   142  			},
   143  			err: errClaimsFunc,
   144  		},
   145  		{
   146  			name:     "no workflow ref claim",
   147  			audience: []string{"hoge"},
   148  			token: &OIDCToken{
   149  				Audience:          []string{"hoge"},
   150  				Expiry:            now.Add(1 * time.Hour),
   151  				RepositoryID:      "1234",
   152  				RepositoryOwnerID: "4321",
   153  				ActorID:           "4567",
   154  			},
   155  			err: errClaimsFunc,
   156  		},
   157  		{
   158  			name:     "no owner id claim",
   159  			audience: []string{"hoge"},
   160  			token: &OIDCToken{
   161  				Audience:       []string{"hoge"},
   162  				Expiry:         now.Add(1 * time.Hour),
   163  				JobWorkflowRef: "pico",
   164  				RepositoryID:   "1234",
   165  				ActorID:        "4567",
   166  			},
   167  			err: errClaimsFunc,
   168  		},
   169  		{
   170  			name:     "no actor id claim",
   171  			audience: []string{"hoge"},
   172  			token: &OIDCToken{
   173  				Audience:          []string{"hoge"},
   174  				Expiry:            now.Add(1 * time.Hour),
   175  				JobWorkflowRef:    "pico",
   176  				RepositoryID:      "1234",
   177  				RepositoryOwnerID: "4321",
   178  			},
   179  			err: errClaimsFunc,
   180  		},
   181  		{
   182  			name:     "expired token",
   183  			audience: []string{"hoge"},
   184  			token: &OIDCToken{
   185  				Audience:          []string{"hoge"},
   186  				Expiry:            now.Add(-1 * time.Hour),
   187  				JobWorkflowRef:    "pico",
   188  				RepositoryID:      "1234",
   189  				RepositoryOwnerID: "4321",
   190  				ActorID:           "4567",
   191  			},
   192  			err: errVerifyFunc,
   193  		},
   194  		{
   195  			name:     "bad audience",
   196  			audience: []string{"hoge"},
   197  			token: &OIDCToken{
   198  				Audience:          []string{"fuga"},
   199  				Expiry:            now.Add(1 * time.Hour),
   200  				JobWorkflowRef:    "pico",
   201  				RepositoryID:      "1234",
   202  				RepositoryOwnerID: "4321",
   203  				ActorID:           "4567",
   204  			},
   205  			err: errVerifyFunc,
   206  		},
   207  		{
   208  			name:     "bad issuer",
   209  			audience: []string{"hoge"},
   210  			token: &OIDCToken{
   211  				Issuer:            "https://www.google.com/",
   212  				Audience:          []string{"hoge"},
   213  				Expiry:            now.Add(1 * time.Hour),
   214  				JobWorkflowRef:    "pico",
   215  				RepositoryID:      "1234",
   216  				RepositoryOwnerID: "4321",
   217  				ActorID:           "4567",
   218  			},
   219  			err: errVerifyFunc,
   220  		},
   221  		{
   222  			name:     "invalid parts",
   223  			audience: []string{"hoge"},
   224  			raw:      `{"value": "part1"}`,
   225  			status:   http.StatusOK,
   226  			err:      errVerifyFunc,
   227  		},
   228  		{
   229  			name:     "invalid base64",
   230  			audience: []string{"hoge"},
   231  			raw:      `{"value": "part1.part2.part3"}`,
   232  			status:   http.StatusOK,
   233  			err:      errVerifyFunc,
   234  		},
   235  		{
   236  			name:     "invalid json part",
   237  			audience: []string{"hoge"},
   238  			raw:      fmt.Sprintf(`{"value": "part1.%s.part3"}`, base64.RawURLEncoding.EncodeToString([]byte("not json"))),
   239  			status:   http.StatusOK,
   240  			err:      errVerifyFunc,
   241  		},
   242  		{
   243  			name:     "invalid response",
   244  			audience: []string{"hoge"},
   245  			raw:      `not json`,
   246  			status:   http.StatusOK,
   247  			err:      errTokenFunc,
   248  		},
   249  		{
   250  			name:     "error response",
   251  			audience: []string{"hoge"},
   252  			raw:      "",
   253  			status:   http.StatusServiceUnavailable,
   254  			err:      errRequestErrorFunc,
   255  		},
   256  		{
   257  			name:     "redirect response",
   258  			audience: []string{"hoge"},
   259  			raw:      "",
   260  			status:   http.StatusFound,
   261  			err:      errRequestErrorFunc,
   262  		},
   263  	}
   264  
   265  	for _, tc := range testCases {
   266  		t.Run(tc.name, func(t *testing.T) {
   267  			var s *httptest.Server
   268  			var c *OIDCClient
   269  			if tc.token != nil {
   270  				s, c = NewTestOIDCServer(t, now, tc.token)
   271  			} else {
   272  				s, c = newRawTestOIDCServer(t, now, tc.status, tc.raw)
   273  			}
   274  			defer s.Close()
   275  
   276  			token, err := c.Token(context.Background(), tc.audience)
   277  			if err != nil {
   278  				if tc.err != nil {
   279  					tc.err(err)
   280  				} else {
   281  					t.Fatalf("unexpected error: %v", cmp.Diff(err, tc.err, cmpopts.EquateErrors()))
   282  				}
   283  			} else {
   284  				if tc.err != nil {
   285  					tc.err(err)
   286  				} else {
   287  					// Successful response, as expected. Check token.
   288  					if want, got := tc.token, token; !tokenEqual(s.URL, want, got) {
   289  						t.Errorf("unexpected workflow ref\nwant: %#v\ngot:  %#v\ndiff:\n%v", want, got, cmp.Diff(want, got))
   290  					}
   291  				}
   292  			}
   293  		})
   294  	}
   295  }
   296  
   297  func Test_compareStringSlice(t *testing.T) {
   298  	testCases := []struct {
   299  		name     string
   300  		left     []string
   301  		right    []string
   302  		expected bool
   303  	}{
   304  		{
   305  			name:     "empty",
   306  			left:     []string{},
   307  			right:    []string{},
   308  			expected: true,
   309  		},
   310  		{
   311  			name:     "nil",
   312  			left:     nil,
   313  			right:    nil,
   314  			expected: true,
   315  		},
   316  		{
   317  			name:     "left nil, right empty",
   318  			left:     nil,
   319  			right:    []string{},
   320  			expected: true,
   321  		},
   322  		{
   323  			name:     "left empty, right nil",
   324  			left:     []string{},
   325  			right:    nil,
   326  			expected: true,
   327  		},
   328  		{
   329  			name:     "equal",
   330  			left:     []string{"hoge", "fuga"},
   331  			right:    []string{"hoge", "fuga"},
   332  			expected: true,
   333  		},
   334  		{
   335  			name:     "unsorted",
   336  			left:     []string{"hoge", "fuga"},
   337  			right:    []string{"fuga", "hoge"},
   338  			expected: true,
   339  		},
   340  		{
   341  			name:     "left bigger",
   342  			left:     []string{"hoge", "fuga", "pico"},
   343  			right:    []string{"fuga", "hoge"},
   344  			expected: false,
   345  		},
   346  		{
   347  			name:     "right bigger",
   348  			left:     []string{"hoge", "fuga"},
   349  			right:    []string{"fuga", "hoge", "pico"},
   350  			expected: false,
   351  		},
   352  		{
   353  			name:     "diff value",
   354  			left:     []string{"hoge", "fuga"},
   355  			right:    []string{"fuga", "pico"},
   356  			expected: false,
   357  		},
   358  		{
   359  			name:     "left nil",
   360  			left:     nil,
   361  			right:    []string{"hoge", "fuga"},
   362  			expected: false,
   363  		},
   364  		{
   365  			name:     "right nil",
   366  			left:     []string{"hoge", "fuga"},
   367  			right:    nil,
   368  			expected: false,
   369  		},
   370  	}
   371  
   372  	for _, tc := range testCases {
   373  		t.Run(tc.name, func(t *testing.T) {
   374  			if want, got := tc.expected, compareStringSlice(tc.left, tc.right); want != got {
   375  				t.Errorf("unexpected result, want: %v, got: %v", want, got)
   376  			}
   377  		})
   378  	}
   379  }