github.com/graywolf-at-work-2/terraform-vendor@v1.4.5/internal/getproviders/registry_client_test.go (about) 1 package getproviders 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "log" 8 "net/http" 9 "net/http/httptest" 10 "os" 11 "strings" 12 "testing" 13 "time" 14 15 "github.com/apparentlymart/go-versions/versions" 16 "github.com/google/go-cmp/cmp" 17 svchost "github.com/hashicorp/terraform-svchost" 18 disco "github.com/hashicorp/terraform-svchost/disco" 19 "github.com/hashicorp/terraform/internal/addrs" 20 ) 21 22 func TestConfigureDiscoveryRetry(t *testing.T) { 23 t.Run("default retry", func(t *testing.T) { 24 if discoveryRetry != defaultRetry { 25 t.Fatalf("expected retry %q, got %q", defaultRetry, discoveryRetry) 26 } 27 28 rc := newRegistryClient(nil, nil) 29 if rc.httpClient.RetryMax != defaultRetry { 30 t.Fatalf("expected client retry %q, got %q", 31 defaultRetry, rc.httpClient.RetryMax) 32 } 33 }) 34 35 t.Run("configured retry", func(t *testing.T) { 36 defer func(retryEnv string) { 37 os.Setenv(registryDiscoveryRetryEnvName, retryEnv) 38 discoveryRetry = defaultRetry 39 }(os.Getenv(registryDiscoveryRetryEnvName)) 40 os.Setenv(registryDiscoveryRetryEnvName, "2") 41 42 configureDiscoveryRetry() 43 expected := 2 44 if discoveryRetry != expected { 45 t.Fatalf("expected retry %q, got %q", 46 expected, discoveryRetry) 47 } 48 49 rc := newRegistryClient(nil, nil) 50 if rc.httpClient.RetryMax != expected { 51 t.Fatalf("expected client retry %q, got %q", 52 expected, rc.httpClient.RetryMax) 53 } 54 }) 55 } 56 57 func TestConfigureRegistryClientTimeout(t *testing.T) { 58 t.Run("default timeout", func(t *testing.T) { 59 if requestTimeout != defaultRequestTimeout { 60 t.Fatalf("expected timeout %q, got %q", 61 defaultRequestTimeout.String(), requestTimeout.String()) 62 } 63 64 rc := newRegistryClient(nil, nil) 65 if rc.httpClient.HTTPClient.Timeout != defaultRequestTimeout { 66 t.Fatalf("expected client timeout %q, got %q", 67 defaultRequestTimeout.String(), rc.httpClient.HTTPClient.Timeout.String()) 68 } 69 }) 70 71 t.Run("configured timeout", func(t *testing.T) { 72 defer func(timeoutEnv string) { 73 os.Setenv(registryClientTimeoutEnvName, timeoutEnv) 74 requestTimeout = defaultRequestTimeout 75 }(os.Getenv(registryClientTimeoutEnvName)) 76 os.Setenv(registryClientTimeoutEnvName, "20") 77 78 configureRequestTimeout() 79 expected := 20 * time.Second 80 if requestTimeout != expected { 81 t.Fatalf("expected timeout %q, got %q", 82 expected, requestTimeout.String()) 83 } 84 85 rc := newRegistryClient(nil, nil) 86 if rc.httpClient.HTTPClient.Timeout != expected { 87 t.Fatalf("expected client timeout %q, got %q", 88 expected, rc.httpClient.HTTPClient.Timeout.String()) 89 } 90 }) 91 } 92 93 // testRegistryServices starts up a local HTTP server running a fake provider registry 94 // service and returns a service discovery object pre-configured to consider 95 // the host "example.com" to be served by the fake registry service. 96 // 97 // The returned discovery object also knows the hostname "not.example.com" 98 // which does not have a provider registry at all and "too-new.example.com" 99 // which has a "providers.v99" service that is inoperable but could be useful 100 // to test the error reporting for detecting an unsupported protocol version. 101 // It also knows fails.example.com but it refers to an endpoint that doesn't 102 // correctly speak HTTP, to simulate a protocol error. 103 // 104 // The second return value is a function to call at the end of a test function 105 // to shut down the test server. After you call that function, the discovery 106 // object becomes useless. 107 func testRegistryServices(t *testing.T) (services *disco.Disco, baseURL string, cleanup func()) { 108 server := httptest.NewServer(http.HandlerFunc(fakeRegistryHandler)) 109 110 services = disco.New() 111 services.ForceHostServices(svchost.Hostname("example.com"), map[string]interface{}{ 112 "providers.v1": server.URL + "/providers/v1/", 113 }) 114 services.ForceHostServices(svchost.Hostname("not.example.com"), map[string]interface{}{}) 115 services.ForceHostServices(svchost.Hostname("too-new.example.com"), map[string]interface{}{ 116 // This service doesn't actually work; it's here only to be 117 // detected as "too new" by the discovery logic. 118 "providers.v99": server.URL + "/providers/v99/", 119 }) 120 services.ForceHostServices(svchost.Hostname("fails.example.com"), map[string]interface{}{ 121 "providers.v1": server.URL + "/fails-immediately/", 122 }) 123 124 // We'll also permit registry.terraform.io here just because it's our 125 // default and has some unique features that are not allowed on any other 126 // hostname. It behaves the same as example.com, which should be preferred 127 // if you're not testing something specific to the default registry in order 128 // to ensure that most things are hostname-agnostic. 129 services.ForceHostServices(svchost.Hostname("registry.terraform.io"), map[string]interface{}{ 130 "providers.v1": server.URL + "/providers/v1/", 131 }) 132 133 return services, server.URL, func() { 134 server.Close() 135 } 136 } 137 138 // testRegistrySource is a wrapper around testServices that uses the created 139 // discovery object to produce a Source instance that is ready to use with the 140 // fake registry services. 141 // 142 // As with testServices, the second return value is a function to call at the end 143 // of your test in order to shut down the test server. 144 func testRegistrySource(t *testing.T) (source *RegistrySource, baseURL string, cleanup func()) { 145 services, baseURL, close := testRegistryServices(t) 146 source = NewRegistrySource(services) 147 return source, baseURL, close 148 } 149 150 func fakeRegistryHandler(resp http.ResponseWriter, req *http.Request) { 151 path := req.URL.EscapedPath() 152 if strings.HasPrefix(path, "/fails-immediately/") { 153 // Here we take over the socket and just close it immediately, to 154 // simulate one possible way a server might not be an HTTP server. 155 hijacker, ok := resp.(http.Hijacker) 156 if !ok { 157 // Not hijackable, so we'll just fail normally. 158 // If this happens, tests relying on this will fail. 159 resp.WriteHeader(500) 160 resp.Write([]byte(`cannot hijack`)) 161 return 162 } 163 conn, _, err := hijacker.Hijack() 164 if err != nil { 165 resp.WriteHeader(500) 166 resp.Write([]byte(`hijack failed`)) 167 return 168 } 169 conn.Close() 170 return 171 } 172 173 if strings.HasPrefix(path, "/pkg/") { 174 switch path { 175 case "/pkg/awesomesauce/happycloud_1.2.0.zip": 176 resp.Write([]byte("some zip file")) 177 case "/pkg/awesomesauce/happycloud_1.2.0_SHA256SUMS": 178 resp.Write([]byte("000000000000000000000000000000000000000000000000000000000000f00d happycloud_1.2.0.zip\n000000000000000000000000000000000000000000000000000000000000face happycloud_1.2.0_face.zip\n")) 179 case "/pkg/awesomesauce/happycloud_1.2.0_SHA256SUMS.sig": 180 resp.Write([]byte("GPG signature")) 181 default: 182 resp.WriteHeader(404) 183 resp.Write([]byte("unknown package file download")) 184 } 185 return 186 } 187 188 if !strings.HasPrefix(path, "/providers/v1/") { 189 resp.WriteHeader(404) 190 resp.Write([]byte(`not a provider registry endpoint`)) 191 return 192 } 193 194 pathParts := strings.Split(path, "/")[3:] 195 if len(pathParts) < 3 { 196 resp.WriteHeader(404) 197 resp.Write([]byte(`unexpected number of path parts`)) 198 return 199 } 200 log.Printf("[TRACE] fake provider registry request for %#v", pathParts) 201 202 if pathParts[2] == "versions" { 203 if len(pathParts) != 3 { 204 resp.WriteHeader(404) 205 resp.Write([]byte(`extraneous path parts`)) 206 return 207 } 208 209 switch pathParts[0] + "/" + pathParts[1] { 210 case "awesomesauce/happycloud": 211 resp.Header().Set("Content-Type", "application/json") 212 resp.WriteHeader(200) 213 // Note that these version numbers are intentionally misordered 214 // so we can test that the client-side code places them in the 215 // correct order (lowest precedence first). 216 resp.Write([]byte(`{"versions":[{"version":"0.1.0","protocols":["1.0"]},{"version":"2.0.0","protocols":["99.0"]},{"version":"1.2.0","protocols":["5.0"]}, {"version":"1.0.0","protocols":["5.0"]}]}`)) 217 case "weaksauce/unsupported-protocol": 218 resp.Header().Set("Content-Type", "application/json") 219 resp.WriteHeader(200) 220 resp.Write([]byte(`{"versions":[{"version":"1.0.0","protocols":["0.1"]}]}`)) 221 case "weaksauce/protocol-six": 222 resp.Header().Set("Content-Type", "application/json") 223 resp.WriteHeader(200) 224 resp.Write([]byte(`{"versions":[{"version":"1.0.0","protocols":["6.0"]}]}`)) 225 case "weaksauce/no-versions": 226 resp.Header().Set("Content-Type", "application/json") 227 resp.WriteHeader(200) 228 resp.Write([]byte(`{"versions":[],"warnings":["this provider is weaksauce"]}`)) 229 case "-/legacy": 230 resp.Header().Set("Content-Type", "application/json") 231 resp.WriteHeader(200) 232 // This response is used for testing LookupLegacyProvider 233 resp.Write([]byte(`{"id":"legacycorp/legacy"}`)) 234 case "-/moved": 235 resp.Header().Set("Content-Type", "application/json") 236 resp.WriteHeader(200) 237 // This response is used for testing LookupLegacyProvider 238 resp.Write([]byte(`{"id":"hashicorp/moved","moved_to":"acme/moved"}`)) 239 case "-/changetype": 240 resp.Header().Set("Content-Type", "application/json") 241 resp.WriteHeader(200) 242 // This (unrealistic) response is used for error handling code coverage 243 resp.Write([]byte(`{"id":"legacycorp/newtype"}`)) 244 case "-/invalid": 245 resp.Header().Set("Content-Type", "application/json") 246 resp.WriteHeader(200) 247 // This (unrealistic) response is used for error handling code coverage 248 resp.Write([]byte(`{"id":"some/invalid/id/string"}`)) 249 default: 250 resp.WriteHeader(404) 251 resp.Write([]byte(`unknown namespace or provider type`)) 252 } 253 return 254 } 255 256 if len(pathParts) == 6 && pathParts[3] == "download" { 257 switch pathParts[0] + "/" + pathParts[1] { 258 case "awesomesauce/happycloud": 259 if pathParts[4] == "nonexist" { 260 resp.WriteHeader(404) 261 resp.Write([]byte(`unsupported OS`)) 262 return 263 } 264 var protocols []string 265 version := pathParts[2] 266 switch version { 267 case "0.1.0": 268 protocols = []string{"1.0"} 269 case "2.0.0": 270 protocols = []string{"99.0"} 271 default: 272 protocols = []string{"5.0"} 273 } 274 275 body := map[string]interface{}{ 276 "protocols": protocols, 277 "os": pathParts[4], 278 "arch": pathParts[5], 279 "filename": "happycloud_" + version + ".zip", 280 "shasum": "000000000000000000000000000000000000000000000000000000000000f00d", 281 "download_url": "/pkg/awesomesauce/happycloud_" + version + ".zip", 282 "shasums_url": "/pkg/awesomesauce/happycloud_" + version + "_SHA256SUMS", 283 "shasums_signature_url": "/pkg/awesomesauce/happycloud_" + version + "_SHA256SUMS.sig", 284 "signing_keys": map[string]interface{}{ 285 "gpg_public_keys": []map[string]interface{}{ 286 { 287 "ascii_armor": HashicorpPublicKey, 288 }, 289 }, 290 }, 291 } 292 enc, err := json.Marshal(body) 293 if err != nil { 294 resp.WriteHeader(500) 295 resp.Write([]byte("failed to encode body")) 296 } 297 resp.Header().Set("Content-Type", "application/json") 298 resp.WriteHeader(200) 299 resp.Write(enc) 300 default: 301 resp.WriteHeader(404) 302 resp.Write([]byte(`unknown namespace/provider/version/architecture`)) 303 } 304 return 305 } 306 307 resp.WriteHeader(404) 308 resp.Write([]byte(`unrecognized path scheme`)) 309 } 310 311 func TestProviderVersions(t *testing.T) { 312 source, _, close := testRegistrySource(t) 313 defer close() 314 315 tests := []struct { 316 provider addrs.Provider 317 wantVersions map[string][]string 318 wantErr string 319 }{ 320 { 321 addrs.MustParseProviderSourceString("example.com/awesomesauce/happycloud"), 322 map[string][]string{ 323 "0.1.0": {"1.0"}, 324 "1.0.0": {"5.0"}, 325 "1.2.0": {"5.0"}, 326 "2.0.0": {"99.0"}, 327 }, 328 ``, 329 }, 330 { 331 addrs.MustParseProviderSourceString("example.com/weaksauce/no-versions"), 332 nil, 333 ``, 334 }, 335 { 336 addrs.MustParseProviderSourceString("example.com/nonexist/nonexist"), 337 nil, 338 `provider registry example.com does not have a provider named example.com/nonexist/nonexist`, 339 }, 340 } 341 for _, test := range tests { 342 t.Run(test.provider.String(), func(t *testing.T) { 343 client, err := source.registryClient(test.provider.Hostname) 344 if err != nil { 345 t.Fatal(err) 346 } 347 348 gotVersions, _, err := client.ProviderVersions(context.Background(), test.provider) 349 350 if err != nil { 351 if test.wantErr == "" { 352 t.Fatalf("wrong error\ngot: %s\nwant: <nil>", err.Error()) 353 } 354 if got, want := err.Error(), test.wantErr; got != want { 355 t.Fatalf("wrong error\ngot: %s\nwant: %s", got, want) 356 } 357 return 358 } 359 360 if test.wantErr != "" { 361 t.Fatalf("wrong error\ngot: <nil>\nwant: %s", test.wantErr) 362 } 363 364 if diff := cmp.Diff(test.wantVersions, gotVersions); diff != "" { 365 t.Errorf("wrong result\n%s", diff) 366 } 367 }) 368 } 369 } 370 371 func TestFindClosestProtocolCompatibleVersion(t *testing.T) { 372 source, _, close := testRegistrySource(t) 373 defer close() 374 375 tests := map[string]struct { 376 provider addrs.Provider 377 version Version 378 wantSuggestion Version 379 wantErr string 380 }{ 381 "pinned version too old": { 382 addrs.MustParseProviderSourceString("example.com/awesomesauce/happycloud"), 383 MustParseVersion("0.1.0"), 384 MustParseVersion("1.2.0"), 385 ``, 386 }, 387 "pinned version too new": { 388 addrs.MustParseProviderSourceString("example.com/awesomesauce/happycloud"), 389 MustParseVersion("2.0.0"), 390 MustParseVersion("1.2.0"), 391 ``, 392 }, 393 // This should not actually happen, the function is only meant to be 394 // called when the requested provider version is not supported 395 "pinned version just right": { 396 addrs.MustParseProviderSourceString("example.com/awesomesauce/happycloud"), 397 MustParseVersion("1.2.0"), 398 MustParseVersion("1.2.0"), 399 ``, 400 }, 401 "nonexisting provider": { 402 addrs.MustParseProviderSourceString("example.com/nonexist/nonexist"), 403 MustParseVersion("1.2.0"), 404 versions.Unspecified, 405 `provider registry example.com does not have a provider named example.com/nonexist/nonexist`, 406 }, 407 "versionless provider": { 408 addrs.MustParseProviderSourceString("example.com/weaksauce/no-versions"), 409 MustParseVersion("1.2.0"), 410 versions.Unspecified, 411 ``, 412 }, 413 "unsupported provider protocol": { 414 addrs.MustParseProviderSourceString("example.com/weaksauce/unsupported-protocol"), 415 MustParseVersion("1.0.0"), 416 versions.Unspecified, 417 ``, 418 }, 419 "provider protocol six": { 420 addrs.MustParseProviderSourceString("example.com/weaksauce/protocol-six"), 421 MustParseVersion("1.0.0"), 422 MustParseVersion("1.0.0"), 423 ``, 424 }, 425 } 426 for name, test := range tests { 427 t.Run(name, func(t *testing.T) { 428 client, err := source.registryClient(test.provider.Hostname) 429 if err != nil { 430 t.Fatal(err) 431 } 432 433 got, err := client.findClosestProtocolCompatibleVersion(context.Background(), test.provider, test.version) 434 435 if err != nil { 436 if test.wantErr == "" { 437 t.Fatalf("wrong error\ngot: %s\nwant: <nil>", err.Error()) 438 } 439 if got, want := err.Error(), test.wantErr; got != want { 440 t.Fatalf("wrong error\ngot: %s\nwant: %s", got, want) 441 } 442 return 443 } 444 445 if test.wantErr != "" { 446 t.Fatalf("wrong error\ngot: <nil>\nwant: %s", test.wantErr) 447 } 448 449 fmt.Printf("Got: %s, Want: %s\n", got, test.wantSuggestion) 450 451 if !got.Same(test.wantSuggestion) { 452 t.Fatalf("wrong result\ngot: %s\nwant: %s", got.String(), test.wantSuggestion.String()) 453 } 454 }) 455 } 456 }