github.com/openshift/installer@v1.4.17/pkg/asset/installconfig/azure/client.go (about)

     1  package azure
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"strings"
     8  	"time"
     9  
    10  	azres "github.com/Azure/azure-sdk-for-go/profiles/2018-03-01/resources/mgmt/resources"
    11  	azsubs "github.com/Azure/azure-sdk-for-go/profiles/2018-03-01/resources/mgmt/subscriptions"
    12  	aznetwork "github.com/Azure/azure-sdk-for-go/profiles/2020-09-01/network/mgmt/network"
    13  	azenc "github.com/Azure/azure-sdk-for-go/profiles/latest/compute/mgmt/compute"
    14  	azmarketplace "github.com/Azure/azure-sdk-for-go/profiles/latest/marketplaceordering/mgmt/marketplaceordering"
    15  	"github.com/Azure/go-autorest/autorest/to"
    16  )
    17  
    18  //go:generate mockgen -source=./client.go -destination=mock/azureclient_generated.go -package=mock
    19  
    20  // API represents the calls made to the API.
    21  type API interface {
    22  	GetVirtualNetwork(ctx context.Context, resourceGroupName, virtualNetwork string) (*aznetwork.VirtualNetwork, error)
    23  	GetComputeSubnet(ctx context.Context, resourceGroupName, virtualNetwork, subnet string) (*aznetwork.Subnet, error)
    24  	GetControlPlaneSubnet(ctx context.Context, resourceGroupName, virtualNetwork, subnet string) (*aznetwork.Subnet, error)
    25  	ListLocations(ctx context.Context) (*[]azsubs.Location, error)
    26  	GetResourcesProvider(ctx context.Context, resourceProviderNamespace string) (*azres.Provider, error)
    27  	GetVirtualMachineSku(ctx context.Context, name, region string) (*azenc.ResourceSku, error)
    28  	GetVirtualMachineFamily(ctx context.Context, name, region string) (string, error)
    29  	GetDiskSkus(ctx context.Context, region string) ([]azenc.ResourceSku, error)
    30  	GetGroup(ctx context.Context, groupName string) (*azres.Group, error)
    31  	ListResourceIDsByGroup(ctx context.Context, groupName string) ([]string, error)
    32  	GetStorageEndpointSuffix(ctx context.Context) (string, error)
    33  	GetDiskEncryptionSet(ctx context.Context, subscriptionID, groupName string, diskEncryptionSetName string) (*azenc.DiskEncryptionSet, error)
    34  	GetHyperVGenerationVersion(ctx context.Context, instanceType string, region string, imageHyperVGen string) (string, error)
    35  	GetMarketplaceImage(ctx context.Context, region, publisher, offer, sku, version string) (azenc.VirtualMachineImage, error)
    36  	AreMarketplaceImageTermsAccepted(ctx context.Context, publisher, offer, sku string) (bool, error)
    37  	GetVMCapabilities(ctx context.Context, instanceType, region string) (map[string]string, error)
    38  	GetAvailabilityZones(ctx context.Context, region string, instanceType string) ([]string, error)
    39  	GetLocationInfo(ctx context.Context, region string, instanceType string) (*azenc.ResourceSkuLocationInfo, error)
    40  }
    41  
    42  // Client makes calls to the Azure API.
    43  type Client struct {
    44  	ssn *Session
    45  }
    46  
    47  // NewClient initializes a client with a session.
    48  func NewClient(ssn *Session) *Client {
    49  	client := &Client{
    50  		ssn: ssn,
    51  	}
    52  	return client
    53  }
    54  
    55  // GetVirtualNetwork gets an Azure virtual network by name
    56  func (c *Client) GetVirtualNetwork(ctx context.Context, resourceGroupName, virtualNetwork string) (*aznetwork.VirtualNetwork, error) {
    57  	ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
    58  	defer cancel()
    59  
    60  	vnetClient, err := c.getVirtualNetworksClient(ctx)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  
    65  	vnet, err := vnetClient.Get(ctx, resourceGroupName, virtualNetwork, "")
    66  	if err != nil {
    67  		return nil, fmt.Errorf("failed to get virtual network %s: %w", virtualNetwork, err)
    68  	}
    69  
    70  	return &vnet, nil
    71  }
    72  
    73  // getSubnet gets an Azure subnet by name
    74  func (c *Client) getSubnet(ctx context.Context, resourceGroupName, virtualNetwork, subNetwork string) (*aznetwork.Subnet, error) {
    75  	ctx, cancel := context.WithTimeout(ctx, 1*time.Minute)
    76  	defer cancel()
    77  
    78  	subnetsClient, err := c.getSubnetsClient(ctx)
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	subnet, err := subnetsClient.Get(ctx, resourceGroupName, virtualNetwork, subNetwork, "")
    84  	if err != nil {
    85  		return nil, fmt.Errorf("failed to get subnet %s: %w", subNetwork, err)
    86  	}
    87  
    88  	return &subnet, nil
    89  }
    90  
    91  // GetComputeSubnet gets the Azure compute subnet
    92  func (c *Client) GetComputeSubnet(ctx context.Context, resourceGroupName, virtualNetwork, subNetwork string) (*aznetwork.Subnet, error) {
    93  	return c.getSubnet(ctx, resourceGroupName, virtualNetwork, subNetwork)
    94  }
    95  
    96  // GetControlPlaneSubnet gets the Azure control plane subnet
    97  func (c *Client) GetControlPlaneSubnet(ctx context.Context, resourceGroupName, virtualNetwork, subNetwork string) (*aznetwork.Subnet, error) {
    98  	return c.getSubnet(ctx, resourceGroupName, virtualNetwork, subNetwork)
    99  }
   100  
   101  // getVnetsClient sets up a new client to retrieve vnets
   102  func (c *Client) getVirtualNetworksClient(ctx context.Context) (*aznetwork.VirtualNetworksClient, error) {
   103  	vnetsClient := aznetwork.NewVirtualNetworksClientWithBaseURI(c.ssn.Environment.ResourceManagerEndpoint, c.ssn.Credentials.SubscriptionID)
   104  	vnetsClient.Authorizer = c.ssn.Authorizer
   105  	return &vnetsClient, nil
   106  }
   107  
   108  // GetStorageEndpointSuffix retrieves the StorageEndpointSuffix from the
   109  // session environment
   110  func (c *Client) GetStorageEndpointSuffix(ctx context.Context) (string, error) {
   111  	return c.ssn.Environment.StorageEndpointSuffix, nil
   112  }
   113  
   114  // getSubnetsClient sets up a new client to retrieve a subnet
   115  func (c *Client) getSubnetsClient(ctx context.Context) (*aznetwork.SubnetsClient, error) {
   116  	subnetClient := aznetwork.NewSubnetsClientWithBaseURI(c.ssn.Environment.ResourceManagerEndpoint, c.ssn.Credentials.SubscriptionID)
   117  	subnetClient.Authorizer = c.ssn.Authorizer
   118  	return &subnetClient, nil
   119  }
   120  
   121  // ListLocations lists the Azure regions dir the given subscription
   122  func (c *Client) ListLocations(ctx context.Context) (*[]azsubs.Location, error) {
   123  	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
   124  	defer cancel()
   125  
   126  	subsClient, err := c.getSubscriptionsClient(ctx)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	locations, err := subsClient.ListLocations(ctx, c.ssn.Credentials.SubscriptionID)
   132  	if err != nil {
   133  		return nil, fmt.Errorf("failed to list locations: %w", err)
   134  	}
   135  
   136  	return locations.Value, nil
   137  }
   138  
   139  // getSubscriptionsClient sets up a new client to retrieve subscription data
   140  func (c *Client) getSubscriptionsClient(ctx context.Context) (azsubs.Client, error) {
   141  	client := azsubs.NewClientWithBaseURI(c.ssn.Environment.ResourceManagerEndpoint)
   142  	client.Authorizer = c.ssn.Authorizer
   143  	return client, nil
   144  }
   145  
   146  // GetResourcesProvider gets the Azure resource provider
   147  func (c *Client) GetResourcesProvider(ctx context.Context, resourceProviderNamespace string) (*azres.Provider, error) {
   148  	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
   149  	defer cancel()
   150  
   151  	providersClient, err := c.getProvidersClient(ctx)
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  
   156  	provider, err := providersClient.Get(ctx, resourceProviderNamespace, "")
   157  	if err != nil {
   158  		return nil, fmt.Errorf("failed to get resource provider %s: %w", resourceProviderNamespace, err)
   159  	}
   160  
   161  	return &provider, nil
   162  }
   163  
   164  // getProvidersClient sets up a new client to retrieve providers data
   165  func (c *Client) getProvidersClient(ctx context.Context) (azres.ProvidersClient, error) {
   166  	client := azres.NewProvidersClientWithBaseURI(c.ssn.Environment.ResourceManagerEndpoint, c.ssn.Credentials.SubscriptionID)
   167  	client.Authorizer = c.ssn.Authorizer
   168  	return client, nil
   169  }
   170  
   171  // GetDiskSkus returns all the disk SKU pages for a given region.
   172  func (c *Client) GetDiskSkus(ctx context.Context, region string) ([]azenc.ResourceSku, error) {
   173  	client := azenc.NewResourceSkusClientWithBaseURI(c.ssn.Environment.ResourceManagerEndpoint, c.ssn.Credentials.SubscriptionID)
   174  	client.Authorizer = c.ssn.Authorizer
   175  	// See https://issues.redhat.com/browse/OCPBUGS-29469 before changing this timeout
   176  	ctx, cancel := context.WithTimeout(ctx, 2*time.Minute)
   177  	defer cancel()
   178  
   179  	var sku []azenc.ResourceSku
   180  	filter := fmt.Sprintf("location eq '%s'", region)
   181  	// This has to be initialized outside the `for` because we need access to
   182  	// `err`. If initialized in the loop and the API call fails right away,
   183  	// `page.NotDone()` will return `false` and we'll never check for the error
   184  	skuPage, err := client.List(ctx, filter, "false")
   185  	if err != nil {
   186  		return nil, fmt.Errorf("failed to list SKUs: %w", err)
   187  	}
   188  	for ; skuPage.NotDone(); err = skuPage.NextWithContext(ctx) {
   189  		if err != nil {
   190  			return nil, fmt.Errorf("error fetching SKU pages: %w", err)
   191  		}
   192  		for _, page := range skuPage.Values() {
   193  			for _, diskRegion := range to.StringSlice(page.Locations) {
   194  				if strings.EqualFold(diskRegion, region) {
   195  					sku = append(sku, page)
   196  				}
   197  			}
   198  		}
   199  	}
   200  
   201  	if len(sku) != 0 {
   202  		return sku, nil
   203  	}
   204  
   205  	return nil, fmt.Errorf("no disks for specified subscription in region %s", region)
   206  }
   207  
   208  // GetGroup returns resource group for the groupName.
   209  func (c *Client) GetGroup(ctx context.Context, groupName string) (*azres.Group, error) {
   210  	client := azres.NewGroupsClientWithBaseURI(c.ssn.Environment.ResourceManagerEndpoint, c.ssn.Credentials.SubscriptionID)
   211  	client.Authorizer = c.ssn.Authorizer
   212  	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
   213  	defer cancel()
   214  
   215  	res, err := client.Get(ctx, groupName)
   216  	if err != nil {
   217  		return nil, fmt.Errorf("failed to get resource group: %w", err)
   218  	}
   219  	return &res, nil
   220  }
   221  
   222  // ListResourceIDsByGroup returns a list of resource IDs for resource group groupName.
   223  func (c *Client) ListResourceIDsByGroup(ctx context.Context, groupName string) ([]string, error) {
   224  	client := azres.NewClientWithBaseURI(c.ssn.Environment.ResourceManagerEndpoint, c.ssn.Credentials.SubscriptionID)
   225  	client.Authorizer = c.ssn.Authorizer
   226  	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
   227  	defer cancel()
   228  
   229  	var res []string
   230  	resPage, err := client.ListByResourceGroup(ctx, groupName, "", "", nil)
   231  	if err != nil {
   232  		return nil, fmt.Errorf("failed to list resources: %w", err)
   233  	}
   234  	for ; resPage.NotDone(); err = resPage.NextWithContext(ctx) {
   235  		if err != nil {
   236  			return nil, fmt.Errorf("error fetching resource pages: %w", err)
   237  		}
   238  		for _, page := range resPage.Values() {
   239  			res = append(res, to.String(page.ID))
   240  		}
   241  	}
   242  	return res, nil
   243  }
   244  
   245  // GetVirtualMachineSku retrieves the resource SKU of a specified virtual machine SKU in the specified region.
   246  func (c *Client) GetVirtualMachineSku(ctx context.Context, name, region string) (*azenc.ResourceSku, error) {
   247  	client := azenc.NewResourceSkusClientWithBaseURI(c.ssn.Environment.ResourceManagerEndpoint, c.ssn.Credentials.SubscriptionID)
   248  	client.Authorizer = c.ssn.Authorizer
   249  
   250  	// See https://issues.redhat.com/browse/OCPBUGS-29469 before chaging this timeout
   251  	ctx, cancel := context.WithTimeout(ctx, 2*time.Minute)
   252  	defer cancel()
   253  
   254  	filter := fmt.Sprintf("location eq '%s'", region)
   255  	// This has to be initialized outside the `for` because we need access to
   256  	// `err`. If initialized in the loop and the API call fails right away,
   257  	// `page.NotDone()` will return `false` and we'll never check for the error
   258  	page, err := client.List(ctx, filter, "false")
   259  	if err != nil {
   260  		return nil, fmt.Errorf("failed to list SKUs: %w", err)
   261  	}
   262  	for ; page.NotDone(); err = page.NextWithContext(ctx) {
   263  		if err != nil {
   264  			return nil, fmt.Errorf("error fetching SKU pages: %w", err)
   265  		}
   266  		for _, sku := range page.Values() {
   267  			// Filter out resources that are not virtualMachines
   268  			if !strings.EqualFold("virtualMachines", *sku.ResourceType) {
   269  				continue
   270  			}
   271  			// Filter out resources that do not match the provided name
   272  			if !strings.EqualFold(name, *sku.Name) {
   273  				continue
   274  			}
   275  			// Return the resource from the provided region
   276  			for _, location := range to.StringSlice(sku.Locations) {
   277  				if strings.EqualFold(location, region) {
   278  					return &sku, nil
   279  				}
   280  			}
   281  		}
   282  	}
   283  
   284  	return nil, nil
   285  }
   286  
   287  // GetDiskEncryptionSet retrieves the specified disk encryption set.
   288  func (c *Client) GetDiskEncryptionSet(ctx context.Context, subscriptionID, groupName, diskEncryptionSetName string) (*azenc.DiskEncryptionSet, error) {
   289  	client := azenc.NewDiskEncryptionSetsClientWithBaseURI(c.ssn.Environment.ResourceManagerEndpoint, subscriptionID)
   290  	client.Authorizer = c.ssn.Authorizer
   291  	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
   292  	defer cancel()
   293  
   294  	diskEncryptionSet, err := client.Get(ctx, groupName, diskEncryptionSetName)
   295  	if err != nil {
   296  		return nil, fmt.Errorf("failed to get disk encryption set: %w", err)
   297  	}
   298  
   299  	return &diskEncryptionSet, nil
   300  }
   301  
   302  // GetVirtualMachineFamily retrieves the VM family of an instance type.
   303  func (c *Client) GetVirtualMachineFamily(ctx context.Context, name, region string) (string, error) {
   304  	typeMeta, err := c.GetVirtualMachineSku(ctx, name, region)
   305  	if err != nil {
   306  		return "", fmt.Errorf("error connecting to Azure client: %w", err)
   307  	}
   308  	if typeMeta == nil {
   309  		return "", fmt.Errorf("not found in region %s", region)
   310  	}
   311  	if typeMeta.Family == nil {
   312  		return "", fmt.Errorf("error getting resource family")
   313  	}
   314  
   315  	return to.String(typeMeta.Family), nil
   316  }
   317  
   318  // GetVMCapabilities retrieves the capabilities of an instant type in a specific region. Returns these values
   319  // in a map with the capability name as the key and the corresponding value.
   320  func (c *Client) GetVMCapabilities(ctx context.Context, instanceType, region string) (map[string]string, error) {
   321  	typeMeta, err := c.GetVirtualMachineSku(ctx, instanceType, region)
   322  	if err != nil {
   323  		return nil, fmt.Errorf("error connecting to Azure client: %w", err)
   324  	}
   325  	if typeMeta == nil {
   326  		return nil, fmt.Errorf("not found in region %s", region)
   327  	}
   328  	capabilities := make(map[string]string)
   329  	for _, capability := range *typeMeta.Capabilities {
   330  		capabilities[to.String(capability.Name)] = to.String(capability.Value)
   331  	}
   332  
   333  	return capabilities, nil
   334  }
   335  
   336  // GetHyperVGenerationVersion gets the HyperVGeneration version for the given instance type and marketplace image version, if specified. Defaults to V2 if either V1 or V2
   337  // available.
   338  func (c *Client) GetHyperVGenerationVersion(ctx context.Context, instanceType string, region string, imageHyperVGen string) (version string, err error) {
   339  	capabilities, err := c.GetVMCapabilities(ctx, instanceType, region)
   340  	if err != nil {
   341  		return "", err
   342  	}
   343  
   344  	return GetHyperVGenerationVersion(capabilities, imageHyperVGen)
   345  }
   346  
   347  // GetMarketplaceImage get the specified marketplace VM image.
   348  func (c *Client) GetMarketplaceImage(ctx context.Context, region, publisher, offer, sku, version string) (azenc.VirtualMachineImage, error) {
   349  	client := azenc.NewVirtualMachineImagesClientWithBaseURI(c.ssn.Environment.ResourceManagerEndpoint, c.ssn.Credentials.SubscriptionID)
   350  	client.Authorizer = c.ssn.Authorizer
   351  	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
   352  	defer cancel()
   353  
   354  	image, err := client.Get(ctx, region, publisher, offer, sku, version)
   355  	if err != nil {
   356  		return image, fmt.Errorf("could not get marketplace image: %w", err)
   357  	}
   358  	return image, nil
   359  }
   360  
   361  // AreMarketplaceImageTermsAccepted tests whether the terms have been accepted for the specified marketplace VM image.
   362  func (c *Client) AreMarketplaceImageTermsAccepted(ctx context.Context, publisher, offer, sku string) (bool, error) {
   363  	client := azmarketplace.NewMarketplaceAgreementsClientWithBaseURI(c.ssn.Environment.ResourceManagerEndpoint, c.ssn.Credentials.SubscriptionID)
   364  	client.Authorizer = c.ssn.Authorizer
   365  	ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
   366  	defer cancel()
   367  
   368  	terms, err := client.Get(ctx, publisher, offer, sku)
   369  	if err != nil {
   370  		return false, err
   371  	}
   372  
   373  	if terms.AgreementProperties == nil {
   374  		return false, errors.New("no agreement properties for image")
   375  	}
   376  
   377  	return terms.AgreementProperties.Accepted != nil && *terms.AgreementProperties.Accepted, nil
   378  }
   379  
   380  // GetAvailabilityZones retrieves a list of availability zones for the given region, and instance type.
   381  func (c *Client) GetAvailabilityZones(ctx context.Context, region string, instanceType string) ([]string, error) {
   382  	locationInfo, err := c.GetLocationInfo(ctx, region, instanceType)
   383  	if err != nil {
   384  		return nil, err
   385  	}
   386  	if locationInfo != nil {
   387  		return to.StringSlice(locationInfo.Zones), nil
   388  	}
   389  
   390  	return nil, fmt.Errorf("error retrieving availability zones for %s in %s", instanceType, region)
   391  }
   392  
   393  // GetLocationInfo retrieves the location info associated with the instance type in region
   394  func (c *Client) GetLocationInfo(ctx context.Context, region string, instanceType string) (*azenc.ResourceSkuLocationInfo, error) {
   395  	client := azenc.NewResourceSkusClientWithBaseURI(c.ssn.Environment.ResourceManagerEndpoint, c.ssn.Credentials.SubscriptionID)
   396  	client.Authorizer = c.ssn.Authorizer
   397  
   398  	// Only supported filter atm is `location`
   399  	filter := fmt.Sprintf("location eq '%s'", region)
   400  	res, err := client.List(ctx, filter, "false")
   401  	if err != nil {
   402  		return nil, fmt.Errorf("failed to list SKUs: %w", err)
   403  	}
   404  	for ; res.NotDone(); err = res.NextWithContext(ctx) {
   405  		if err != nil {
   406  			return nil, err
   407  		}
   408  
   409  		for _, resSku := range res.Values() {
   410  			if !strings.EqualFold(to.String(resSku.ResourceType), "virtualMachines") {
   411  				continue
   412  			}
   413  			if strings.EqualFold(to.String(resSku.Name), instanceType) {
   414  				for _, locationInfo := range *resSku.LocationInfo {
   415  					return &locationInfo, nil
   416  				}
   417  			}
   418  		}
   419  	}
   420  
   421  	return nil, fmt.Errorf("location information not found for %s in %s", instanceType, region)
   422  }