github.com/yogeshkumararora/slsa-github-generator@v1.10.1-0.20240520161934-11278bd5afb4/github/oidc_test.go (about) 1 // Copyright 2023 SLSA Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package github 16 17 import ( 18 "context" 19 "encoding/base64" 20 "errors" 21 "fmt" 22 "net/http" 23 "net/http/httptest" 24 "os" 25 "testing" 26 "time" 27 28 "github.com/google/go-cmp/cmp" 29 "github.com/google/go-cmp/cmp/cmpopts" 30 ) 31 32 // tokenEqual returns whether the tokens are functionally equal for the purposes of the test. 33 func tokenEqual(issuer string, wantToken, gotToken *OIDCToken) bool { 34 if wantToken == nil && gotToken == nil { 35 return true 36 } 37 38 if gotToken == nil || wantToken == nil { 39 return false 40 } 41 42 // NOTE: don't check the wantToken issuer because it's not known until the 43 // server is created and we can't use a dummy value because verification checks 44 // it. 45 if want, got := issuer, gotToken.Issuer; want != got { 46 return false 47 } 48 49 if want, got := wantToken.Audience, gotToken.Audience; !compareStringSlice(want, got) { 50 return false 51 } 52 53 if want, got := wantToken.Expiry, gotToken.Expiry; !want.Equal(got) { 54 return false 55 } 56 57 if want, got := wantToken.JobWorkflowRef, gotToken.JobWorkflowRef; want != got { 58 return false 59 } 60 61 return true 62 } 63 64 func TestNewOIDCClient(t *testing.T) { 65 // Tests that NewOIDCClient returns an error when the 66 // ACTIONS_ID_TOKEN_REQUEST_URL env var is empty. 67 t.Run("empty url", func(t *testing.T) { 68 if os.Getenv(requestURLEnvKey) != "" { 69 panic(fmt.Sprintf("expected %v to be empty", requestURLEnvKey)) 70 } 71 72 _, err := NewOIDCClient() 73 if err == nil { 74 t.Fatalf("expected error") 75 } 76 if got, want := err, errURLError; !errors.Is(got, want) { 77 t.Fatalf("unexpected error, got: %#v, want: %#v", got, want) 78 } 79 }) 80 } 81 82 func TestToken(t *testing.T) { 83 now := time.Date(2022, 4, 14, 12, 24, 0, 0, time.UTC) 84 85 errClaimsFunc := func(got error) { 86 want := errClaims 87 if !errors.Is(got, want) { 88 t.Fatalf("unexpected error: %v", cmp.Diff(got, want, cmpopts.EquateErrors())) 89 } 90 } 91 92 errVerifyFunc := func(got error) { 93 want := errVerify 94 if !errors.Is(got, want) { 95 t.Fatalf("unexpected error: %v", cmp.Diff(got, want, cmpopts.EquateErrors())) 96 } 97 } 98 99 errTokenFunc := func(got error) { 100 want := errToken 101 if !errors.Is(got, want) { 102 t.Fatalf("unexpected error: %v", cmp.Diff(got, want, cmpopts.EquateErrors())) 103 } 104 } 105 106 errRequestErrorFunc := func(got error) { 107 want := errRequestError 108 if !errors.Is(got, want) { 109 t.Fatalf("unexpected error: %v", cmp.Diff(got, want, cmpopts.EquateErrors())) 110 } 111 } 112 113 testCases := []struct { 114 name string 115 raw string 116 token *OIDCToken 117 err func(error) 118 audience []string 119 status int 120 }{ 121 { 122 name: "basic token", 123 audience: []string{"hoge"}, 124 token: &OIDCToken{ 125 Audience: []string{"hoge"}, 126 Expiry: now.Add(1 * time.Hour), 127 JobWorkflowRef: "pico", 128 RepositoryID: "1234", 129 RepositoryOwnerID: "4321", 130 ActorID: "4567", 131 }, 132 }, 133 { 134 name: "no repository id claim", 135 audience: []string{"hoge"}, 136 token: &OIDCToken{ 137 Audience: []string{"hoge"}, 138 Expiry: now.Add(1 * time.Hour), 139 JobWorkflowRef: "pico", 140 RepositoryOwnerID: "4321", 141 ActorID: "4567", 142 }, 143 err: errClaimsFunc, 144 }, 145 { 146 name: "no workflow ref claim", 147 audience: []string{"hoge"}, 148 token: &OIDCToken{ 149 Audience: []string{"hoge"}, 150 Expiry: now.Add(1 * time.Hour), 151 RepositoryID: "1234", 152 RepositoryOwnerID: "4321", 153 ActorID: "4567", 154 }, 155 err: errClaimsFunc, 156 }, 157 { 158 name: "no owner id claim", 159 audience: []string{"hoge"}, 160 token: &OIDCToken{ 161 Audience: []string{"hoge"}, 162 Expiry: now.Add(1 * time.Hour), 163 JobWorkflowRef: "pico", 164 RepositoryID: "1234", 165 ActorID: "4567", 166 }, 167 err: errClaimsFunc, 168 }, 169 { 170 name: "no actor id claim", 171 audience: []string{"hoge"}, 172 token: &OIDCToken{ 173 Audience: []string{"hoge"}, 174 Expiry: now.Add(1 * time.Hour), 175 JobWorkflowRef: "pico", 176 RepositoryID: "1234", 177 RepositoryOwnerID: "4321", 178 }, 179 err: errClaimsFunc, 180 }, 181 { 182 name: "expired token", 183 audience: []string{"hoge"}, 184 token: &OIDCToken{ 185 Audience: []string{"hoge"}, 186 Expiry: now.Add(-1 * time.Hour), 187 JobWorkflowRef: "pico", 188 RepositoryID: "1234", 189 RepositoryOwnerID: "4321", 190 ActorID: "4567", 191 }, 192 err: errVerifyFunc, 193 }, 194 { 195 name: "bad audience", 196 audience: []string{"hoge"}, 197 token: &OIDCToken{ 198 Audience: []string{"fuga"}, 199 Expiry: now.Add(1 * time.Hour), 200 JobWorkflowRef: "pico", 201 RepositoryID: "1234", 202 RepositoryOwnerID: "4321", 203 ActorID: "4567", 204 }, 205 err: errVerifyFunc, 206 }, 207 { 208 name: "bad issuer", 209 audience: []string{"hoge"}, 210 token: &OIDCToken{ 211 Issuer: "https://www.google.com/", 212 Audience: []string{"hoge"}, 213 Expiry: now.Add(1 * time.Hour), 214 JobWorkflowRef: "pico", 215 RepositoryID: "1234", 216 RepositoryOwnerID: "4321", 217 ActorID: "4567", 218 }, 219 err: errVerifyFunc, 220 }, 221 { 222 name: "invalid parts", 223 audience: []string{"hoge"}, 224 raw: `{"value": "part1"}`, 225 status: http.StatusOK, 226 err: errVerifyFunc, 227 }, 228 { 229 name: "invalid base64", 230 audience: []string{"hoge"}, 231 raw: `{"value": "part1.part2.part3"}`, 232 status: http.StatusOK, 233 err: errVerifyFunc, 234 }, 235 { 236 name: "invalid json part", 237 audience: []string{"hoge"}, 238 raw: fmt.Sprintf(`{"value": "part1.%s.part3"}`, base64.RawURLEncoding.EncodeToString([]byte("not json"))), 239 status: http.StatusOK, 240 err: errVerifyFunc, 241 }, 242 { 243 name: "invalid response", 244 audience: []string{"hoge"}, 245 raw: `not json`, 246 status: http.StatusOK, 247 err: errTokenFunc, 248 }, 249 { 250 name: "error response", 251 audience: []string{"hoge"}, 252 raw: "", 253 status: http.StatusServiceUnavailable, 254 err: errRequestErrorFunc, 255 }, 256 { 257 name: "redirect response", 258 audience: []string{"hoge"}, 259 raw: "", 260 status: http.StatusFound, 261 err: errRequestErrorFunc, 262 }, 263 } 264 265 for _, tc := range testCases { 266 t.Run(tc.name, func(t *testing.T) { 267 var s *httptest.Server 268 var c *OIDCClient 269 if tc.token != nil { 270 s, c = NewTestOIDCServer(t, now, tc.token) 271 } else { 272 s, c = newRawTestOIDCServer(t, now, tc.status, tc.raw) 273 } 274 defer s.Close() 275 276 token, err := c.Token(context.Background(), tc.audience) 277 if err != nil { 278 if tc.err != nil { 279 tc.err(err) 280 } else { 281 t.Fatalf("unexpected error: %v", cmp.Diff(err, tc.err, cmpopts.EquateErrors())) 282 } 283 } else { 284 if tc.err != nil { 285 tc.err(err) 286 } else { 287 // Successful response, as expected. Check token. 288 if want, got := tc.token, token; !tokenEqual(s.URL, want, got) { 289 t.Errorf("unexpected workflow ref\nwant: %#v\ngot: %#v\ndiff:\n%v", want, got, cmp.Diff(want, got)) 290 } 291 } 292 } 293 }) 294 } 295 } 296 297 func Test_compareStringSlice(t *testing.T) { 298 testCases := []struct { 299 name string 300 left []string 301 right []string 302 expected bool 303 }{ 304 { 305 name: "empty", 306 left: []string{}, 307 right: []string{}, 308 expected: true, 309 }, 310 { 311 name: "nil", 312 left: nil, 313 right: nil, 314 expected: true, 315 }, 316 { 317 name: "left nil, right empty", 318 left: nil, 319 right: []string{}, 320 expected: true, 321 }, 322 { 323 name: "left empty, right nil", 324 left: []string{}, 325 right: nil, 326 expected: true, 327 }, 328 { 329 name: "equal", 330 left: []string{"hoge", "fuga"}, 331 right: []string{"hoge", "fuga"}, 332 expected: true, 333 }, 334 { 335 name: "unsorted", 336 left: []string{"hoge", "fuga"}, 337 right: []string{"fuga", "hoge"}, 338 expected: true, 339 }, 340 { 341 name: "left bigger", 342 left: []string{"hoge", "fuga", "pico"}, 343 right: []string{"fuga", "hoge"}, 344 expected: false, 345 }, 346 { 347 name: "right bigger", 348 left: []string{"hoge", "fuga"}, 349 right: []string{"fuga", "hoge", "pico"}, 350 expected: false, 351 }, 352 { 353 name: "diff value", 354 left: []string{"hoge", "fuga"}, 355 right: []string{"fuga", "pico"}, 356 expected: false, 357 }, 358 { 359 name: "left nil", 360 left: nil, 361 right: []string{"hoge", "fuga"}, 362 expected: false, 363 }, 364 { 365 name: "right nil", 366 left: []string{"hoge", "fuga"}, 367 right: nil, 368 expected: false, 369 }, 370 } 371 372 for _, tc := range testCases { 373 t.Run(tc.name, func(t *testing.T) { 374 if want, got := tc.expected, compareStringSlice(tc.left, tc.right); want != got { 375 t.Errorf("unexpected result, want: %v, got: %v", want, got) 376 } 377 }) 378 } 379 }