github.com/rstandt/terraform@v0.12.32-0.20230710220336-b1063613405c/registry/client_test.go (about)

     1  package registry
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net/http"
     7  	"os"
     8  	"strings"
     9  	"testing"
    10  	"time"
    11  
    12  	version "github.com/hashicorp/go-version"
    13  	"github.com/hashicorp/terraform-svchost/disco"
    14  	"github.com/hashicorp/terraform/httpclient"
    15  	"github.com/hashicorp/terraform/registry/regsrc"
    16  	"github.com/hashicorp/terraform/registry/test"
    17  	tfversion "github.com/hashicorp/terraform/version"
    18  )
    19  
    20  func TestConfigureDiscoveryRetry(t *testing.T) {
    21  	t.Run("default retry", func(t *testing.T) {
    22  		if discoveryRetry != defaultRetry {
    23  			t.Fatalf("expected retry %q, got %q", defaultRetry, discoveryRetry)
    24  		}
    25  
    26  		rc := NewClient(nil, nil)
    27  		if rc.client.RetryMax != defaultRetry {
    28  			t.Fatalf("expected client retry %q, got %q",
    29  				defaultRetry, rc.client.RetryMax)
    30  		}
    31  	})
    32  
    33  	t.Run("configured retry", func(t *testing.T) {
    34  		defer func(retryEnv string) {
    35  			os.Setenv(registryDiscoveryRetryEnvName, retryEnv)
    36  			discoveryRetry = defaultRetry
    37  		}(os.Getenv(registryDiscoveryRetryEnvName))
    38  		os.Setenv(registryDiscoveryRetryEnvName, "2")
    39  
    40  		configureDiscoveryRetry()
    41  		expected := 2
    42  		if discoveryRetry != expected {
    43  			t.Fatalf("expected retry %q, got %q",
    44  				expected, discoveryRetry)
    45  		}
    46  
    47  		rc := NewClient(nil, nil)
    48  		if rc.client.RetryMax != expected {
    49  			t.Fatalf("expected client retry %q, got %q",
    50  				expected, rc.client.RetryMax)
    51  		}
    52  	})
    53  }
    54  
    55  func TestConfigureRegistryClientTimeout(t *testing.T) {
    56  	t.Run("default timeout", func(t *testing.T) {
    57  		if requestTimeout != defaultRequestTimeout {
    58  			t.Fatalf("expected timeout %q, got %q",
    59  				defaultRequestTimeout.String(), requestTimeout.String())
    60  		}
    61  
    62  		rc := NewClient(nil, nil)
    63  		if rc.client.HTTPClient.Timeout != defaultRequestTimeout {
    64  			t.Fatalf("expected client timeout %q, got %q",
    65  				defaultRequestTimeout.String(), rc.client.HTTPClient.Timeout.String())
    66  		}
    67  	})
    68  
    69  	t.Run("configured timeout", func(t *testing.T) {
    70  		defer func(timeoutEnv string) {
    71  			os.Setenv(registryClientTimeoutEnvName, timeoutEnv)
    72  			requestTimeout = defaultRequestTimeout
    73  		}(os.Getenv(registryClientTimeoutEnvName))
    74  		os.Setenv(registryClientTimeoutEnvName, "20")
    75  
    76  		configureRequestTimeout()
    77  		expected := 20 * time.Second
    78  		if requestTimeout != expected {
    79  			t.Fatalf("expected timeout %q, got %q",
    80  				expected, requestTimeout.String())
    81  		}
    82  
    83  		rc := NewClient(nil, nil)
    84  		if rc.client.HTTPClient.Timeout != expected {
    85  			t.Fatalf("expected client timeout %q, got %q",
    86  				expected, rc.client.HTTPClient.Timeout.String())
    87  		}
    88  	})
    89  }
    90  
    91  func TestLookupModuleVersions(t *testing.T) {
    92  	server := test.Registry()
    93  	defer server.Close()
    94  
    95  	client := NewClient(test.Disco(server), nil)
    96  
    97  	// test with and without a hostname
    98  	for _, src := range []string{
    99  		"example.com/test-versions/name/provider",
   100  		"test-versions/name/provider",
   101  	} {
   102  		modsrc, err := regsrc.ParseModuleSource(src)
   103  		if err != nil {
   104  			t.Fatal(err)
   105  		}
   106  
   107  		resp, err := client.ModuleVersions(modsrc)
   108  		if err != nil {
   109  			t.Fatal(err)
   110  		}
   111  
   112  		if len(resp.Modules) != 1 {
   113  			t.Fatal("expected 1 module, got", len(resp.Modules))
   114  		}
   115  
   116  		mod := resp.Modules[0]
   117  		name := "test-versions/name/provider"
   118  		if mod.Source != name {
   119  			t.Fatalf("expected module name %q, got %q", name, mod.Source)
   120  		}
   121  
   122  		if len(mod.Versions) != 4 {
   123  			t.Fatal("expected 4 versions, got", len(mod.Versions))
   124  		}
   125  
   126  		for _, v := range mod.Versions {
   127  			_, err := version.NewVersion(v.Version)
   128  			if err != nil {
   129  				t.Fatalf("invalid version %q: %s", v.Version, err)
   130  			}
   131  		}
   132  	}
   133  }
   134  
   135  func TestInvalidRegistry(t *testing.T) {
   136  	server := test.Registry()
   137  	defer server.Close()
   138  
   139  	client := NewClient(test.Disco(server), nil)
   140  
   141  	src := "non-existent.localhost.localdomain/test-versions/name/provider"
   142  	modsrc, err := regsrc.ParseModuleSource(src)
   143  	if err != nil {
   144  		t.Fatal(err)
   145  	}
   146  
   147  	if _, err := client.ModuleVersions(modsrc); err == nil {
   148  		t.Fatal("expected error")
   149  	}
   150  }
   151  
   152  func TestRegistryAuth(t *testing.T) {
   153  	server := test.Registry()
   154  	defer server.Close()
   155  
   156  	client := NewClient(test.Disco(server), nil)
   157  
   158  	src := "private/name/provider"
   159  	mod, err := regsrc.ParseModuleSource(src)
   160  	if err != nil {
   161  		t.Fatal(err)
   162  	}
   163  
   164  	_, err = client.ModuleVersions(mod)
   165  	if err != nil {
   166  		t.Fatal(err)
   167  	}
   168  	_, err = client.ModuleLocation(mod, "1.0.0")
   169  	if err != nil {
   170  		t.Fatal(err)
   171  	}
   172  
   173  	// Also test without a credentials source
   174  	client.services.SetCredentialsSource(nil)
   175  
   176  	// both should fail without auth
   177  	_, err = client.ModuleVersions(mod)
   178  	if err == nil {
   179  		t.Fatal("expected error")
   180  	}
   181  	_, err = client.ModuleLocation(mod, "1.0.0")
   182  	if err == nil {
   183  		t.Fatal("expected error")
   184  	}
   185  }
   186  
   187  func TestLookupModuleLocationRelative(t *testing.T) {
   188  	server := test.Registry()
   189  	defer server.Close()
   190  
   191  	client := NewClient(test.Disco(server), nil)
   192  
   193  	src := "relative/foo/bar"
   194  	mod, err := regsrc.ParseModuleSource(src)
   195  	if err != nil {
   196  		t.Fatal(err)
   197  	}
   198  
   199  	got, err := client.ModuleLocation(mod, "0.2.0")
   200  	if err != nil {
   201  		t.Fatal(err)
   202  	}
   203  
   204  	want := server.URL + "/relative-path"
   205  	if got != want {
   206  		t.Errorf("wrong location %s; want %s", got, want)
   207  	}
   208  }
   209  
   210  func TestAccLookupModuleVersions(t *testing.T) {
   211  	if os.Getenv("TF_ACC") == "" {
   212  		t.Skip()
   213  	}
   214  	regDisco := disco.New()
   215  	regDisco.SetUserAgent(httpclient.TerraformUserAgent(tfversion.String()))
   216  
   217  	// test with and without a hostname
   218  	for _, src := range []string{
   219  		"terraform-aws-modules/vpc/aws",
   220  		regsrc.PublicRegistryHost.String() + "/terraform-aws-modules/vpc/aws",
   221  	} {
   222  		modsrc, err := regsrc.ParseModuleSource(src)
   223  		if err != nil {
   224  			t.Fatal(err)
   225  		}
   226  
   227  		s := NewClient(regDisco, nil)
   228  		resp, err := s.ModuleVersions(modsrc)
   229  		if err != nil {
   230  			t.Fatal(err)
   231  		}
   232  
   233  		if len(resp.Modules) != 1 {
   234  			t.Fatal("expected 1 module, got", len(resp.Modules))
   235  		}
   236  
   237  		mod := resp.Modules[0]
   238  		name := "terraform-aws-modules/vpc/aws"
   239  		if mod.Source != name {
   240  			t.Fatalf("expected module name %q, got %q", name, mod.Source)
   241  		}
   242  
   243  		if len(mod.Versions) == 0 {
   244  			t.Fatal("expected multiple versions, got 0")
   245  		}
   246  
   247  		for _, v := range mod.Versions {
   248  			_, err := version.NewVersion(v.Version)
   249  			if err != nil {
   250  				t.Fatalf("invalid version %q: %s", v.Version, err)
   251  			}
   252  		}
   253  	}
   254  }
   255  
   256  // the error should reference the config source exactly, not the discovered path.
   257  func TestLookupLookupModuleError(t *testing.T) {
   258  	server := test.Registry()
   259  	defer server.Close()
   260  
   261  	client := NewClient(test.Disco(server), nil)
   262  
   263  	// this should not be found in the registry
   264  	src := "bad/local/path"
   265  	mod, err := regsrc.ParseModuleSource(src)
   266  	if err != nil {
   267  		t.Fatal(err)
   268  	}
   269  
   270  	// Instrument CheckRetry to make sure 404s are not retried
   271  	retries := 0
   272  	oldCheck := client.client.CheckRetry
   273  	client.client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) {
   274  		if retries > 0 {
   275  			t.Fatal("retried after module not found")
   276  		}
   277  		retries++
   278  		return oldCheck(ctx, resp, err)
   279  	}
   280  
   281  	_, err = client.ModuleLocation(mod, "0.2.0")
   282  	if err == nil {
   283  		t.Fatal("expected error")
   284  	}
   285  
   286  	// check for the exact quoted string to ensure we didn't prepend a hostname.
   287  	if !strings.Contains(err.Error(), `"bad/local/path"`) {
   288  		t.Fatal("error should not include the hostname. got:", err)
   289  	}
   290  }
   291  
   292  func TestLookupModuleRetryError(t *testing.T) {
   293  	server := test.RegistryRetryableErrorsServer()
   294  	defer server.Close()
   295  
   296  	client := NewClient(test.Disco(server), nil)
   297  
   298  	src := "example.com/test-versions/name/provider"
   299  	modsrc, err := regsrc.ParseModuleSource(src)
   300  	if err != nil {
   301  		t.Fatal(err)
   302  	}
   303  	resp, err := client.ModuleVersions(modsrc)
   304  	if err == nil {
   305  		t.Fatal("expected requests to exceed retry", err)
   306  	}
   307  	if resp != nil {
   308  		t.Fatal("unexpected response", *resp)
   309  	}
   310  
   311  	// verify maxRetryErrorHandler handler returned the error
   312  	if !strings.Contains(err.Error(), "the request failed after 2 attempts, please try again later") {
   313  		t.Fatal("unexpected error, got:", err)
   314  	}
   315  }
   316  
   317  func TestLookupModuleNoRetryError(t *testing.T) {
   318  	// Disable retries
   319  	discoveryRetry = 0
   320  	defer configureDiscoveryRetry()
   321  
   322  	server := test.RegistryRetryableErrorsServer()
   323  	defer server.Close()
   324  
   325  	client := NewClient(test.Disco(server), nil)
   326  
   327  	src := "example.com/test-versions/name/provider"
   328  	modsrc, err := regsrc.ParseModuleSource(src)
   329  	if err != nil {
   330  		t.Fatal(err)
   331  	}
   332  	resp, err := client.ModuleVersions(modsrc)
   333  	if err == nil {
   334  		t.Fatal("expected request to fail", err)
   335  	}
   336  	if resp != nil {
   337  		t.Fatal("unexpected response", *resp)
   338  	}
   339  
   340  	// verify maxRetryErrorHandler handler returned the error
   341  	if !strings.Contains(err.Error(), "the request failed, please try again later") {
   342  		t.Fatal("unexpected error, got:", err)
   343  	}
   344  }
   345  
   346  func TestLookupModuleNetworkError(t *testing.T) {
   347  	server := test.RegistryRetryableErrorsServer()
   348  	client := NewClient(test.Disco(server), nil)
   349  
   350  	// Shut down the server to simulate network failure
   351  	server.Close()
   352  
   353  	src := "example.com/test-versions/name/provider"
   354  	modsrc, err := regsrc.ParseModuleSource(src)
   355  	if err != nil {
   356  		t.Fatal(err)
   357  	}
   358  	resp, err := client.ModuleVersions(modsrc)
   359  	if err == nil {
   360  		t.Fatal("expected request to fail", err)
   361  	}
   362  	if resp != nil {
   363  		t.Fatal("unexpected response", *resp)
   364  	}
   365  
   366  	// verify maxRetryErrorHandler handler returned the correct error
   367  	if !strings.Contains(err.Error(), "the request failed after 2 attempts, please try again later") {
   368  		t.Fatal("unexpected error, got:", err)
   369  	}
   370  }
   371  
   372  func TestLookupProviderVersions(t *testing.T) {
   373  	server := test.Registry()
   374  	defer server.Close()
   375  
   376  	client := NewClient(test.Disco(server), nil)
   377  
   378  	tests := []struct {
   379  		name string
   380  	}{
   381  		{"foo"},
   382  		{"bar"},
   383  	}
   384  	for _, tt := range tests {
   385  		provider := regsrc.NewTerraformProvider(tt.name, "", "")
   386  		resp, err := client.TerraformProviderVersions(provider)
   387  		if err != nil {
   388  			t.Fatal(err)
   389  		}
   390  
   391  		name := fmt.Sprintf("terraform-providers/%s", tt.name)
   392  		if resp.ID != name {
   393  			t.Fatalf("expected provider name %q, got %q", name, resp.ID)
   394  		}
   395  
   396  		if len(resp.Versions) != 2 {
   397  			t.Fatal("expected 2 versions, got", len(resp.Versions))
   398  		}
   399  
   400  		for _, v := range resp.Versions {
   401  			_, err := version.NewVersion(v.Version)
   402  			if err != nil {
   403  				t.Fatalf("invalid version %#v: %v", v, err)
   404  			}
   405  		}
   406  	}
   407  }
   408  
   409  func TestLookupProviderLocation(t *testing.T) {
   410  	server := test.Registry()
   411  	defer server.Close()
   412  
   413  	client := NewClient(test.Disco(server), nil)
   414  
   415  	tests := []struct {
   416  		Name    string
   417  		Version string
   418  		Err     bool
   419  	}{
   420  		{
   421  			"foo",
   422  			"0.2.3",
   423  			false,
   424  		},
   425  		{
   426  			"bar",
   427  			"0.1.1",
   428  			false,
   429  		},
   430  		{
   431  			"baz",
   432  			"0.0.0",
   433  			true,
   434  		},
   435  	}
   436  	for _, tt := range tests {
   437  		// FIXME: the tests are set up to succeed - os/arch is not being validated at this time
   438  		p := regsrc.NewTerraformProvider(tt.Name, "linux", "amd64")
   439  
   440  		locationMetadata, err := client.TerraformProviderLocation(p, tt.Version)
   441  		if tt.Err {
   442  			if err == nil {
   443  				t.Fatal("succeeded; want error")
   444  			}
   445  			return
   446  		} else if err != nil {
   447  			t.Fatalf("unexpected error: %s", err)
   448  		}
   449  
   450  		downloadURL := fmt.Sprintf("https://releases.hashicorp.com/terraform-provider-%s/%s/terraform-provider-%s.zip", tt.Name, tt.Version, tt.Name)
   451  
   452  		if locationMetadata.DownloadURL != downloadURL {
   453  			t.Fatalf("incorrect download URL: expected %q, got %q", downloadURL, locationMetadata.DownloadURL)
   454  		}
   455  	}
   456  
   457  }