github.com/rstandt/terraform@v0.12.32-0.20230710220336-b1063613405c/registry/test/mock_registry.go (about)

     1  package test
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"os"
    10  	"regexp"
    11  	"sort"
    12  	"strings"
    13  
    14  	version "github.com/hashicorp/go-version"
    15  	svchost "github.com/hashicorp/terraform-svchost"
    16  	"github.com/hashicorp/terraform-svchost/auth"
    17  	"github.com/hashicorp/terraform-svchost/disco"
    18  	"github.com/hashicorp/terraform/httpclient"
    19  	"github.com/hashicorp/terraform/registry/regsrc"
    20  	"github.com/hashicorp/terraform/registry/response"
    21  	tfversion "github.com/hashicorp/terraform/version"
    22  )
    23  
    24  // Disco return a *disco.Disco mapping registry.terraform.io, localhost,
    25  // localhost.localdomain, and example.com to the test server.
    26  func Disco(s *httptest.Server) *disco.Disco {
    27  	services := map[string]interface{}{
    28  		// Note that both with and without trailing slashes are supported behaviours
    29  		// TODO: add specific tests to enumerate both possibilities.
    30  		"modules.v1":   fmt.Sprintf("%s/v1/modules", s.URL),
    31  		"providers.v1": fmt.Sprintf("%s/v1/providers", s.URL),
    32  	}
    33  	d := disco.NewWithCredentialsSource(credsSrc)
    34  	d.SetUserAgent(httpclient.TerraformUserAgent(tfversion.String()))
    35  
    36  	d.ForceHostServices(svchost.Hostname("registry.terraform.io"), services)
    37  	d.ForceHostServices(svchost.Hostname("localhost"), services)
    38  	d.ForceHostServices(svchost.Hostname("localhost.localdomain"), services)
    39  	d.ForceHostServices(svchost.Hostname("example.com"), services)
    40  	return d
    41  }
    42  
    43  // Map of module names and location of test modules.
    44  // Only one version for now, as we only lookup latest from the registry.
    45  type testMod struct {
    46  	location string
    47  	version  string
    48  }
    49  
    50  // Map of provider names and location of test providers.
    51  // Only one version for now, as we only lookup latest from the registry.
    52  type testProvider struct {
    53  	version string
    54  	os      string
    55  	arch    string
    56  	url     string
    57  }
    58  
    59  const (
    60  	testCred = "test-auth-token"
    61  )
    62  
    63  var (
    64  	regHost  = svchost.Hostname(regsrc.PublicRegistryHost.Normalized())
    65  	credsSrc = auth.StaticCredentialsSource(map[svchost.Hostname]map[string]interface{}{
    66  		regHost: {"token": testCred},
    67  	})
    68  )
    69  
// testMods maps three-part module addresses (e.g. "registry/foo/bar") to the
// versions the mock registry serves for them.
//
// The download endpoint always serves the location of the FIRST entry in the
// X-Terraform-Get header. Locations that already start with "file:///" are
// served unchanged, as are locations of modules under the "relative/" prefix.
// Any other location (including an empty one) is treated as relative to the
// working directory and turned into an absolute file:// URL. Entries without a
// location exist only to feed the versions endpoint.
var testMods = map[string][]testMod{
	"registry/foo/bar": {{
		location: "file:///download/registry/foo/bar/0.2.3//*?archive=tar.gz",
		version:  "0.2.3",
	}},
	"registry/foo/baz": {{
		location: "file:///download/registry/foo/baz/1.10.0//*?archive=tar.gz",
		version:  "1.10.0",
	}},
	"registry/local/sub": {{
		location: "testdata/registry-tar-subdir/foo.tgz//*?archive=tar.gz",
		version:  "0.1.2",
	}},
	"exists-in-registry/identifier/provider": {{
		location: "file:///registry/exists",
		version:  "0.2.0",
	}},
	"relative/foo/bar": {{ // There is an exception for the "relative/" prefix in the test registry server
		location: "/relative-path",
		version:  "0.2.0",
	}},
	"test-versions/name/provider": {
		{version: "2.2.0"},
		{version: "2.1.1"},
		{version: "1.2.2"},
		{version: "1.2.1"},
	},
	// Requests for "private/" modules require testCred in the Authorization header.
	"private/name/provider": {
		{version: "1.0.0"},
	},
}
   104  
