github.com/opentofu/opentofu@v1.7.1/internal/registry/client.go (about)

     1  // Copyright (c) The OpenTofu Authors
     2  // SPDX-License-Identifier: MPL-2.0
     3  // Copyright (c) 2023 HashiCorp, Inc.
     4  // SPDX-License-Identifier: MPL-2.0
     5  
     6  package registry
     7  
     8  import (
     9  	"context"
    10  	"encoding/json"
    11  	"fmt"
    12  	"io"
    13  	"log"
    14  	"net/http"
    15  	"net/url"
    16  	"os"
    17  	"path"
    18  	"strconv"
    19  	"strings"
    20  	"time"
    21  
    22  	"github.com/hashicorp/go-retryablehttp"
    23  	svchost "github.com/hashicorp/terraform-svchost"
    24  	"github.com/hashicorp/terraform-svchost/disco"
    25  	"github.com/opentofu/opentofu/internal/httpclient"
    26  	"github.com/opentofu/opentofu/internal/logging"
    27  	"github.com/opentofu/opentofu/internal/registry/regsrc"
    28  	"github.com/opentofu/opentofu/internal/registry/response"
    29  	"github.com/opentofu/opentofu/version"
    30  )
    31  
    32  const (
    33  	xTerraformGet      = "X-Terraform-Get"
    34  	xTerraformVersion  = "X-Terraform-Version"
    35  	modulesServiceID   = "modules.v1"
    36  	providersServiceID = "providers.v1"
    37  
    38  	// registryDiscoveryRetryEnvName is the name of the environment variable that
    39  	// can be configured to customize number of retries for module and provider
    40  	// discovery requests with the remote registry.
    41  	registryDiscoveryRetryEnvName = "TF_REGISTRY_DISCOVERY_RETRY"
    42  	defaultRetry                  = 1
    43  
    44  	// registryClientTimeoutEnvName is the name of the environment variable that
    45  	// can be configured to customize the timeout duration (seconds) for module
    46  	// and provider discovery with the remote registry.
    47  	registryClientTimeoutEnvName = "TF_REGISTRY_CLIENT_TIMEOUT"
    48  
    49  	// defaultRequestTimeout is the default timeout duration for requests to the
    50  	// remote registry.
    51  	defaultRequestTimeout = 10 * time.Second
    52  )
    53  
    54  var (
    55  	tfVersion = version.String()
    56  
    57  	discoveryRetry int
    58  	requestTimeout time.Duration
    59  )
    60  
    61  func init() {
    62  	configureDiscoveryRetry()
    63  	configureRequestTimeout()
    64  }
    65  
