github.com/opentofu/opentofu@v1.7.1/internal/registry/client_test.go (about) 1 // Copyright (c) The OpenTofu Authors 2 // SPDX-License-Identifier: MPL-2.0 3 // Copyright (c) 2023 HashiCorp, Inc. 4 // SPDX-License-Identifier: MPL-2.0 5 6 package registry 7 8 import ( 9 "context" 10 "errors" 11 "io" 12 "net/http" 13 "os" 14 "reflect" 15 "strings" 16 "testing" 17 "time" 18 19 version "github.com/hashicorp/go-version" 20 "github.com/hashicorp/terraform-svchost/disco" 21 "github.com/opentofu/opentofu/internal/httpclient" 22 "github.com/opentofu/opentofu/internal/registry/regsrc" 23 "github.com/opentofu/opentofu/internal/registry/test" 24 tfversion "github.com/opentofu/opentofu/version" 25 ) 26 27 func TestConfigureDiscoveryRetry(t *testing.T) { 28 t.Run("default retry", func(t *testing.T) { 29 if discoveryRetry != defaultRetry { 30 t.Fatalf("expected retry %q, got %q", defaultRetry, discoveryRetry) 31 } 32 33 rc := NewClient(nil, nil) 34 if rc.client.RetryMax != defaultRetry { 35 t.Fatalf("expected client retry %q, got %q", 36 defaultRetry, rc.client.RetryMax) 37 } 38 }) 39 40 t.Run("configured retry", func(t *testing.T) { 41 defer func() { 42 discoveryRetry = defaultRetry 43 }() 44 t.Setenv(registryDiscoveryRetryEnvName, "2") 45 46 configureDiscoveryRetry() 47 expected := 2 48 if discoveryRetry != expected { 49 t.Fatalf("expected retry %q, got %q", 50 expected, discoveryRetry) 51 } 52 53 rc := NewClient(nil, nil) 54 if rc.client.RetryMax != expected { 55 t.Fatalf("expected client retry %q, got %q", 56 expected, rc.client.RetryMax) 57 } 58 }) 59 } 60 61 func TestConfigureRegistryClientTimeout(t *testing.T) { 62 t.Run("default timeout", func(t *testing.T) { 63 if requestTimeout != defaultRequestTimeout { 64 t.Fatalf("expected timeout %q, got %q", 65 defaultRequestTimeout.String(), requestTimeout.String()) 66 } 67 68 rc := NewClient(nil, nil) 69 if rc.client.HTTPClient.Timeout != defaultRequestTimeout { 70 t.Fatalf("expected client timeout %q, got %q", 71 defaultRequestTimeout.String(), rc.client.HTTPClient.Timeout.String()) 72 } 73 }) 74 75 t.Run("configured timeout", func(t *testing.T) { 76 defer func() { 77 requestTimeout = defaultRequestTimeout 78 }() 79 t.Setenv(registryClientTimeoutEnvName, "20") 80 81 configureRequestTimeout() 82 expected := 20 * time.Second 83 if requestTimeout != expected { 84 t.Fatalf("expected timeout %q, got %q", 85 expected, requestTimeout.String()) 86 } 87 88 rc := NewClient(nil, nil) 89 if rc.client.HTTPClient.Timeout != expected { 90 t.Fatalf("expected client timeout %q, got %q", 91 expected, rc.client.HTTPClient.Timeout.String()) 92 } 93 }) 94 } 95 96 func TestLookupModuleVersions(t *testing.T) { 97 server := test.Registry() 98 defer server.Close() 99 100 client := NewClient(test.Disco(server), nil) 101 102 // test with and without a hostname 103 for _, src := range []string{ 104 "example.com/test-versions/name/provider", 105 "test-versions/name/provider", 106 } { 107 modsrc, err := regsrc.ParseModuleSource(src) 108 if err != nil { 109 t.Fatal(err) 110 } 111 112 resp, err := client.ModuleVersions(context.Background(), modsrc) 113 if err != nil { 114 t.Fatal(err) 115 } 116 117 if len(resp.Modules) != 1 { 118 t.Fatal("expected 1 module, got", len(resp.Modules)) 119 } 120 121 mod := resp.Modules[0] 122 name := "test-versions/name/provider" 123 if mod.Source != name { 124 t.Fatalf("expected module name %q, got %q", name, mod.Source) 125 } 126 127 if len(mod.Versions) != 4 { 128 t.Fatal("expected 4 versions, got", len(mod.Versions)) 129 } 130 131 for _, v := range mod.Versions { 132 _, err := version.NewVersion(v.Version) 133 if err != nil { 134 t.Fatalf("invalid version %q: %s", v.Version, err) 135 } 136 } 137 } 138 } 139 140 func TestInvalidRegistry(t *testing.T) { 141 server := test.Registry() 142 defer server.Close() 143 144 client := NewClient(test.Disco(server), nil) 145 146 src := "non-existent.localhost.localdomain/test-versions/name/provider" 147 modsrc, err := regsrc.ParseModuleSource(src) 148 if err != nil { 149 t.Fatal(err) 150 } 151 152 if _, err := client.ModuleVersions(context.Background(), modsrc); err == nil { 153 t.Fatal("expected error") 154 } 155 } 156 157 func TestRegistryAuth(t *testing.T) { 158 server := test.Registry() 159 defer server.Close() 160 161 client := NewClient(test.Disco(server), nil) 162 163 src := "private/name/provider" 164 mod, err := regsrc.ParseModuleSource(src) 165 if err != nil { 166 t.Fatal(err) 167 } 168 169 _, err = client.ModuleVersions(context.Background(), mod) 170 if err != nil { 171 t.Fatal(err) 172 } 173 _, err = client.ModuleLocation(context.Background(), mod, "1.0.0") 174 if err != nil { 175 t.Fatal(err) 176 } 177 178 // Also test without a credentials source 179 client.services.SetCredentialsSource(nil) 180 181 // both should fail without auth 182 _, err = client.ModuleVersions(context.Background(), mod) 183 if err == nil { 184 t.Fatal("expected error") 185 } 186 _, err = client.ModuleLocation(context.Background(), mod, "1.0.0") 187 if err == nil { 188 t.Fatal("expected error") 189 } 190 } 191 192 func TestLookupModuleLocationRelative(t *testing.T) { 193 server := test.Registry() 194 defer server.Close() 195 196 client := NewClient(test.Disco(server), nil) 197 198 src := "relative/foo/bar" 199 mod, err := regsrc.ParseModuleSource(src) 200 if err != nil { 201 t.Fatal(err) 202 } 203 204 got, err := client.ModuleLocation(context.Background(), mod, "0.2.0") 205 if err != nil { 206 t.Fatal(err) 207 } 208 209 want := server.URL + "/relative-path" 210 if got != want { 211 t.Errorf("wrong location %s; want %s", got, want) 212 } 213 } 214 215 func TestAccLookupModuleVersions(t *testing.T) { 216 if os.Getenv("TF_ACC") == "" { 217 t.Skip() 218 } 219 regDisco := disco.New() 220 regDisco.SetUserAgent(httpclient.OpenTofuUserAgent(tfversion.String())) 221 222 // test with and without a hostname 223 for _, src := range []string{ 224 "terraform-aws-modules/vpc/aws", 225 regsrc.PublicRegistryHost.String() + "/terraform-aws-modules/vpc/aws", 226 } { 227 modsrc, err := regsrc.ParseModuleSource(src) 228 if err != nil { 229 t.Fatal(err) 230 } 231 232 s := NewClient(regDisco, nil) 233 resp, err := s.ModuleVersions(context.Background(), modsrc) 234 if err != nil { 235 t.Fatal(err) 236 } 237 238 if len(resp.Modules) != 1 { 239 t.Fatal("expected 1 module, got", len(resp.Modules)) 240 } 241 242 mod := resp.Modules[0] 243 name := "terraform-aws-modules/vpc/aws" 244 if mod.Source != name { 245 t.Fatalf("expected module name %q, got %q", name, mod.Source) 246 } 247 248 if len(mod.Versions) == 0 { 249 t.Fatal("expected multiple versions, got 0") 250 } 251 252 for _, v := range mod.Versions { 253 _, err := version.NewVersion(v.Version) 254 if err != nil { 255 t.Fatalf("invalid version %q: %s", v.Version, err) 256 } 257 } 258 } 259 } 260 261 // the error should reference the config source exactly, not the discovered path. 262 func TestLookupLookupModuleError(t *testing.T) { 263 server := test.Registry() 264 defer server.Close() 265 266 client := NewClient(test.Disco(server), nil) 267 268 // this should not be found in the registry 269 src := "bad/local/path" 270 mod, err := regsrc.ParseModuleSource(src) 271 if err != nil { 272 t.Fatal(err) 273 } 274 275 // Instrument CheckRetry to make sure 404s are not retried 276 retries := 0 277 oldCheck := client.client.CheckRetry 278 client.client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) { 279 if retries > 0 { 280 t.Fatal("retried after module not found") 281 } 282 retries++ 283 return oldCheck(ctx, resp, err) 284 } 285 286 _, err = client.ModuleLocation(context.Background(), mod, "0.2.0") 287 if err == nil { 288 t.Fatal("expected error") 289 } 290 291 // check for the exact quoted string to ensure we didn't prepend a hostname. 292 if !strings.Contains(err.Error(), `"bad/local/path"`) { 293 t.Fatal("error should not include the hostname. got:", err) 294 } 295 } 296 297 func TestLookupModuleRetryError(t *testing.T) { 298 server := test.RegistryRetryableErrorsServer() 299 defer server.Close() 300 301 client := NewClient(test.Disco(server), nil) 302 303 src := "example.com/test-versions/name/provider" 304 modsrc, err := regsrc.ParseModuleSource(src) 305 if err != nil { 306 t.Fatal(err) 307 } 308 resp, err := client.ModuleVersions(context.Background(), modsrc) 309 if err == nil { 310 t.Fatal("expected requests to exceed retry", err) 311 } 312 if resp != nil { 313 t.Fatal("unexpected response", *resp) 314 } 315 316 // verify maxRetryErrorHandler handler returned the error 317 if !strings.Contains(err.Error(), "the request failed after 2 attempts, please try again later") { 318 t.Fatal("unexpected error, got:", err) 319 } 320 } 321 322 func TestLookupModuleNoRetryError(t *testing.T) { 323 // Disable retries 324 discoveryRetry = 0 325 defer configureDiscoveryRetry() 326 327 server := test.RegistryRetryableErrorsServer() 328 defer server.Close() 329 330 client := NewClient(test.Disco(server), nil) 331 332 src := "example.com/test-versions/name/provider" 333 modsrc, err := regsrc.ParseModuleSource(src) 334 if err != nil { 335 t.Fatal(err) 336 } 337 resp, err := client.ModuleVersions(context.Background(), modsrc) 338 if err == nil { 339 t.Fatal("expected request to fail", err) 340 } 341 if resp != nil { 342 t.Fatal("unexpected response", *resp) 343 } 344 345 // verify maxRetryErrorHandler handler returned the error 346 if !strings.Contains(err.Error(), "the request failed, please try again later") { 347 t.Fatal("unexpected error, got:", err) 348 } 349 } 350 351 func TestLookupModuleNetworkError(t *testing.T) { 352 server := test.RegistryRetryableErrorsServer() 353 client := NewClient(test.Disco(server), nil) 354 355 // Shut down the server to simulate network failure 356 server.Close() 357 358 src := "example.com/test-versions/name/provider" 359 modsrc, err := regsrc.ParseModuleSource(src) 360 if err != nil { 361 t.Fatal(err) 362 } 363 resp, err := client.ModuleVersions(context.Background(), modsrc) 364 if err == nil { 365 t.Fatal("expected request to fail", err) 366 } 367 if resp != nil { 368 t.Fatal("unexpected response", *resp) 369 } 370 371 // verify maxRetryErrorHandler handler returned the correct error 372 if !strings.Contains(err.Error(), "the request failed after 2 attempts, please try again later") { 373 t.Fatal("unexpected error, got:", err) 374 } 375 } 376 377 func TestModuleLocation_readRegistryResponse(t *testing.T) { 378 cases := map[string]struct { 379 src string 380 httpClient *http.Client 381 registryFlags []uint8 382 want string 383 wantErrorStr string 384 wantToReadFromHeader bool 385 wantStatusCode int 386 }{ 387 "shall find the module location in the registry response body": { 388 src: "exists-in-registry/identifier/provider", 389 want: "file:///registry/exists", 390 wantStatusCode: http.StatusOK, 391 httpClient: &http.Client{ 392 Transport: &mockRoundTripper{}, 393 }, 394 }, 395 "shall find the module location in the registry response header": { 396 src: "exists-in-registry/identifier/provider", 397 registryFlags: []uint8{test.WithModuleLocationInHeader}, 398 want: "file:///registry/exists", 399 wantToReadFromHeader: true, 400 wantStatusCode: http.StatusNoContent, 401 httpClient: &http.Client{ 402 Transport: &mockRoundTripper{}, 403 }, 404 }, 405 "shall read location from the registry response body even if the header with location address is also set": { 406 src: "exists-in-registry/identifier/provider", 407 want: "file:///registry/exists", 408 wantStatusCode: http.StatusOK, 409 wantToReadFromHeader: false, 410 registryFlags: []uint8{test.WithModuleLocationInBody, test.WithModuleLocationInHeader}, 411 httpClient: &http.Client{ 412 Transport: &mockRoundTripper{}, 413 }, 414 }, 415 "shall fail to find the module": { 416 src: "not-exist/identifier/provider", 417 // note that the version is fixed in the mock 418 // see: /internal/registry/test/mock_registry.go:testMods 419 wantErrorStr: `module "not-exist/identifier/provider" version "0.2.0" not found`, 420 wantStatusCode: http.StatusNotFound, 421 httpClient: &http.Client{ 422 Transport: &mockRoundTripper{}, 423 }, 424 }, 425 "shall fail because of reading response body error": { 426 src: "foo/bar/baz", 427 wantErrorStr: "error reading response body from registry: foo", 428 wantStatusCode: http.StatusOK, 429 httpClient: &http.Client{ 430 Transport: &mockRoundTripper{ 431 forwardResponse: &http.Response{ 432 StatusCode: http.StatusOK, 433 Body: mockErrorReadCloser{err: errors.New("foo")}, 434 }, 435 }, 436 }, 437 }, 438 "shall fail to deserialize JSON response": { 439 src: "foo/bar/baz", 440 wantErrorStr: `module "foo/bar/baz" version "0.2.0" failed to deserialize response body {: unexpected end of JSON input`, 441 wantStatusCode: http.StatusOK, 442 httpClient: &http.Client{ 443 Transport: &mockRoundTripper{ 444 forwardResponse: &http.Response{ 445 StatusCode: http.StatusOK, 446 Body: io.NopCloser(strings.NewReader("{")), 447 }, 448 }, 449 }, 450 }, 451 "shall fail because of unexpected protocol change - 422 http status": { 452 src: "foo/bar/baz", 453 wantErrorStr: `error getting download location for "foo/bar/baz": foo resp:bar`, 454 wantStatusCode: http.StatusUnprocessableEntity, 455 httpClient: &http.Client{ 456 Transport: &mockRoundTripper{ 457 forwardResponse: &http.Response{ 458 StatusCode: http.StatusUnprocessableEntity, 459 Status: "foo", 460 Body: io.NopCloser(strings.NewReader("bar")), 461 }, 462 }, 463 }, 464 }, 465 "shall fail because location is not found in the response": { 466 src: "foo/bar/baz", 467 wantErrorStr: `failed to get download URL for "foo/bar/baz": OK resp:{"foo":"git::https://github.com/foo/terraform-baz-bar?ref=v0.2.0"}`, 468 wantStatusCode: http.StatusOK, 469 httpClient: &http.Client{ 470 Transport: &mockRoundTripper{ 471 forwardResponse: &http.Response{ 472 StatusCode: http.StatusOK, 473 Status: "OK", 474 // note that the response emulates a contract change 475 Body: io.NopCloser(strings.NewReader(`{"foo":"git::https://github.com/foo/terraform-baz-bar?ref=v0.2.0"}`)), 476 }, 477 }, 478 }, 479 }, 480 } 481 482 t.Parallel() 483 for name, tc := range cases { 484 t.Run(name, func(t *testing.T) { 485 server := test.Registry(tc.registryFlags...) 486 defer server.Close() 487 488 client := NewClient(test.Disco(server), tc.httpClient) 489 490 mod, err := regsrc.ParseModuleSource(tc.src) 491 if err != nil { 492 t.Fatal(err) 493 } 494 495 got, err := client.ModuleLocation(context.Background(), mod, "0.2.0") 496 if err != nil && tc.wantErrorStr == "" { 497 t.Fatalf("unexpected error: %v", err) 498 } 499 if err != nil && err.Error() != tc.wantErrorStr { 500 t.Fatalf("unexpected error content: want=%s, got=%v", tc.wantErrorStr, err) 501 } 502 if got != tc.want { 503 t.Fatalf("unexpected location: want=%s, got=%v", tc.want, got) 504 } 505 506 gotStatusCode := tc.httpClient.Transport.(*mockRoundTripper).reverseResponse.StatusCode 507 if tc.wantStatusCode != gotStatusCode { 508 t.Fatalf("unexpected response status code: want=%d, got=%d", tc.wantStatusCode, gotStatusCode) 509 } 510 511 if tc.wantToReadFromHeader { 512 resp := tc.httpClient.Transport.(*mockRoundTripper).reverseResponse 513 if !reflect.DeepEqual(resp.Body, http.NoBody) { 514 t.Fatalf("expected no body") 515 } 516 } 517 }) 518 } 519 } 520 521 type mockRoundTripper struct { 522 // response to return without calling the server 523 // SET TO USE AS A REVERSE PROXY 524 forwardResponse *http.Response 525 // the response from the server will be written here 526 // DO NOT SET 527 reverseResponse *http.Response 528 err error 529 } 530 531 func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { 532 if m.err != nil { 533 return nil, m.err 534 } 535 if m.forwardResponse != nil { 536 m.reverseResponse = m.forwardResponse 537 return m.forwardResponse, nil 538 } 539 resp, err := http.DefaultTransport.RoundTrip(req) 540 m.reverseResponse = resp 541 return resp, err 542 } 543 544 type mockErrorReadCloser struct { 545 err error 546 } 547 548 func (m mockErrorReadCloser) Read(_ []byte) (n int, err error) { 549 return 0, m.err 550 } 551 552 func (m mockErrorReadCloser) Close() error { 553 return m.err 554 }