// testProviders maps provider addresses to the versions the mock registry
// serves for them. Keys using the "-/" shorthand namespace are additionally
// registered under their "terraform-providers/" alias by init.
//
// The download endpoint serves only the url of the FIRST entry; later entries
// exist to feed the versions endpoint.
var testProviders = map[string][]testProvider{
	"-/foo": {
		{
			version: "0.2.3",
			url:     "https://releases.hashicorp.com/terraform-provider-foo/0.2.3/terraform-provider-foo.zip",
		},
		{version: "0.3.0"},
	},
	"-/bar": {
		{
			version: "0.1.1",
			url:     "https://releases.hashicorp.com/terraform-provider-bar/0.1.1/terraform-provider-bar.zip",
		},
		{version: "0.1.2"},
	},
}
   121  
   122  func providerAlias(provider string) string {
   123  	re := regexp.MustCompile("^-/")
   124  	if re.MatchString(provider) {
   125  		return re.ReplaceAllString(provider, "terraform-providers/")
   126  	}
   127  	return provider
   128  }
   129  
   130  func init() {
   131  	// Add provider aliases
   132  	for provider, info := range testProviders {
   133  		alias := providerAlias(provider)
   134  		testProviders[alias] = info
   135  	}
   136  }
   137  
   138  func latestVersion(versions []string) string {
   139  	var col version.Collection
   140  	for _, v := range versions {
   141  		ver, err := version.NewVersion(v)
   142  		if err != nil {
   143  			panic(err)
   144  		}
   145  		col = append(col, ver)
   146  	}
   147  
   148  	sort.Sort(col)
   149  	return col[len(col)-1].String()
   150  }
   151  
   152  func mockRegHandler() http.Handler {
   153  	mux := http.NewServeMux()
   154  
   155  	moduleDownload := func(w http.ResponseWriter, r *http.Request) {
   156  		p := strings.TrimLeft(r.URL.Path, "/")
   157  		// handle download request
   158  		re := regexp.MustCompile(`^([-a-z]+/\w+/\w+).*/download$`)
   159  		// download lookup
   160  		matches := re.FindStringSubmatch(p)
   161  		if len(matches) != 2 {
   162  			w.WriteHeader(http.StatusBadRequest)
   163  			return
   164  		}
   165  
   166  		// check for auth
   167  		if strings.Contains(matches[0], "private/") {
   168  			if !strings.Contains(r.Header.Get("Authorization"), testCred) {
   169  				http.Error(w, "", http.StatusForbidden)
   170  				return
   171  			}
   172  		}
   173  
   174  		versions, ok := testMods[matches[1]]
   175  		if !ok {
   176  			http.NotFound(w, r)
   177  			return
   178  		}
   179  		mod := versions[0]
   180  
   181  		location := mod.location
   182  		if !strings.HasPrefix(matches[0], "relative/") && !strings.HasPrefix(location, "file:///") {
   183  			// we can't use filepath.Abs because it will clean `//`
   184  			wd, _ := os.Getwd()
   185  			location = fmt.Sprintf("file://%s/%s", wd, location)
   186  		}
   187  
   188  		w.Header().Set("X-Terraform-Get", location)
   189  		w.WriteHeader(http.StatusNoContent)
   190  		// no body
   191  		return
   192  	}
   193  
   194  	moduleVersions := func(w http.ResponseWriter, r *http.Request) {
   195  		p := strings.TrimLeft(r.URL.Path, "/")
   196  		re := regexp.MustCompile(`^([-a-z]+/\w+/\w+)/versions$`)
   197  		matches := re.FindStringSubmatch(p)
   198  		if len(matches) != 2 {
   199  			w.WriteHeader(http.StatusBadRequest)
   200  			return
   201  		}
   202  
   203  		// check for auth
   204  		if strings.Contains(matches[1], "private/") {
   205  			if !strings.Contains(r.Header.Get("Authorization"), testCred) {
   206  				http.Error(w, "", http.StatusForbidden)
   207  			}
   208  		}
   209  
   210  		name := matches[1]
   211  		versions, ok := testMods[name]
   212  		if !ok {
   213  			http.NotFound(w, r)
   214  			return
   215  		}
   216  
   217  		// only adding the single requested module for now
   218  		// this is the minimal that any regisry is epected to support
   219  		mpvs := &response.ModuleProviderVersions{
   220  			Source: name,
   221  		}
   222  
   223  		for _, v := range versions {
   224  			mv := &response.ModuleVersion{
   225  				Version: v.version,
   226  			}
   227  			mpvs.Versions = append(mpvs.Versions, mv)
   228  		}
   229  
   230  		resp := response.ModuleVersions{
   231  			Modules: []*response.ModuleProviderVersions{mpvs},
   232  		}
   233  
   234  		js, err := json.Marshal(resp)
   235  		if err != nil {
   236  			http.Error(w, err.Error(), http.StatusInternalServerError)
   237  			return
   238  		}
   239  		w.Header().Set("Content-Type", "application/json")
   240  		w.Write(js)
   241  	}
   242  
   243  	mux.Handle("/v1/modules/",
   244  		http.StripPrefix("/v1/modules/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   245  			if strings.HasSuffix(r.URL.Path, "/download") {
   246  				moduleDownload(w, r)
   247  				return
   248  			}
   249  
   250  			if strings.HasSuffix(r.URL.Path, "/versions") {
   251  				moduleVersions(w, r)
   252  				return
   253  			}
   254  
   255  			http.NotFound(w, r)
   256  		})),
   257  	)
   258  
   259  	providerDownload := func(w http.ResponseWriter, r *http.Request) {
   260  		p := strings.TrimLeft(r.URL.Path, "/")
   261  		v := strings.Split(string(p), "/")
   262  
   263  		if len(v) != 6 {
   264  			w.WriteHeader(http.StatusBadRequest)
   265  			return
   266  		}
   267  
   268  		name := fmt.Sprintf("%s/%s", v[0], v[1])
   269  
   270  		providers, ok := testProviders[name]
   271  		if !ok {
   272  			http.NotFound(w, r)
   273  			return
   274  		}
   275  
   276  		// for this test / moment we will only return the one provider
   277  		loc := response.TerraformProviderPlatformLocation{
   278  			DownloadURL: providers[0].url,
   279  		}
   280  
   281  		js, err := json.Marshal(loc)
   282  		if err != nil {
   283  			http.Error(w, err.Error(), http.StatusInternalServerError)
   284  			return
   285  		}
   286  
   287  		w.Header().Set("Content-Type", "application/json")
   288  		w.Write(js)
   289  
   290  	}
   291  
   292  	providerVersions := func(w http.ResponseWriter, r *http.Request) {
   293  		p := strings.TrimLeft(r.URL.Path, "/")
   294  		re := regexp.MustCompile(`^([-a-z]+/\w+)/versions$`)
   295  		matches := re.FindStringSubmatch(p)
   296  
   297  		if len(matches) != 2 {
   298  			w.WriteHeader(http.StatusBadRequest)
   299  			return
   300  		}
   301  
   302  		// check for auth
   303  		if strings.Contains(matches[1], "private/") {
   304  			if !strings.Contains(r.Header.Get("Authorization"), testCred) {
   305  				http.Error(w, "", http.StatusForbidden)
   306  			}
   307  		}
   308  
   309  		name := providerAlias(fmt.Sprintf("%s", matches[1]))
   310  		versions, ok := testProviders[name]
   311  		if !ok {
   312  			http.NotFound(w, r)
   313  			return
   314  		}
   315  
   316  		// only adding the single requested provider for now
   317  		// this is the minimal that any registry is expected to support
   318  		pvs := &response.TerraformProviderVersions{
   319  			ID: name,
   320  		}
   321  
   322  		for _, v := range versions {
   323  			pv := &response.TerraformProviderVersion{
   324  				Version: v.version,
   325  			}
   326  			pvs.Versions = append(pvs.Versions, pv)
   327  		}
   328  
   329  		js, err := json.Marshal(pvs)
   330  		if err != nil {
   331  			http.Error(w, err.Error(), http.StatusInternalServerError)
   332  			return
   333  		}
   334  
   335  		w.Header().Set("Content-Type", "application/json")
   336  		w.Write(js)
   337  	}
   338  
   339  	mux.Handle("/v1/providers/",
   340  		http.StripPrefix("/v1/providers/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   341  			if strings.Contains(r.URL.Path, "/download") {
   342  				providerDownload(w, r)
   343  				return
   344  			}
   345  
   346  			if strings.HasSuffix(r.URL.Path, "/versions") {
   347  				providerVersions(w, r)
   348  				return
   349  			}
   350  
   351  			http.NotFound(w, r)
   352  		})),
   353  	)
   354  
   355  	mux.HandleFunc("/.well-known/terraform.json", func(w http.ResponseWriter, r *http.Request) {
   356  		w.Header().Set("Content-Type", "application/json")
   357  		io.WriteString(w, `{"modules.v1":"http://localhost/v1/modules/", "providers.v1":"http://localhost/v1/providers/"}`)
   358  	})
   359  	return mux
   360  }
   361  
   362  // Registry returns an httptest server that mocks out some registry functionality.
   363  func Registry() *httptest.Server {
   364  	return httptest.NewServer(mockRegHandler())
   365  }
   366  
   367  // RegistryRetryableErrorsServer returns an httptest server that mocks out the
   368  // registry API to return 502 errors.
   369  func RegistryRetryableErrorsServer() *httptest.Server {
   370  	mux := http.NewServeMux()
   371  	mux.HandleFunc("/v1/modules/", func(w http.ResponseWriter, r *http.Request) {
   372  		http.Error(w, "mocked server error", http.StatusBadGateway)
   373  	})
   374  	mux.HandleFunc("/v1/providers/", func(w http.ResponseWriter, r *http.Request) {
   375  		http.Error(w, "mocked server error", http.StatusBadGateway)
   376  	})
   377  	return httptest.NewServer(mux)
   378  }