github.com/opentofu/opentofu@v1.7.1/internal/registry/client_test.go (about)

     1  // Copyright (c) The OpenTofu Authors
     2  // SPDX-License-Identifier: MPL-2.0
     3  // Copyright (c) 2023 HashiCorp, Inc.
     4  // SPDX-License-Identifier: MPL-2.0
     5  
     6  package registry
     7  
     8  import (
     9  	"context"
    10  	"errors"
    11  	"io"
    12  	"net/http"
    13  	"os"
    14  	"reflect"
    15  	"strings"
    16  	"testing"
    17  	"time"
    18  
    19  	version "github.com/hashicorp/go-version"
    20  	"github.com/hashicorp/terraform-svchost/disco"
    21  	"github.com/opentofu/opentofu/internal/httpclient"
    22  	"github.com/opentofu/opentofu/internal/registry/regsrc"
    23  	"github.com/opentofu/opentofu/internal/registry/test"
    24  	tfversion "github.com/opentofu/opentofu/version"
    25  )
    26  
    27  func TestConfigureDiscoveryRetry(t *testing.T) {
    28  	t.Run("default retry", func(t *testing.T) {
    29  		if discoveryRetry != defaultRetry {
    30  			t.Fatalf("expected retry %q, got %q", defaultRetry, discoveryRetry)
    31  		}
    32  
    33  		rc := NewClient(nil, nil)
    34  		if rc.client.RetryMax != defaultRetry {
    35  			t.Fatalf("expected client retry %q, got %q",
    36  				defaultRetry, rc.client.RetryMax)
    37  		}
    38  	})
    39  
    40  	t.Run("configured retry", func(t *testing.T) {
    41  		defer func() {
    42  			discoveryRetry = defaultRetry
    43  		}()
    44  		t.Setenv(registryDiscoveryRetryEnvName, "2")
    45  
    46  		configureDiscoveryRetry()
    47  		expected := 2
    48  		if discoveryRetry != expected {
    49  			t.Fatalf("expected retry %q, got %q",
    50  				expected, discoveryRetry)
    51  		}
    52  
    53  		rc := NewClient(nil, nil)
    54  		if rc.client.RetryMax != expected {
    55  			t.Fatalf("expected client retry %q, got %q",
    56  				expected, rc.client.RetryMax)
    57  		}
    58  	})
    59  }
    60  
    61  func TestConfigureRegistryClientTimeout(t *testing.T) {
    62  	t.Run("default timeout", func(t *testing.T) {
    63  		if requestTimeout != defaultRequestTimeout {
    64  			t.Fatalf("expected timeout %q, got %q",
    65  				defaultRequestTimeout.String(), requestTimeout.String())
    66  		}
    67  
    68  		rc := NewClient(nil, nil)
    69  		if rc.client.HTTPClient.Timeout != defaultRequestTimeout {
    70  			t.Fatalf("expected client timeout %q, got %q",
    71  				defaultRequestTimeout.String(), rc.client.HTTPClient.Timeout.String())
    72  		}
    73  	})
    74  
    75  	t.Run("configured timeout", func(t *testing.T) {
    76  		defer func() {
    77  			requestTimeout = defaultRequestTimeout
    78  		}()
    79  		t.Setenv(registryClientTimeoutEnvName, "20")
    80  
    81  		configureRequestTimeout()
    82  		expected := 20 * time.Second
    83  		if requestTimeout != expected {
    84  			t.Fatalf("expected timeout %q, got %q",
    85  				expected, requestTimeout.String())
    86  		}
    87  
    88  		rc := NewClient(nil, nil)
    89  		if rc.client.HTTPClient.Timeout != expected {
    90  			t.Fatalf("expected client timeout %q, got %q",
    91  				expected, rc.client.HTTPClient.Timeout.String())
    92  		}
    93  	})
    94  }
    95  
    96  func TestLookupModuleVersions(t *testing.T) {
    97  	server := test.Registry()
    98  	defer server.Close()
    99  
   100  	client := NewClient(test.Disco(server), nil)
   101  
   102  	// test with and without a hostname
   103  	for _, src := range []string{
   104  		"example.com/test-versions/name/provider",
   105  		"test-versions/name/provider",
   106  	} {
   107  		modsrc, err := regsrc.ParseModuleSource(src)
   108  		if err != nil {
   109  			t.Fatal(err)
   110  		}
   111  
   112  		resp, err := client.ModuleVersions(context.Background(), modsrc)
   113  		if err != nil {
   114  			t.Fatal(err)
   115  		}
   116  
   117  		if len(resp.Modules) != 1 {
   118  			t.Fatal("expected 1 module, got", len(resp.Modules))
   119  		}
   120  
   121  		mod := resp.Modules[0]
   122  		name := "test-versions/name/provider"
   123  		if mod.Source != name {
   124  			t.Fatalf("expected module name %q, got %q", name, mod.Source)
   125  		}
   126  
   127  		if len(mod.Versions) != 4 {
   128  			t.Fatal("expected 4 versions, got", len(mod.Versions))
   129  		}
   130  
   131  		for _, v := range mod.Versions {
   132  			_, err := version.NewVersion(v.Version)
   133  			if err != nil {
   134  				t.Fatalf("invalid version %q: %s", v.Version, err)
   135  			}
   136  		}
   137  	}
   138  }
   139  
   140  func TestInvalidRegistry(t *testing.T) {
   141  	server := test.Registry()
   142  	defer server.Close()
   143  
   144  	client := NewClient(test.Disco(server), nil)
   145  
   146  	src := "non-existent.localhost.localdomain/test-versions/name/provider"
   147  	modsrc, err := regsrc.ParseModuleSource(src)
   148  	if err != nil {
   149  		t.Fatal(err)
   150  	}
   151  
   152  	if _, err := client.ModuleVersions(context.Background(), modsrc); err == nil {
   153  		t.Fatal("expected error")
   154  	}
   155  }
   156  
   157  func TestRegistryAuth(t *testing.T) {
   158  	server := test.Registry()
   159  	defer server.Close()
   160  
   161  	client := NewClient(test.Disco(server), nil)
   162  
   163  	src := "private/name/provider"
   164  	mod, err := regsrc.ParseModuleSource(src)
   165  	if err != nil {
   166  		t.Fatal(err)
   167  	}
   168  
   169  	_, err = client.ModuleVersions(context.Background(), mod)
   170  	if err != nil {
   171  		t.Fatal(err)
   172  	}
   173  	_, err = client.ModuleLocation(context.Background(), mod, "1.0.0")
   174  	if err != nil {
   175  		t.Fatal(err)
   176  	}
   177  
   178  	// Also test without a credentials source
   179  	client.services.SetCredentialsSource(nil)
   180  
   181  	// both should fail without auth
   182  	_, err = client.ModuleVersions(context.Background(), mod)
   183  	if err == nil {
   184  		t.Fatal("expected error")
   185  	}
   186  	_, err = client.ModuleLocation(context.Background(), mod, "1.0.0")
   187  	if err == nil {
   188  		t.Fatal("expected error")
   189  	}
   190  }
   191  
   192  func TestLookupModuleLocationRelative(t *testing.T) {
   193  	server := test.Registry()
   194  	defer server.Close()
   195  
   196  	client := NewClient(test.Disco(server), nil)
   197  
   198  	src := "relative/foo/bar"
   199  	mod, err := regsrc.ParseModuleSource(src)
   200  	if err != nil {
   201  		t.Fatal(err)
   202  	}
   203  
   204  	got, err := client.ModuleLocation(context.Background(), mod, "0.2.0")
   205  	if err != nil {
   206  		t.Fatal(err)
   207  	}
   208  
   209  	want := server.URL + "/relative-path"
   210  	if got != want {
   211  		t.Errorf("wrong location %s; want %s", got, want)
   212  	}
   213  }
   214  
   215  func TestAccLookupModuleVersions(t *testing.T) {
   216  	if os.Getenv("TF_ACC") == "" {
   217  		t.Skip()
   218  	}
   219  	regDisco := disco.New()
   220  	regDisco.SetUserAgent(httpclient.OpenTofuUserAgent(tfversion.String()))
   221  
   222  	// test with and without a hostname
   223  	for _, src := range []string{
   224  		"terraform-aws-modules/vpc/aws",
   225  		regsrc.PublicRegistryHost.String() + "/terraform-aws-modules/vpc/aws",
   226  	} {
   227  		modsrc, err := regsrc.ParseModuleSource(src)
   228  		if err != nil {
   229  			t.Fatal(err)
   230  		}
   231  
   232  		s := NewClient(regDisco, nil)
   233  		resp, err := s.ModuleVersions(context.Background(), modsrc)
   234  		if err != nil {
   235  			t.Fatal(err)
   236  		}
   237  
   238  		if len(resp.Modules) != 1 {
   239  			t.Fatal("expected 1 module, got", len(resp.Modules))
   240  		}
   241  
   242  		mod := resp.Modules[0]
   243  		name := "terraform-aws-modules/vpc/aws"
   244  		if mod.Source != name {
   245  			t.Fatalf("expected module name %q, got %q", name, mod.Source)
   246  		}
   247  
   248  		if len(mod.Versions) == 0 {
   249  			t.Fatal("expected multiple versions, got 0")
   250  		}
   251  
   252  		for _, v := range mod.Versions {
   253  			_, err := version.NewVersion(v.Version)
   254  			if err != nil {
   255  				t.Fatalf("invalid version %q: %s", v.Version, err)
   256  			}
   257  		}
   258  	}
   259  }
   260  
   261  // the error should reference the config source exactly, not the discovered path.
   262  func TestLookupLookupModuleError(t *testing.T) {
   263  	server := test.Registry()
   264  	defer server.Close()
   265  
   266  	client := NewClient(test.Disco(server), nil)
   267  
   268  	// this should not be found in the registry
   269  	src := "bad/local/path"
   270  	mod, err := regsrc.ParseModuleSource(src)
   271  	if err != nil {
   272  		t.Fatal(err)
   273  	}
   274  
   275  	// Instrument CheckRetry to make sure 404s are not retried
   276  	retries := 0
   277  	oldCheck := client.client.CheckRetry
   278  	client.client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) {
   279  		if retries > 0 {
   280  			t.Fatal("retried after module not found")
   281  		}
   282  		retries++
   283  		return oldCheck(ctx, resp, err)
   284  	}
   285  
   286  	_, err = client.ModuleLocation(context.Background(), mod, "0.2.0")
   287  	if err == nil {
   288  		t.Fatal("expected error")
   289  	}
   290  
   291  	// check for the exact quoted string to ensure we didn't prepend a hostname.
   292  	if !strings.Contains(err.Error(), `"bad/local/path"`) {
   293  		t.Fatal("error should not include the hostname. got:", err)
   294  	}
   295  }
   296  
   297  func TestLookupModuleRetryError(t *testing.T) {
   298  	server := test.RegistryRetryableErrorsServer()
   299  	defer server.Close()
   300  
   301  	client := NewClient(test.Disco(server), nil)
   302  
   303  	src := "example.com/test-versions/name/provider"
   304  	modsrc, err := regsrc.ParseModuleSource(src)
   305  	if err != nil {
   306  		t.Fatal(err)
   307  	}
   308  	resp, err := client.ModuleVersions(context.Background(), modsrc)
   309  	if err == nil {
   310  		t.Fatal("expected requests to exceed retry", err)
   311  	}
   312  	if resp != nil {
   313  		t.Fatal("unexpected response", *resp)
   314  	}
   315  
   316  	// verify maxRetryErrorHandler handler returned the error
   317  	if !strings.Contains(err.Error(), "the request failed after 2 attempts, please try again later") {
   318  		t.Fatal("unexpected error, got:", err)
   319  	}
   320  }
   321  
   322  func TestLookupModuleNoRetryError(t *testing.T) {
   323  	// Disable retries
   324  	discoveryRetry = 0
   325  	defer configureDiscoveryRetry()
   326  
   327  	server := test.RegistryRetryableErrorsServer()
   328  	defer server.Close()
   329  
   330  	client := NewClient(test.Disco(server), nil)
   331  
   332  	src := "example.com/test-versions/name/provider"
   333  	modsrc, err := regsrc.ParseModuleSource(src)
   334  	if err != nil {
   335  		t.Fatal(err)
   336  	}
   337  	resp, err := client.ModuleVersions(context.Background(), modsrc)
   338  	if err == nil {
   339  		t.Fatal("expected request to fail", err)
   340  	}
   341  	if resp != nil {
   342  		t.Fatal("unexpected response", *resp)
   343  	}
   344  
   345  	// verify maxRetryErrorHandler handler returned the error
   346  	if !strings.Contains(err.Error(), "the request failed, please try again later") {
   347  		t.Fatal("unexpected error, got:", err)
   348  	}
   349  }
   350  
   351  func TestLookupModuleNetworkError(t *testing.T) {
   352  	server := test.RegistryRetryableErrorsServer()
   353  	client := NewClient(test.Disco(server), nil)
   354  
   355  	// Shut down the server to simulate network failure
   356  	server.Close()
   357  
   358  	src := "example.com/test-versions/name/provider"
   359  	modsrc, err := regsrc.ParseModuleSource(src)
   360  	if err != nil {
   361  		t.Fatal(err)
   362  	}
   363  	resp, err := client.ModuleVersions(context.Background(), modsrc)
   364  	if err == nil {
   365  		t.Fatal("expected request to fail", err)
   366  	}
   367  	if resp != nil {
   368  		t.Fatal("unexpected response", *resp)
   369  	}
   370  
   371  	// verify maxRetryErrorHandler handler returned the correct error
   372  	if !strings.Contains(err.Error(), "the request failed after 2 attempts, please try again later") {
   373  		t.Fatal("unexpected error, got:", err)
   374  	}
   375  }
   376  
   377  func TestModuleLocation_readRegistryResponse(t *testing.T) {
   378  	cases := map[string]struct {
   379  		src                  string
   380  		httpClient           *http.Client
   381  		registryFlags        []uint8
   382  		want                 string
   383  		wantErrorStr         string
   384  		wantToReadFromHeader bool
   385  		wantStatusCode       int
   386  	}{
   387  		"shall find the module location in the registry response body": {
   388  			src:            "exists-in-registry/identifier/provider",
   389  			want:           "file:///registry/exists",
   390  			wantStatusCode: http.StatusOK,
   391  			httpClient: &http.Client{
   392  				Transport: &mockRoundTripper{},
   393  			},
   394  		},
   395  		"shall find the module location in the registry response header": {
   396  			src:                  "exists-in-registry/identifier/provider",
   397  			registryFlags:        []uint8{test.WithModuleLocationInHeader},
   398  			want:                 "file:///registry/exists",
   399  			wantToReadFromHeader: true,
   400  			wantStatusCode:       http.StatusNoContent,
   401  			httpClient: &http.Client{
   402  				Transport: &mockRoundTripper{},
   403  			},
   404  		},
   405  		"shall read location from the registry response body even if the header with location address is also set": {
   406  			src:                  "exists-in-registry/identifier/provider",
   407  			want:                 "file:///registry/exists",
   408  			wantStatusCode:       http.StatusOK,
   409  			wantToReadFromHeader: false,
   410  			registryFlags:        []uint8{test.WithModuleLocationInBody, test.WithModuleLocationInHeader},
   411  			httpClient: &http.Client{
   412  				Transport: &mockRoundTripper{},
   413  			},
   414  		},
   415  		"shall fail to find the module": {
   416  			src: "not-exist/identifier/provider",
   417  			// note that the version is fixed in the mock
   418  			// see: /internal/registry/test/mock_registry.go:testMods
   419  			wantErrorStr:   `module "not-exist/identifier/provider" version "0.2.0" not found`,
   420  			wantStatusCode: http.StatusNotFound,
   421  			httpClient: &http.Client{
   422  				Transport: &mockRoundTripper{},
   423  			},
   424  		},
   425  		"shall fail because of reading response body error": {
   426  			src:            "foo/bar/baz",
   427  			wantErrorStr:   "error reading response body from registry: foo",
   428  			wantStatusCode: http.StatusOK,
   429  			httpClient: &http.Client{
   430  				Transport: &mockRoundTripper{
   431  					forwardResponse: &http.Response{
   432  						StatusCode: http.StatusOK,
   433  						Body:       mockErrorReadCloser{err: errors.New("foo")},
   434  					},
   435  				},
   436  			},
   437  		},
   438  		"shall fail to deserialize JSON response": {
   439  			src:            "foo/bar/baz",
   440  			wantErrorStr:   `module "foo/bar/baz" version "0.2.0" failed to deserialize response body {: unexpected end of JSON input`,
   441  			wantStatusCode: http.StatusOK,
   442  			httpClient: &http.Client{
   443  				Transport: &mockRoundTripper{
   444  					forwardResponse: &http.Response{
   445  						StatusCode: http.StatusOK,
   446  						Body:       io.NopCloser(strings.NewReader("{")),
   447  					},
   448  				},
   449  			},
   450  		},
   451  		"shall fail because of unexpected protocol change - 422 http status": {
   452  			src:            "foo/bar/baz",
   453  			wantErrorStr:   `error getting download location for "foo/bar/baz": foo resp:bar`,
   454  			wantStatusCode: http.StatusUnprocessableEntity,
   455  			httpClient: &http.Client{
   456  				Transport: &mockRoundTripper{
   457  					forwardResponse: &http.Response{
   458  						StatusCode: http.StatusUnprocessableEntity,
   459  						Status:     "foo",
   460  						Body:       io.NopCloser(strings.NewReader("bar")),
   461  					},
   462  				},
   463  			},
   464  		},
   465  		"shall fail because location is not found in the response": {
   466  			src:            "foo/bar/baz",
   467  			wantErrorStr:   `failed to get download URL for "foo/bar/baz": OK resp:{"foo":"git::https://github.com/foo/terraform-baz-bar?ref=v0.2.0"}`,
   468  			wantStatusCode: http.StatusOK,
   469  			httpClient: &http.Client{
   470  				Transport: &mockRoundTripper{
   471  					forwardResponse: &http.Response{
   472  						StatusCode: http.StatusOK,
   473  						Status:     "OK",
   474  						// note that the response emulates a contract change
   475  						Body: io.NopCloser(strings.NewReader(`{"foo":"git::https://github.com/foo/terraform-baz-bar?ref=v0.2.0"}`)),
   476  					},
   477  				},
   478  			},
   479  		},
   480  	}
   481  
   482  	t.Parallel()
   483  	for name, tc := range cases {
   484  		t.Run(name, func(t *testing.T) {
   485  			server := test.Registry(tc.registryFlags...)
   486  			defer server.Close()
   487  
   488  			client := NewClient(test.Disco(server), tc.httpClient)
   489  
   490  			mod, err := regsrc.ParseModuleSource(tc.src)
   491  			if err != nil {
   492  				t.Fatal(err)
   493  			}
   494  
   495  			got, err := client.ModuleLocation(context.Background(), mod, "0.2.0")
   496  			if err != nil && tc.wantErrorStr == "" {
   497  				t.Fatalf("unexpected error: %v", err)
   498  			}
   499  			if err != nil && err.Error() != tc.wantErrorStr {
   500  				t.Fatalf("unexpected error content: want=%s, got=%v", tc.wantErrorStr, err)
   501  			}
   502  			if got != tc.want {
   503  				t.Fatalf("unexpected location: want=%s, got=%v", tc.want, got)
   504  			}
   505  
   506  			gotStatusCode := tc.httpClient.Transport.(*mockRoundTripper).reverseResponse.StatusCode
   507  			if tc.wantStatusCode != gotStatusCode {
   508  				t.Fatalf("unexpected response status code: want=%d, got=%d", tc.wantStatusCode, gotStatusCode)
   509  			}
   510  
   511  			if tc.wantToReadFromHeader {
   512  				resp := tc.httpClient.Transport.(*mockRoundTripper).reverseResponse
   513  				if !reflect.DeepEqual(resp.Body, http.NoBody) {
   514  					t.Fatalf("expected no body")
   515  				}
   516  			}
   517  		})
   518  	}
   519  }
   520  
   521  type mockRoundTripper struct {
   522  	// response to return without calling the server
   523  	// SET TO USE AS A REVERSE PROXY
   524  	forwardResponse *http.Response
   525  	// the response from the server will be written here
   526  	// DO NOT SET
   527  	reverseResponse *http.Response
   528  	err             error
   529  }
   530  
   531  func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   532  	if m.err != nil {
   533  		return nil, m.err
   534  	}
   535  	if m.forwardResponse != nil {
   536  		m.reverseResponse = m.forwardResponse
   537  		return m.forwardResponse, nil
   538  	}
   539  	resp, err := http.DefaultTransport.RoundTrip(req)
   540  	m.reverseResponse = resp
   541  	return resp, err
   542  }
   543  
   544  type mockErrorReadCloser struct {
   545  	err error
   546  }
   547  
   548  func (m mockErrorReadCloser) Read(_ []byte) (n int, err error) {
   549  	return 0, m.err
   550  }
   551  
   552  func (m mockErrorReadCloser) Close() error {
   553  	return m.err
   554  }