github.com/opentofu/opentofu@v1.7.1/internal/getproviders/registry_client.go (about)

     1  // Copyright (c) The OpenTofu Authors
     2  // SPDX-License-Identifier: MPL-2.0
     3  // Copyright (c) 2023 HashiCorp, Inc.
     4  // SPDX-License-Identifier: MPL-2.0
     5  
     6  package getproviders
     7  
     8  import (
     9  	"context"
    10  	"crypto/sha256"
    11  	"encoding/hex"
    12  	"encoding/json"
    13  	"errors"
    14  	"fmt"
    15  	"io"
    16  	"log"
    17  	"net/http"
    18  	"net/url"
    19  	"os"
    20  	"path"
    21  	"strconv"
    22  	"time"
    23  
    24  	"github.com/hashicorp/go-retryablehttp"
    25  	svchost "github.com/hashicorp/terraform-svchost"
    26  	svcauth "github.com/hashicorp/terraform-svchost/auth"
    27  
    28  	"github.com/opentofu/opentofu/internal/addrs"
    29  	"github.com/opentofu/opentofu/internal/httpclient"
    30  	"github.com/opentofu/opentofu/internal/logging"
    31  	"github.com/opentofu/opentofu/version"
    32  )
    33  
    34  const (
    35  	terraformVersionHeader = "X-Terraform-Version"
    36  
    37  	// registryDiscoveryRetryEnvName is the name of the environment variable that
    38  	// can be configured to customize number of retries for module and provider
    39  	// discovery requests with the remote registry.
    40  	registryDiscoveryRetryEnvName = "TF_REGISTRY_DISCOVERY_RETRY"
    41  	defaultRetry                  = 1
    42  
    43  	// registryClientTimeoutEnvName is the name of the environment variable that
    44  	// can be configured to customize the timeout duration (seconds) for module
    45  	// and provider discovery with the remote registry.
    46  	registryClientTimeoutEnvName = "TF_REGISTRY_CLIENT_TIMEOUT"
    47  
    48  	// defaultRequestTimeout is the default timeout duration for requests to the
    49  	// remote registry.
    50  	defaultRequestTimeout = 10 * time.Second
    51  )
    52  
    53  var (
    54  	discoveryRetry int
    55  	requestTimeout time.Duration
    56  )
    57  
    58  func init() {
    59  	configureDiscoveryRetry()
    60  	configureRequestTimeout()
    61  }
    62  
    63  var SupportedPluginProtocols = MustParseVersionConstraints(">= 5, <7")
    64  
    65  // registryClient is a client for the provider registry protocol that is
    66  // specialized only for the needs of this package. It's not intended as a
    67  // general registry API client.
    68  type registryClient struct {
    69  	baseURL *url.URL
    70  	creds   svcauth.HostCredentials
    71  
    72  	httpClient *retryablehttp.Client
    73  }
    74  
    75  func newRegistryClient(baseURL *url.URL, creds svcauth.HostCredentials) *registryClient {
    76  	httpClient := httpclient.New()
    77  	httpClient.Timeout = requestTimeout
    78  
    79  	retryableClient := retryablehttp.NewClient()
    80  	retryableClient.HTTPClient = httpClient
    81  	retryableClient.RetryMax = discoveryRetry
    82  	retryableClient.RequestLogHook = requestLogHook
    83  	retryableClient.ErrorHandler = maxRetryErrorHandler
    84  
    85  	retryableClient.Logger = log.New(logging.LogOutput(), "", log.Flags())
    86  
    87  	return &registryClient{
    88  		baseURL:    baseURL,
    89  		creds:      creds,
    90  		httpClient: retryableClient,
    91  	}
    92  }
    93  
    94  // ProviderVersions returns the raw version and protocol strings produced by the
    95  // registry for the given provider.
    96  //
    97  // The returned error will be ErrRegistryProviderNotKnown if the registry responds with
    98  // 404 Not Found to indicate that the namespace or provider type are not known,
    99  // ErrUnauthorized if the registry responds with 401 or 403 status codes, or
   100  // ErrQueryFailed for any other protocol or operational problem.
   101  func (c *registryClient) ProviderVersions(ctx context.Context, addr addrs.Provider) (map[string][]string, []string, error) {
   102  	endpointPath, err := url.Parse(path.Join(addr.Namespace, addr.Type, "versions"))
   103  	if err != nil {
   104  		// Should never happen because we're constructing this from
   105  		// already-validated components.
   106  		return nil, nil, err
   107  	}
   108  	endpointURL := c.baseURL.ResolveReference(endpointPath)
   109  	req, err := retryablehttp.NewRequest("GET", endpointURL.String(), nil)
   110  	if err != nil {
   111  		return nil, nil, err
   112  	}
   113  	req = req.WithContext(ctx)
   114  	c.addHeadersToRequest(req.Request)
   115  
   116  	resp, err := c.httpClient.Do(req)
   117  	if err != nil {
   118  		return nil, nil, c.errQueryFailed(addr, err)
   119  	}
   120  	defer resp.Body.Close()
   121  
   122  	switch resp.StatusCode {
   123  	case http.StatusOK:
   124  		// Great!
   125  	case http.StatusNotFound:
   126  		return nil, nil, ErrRegistryProviderNotKnown{
   127  			Provider: addr,
   128  		}
   129  	case http.StatusUnauthorized, http.StatusForbidden:
   130  		return nil, nil, c.errUnauthorized(addr.Hostname)
   131  	default:
   132  		return nil, nil, c.errQueryFailed(addr, errors.New(resp.Status))
   133  	}
   134  
   135  	// We ignore the platforms portion of the response body, because the
   136  	// installer verifies the platform compatibility after pulling a provider
   137  	// versions' metadata.
   138  	type ResponseBody struct {
   139  		Versions []struct {
   140  			Version   string   `json:"version"`
   141  			Protocols []string `json:"protocols"`
   142  		} `json:"versions"`
   143  		Warnings []string `json:"warnings"`
   144  	}
   145  	var body ResponseBody
   146  
   147  	dec := json.NewDecoder(resp.Body)
   148  	if err := dec.Decode(&body); err != nil {
   149  		return nil, nil, c.errQueryFailed(addr, err)
   150  	}
   151  
   152  	if len(body.Versions) == 0 {
   153  		return nil, body.Warnings, nil
   154  	}
   155  
   156  	ret := make(map[string][]string, len(body.Versions))
   157  	for _, v := range body.Versions {
   158  		ret[v.Version] = v.Protocols
   159  	}
   160  
   161  	return ret, body.Warnings, nil
   162  }
   163  
   164  // PackageMeta returns metadata about a distribution package for a provider.
   165  //
   166  // The returned error will be one of the following:
   167  //
   168  //   - ErrPlatformNotSupported if the registry responds with 404 Not Found,
   169  //     under the assumption that the caller previously checked that the provider
   170  //     and version are valid.
   171  //   - ErrProtocolNotSupported if the requested provider version's protocols are not
   172  //     supported by this version of tofu.
   173  //   - ErrUnauthorized if the registry responds with 401 or 403 status codes
   174  //   - ErrQueryFailed for any other operational problem.
   175  func (c *registryClient) PackageMeta(ctx context.Context, provider addrs.Provider, version Version, target Platform) (PackageMeta, error) {
   176  	endpointPath, err := url.Parse(path.Join(
   177  		provider.Namespace,
   178  		provider.Type,
   179  		version.String(),
   180  		"download",
   181  		target.OS,
   182  		target.Arch,
   183  	))
   184  	if err != nil {
   185  		// Should never happen because we're constructing this from
   186  		// already-validated components.
   187  		return PackageMeta{}, err
   188  	}
   189  	endpointURL := c.baseURL.ResolveReference(endpointPath)
   190  
   191  	req, err := retryablehttp.NewRequest("GET", endpointURL.String(), nil)
   192  	if err != nil {
   193  		return PackageMeta{}, err
   194  	}
   195  	req = req.WithContext(ctx)
   196  	c.addHeadersToRequest(req.Request)
   197  
   198  	resp, err := c.httpClient.Do(req)
   199  	if err != nil {
   200  		return PackageMeta{}, c.errQueryFailed(provider, err)
   201  	}
   202  	defer resp.Body.Close()
   203  
   204  	switch resp.StatusCode {
   205  	case http.StatusOK:
   206  		// Great!
   207  	case http.StatusNotFound:
   208  		return PackageMeta{}, ErrPlatformNotSupported{
   209  			Provider: provider,
   210  			Version:  version,
   211  			Platform: target,
   212  		}
   213  	case http.StatusUnauthorized, http.StatusForbidden:
   214  		return PackageMeta{}, c.errUnauthorized(provider.Hostname)
   215  	default:
   216  		return PackageMeta{}, c.errQueryFailed(provider, errors.New(resp.Status))
   217  	}
   218  
   219  	type SigningKeyList struct {
   220  		GPGPublicKeys []*SigningKey `json:"gpg_public_keys"`
   221  	}
   222  	type ResponseBody struct {
   223  		Protocols   []string `json:"protocols"`
   224  		OS          string   `json:"os"`
   225  		Arch        string   `json:"arch"`
   226  		Filename    string   `json:"filename"`
   227  		DownloadURL string   `json:"download_url"`
   228  		SHA256Sum   string   `json:"shasum"`
   229  
   230  		SHA256SumsURL          string `json:"shasums_url"`
   231  		SHA256SumsSignatureURL string `json:"shasums_signature_url"`
   232  
   233  		SigningKeys SigningKeyList `json:"signing_keys"`
   234  	}
   235  	var body ResponseBody
   236  
   237  	dec := json.NewDecoder(resp.Body)
   238  	if err := dec.Decode(&body); err != nil {
   239  		return PackageMeta{}, c.errQueryFailed(provider, err)
   240  	}
   241  
   242  	var protoVersions VersionList
   243  	for _, versionStr := range body.Protocols {
   244  		v, err := ParseVersion(versionStr)
   245  		if err != nil {
   246  			return PackageMeta{}, c.errQueryFailed(
   247  				provider,
   248  				fmt.Errorf("registry response includes invalid version string %q: %w", versionStr, err),
   249  			)
   250  		}
   251  		protoVersions = append(protoVersions, v)
   252  	}
   253  	protoVersions.Sort()
   254  
   255  	// Verify that this version of tofu supports the providers' protocol
   256  	// version(s)
   257  	if len(protoVersions) > 0 {
   258  		supportedProtos := MeetingConstraints(SupportedPluginProtocols)
   259  		protoErr := ErrProtocolNotSupported{
   260  			Provider: provider,
   261  			Version:  version,
   262  		}
   263  		match := false
   264  		for _, version := range protoVersions {
   265  			if supportedProtos.Has(version) {
   266  				match = true
   267  			}
   268  		}
   269  		if !match {
   270  			// If the protocol version is not supported, try to find the closest
   271  			// matching version.
   272  			closest, err := c.findClosestProtocolCompatibleVersion(ctx, provider, version)
   273  			if err != nil {
   274  				return PackageMeta{}, err
   275  			}
   276  			protoErr.Suggestion = closest
   277  			return PackageMeta{}, protoErr
   278  		}
   279  	}
   280  
   281  	if body.OS != target.OS || body.Arch != target.Arch {
   282  		return PackageMeta{}, fmt.Errorf("registry response to request for %s archive has incorrect target %s", target, Platform{body.OS, body.Arch})
   283  	}
   284  
   285  	downloadURL, err := url.Parse(body.DownloadURL)
   286  	if err != nil {
   287  		return PackageMeta{}, fmt.Errorf("registry response includes invalid download URL: %w", err)
   288  	}
   289  	downloadURL = resp.Request.URL.ResolveReference(downloadURL)
   290  	if downloadURL.Scheme != "http" && downloadURL.Scheme != "https" {
   291  		return PackageMeta{}, fmt.Errorf("registry response includes invalid download URL: must use http or https scheme")
   292  	}
   293  
   294  	ret := PackageMeta{
   295  		Provider:         provider,
   296  		Version:          version,
   297  		ProtocolVersions: protoVersions,
   298  		TargetPlatform: Platform{
   299  			OS:   body.OS,
   300  			Arch: body.Arch,
   301  		},
   302  		Filename: body.Filename,
   303  		Location: PackageHTTPURL(downloadURL.String()),
   304  		// "Authentication" is populated below
   305  	}
   306  
   307  	if len(body.SHA256Sum) != sha256.Size*2 { // *2 because it's hex-encoded
   308  		return PackageMeta{}, c.errQueryFailed(
   309  			provider,
   310  			fmt.Errorf("registry response includes invalid SHA256 hash %q: %w", body.SHA256Sum, err),
   311  		)
   312  	}
   313  
   314  	var checksum [sha256.Size]byte
   315  	_, err = hex.Decode(checksum[:], []byte(body.SHA256Sum))
   316  	if err != nil {
   317  		return PackageMeta{}, c.errQueryFailed(
   318  			provider,
   319  			fmt.Errorf("registry response includes invalid SHA256 hash %q: %w", body.SHA256Sum, err),
   320  		)
   321  	}
   322  
   323  	shasumsURL, err := url.Parse(body.SHA256SumsURL)
   324  	if err != nil {
   325  		return PackageMeta{}, fmt.Errorf("registry response includes invalid SHASUMS URL: %w", err)
   326  	}
   327  	shasumsURL = resp.Request.URL.ResolveReference(shasumsURL)
   328  	if shasumsURL.Scheme != "http" && shasumsURL.Scheme != "https" {
   329  		return PackageMeta{}, fmt.Errorf("registry response includes invalid SHASUMS URL: must use http or https scheme")
   330  	}
   331  	document, err := c.getFile(shasumsURL)
   332  	if err != nil {
   333  		return PackageMeta{}, c.errQueryFailed(
   334  			provider,
   335  			fmt.Errorf("failed to retrieve authentication checksums for provider: %w", err),
   336  		)
   337  	}
   338  	signatureURL, err := url.Parse(body.SHA256SumsSignatureURL)
   339  	if err != nil {
   340  		return PackageMeta{}, fmt.Errorf("registry response includes invalid SHASUMS signature URL: %w", err)
   341  	}
   342  	signatureURL = resp.Request.URL.ResolveReference(signatureURL)
   343  	if signatureURL.Scheme != "http" && signatureURL.Scheme != "https" {
   344  		return PackageMeta{}, fmt.Errorf("registry response includes invalid SHASUMS signature URL: must use http or https scheme")
   345  	}
   346  	signature, err := c.getFile(signatureURL)
   347  	if err != nil {
   348  		return PackageMeta{}, c.errQueryFailed(
   349  			provider,
   350  			fmt.Errorf("failed to retrieve cryptographic signature for provider: %w", err),
   351  		)
   352  	}
   353  
   354  	keys := make([]SigningKey, len(body.SigningKeys.GPGPublicKeys))
   355  	for i, key := range body.SigningKeys.GPGPublicKeys {
   356  		keys[i] = *key
   357  	}
   358  
   359  	ret.Authentication = PackageAuthenticationAll(
   360  		NewMatchingChecksumAuthentication(document, body.Filename, checksum),
   361  		NewArchiveChecksumAuthentication(ret.TargetPlatform, checksum),
   362  		NewSignatureAuthentication(ret, document, signature, keys, &provider),
   363  	)
   364  
   365  	return ret, nil
   366  }
   367  
   368  // findClosestProtocolCompatibleVersion searches for the provider version with the closest protocol match.
   369  func (c *registryClient) findClosestProtocolCompatibleVersion(ctx context.Context, provider addrs.Provider, version Version) (Version, error) {
   370  	var match Version
   371  	available, _, err := c.ProviderVersions(ctx, provider)
   372  	if err != nil {
   373  		return UnspecifiedVersion, err
   374  	}
   375  
   376  	// extract the maps keys so we can make a sorted list of available versions.
   377  	versionList := make(VersionList, 0, len(available))
   378  	for versionStr := range available {
   379  		v, err := ParseVersion(versionStr)
   380  		if err != nil {
   381  			return UnspecifiedVersion, ErrQueryFailed{
   382  				Provider: provider,
   383  				Wrapped:  fmt.Errorf("registry response includes invalid version string %q: %w", versionStr, err),
   384  			}
   385  		}
   386  		versionList = append(versionList, v)
   387  	}
   388  	versionList.Sort() // lowest precedence first, preserving order when equal precedence
   389  
   390  	protoVersions := MeetingConstraints(SupportedPluginProtocols)
   391  FindMatch:
   392  	// put the versions in increasing order of precedence
   393  	for index := len(versionList) - 1; index >= 0; index-- { // walk backwards to consider newer versions first
   394  		for _, protoStr := range available[versionList[index].String()] {
   395  			p, err := ParseVersion(protoStr)
   396  			if err != nil {
   397  				return UnspecifiedVersion, ErrQueryFailed{
   398  					Provider: provider,
   399  					Wrapped:  fmt.Errorf("registry response includes invalid protocol string %q: %w", protoStr, err),
   400  				}
   401  			}
   402  			if protoVersions.Has(p) {
   403  				match = versionList[index]
   404  				break FindMatch
   405  			}
   406  		}
   407  	}
   408  	return match, nil
   409  }
   410  
   411  func (c *registryClient) addHeadersToRequest(req *http.Request) {
   412  	if c.creds != nil {
   413  		c.creds.PrepareRequest(req)
   414  	}
   415  	req.Header.Set(terraformVersionHeader, version.String())
   416  }
   417  
   418  func (c *registryClient) errQueryFailed(provider addrs.Provider, err error) error {
   419  	if err == context.Canceled {
   420  		// This one has a special error type so that callers can
   421  		// handle it in a different way.
   422  		return ErrRequestCanceled{}
   423  	}
   424  	return ErrQueryFailed{
   425  		Provider: provider,
   426  		Wrapped:  err,
   427  	}
   428  }
   429  
   430  func (c *registryClient) errUnauthorized(hostname svchost.Hostname) error {
   431  	return ErrUnauthorized{
   432  		Hostname:        hostname,
   433  		HaveCredentials: c.creds != nil,
   434  	}
   435  }
   436  
   437  func (c *registryClient) getFile(url *url.URL) ([]byte, error) {
   438  	resp, err := c.httpClient.Get(url.String())
   439  	if err != nil {
   440  		return nil, err
   441  	}
   442  	defer resp.Body.Close()
   443  
   444  	if resp.StatusCode != http.StatusOK {
   445  		return nil, fmt.Errorf("%s returned from %s", resp.Status, HostFromRequest(resp.Request))
   446  	}
   447  
   448  	data, err := io.ReadAll(resp.Body)
   449  	if err != nil {
   450  		return data, err
   451  	}
   452  
   453  	return data, nil
   454  }
   455  
   456  // configureDiscoveryRetry configures the number of retries the registry client
   457  // will attempt for requests with retryable errors, like 502 status codes
   458  func configureDiscoveryRetry() {
   459  	discoveryRetry = defaultRetry
   460  
   461  	if v := os.Getenv(registryDiscoveryRetryEnvName); v != "" {
   462  		retry, err := strconv.Atoi(v)
   463  		if err == nil && retry > 0 {
   464  			discoveryRetry = retry
   465  		}
   466  	}
   467  }
   468  
   469  func requestLogHook(logger retryablehttp.Logger, req *http.Request, i int) {
   470  	if i > 0 {
   471  		logger.Printf("[INFO] Previous request to the remote registry failed, attempting retry.")
   472  	}
   473  }
   474  
   475  func maxRetryErrorHandler(resp *http.Response, err error, numTries int) (*http.Response, error) {
   476  	// Close the body per library instructions
   477  	if resp != nil {
   478  		resp.Body.Close()
   479  	}
   480  
   481  	// Additional error detail: if we have a response, use the status code;
   482  	// if we have an error, use that; otherwise nothing. We will never have
   483  	// both response and error.
   484  	var errMsg string
   485  	if resp != nil {
   486  		errMsg = fmt.Sprintf(": %s returned from %s", resp.Status, HostFromRequest(resp.Request))
   487  	} else if err != nil {
   488  		errMsg = fmt.Sprintf(": %s", err)
   489  	}
   490  
   491  	// This function is always called with numTries=RetryMax+1. If we made any
   492  	// retry attempts, include that in the error message.
   493  	if numTries > 1 {
   494  		return resp, fmt.Errorf("the request failed after %d attempts, please try again later%s",
   495  			numTries, errMsg)
   496  	}
   497  	return resp, fmt.Errorf("the request failed, please try again later%s", errMsg)
   498  }
   499  
   500  // HostFromRequest extracts host the same way net/http Request.Write would,
   501  // accounting for empty Request.Host
   502  func HostFromRequest(req *http.Request) string {
   503  	if req.Host != "" {
   504  		return req.Host
   505  	}
   506  	if req.URL != nil {
   507  		return req.URL.Host
   508  	}
   509  
   510  	// this should never happen and if it does
   511  	// it will be handled as part of Request.Write()
   512  	// https://cs.opensource.google/go/go/+/refs/tags/go1.18.4:src/net/http/request.go;l=574
   513  	return ""
   514  }
   515  
   516  // configureRequestTimeout configures the registry client request timeout from
   517  // environment variables
   518  func configureRequestTimeout() {
   519  	requestTimeout = defaultRequestTimeout
   520  
   521  	if v := os.Getenv(registryClientTimeoutEnvName); v != "" {
   522  		timeout, err := strconv.Atoi(v)
   523  		if err == nil && timeout > 0 {
   524  			requestTimeout = time.Duration(timeout) * time.Second
   525  		}
   526  	}
   527  }