github.com/hugorut/terraform@v1.1.3/src/registry/client_test.go (about)

     1  package registry
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"os"
     7  	"strings"
     8  	"testing"
     9  	"time"
    10  
    11  	version "github.com/hashicorp/go-version"
    12  	"github.com/hashicorp/terraform-svchost/disco"
    13  	"github.com/hugorut/terraform/src/httpclient"
    14  	"github.com/hugorut/terraform/src/registry/regsrc"
    15  	"github.com/hugorut/terraform/src/registry/test"
    16  	tfversion "github.com/hugorut/terraform/version"
    17  )
    18  
    19  func TestConfigureDiscoveryRetry(t *testing.T) {
    20  	t.Run("default retry", func(t *testing.T) {
    21  		if discoveryRetry != defaultRetry {
    22  			t.Fatalf("expected retry %q, got %q", defaultRetry, discoveryRetry)
    23  		}
    24  
    25  		rc := NewClient(nil, nil)
    26  		if rc.client.RetryMax != defaultRetry {
    27  			t.Fatalf("expected client retry %q, got %q",
    28  				defaultRetry, rc.client.RetryMax)
    29  		}
    30  	})
    31  
    32  	t.Run("configured retry", func(t *testing.T) {
    33  		defer func(retryEnv string) {
    34  			os.Setenv(registryDiscoveryRetryEnvName, retryEnv)
    35  			discoveryRetry = defaultRetry
    36  		}(os.Getenv(registryDiscoveryRetryEnvName))
    37  		os.Setenv(registryDiscoveryRetryEnvName, "2")
    38  
    39  		configureDiscoveryRetry()
    40  		expected := 2
    41  		if discoveryRetry != expected {
    42  			t.Fatalf("expected retry %q, got %q",
    43  				expected, discoveryRetry)
    44  		}
    45  
    46  		rc := NewClient(nil, nil)
    47  		if rc.client.RetryMax != expected {
    48  			t.Fatalf("expected client retry %q, got %q",
    49  				expected, rc.client.RetryMax)
    50  		}
    51  	})
    52  }
    53  
    54  func TestConfigureRegistryClientTimeout(t *testing.T) {
    55  	t.Run("default timeout", func(t *testing.T) {
    56  		if requestTimeout != defaultRequestTimeout {
    57  			t.Fatalf("expected timeout %q, got %q",
    58  				defaultRequestTimeout.String(), requestTimeout.String())
    59  		}
    60  
    61  		rc := NewClient(nil, nil)
    62  		if rc.client.HTTPClient.Timeout != defaultRequestTimeout {
    63  			t.Fatalf("expected client timeout %q, got %q",
    64  				defaultRequestTimeout.String(), rc.client.HTTPClient.Timeout.String())
    65  		}
    66  	})
    67  
    68  	t.Run("configured timeout", func(t *testing.T) {
    69  		defer func(timeoutEnv string) {
    70  			os.Setenv(registryClientTimeoutEnvName, timeoutEnv)
    71  			requestTimeout = defaultRequestTimeout
    72  		}(os.Getenv(registryClientTimeoutEnvName))
    73  		os.Setenv(registryClientTimeoutEnvName, "20")
    74  
    75  		configureRequestTimeout()
    76  		expected := 20 * time.Second
    77  		if requestTimeout != expected {
    78  			t.Fatalf("expected timeout %q, got %q",
    79  				expected, requestTimeout.String())
    80  		}
    81  
    82  		rc := NewClient(nil, nil)
    83  		if rc.client.HTTPClient.Timeout != expected {
    84  			t.Fatalf("expected client timeout %q, got %q",
    85  				expected, rc.client.HTTPClient.Timeout.String())
    86  		}
    87  	})
    88  }
    89  
    90  func TestLookupModuleVersions(t *testing.T) {
    91  	server := test.Registry()
    92  	defer server.Close()
    93  
    94  	client := NewClient(test.Disco(server), nil)
    95  
    96  	// test with and without a hostname
    97  	for _, src := range []string{
    98  		"example.com/test-versions/name/provider",
    99  		"test-versions/name/provider",
   100  	} {
   101  		modsrc, err := regsrc.ParseModuleSource(src)
   102  		if err != nil {
   103  			t.Fatal(err)
   104  		}
   105  
   106  		resp, err := client.ModuleVersions(context.Background(), modsrc)
   107  		if err != nil {
   108  			t.Fatal(err)
   109  		}
   110  
   111  		if len(resp.Modules) != 1 {
   112  			t.Fatal("expected 1 module, got", len(resp.Modules))
   113  		}
   114  
   115  		mod := resp.Modules[0]
   116  		name := "test-versions/name/provider"
   117  		if mod.Source != name {
   118  			t.Fatalf("expected module name %q, got %q", name, mod.Source)
   119  		}
   120  
   121  		if len(mod.Versions) != 4 {
   122  			t.Fatal("expected 4 versions, got", len(mod.Versions))
   123  		}
   124  
   125  		for _, v := range mod.Versions {
   126  			_, err := version.NewVersion(v.Version)
   127  			if err != nil {
   128  				t.Fatalf("invalid version %q: %s", v.Version, err)
   129  			}
   130  		}
   131  	}
   132  }
   133  
   134  func TestInvalidRegistry(t *testing.T) {
   135  	server := test.Registry()
   136  	defer server.Close()
   137  
   138  	client := NewClient(test.Disco(server), nil)
   139  
   140  	src := "non-existent.localhost.localdomain/test-versions/name/provider"
   141  	modsrc, err := regsrc.ParseModuleSource(src)
   142  	if err != nil {
   143  		t.Fatal(err)
   144  	}
   145  
   146  	if _, err := client.ModuleVersions(context.Background(), modsrc); err == nil {
   147  		t.Fatal("expected error")
   148  	}
   149  }
   150  
   151  func TestRegistryAuth(t *testing.T) {
   152  	server := test.Registry()
   153  	defer server.Close()
   154  
   155  	client := NewClient(test.Disco(server), nil)
   156  
   157  	src := "private/name/provider"
   158  	mod, err := regsrc.ParseModuleSource(src)
   159  	if err != nil {
   160  		t.Fatal(err)
   161  	}
   162  
   163  	_, err = client.ModuleVersions(context.Background(), mod)
   164  	if err != nil {
   165  		t.Fatal(err)
   166  	}
   167  	_, err = client.ModuleLocation(context.Background(), mod, "1.0.0")
   168  	if err != nil {
   169  		t.Fatal(err)
   170  	}
   171  
   172  	// Also test without a credentials source
   173  	client.services.SetCredentialsSource(nil)
   174  
   175  	// both should fail without auth
   176  	_, err = client.ModuleVersions(context.Background(), mod)
   177  	if err == nil {
   178  		t.Fatal("expected error")
   179  	}
   180  	_, err = client.ModuleLocation(context.Background(), mod, "1.0.0")
   181  	if err == nil {
   182  		t.Fatal("expected error")
   183  	}
   184  }
   185  
   186  func TestLookupModuleLocationRelative(t *testing.T) {
   187  	server := test.Registry()
   188  	defer server.Close()
   189  
   190  	client := NewClient(test.Disco(server), nil)
   191  
   192  	src := "relative/foo/bar"
   193  	mod, err := regsrc.ParseModuleSource(src)
   194  	if err != nil {
   195  		t.Fatal(err)
   196  	}
   197  
   198  	got, err := client.ModuleLocation(context.Background(), mod, "0.2.0")
   199  	if err != nil {
   200  		t.Fatal(err)
   201  	}
   202  
   203  	want := server.URL + "/relative-path"
   204  	if got != want {
   205  		t.Errorf("wrong location %s; want %s", got, want)
   206  	}
   207  }
   208  
   209  func TestAccLookupModuleVersions(t *testing.T) {
   210  	if os.Getenv("TF_ACC") == "" {
   211  		t.Skip()
   212  	}
   213  	regDisco := disco.New()
   214  	regDisco.SetUserAgent(httpclient.TerraformUserAgent(tfversion.String()))
   215  
   216  	// test with and without a hostname
   217  	for _, src := range []string{
   218  		"terraform-aws-modules/vpc/aws",
   219  		regsrc.PublicRegistryHost.String() + "/terraform-aws-modules/vpc/aws",
   220  	} {
   221  		modsrc, err := regsrc.ParseModuleSource(src)
   222  		if err != nil {
   223  			t.Fatal(err)
   224  		}
   225  
   226  		s := NewClient(regDisco, nil)
   227  		resp, err := s.ModuleVersions(context.Background(), modsrc)
   228  		if err != nil {
   229  			t.Fatal(err)
   230  		}
   231  
   232  		if len(resp.Modules) != 1 {
   233  			t.Fatal("expected 1 module, got", len(resp.Modules))
   234  		}
   235  
   236  		mod := resp.Modules[0]
   237  		name := "terraform-aws-modules/vpc/aws"
   238  		if mod.Source != name {
   239  			t.Fatalf("expected module name %q, got %q", name, mod.Source)
   240  		}
   241  
   242  		if len(mod.Versions) == 0 {
   243  			t.Fatal("expected multiple versions, got 0")
   244  		}
   245  
   246  		for _, v := range mod.Versions {
   247  			_, err := version.NewVersion(v.Version)
   248  			if err != nil {
   249  				t.Fatalf("invalid version %q: %s", v.Version, err)
   250  			}
   251  		}
   252  	}
   253  }
   254  
   255  // the error should reference the config source exactly, not the discovered path.
   256  func TestLookupLookupModuleError(t *testing.T) {
   257  	server := test.Registry()
   258  	defer server.Close()
   259  
   260  	client := NewClient(test.Disco(server), nil)
   261  
   262  	// this should not be found in the registry
   263  	src := "bad/local/path"
   264  	mod, err := regsrc.ParseModuleSource(src)
   265  	if err != nil {
   266  		t.Fatal(err)
   267  	}
   268  
   269  	// Instrument CheckRetry to make sure 404s are not retried
   270  	retries := 0
   271  	oldCheck := client.client.CheckRetry
   272  	client.client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) {
   273  		if retries > 0 {
   274  			t.Fatal("retried after module not found")
   275  		}
   276  		retries++
   277  		return oldCheck(ctx, resp, err)
   278  	}
   279  
   280  	_, err = client.ModuleLocation(context.Background(), mod, "0.2.0")
   281  	if err == nil {
   282  		t.Fatal("expected error")
   283  	}
   284  
   285  	// check for the exact quoted string to ensure we didn't prepend a hostname.
   286  	if !strings.Contains(err.Error(), `"bad/local/path"`) {
   287  		t.Fatal("error should not include the hostname. got:", err)
   288  	}
   289  }
   290  
   291  func TestLookupModuleRetryError(t *testing.T) {
   292  	server := test.RegistryRetryableErrorsServer()
   293  	defer server.Close()
   294  
   295  	client := NewClient(test.Disco(server), nil)
   296  
   297  	src := "example.com/test-versions/name/provider"
   298  	modsrc, err := regsrc.ParseModuleSource(src)
   299  	if err != nil {
   300  		t.Fatal(err)
   301  	}
   302  	resp, err := client.ModuleVersions(context.Background(), modsrc)
   303  	if err == nil {
   304  		t.Fatal("expected requests to exceed retry", err)
   305  	}
   306  	if resp != nil {
   307  		t.Fatal("unexpected response", *resp)
   308  	}
   309  
   310  	// verify maxRetryErrorHandler handler returned the error
   311  	if !strings.Contains(err.Error(), "the request failed after 2 attempts, please try again later") {
   312  		t.Fatal("unexpected error, got:", err)
   313  	}
   314  }
   315  
   316  func TestLookupModuleNoRetryError(t *testing.T) {
   317  	// Disable retries
   318  	discoveryRetry = 0
   319  	defer configureDiscoveryRetry()
   320  
   321  	server := test.RegistryRetryableErrorsServer()
   322  	defer server.Close()
   323  
   324  	client := NewClient(test.Disco(server), nil)
   325  
   326  	src := "example.com/test-versions/name/provider"
   327  	modsrc, err := regsrc.ParseModuleSource(src)
   328  	if err != nil {
   329  		t.Fatal(err)
   330  	}
   331  	resp, err := client.ModuleVersions(context.Background(), modsrc)
   332  	if err == nil {
   333  		t.Fatal("expected request to fail", err)
   334  	}
   335  	if resp != nil {
   336  		t.Fatal("unexpected response", *resp)
   337  	}
   338  
   339  	// verify maxRetryErrorHandler handler returned the error
   340  	if !strings.Contains(err.Error(), "the request failed, please try again later") {
   341  		t.Fatal("unexpected error, got:", err)
   342  	}
   343  }
   344  
   345  func TestLookupModuleNetworkError(t *testing.T) {
   346  	server := test.RegistryRetryableErrorsServer()
   347  	client := NewClient(test.Disco(server), nil)
   348  
   349  	// Shut down the server to simulate network failure
   350  	server.Close()
   351  
   352  	src := "example.com/test-versions/name/provider"
   353  	modsrc, err := regsrc.ParseModuleSource(src)
   354  	if err != nil {
   355  		t.Fatal(err)
   356  	}
   357  	resp, err := client.ModuleVersions(context.Background(), modsrc)
   358  	if err == nil {
   359  		t.Fatal("expected request to fail", err)
   360  	}
   361  	if resp != nil {
   362  		t.Fatal("unexpected response", *resp)
   363  	}
   364  
   365  	// verify maxRetryErrorHandler handler returned the correct error
   366  	if !strings.Contains(err.Error(), "the request failed after 2 attempts, please try again later") {
   367  		t.Fatal("unexpected error, got:", err)
   368  	}
   369  }