github.com/muratcelep/terraform@v1.1.0-beta2-not-internal-4/not-internal/getproviders/registry_client_test.go (about)

     1  package getproviders
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"log"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"os"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/apparentlymart/go-versions/versions"
    16  	"github.com/google/go-cmp/cmp"
    17  	svchost "github.com/hashicorp/terraform-svchost"
    18  	disco "github.com/hashicorp/terraform-svchost/disco"
    19  	"github.com/muratcelep/terraform/not-internal/addrs"
    20  )
    21  
    22  func TestConfigureDiscoveryRetry(t *testing.T) {
    23  	t.Run("default retry", func(t *testing.T) {
    24  		if discoveryRetry != defaultRetry {
    25  			t.Fatalf("expected retry %q, got %q", defaultRetry, discoveryRetry)
    26  		}
    27  
    28  		rc := newRegistryClient(nil, nil)
    29  		if rc.httpClient.RetryMax != defaultRetry {
    30  			t.Fatalf("expected client retry %q, got %q",
    31  				defaultRetry, rc.httpClient.RetryMax)
    32  		}
    33  	})
    34  
    35  	t.Run("configured retry", func(t *testing.T) {
    36  		defer func(retryEnv string) {
    37  			os.Setenv(registryDiscoveryRetryEnvName, retryEnv)
    38  			discoveryRetry = defaultRetry
    39  		}(os.Getenv(registryDiscoveryRetryEnvName))
    40  		os.Setenv(registryDiscoveryRetryEnvName, "2")
    41  
    42  		configureDiscoveryRetry()
    43  		expected := 2
    44  		if discoveryRetry != expected {
    45  			t.Fatalf("expected retry %q, got %q",
    46  				expected, discoveryRetry)
    47  		}
    48  
    49  		rc := newRegistryClient(nil, nil)
    50  		if rc.httpClient.RetryMax != expected {
    51  			t.Fatalf("expected client retry %q, got %q",
    52  				expected, rc.httpClient.RetryMax)
    53  		}
    54  	})
    55  }
    56  
    57  func TestConfigureRegistryClientTimeout(t *testing.T) {
    58  	t.Run("default timeout", func(t *testing.T) {
    59  		if requestTimeout != defaultRequestTimeout {
    60  			t.Fatalf("expected timeout %q, got %q",
    61  				defaultRequestTimeout.String(), requestTimeout.String())
    62  		}
    63  
    64  		rc := newRegistryClient(nil, nil)
    65  		if rc.httpClient.HTTPClient.Timeout != defaultRequestTimeout {
    66  			t.Fatalf("expected client timeout %q, got %q",
    67  				defaultRequestTimeout.String(), rc.httpClient.HTTPClient.Timeout.String())
    68  		}
    69  	})
    70  
    71  	t.Run("configured timeout", func(t *testing.T) {
    72  		defer func(timeoutEnv string) {
    73  			os.Setenv(registryClientTimeoutEnvName, timeoutEnv)
    74  			requestTimeout = defaultRequestTimeout
    75  		}(os.Getenv(registryClientTimeoutEnvName))
    76  		os.Setenv(registryClientTimeoutEnvName, "20")
    77  
    78  		configureRequestTimeout()
    79  		expected := 20 * time.Second
    80  		if requestTimeout != expected {
    81  			t.Fatalf("expected timeout %q, got %q",
    82  				expected, requestTimeout.String())
    83  		}
    84  
    85  		rc := newRegistryClient(nil, nil)
    86  		if rc.httpClient.HTTPClient.Timeout != expected {
    87  			t.Fatalf("expected client timeout %q, got %q",
    88  				expected, rc.httpClient.HTTPClient.Timeout.String())
    89  		}
    90  	})
    91  }
    92  
    93  // testRegistryServices starts up a local HTTP server running a fake provider registry
    94  // service and returns a service discovery object pre-configured to consider
    95  // the host "example.com" to be served by the fake registry service.
    96  //
    97  // The returned discovery object also knows the hostname "not.example.com"
    98  // which does not have a provider registry at all and "too-new.example.com"
    99  // which has a "providers.v99" service that is inoperable but could be useful
   100  // to test the error reporting for detecting an unsupported protocol version.
   101  // It also knows fails.example.com but it refers to an endpoint that doesn't
   102  // correctly speak HTTP, to simulate a protocol error.
   103  //
   104  // The second return value is a function to call at the end of a test function
   105  // to shut down the test server. After you call that function, the discovery
   106  // object becomes useless.
   107  func testRegistryServices(t *testing.T) (services *disco.Disco, baseURL string, cleanup func()) {
   108  	server := httptest.NewServer(http.HandlerFunc(fakeRegistryHandler))
   109  
   110  	services = disco.New()
   111  	services.ForceHostServices(svchost.Hostname("example.com"), map[string]interface{}{
   112  		"providers.v1": server.URL + "/providers/v1/",
   113  	})
   114  	services.ForceHostServices(svchost.Hostname("not.example.com"), map[string]interface{}{})
   115  	services.ForceHostServices(svchost.Hostname("too-new.example.com"), map[string]interface{}{
   116  		// This service doesn't actually work; it's here only to be
   117  		// detected as "too new" by the discovery logic.
   118  		"providers.v99": server.URL + "/providers/v99/",
   119  	})
   120  	services.ForceHostServices(svchost.Hostname("fails.example.com"), map[string]interface{}{
   121  		"providers.v1": server.URL + "/fails-immediately/",
   122  	})
   123  
   124  	// We'll also permit registry.terraform.io here just because it's our
   125  	// default and has some unique features that are not allowed on any other
   126  	// hostname. It behaves the same as example.com, which should be preferred
   127  	// if you're not testing something specific to the default registry in order
   128  	// to ensure that most things are hostname-agnostic.
   129  	services.ForceHostServices(svchost.Hostname("registry.terraform.io"), map[string]interface{}{
   130  		"providers.v1": server.URL + "/providers/v1/",
   131  	})
   132  
   133  	return services, server.URL, func() {
   134  		server.Close()
   135  	}
   136  }
   137  
   138  // testRegistrySource is a wrapper around testServices that uses the created
   139  // discovery object to produce a Source instance that is ready to use with the
   140  // fake registry services.
   141  //
   142  // As with testServices, the second return value is a function to call at the end
   143  // of your test in order to shut down the test server.
   144  func testRegistrySource(t *testing.T) (source *RegistrySource, baseURL string, cleanup func()) {
   145  	services, baseURL, close := testRegistryServices(t)
   146  	source = NewRegistrySource(services)
   147  	return source, baseURL, close
   148  }
   149  
   150  func fakeRegistryHandler(resp http.ResponseWriter, req *http.Request) {
   151  	path := req.URL.EscapedPath()
   152  	if strings.HasPrefix(path, "/fails-immediately/") {
   153  		// Here we take over the socket and just close it immediately, to
   154  		// simulate one possible way a server might not be an HTTP server.
   155  		hijacker, ok := resp.(http.Hijacker)
   156  		if !ok {
   157  			// Not hijackable, so we'll just fail normally.
   158  			// If this happens, tests relying on this will fail.
   159  			resp.WriteHeader(500)
   160  			resp.Write([]byte(`cannot hijack`))
   161  			return
   162  		}
   163  		conn, _, err := hijacker.Hijack()
   164  		if err != nil {
   165  			resp.WriteHeader(500)
   166  			resp.Write([]byte(`hijack failed`))
   167  			return
   168  		}
   169  		conn.Close()
   170  		return
   171  	}
   172  
   173  	if strings.HasPrefix(path, "/pkg/") {
   174  		switch path {
   175  		case "/pkg/awesomesauce/happycloud_1.2.0.zip":
   176  			resp.Write([]byte("some zip file"))
   177  		case "/pkg/awesomesauce/happycloud_1.2.0_SHA256SUMS":
   178  			resp.Write([]byte("000000000000000000000000000000000000000000000000000000000000f00d happycloud_1.2.0.zip\n000000000000000000000000000000000000000000000000000000000000face happycloud_1.2.0_face.zip\n"))
   179  		case "/pkg/awesomesauce/happycloud_1.2.0_SHA256SUMS.sig":
   180  			resp.Write([]byte("GPG signature"))
   181  		default:
   182  			resp.WriteHeader(404)
   183  			resp.Write([]byte("unknown package file download"))
   184  		}
   185  		return
   186  	}
   187  
   188  	if !strings.HasPrefix(path, "/providers/v1/") {
   189  		resp.WriteHeader(404)
   190  		resp.Write([]byte(`not a provider registry endpoint`))
   191  		return
   192  	}
   193  
   194  	pathParts := strings.Split(path, "/")[3:]
   195  	if len(pathParts) < 3 {
   196  		resp.WriteHeader(404)
   197  		resp.Write([]byte(`unexpected number of path parts`))
   198  		return
   199  	}
   200  	log.Printf("[TRACE] fake provider registry request for %#v", pathParts)
   201  
   202  	if pathParts[2] == "versions" {
   203  		if len(pathParts) != 3 {
   204  			resp.WriteHeader(404)
   205  			resp.Write([]byte(`extraneous path parts`))
   206  			return
   207  		}
   208  
   209  		switch pathParts[0] + "/" + pathParts[1] {
   210  		case "awesomesauce/happycloud":
   211  			resp.Header().Set("Content-Type", "application/json")
   212  			resp.WriteHeader(200)
   213  			// Note that these version numbers are intentionally misordered
   214  			// so we can test that the client-side code places them in the
   215  			// correct order (lowest precedence first).
   216  			resp.Write([]byte(`{"versions":[{"version":"0.1.0","protocols":["1.0"]},{"version":"2.0.0","protocols":["99.0"]},{"version":"1.2.0","protocols":["5.0"]}, {"version":"1.0.0","protocols":["5.0"]}]}`))
   217  		case "weaksauce/unsupported-protocol":
   218  			resp.Header().Set("Content-Type", "application/json")
   219  			resp.WriteHeader(200)
   220  			resp.Write([]byte(`{"versions":[{"version":"1.0.0","protocols":["0.1"]}]}`))
   221  		case "weaksauce/protocol-six":
   222  			resp.Header().Set("Content-Type", "application/json")
   223  			resp.WriteHeader(200)
   224  			resp.Write([]byte(`{"versions":[{"version":"1.0.0","protocols":["6.0"]}]}`))
   225  		case "weaksauce/no-versions":
   226  			resp.Header().Set("Content-Type", "application/json")
   227  			resp.WriteHeader(200)
   228  			resp.Write([]byte(`{"versions":[],"warnings":["this provider is weaksauce"]}`))
   229  		case "-/legacy":
   230  			resp.Header().Set("Content-Type", "application/json")
   231  			resp.WriteHeader(200)
   232  			// This response is used for testing LookupLegacyProvider
   233  			resp.Write([]byte(`{"id":"legacycorp/legacy"}`))
   234  		case "-/moved":
   235  			resp.Header().Set("Content-Type", "application/json")
   236  			resp.WriteHeader(200)
   237  			// This response is used for testing LookupLegacyProvider
   238  			resp.Write([]byte(`{"id":"hashicorp/moved","moved_to":"acme/moved"}`))
   239  		case "-/changetype":
   240  			resp.Header().Set("Content-Type", "application/json")
   241  			resp.WriteHeader(200)
   242  			// This (unrealistic) response is used for error handling code coverage
   243  			resp.Write([]byte(`{"id":"legacycorp/newtype"}`))
   244  		case "-/invalid":
   245  			resp.Header().Set("Content-Type", "application/json")
   246  			resp.WriteHeader(200)
   247  			// This (unrealistic) response is used for error handling code coverage
   248  			resp.Write([]byte(`{"id":"some/invalid/id/string"}`))
   249  		default:
   250  			resp.WriteHeader(404)
   251  			resp.Write([]byte(`unknown namespace or provider type`))
   252  		}
   253  		return
   254  	}
   255  
   256  	if len(pathParts) == 6 && pathParts[3] == "download" {
   257  		switch pathParts[0] + "/" + pathParts[1] {
   258  		case "awesomesauce/happycloud":
   259  			if pathParts[4] == "nonexist" {
   260  				resp.WriteHeader(404)
   261  				resp.Write([]byte(`unsupported OS`))
   262  				return
   263  			}
   264  			var protocols []string
   265  			version := pathParts[2]
   266  			switch version {
   267  			case "0.1.0":
   268  				protocols = []string{"1.0"}
   269  			case "2.0.0":
   270  				protocols = []string{"99.0"}
   271  			default:
   272  				protocols = []string{"5.0"}
   273  			}
   274  
   275  			body := map[string]interface{}{
   276  				"protocols":             protocols,
   277  				"os":                    pathParts[4],
   278  				"arch":                  pathParts[5],
   279  				"filename":              "happycloud_" + version + ".zip",
   280  				"shasum":                "000000000000000000000000000000000000000000000000000000000000f00d",
   281  				"download_url":          "/pkg/awesomesauce/happycloud_" + version + ".zip",
   282  				"shasums_url":           "/pkg/awesomesauce/happycloud_" + version + "_SHA256SUMS",
   283  				"shasums_signature_url": "/pkg/awesomesauce/happycloud_" + version + "_SHA256SUMS.sig",
   284  				"signing_keys": map[string]interface{}{
   285  					"gpg_public_keys": []map[string]interface{}{
   286  						{
   287  							"ascii_armor": HashicorpPublicKey,
   288  						},
   289  					},
   290  				},
   291  			}
   292  			enc, err := json.Marshal(body)
   293  			if err != nil {
   294  				resp.WriteHeader(500)
   295  				resp.Write([]byte("failed to encode body"))
   296  			}
   297  			resp.Header().Set("Content-Type", "application/json")
   298  			resp.WriteHeader(200)
   299  			resp.Write(enc)
   300  		default:
   301  			resp.WriteHeader(404)
   302  			resp.Write([]byte(`unknown namespace/provider/version/architecture`))
   303  		}
   304  		return
   305  	}
   306  
   307  	resp.WriteHeader(404)
   308  	resp.Write([]byte(`unrecognized path scheme`))
   309  }
   310  
   311  func TestProviderVersions(t *testing.T) {
   312  	source, _, close := testRegistrySource(t)
   313  	defer close()
   314  
   315  	tests := []struct {
   316  		provider     addrs.Provider
   317  		wantVersions map[string][]string
   318  		wantErr      string
   319  	}{
   320  		{
   321  			addrs.MustParseProviderSourceString("example.com/awesomesauce/happycloud"),
   322  			map[string][]string{
   323  				"0.1.0": {"1.0"},
   324  				"1.0.0": {"5.0"},
   325  				"1.2.0": {"5.0"},
   326  				"2.0.0": {"99.0"},
   327  			},
   328  			``,
   329  		},
   330  		{
   331  			addrs.MustParseProviderSourceString("example.com/weaksauce/no-versions"),
   332  			nil,
   333  			``,
   334  		},
   335  		{
   336  			addrs.MustParseProviderSourceString("example.com/nonexist/nonexist"),
   337  			nil,
   338  			`provider registry example.com does not have a provider named example.com/nonexist/nonexist`,
   339  		},
   340  	}
   341  	for _, test := range tests {
   342  		t.Run(test.provider.String(), func(t *testing.T) {
   343  			client, err := source.registryClient(test.provider.Hostname)
   344  			if err != nil {
   345  				t.Fatal(err)
   346  			}
   347  
   348  			gotVersions, _, err := client.ProviderVersions(context.Background(), test.provider)
   349  
   350  			if err != nil {
   351  				if test.wantErr == "" {
   352  					t.Fatalf("wrong error\ngot:  %s\nwant: <nil>", err.Error())
   353  				}
   354  				if got, want := err.Error(), test.wantErr; got != want {
   355  					t.Fatalf("wrong error\ngot:  %s\nwant: %s", got, want)
   356  				}
   357  				return
   358  			}
   359  
   360  			if test.wantErr != "" {
   361  				t.Fatalf("wrong error\ngot:  <nil>\nwant: %s", test.wantErr)
   362  			}
   363  
   364  			if diff := cmp.Diff(test.wantVersions, gotVersions); diff != "" {
   365  				t.Errorf("wrong result\n%s", diff)
   366  			}
   367  		})
   368  	}
   369  }
   370  
   371  func TestFindClosestProtocolCompatibleVersion(t *testing.T) {
   372  	source, _, close := testRegistrySource(t)
   373  	defer close()
   374  
   375  	tests := map[string]struct {
   376  		provider       addrs.Provider
   377  		version        Version
   378  		wantSuggestion Version
   379  		wantErr        string
   380  	}{
   381  		"pinned version too old": {
   382  			addrs.MustParseProviderSourceString("example.com/awesomesauce/happycloud"),
   383  			MustParseVersion("0.1.0"),
   384  			MustParseVersion("1.2.0"),
   385  			``,
   386  		},
   387  		"pinned version too new": {
   388  			addrs.MustParseProviderSourceString("example.com/awesomesauce/happycloud"),
   389  			MustParseVersion("2.0.0"),
   390  			MustParseVersion("1.2.0"),
   391  			``,
   392  		},
   393  		// This should not actually happen, the function is only meant to be
   394  		// called when the requested provider version is not supported
   395  		"pinned version just right": {
   396  			addrs.MustParseProviderSourceString("example.com/awesomesauce/happycloud"),
   397  			MustParseVersion("1.2.0"),
   398  			MustParseVersion("1.2.0"),
   399  			``,
   400  		},
   401  		"nonexisting provider": {
   402  			addrs.MustParseProviderSourceString("example.com/nonexist/nonexist"),
   403  			MustParseVersion("1.2.0"),
   404  			versions.Unspecified,
   405  			`provider registry example.com does not have a provider named example.com/nonexist/nonexist`,
   406  		},
   407  		"versionless provider": {
   408  			addrs.MustParseProviderSourceString("example.com/weaksauce/no-versions"),
   409  			MustParseVersion("1.2.0"),
   410  			versions.Unspecified,
   411  			``,
   412  		},
   413  		"unsupported provider protocol": {
   414  			addrs.MustParseProviderSourceString("example.com/weaksauce/unsupported-protocol"),
   415  			MustParseVersion("1.0.0"),
   416  			versions.Unspecified,
   417  			``,
   418  		},
   419  		"provider protocol six": {
   420  			addrs.MustParseProviderSourceString("example.com/weaksauce/protocol-six"),
   421  			MustParseVersion("1.0.0"),
   422  			MustParseVersion("1.0.0"),
   423  			``,
   424  		},
   425  	}
   426  	for name, test := range tests {
   427  		t.Run(name, func(t *testing.T) {
   428  			client, err := source.registryClient(test.provider.Hostname)
   429  			if err != nil {
   430  				t.Fatal(err)
   431  			}
   432  
   433  			got, err := client.findClosestProtocolCompatibleVersion(context.Background(), test.provider, test.version)
   434  
   435  			if err != nil {
   436  				if test.wantErr == "" {
   437  					t.Fatalf("wrong error\ngot:  %s\nwant: <nil>", err.Error())
   438  				}
   439  				if got, want := err.Error(), test.wantErr; got != want {
   440  					t.Fatalf("wrong error\ngot:  %s\nwant: %s", got, want)
   441  				}
   442  				return
   443  			}
   444  
   445  			if test.wantErr != "" {
   446  				t.Fatalf("wrong error\ngot:  <nil>\nwant: %s", test.wantErr)
   447  			}
   448  
   449  			fmt.Printf("Got: %s, Want: %s\n", got, test.wantSuggestion)
   450  
   451  			if !got.Same(test.wantSuggestion) {
   452  				t.Fatalf("wrong result\ngot:  %s\nwant: %s", got.String(), test.wantSuggestion.String())
   453  			}
   454  		})
   455  	}
   456  }