// Client provides methods to query OpenTofu Registries.
//
// Construct it with NewClient, which returns a *Client; all of its methods
// use pointer receivers.
type Client struct {
	// this is the client to be used for all requests. Every registry HTTP
	// request this Client makes goes through it, so its retry and error
	// handling settings apply to all of them.
	client *retryablehttp.Client

	// services is a required *disco.Disco, which may have services and
	// credentials pre-loaded. It is used both to discover registry service
	// URLs (Discover) and to look up per-host credentials (addRequestCreds).
	services *disco.Disco
}
    75  
    76  // NewClient returns a new initialized registry client.
    77  func NewClient(services *disco.Disco, client *http.Client) *Client {
    78  	if services == nil {
    79  		services = disco.New()
    80  	}
    81  
    82  	if client == nil {
    83  		client = httpclient.New()
    84  		client.Timeout = requestTimeout
    85  	}
    86  	retryableClient := retryablehttp.NewClient()
    87  	retryableClient.HTTPClient = client
    88  	retryableClient.RetryMax = discoveryRetry
    89  	retryableClient.RequestLogHook = requestLogHook
    90  	retryableClient.ErrorHandler = maxRetryErrorHandler
    91  
    92  	logOutput := logging.LogOutput()
    93  	retryableClient.Logger = log.New(logOutput, "", log.Flags())
    94  
    95  	services.Transport = retryableClient.HTTPClient.Transport
    96  
    97  	services.SetUserAgent(httpclient.OpenTofuUserAgent(version.String()))
    98  
    99  	return &Client{
   100  		client:   retryableClient,
   101  		services: services,
   102  	}
   103  }
   104  
   105  // Discover queries the host, and returns the url for the registry.
   106  func (c *Client) Discover(host svchost.Hostname, serviceID string) (*url.URL, error) {
   107  	service, err := c.services.DiscoverServiceURL(host, serviceID)
   108  	if err != nil {
   109  		return nil, &ServiceUnreachableError{err}
   110  	}
   111  	if !strings.HasSuffix(service.Path, "/") {
   112  		service.Path += "/"
   113  	}
   114  	return service, nil
   115  }
   116  
   117  // ModuleVersions queries the registry for a module, and returns the available versions.
   118  func (c *Client) ModuleVersions(ctx context.Context, module *regsrc.Module) (*response.ModuleVersions, error) {
   119  	host, err := module.SvcHost()
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  
   124  	service, err := c.Discover(host, modulesServiceID)
   125  	if err != nil {
   126  		return nil, err
   127  	}
   128  
   129  	p, err := url.Parse(path.Join(module.Module(), "versions"))
   130  	if err != nil {
   131  		return nil, err
   132  	}
   133  
   134  	service = service.ResolveReference(p)
   135  
   136  	log.Printf("[DEBUG] fetching module versions from %q", service)
   137  
   138  	req, err := retryablehttp.NewRequest("GET", service.String(), nil)
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  	req = req.WithContext(ctx)
   143  
   144  	c.addRequestCreds(host, req.Request)
   145  	req.Header.Set(xTerraformVersion, tfVersion)
   146  
   147  	resp, err := c.client.Do(req)
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  	defer resp.Body.Close()
   152  
   153  	switch resp.StatusCode {
   154  	case http.StatusOK:
   155  		// OK
   156  	case http.StatusNotFound:
   157  		return nil, &errModuleNotFound{addr: module}
   158  	default:
   159  		return nil, fmt.Errorf("error looking up module versions: %s", resp.Status)
   160  	}
   161  
   162  	var versions response.ModuleVersions
   163  
   164  	dec := json.NewDecoder(resp.Body)
   165  	if err := dec.Decode(&versions); err != nil {
   166  		return nil, err
   167  	}
   168  
   169  	for _, mod := range versions.Modules {
   170  		for _, v := range mod.Versions {
   171  			log.Printf("[DEBUG] found available version %q for %s", v.Version, mod.Source)
   172  		}
   173  	}
   174  
   175  	return &versions, nil
   176  }
   177  
   178  func (c *Client) addRequestCreds(host svchost.Hostname, req *http.Request) {
   179  	creds, err := c.services.CredentialsForHost(host)
   180  	if err != nil {
   181  		log.Printf("[WARN] Failed to get credentials for %s: %s (ignoring)", host, err)
   182  		return
   183  	}
   184  
   185  	if creds != nil {
   186  		creds.PrepareRequest(req)
   187  	}
   188  }
   189  
   190  // ModuleLocation find the download location for a specific version module.
   191  // This returns a string, because the final location may contain special go-getter syntax.
   192  func (c *Client) ModuleLocation(ctx context.Context, module *regsrc.Module, version string) (string, error) {
   193  	host, err := module.SvcHost()
   194  	if err != nil {
   195  		return "", err
   196  	}
   197  
   198  	service, err := c.Discover(host, modulesServiceID)
   199  	if err != nil {
   200  		return "", err
   201  	}
   202  
   203  	var p *url.URL
   204  	if version == "" {
   205  		p, err = url.Parse(path.Join(module.Module(), "download"))
   206  	} else {
   207  		p, err = url.Parse(path.Join(module.Module(), version, "download"))
   208  	}
   209  	if err != nil {
   210  		return "", err
   211  	}
   212  	download := service.ResolveReference(p)
   213  
   214  	log.Printf("[DEBUG] looking up module location from %q", download)
   215  
   216  	req, err := retryablehttp.NewRequest("GET", download.String(), nil)
   217  	if err != nil {
   218  		return "", err
   219  	}
   220  
   221  	req = req.WithContext(ctx)
   222  
   223  	c.addRequestCreds(host, req.Request)
   224  	req.Header.Set(xTerraformVersion, tfVersion)
   225  
   226  	resp, err := c.client.Do(req)
   227  	if err != nil {
   228  		return "", err
   229  	}
   230  	defer resp.Body.Close()
   231  
   232  	body, err := io.ReadAll(resp.Body)
   233  	if err != nil {
   234  		return "", fmt.Errorf("error reading response body from registry: %w", err)
   235  	}
   236  
   237  	var location string
   238  
   239  	switch resp.StatusCode {
   240  	case http.StatusOK:
   241  		var v response.ModuleLocationRegistryResp
   242  		if err := json.Unmarshal(body, &v); err != nil {
   243  			return "", fmt.Errorf("module %q version %q failed to deserialize response body %s: %w",
   244  				module, version, body, err)
   245  		}
   246  
   247  		location = v.Location
   248  
   249  	case http.StatusNoContent:
   250  		// FALLBACK: set the found location from the header
   251  		location = resp.Header.Get(xTerraformGet)
   252  
   253  	case http.StatusNotFound:
   254  		return "", fmt.Errorf("module %q version %q not found", module, version)
   255  
   256  	default:
   257  		// anything else is an error:
   258  		return "", fmt.Errorf("error getting download location for %q: %s resp:%s", module, resp.Status, body)
   259  	}
   260  
   261  	if location == "" {
   262  		return "", fmt.Errorf("failed to get download URL for %q: %s resp:%s", module, resp.Status, body)
   263  	}
   264  
   265  	// If location looks like it's trying to be a relative URL, treat it as
   266  	// one.
   267  	//
   268  	// We don't do this for just _any_ location, since the X-Terraform-Get
   269  	// header is a go-getter location rather than a URL, and so not all
   270  	// possible values will parse reasonably as URLs.)
   271  	//
   272  	// When used in conjunction with go-getter we normally require this header
   273  	// to be an absolute URL, but we are more liberal here because third-party
   274  	// registry implementations may not "know" their own absolute URLs if
   275  	// e.g. they are running behind a reverse proxy frontend, or such.
   276  	if strings.HasPrefix(location, "/") || strings.HasPrefix(location, "./") || strings.HasPrefix(location, "../") {
   277  		locationURL, err := url.Parse(location)
   278  		if err != nil {
   279  			return "", fmt.Errorf("invalid relative URL for %q: %w", module, err)
   280  		}
   281  		locationURL = download.ResolveReference(locationURL)
   282  		location = locationURL.String()
   283  	}
   284  
   285  	return location, nil
   286  }
   287  
   288  // configureDiscoveryRetry configures the number of retries the registry client
   289  // will attempt for requests with retryable errors, like 502 status codes
   290  func configureDiscoveryRetry() {
   291  	discoveryRetry = defaultRetry
   292  
   293  	if v := os.Getenv(registryDiscoveryRetryEnvName); v != "" {
   294  		retry, err := strconv.Atoi(v)
   295  		if err == nil && retry > 0 {
   296  			discoveryRetry = retry
   297  		}
   298  	}
   299  }
   300  
   301  func requestLogHook(logger retryablehttp.Logger, req *http.Request, i int) {
   302  	if i > 0 {
   303  		logger.Printf("[INFO] Previous request to the remote registry failed, attempting retry.")
   304  	}
   305  }
   306  
   307  func maxRetryErrorHandler(resp *http.Response, err error, numTries int) (*http.Response, error) {
   308  	// Close the body per library instructions
   309  	if resp != nil {
   310  		resp.Body.Close()
   311  	}
   312  
   313  	// Additional error detail: if we have a response, use the status code;
   314  	// if we have an error, use that; otherwise nothing. We will never have
   315  	// both response and error.
   316  	var errMsg string
   317  	if resp != nil {
   318  		errMsg = fmt.Sprintf(": %s returned from %s", resp.Status, resp.Request.URL)
   319  	} else if err != nil {
   320  		errMsg = fmt.Sprintf(": %s", err)
   321  	}
   322  
   323  	// This function is always called with numTries=RetryMax+1. If we made any
   324  	// retry attempts, include that in the error message.
   325  	if numTries > 1 {
   326  		return resp, fmt.Errorf("the request failed after %d attempts, please try again later%s",
   327  			numTries, errMsg)
   328  	}
   329  	return resp, fmt.Errorf("the request failed, please try again later%s", errMsg)
   330  }
   331  
   332  // configureRequestTimeout configures the registry client request timeout from
   333  // environment variables
   334  func configureRequestTimeout() {
   335  	requestTimeout = defaultRequestTimeout
   336  
   337  	if v := os.Getenv(registryClientTimeoutEnvName); v != "" {
   338  		timeout, err := strconv.Atoi(v)
   339  		if err == nil && timeout > 0 {
   340  			requestTimeout = time.Duration(timeout) * time.Second
   341  		}
   342  	}
   343  }