github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/roachprod/vm/azure/azure.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package azure
    12  
    13  import (
    14  	"context"
    15  	"encoding/base64"
    16  	"encoding/json"
    17  	"fmt"
    18  	"io/ioutil"
    19  	"log"
    20  	"os"
    21  	"strconv"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/Azure/azure-sdk-for-go/profiles/latest/compute/mgmt/compute"
    26  	"github.com/Azure/azure-sdk-for-go/profiles/latest/network/mgmt/network"
    27  	"github.com/Azure/azure-sdk-for-go/profiles/latest/resources/mgmt/resources"
    28  	"github.com/Azure/azure-sdk-for-go/profiles/latest/resources/mgmt/subscriptions"
    29  	"github.com/Azure/go-autorest/autorest"
    30  	"github.com/Azure/go-autorest/autorest/to"
    31  	"github.com/cockroachdb/cockroach/pkg/cmd/roachprod/vm"
    32  	"github.com/cockroachdb/cockroach/pkg/cmd/roachprod/vm/flagstub"
    33  	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
    34  	"github.com/cockroachdb/cockroach/pkg/util/timeutil"
    35  	"github.com/cockroachdb/errors"
    36  	"golang.org/x/sync/errgroup"
    37  )
    38  
    39  const (
    40  	// ProviderName is "azure".
    41  	ProviderName = "azure"
    42  	remoteUser   = "ubuntu"
    43  	tagCluster   = "cluster"
    44  	tagComment   = "comment"
    45  	// RFC3339-formatted timestamp.
    46  	tagCreated   = "created"
    47  	tagLifetime  = "lifetime"
    48  	tagRoachprod = "roachprod"
    49  	tagSubnet    = "subnetPrefix"
    50  )
    51  
    52  // init registers Provider with the top-level vm package.
    53  func init() {
    54  	const unimplemented = "please install the Azure CLI utilities +" +
    55  		"(https://docs.microsoft.com/en-us/cli/azure/install-azure-cli)"
    56  
    57  	p := New()
    58  
    59  	if _, err := p.getAuthToken(); err == nil {
    60  		vm.Providers[ProviderName] = p
    61  	} else {
    62  		vm.Providers[ProviderName] = flagstub.New(p, unimplemented)
    63  	}
    64  }
    65  
// Provider implements the vm.Provider interface for the Microsoft Azure
// cloud.
//
// A Provider should be obtained from New, which allocates the maps held
// under mu.
type Provider struct {
	// opts holds the provider options; Flags hands out a pointer to it.
	opts providerOpts
	// mu bundles a mutex with the provider's shared state. The subnets map
	// (location name -> subnet) is written by createVNet and read by Create
	// with the mutex held.
	// NOTE(review): authorizer, subscription and resourceGroups look like
	// memoized lookups guarded by the same mutex, but their accessors are
	// not visible in this part of the file — confirm.
	mu   struct {
		syncutil.Mutex

		authorizer     autorest.Authorizer
		subscription   subscriptions.Subscription
		resourceGroups map[string]resources.Group
		subnets        map[string]network.Subnet
	}
}
    79  
    80  // New constructs a new Provider instance.
    81  func New() *Provider {
    82  	p := &Provider{}
    83  	p.mu.resourceGroups = make(map[string]resources.Group)
    84  	p.mu.subnets = make(map[string]network.Subnet)
    85  	return p
    86  }
    87  
    88  // Active implements vm.Provider and always returns true.
    89  func (p *Provider) Active() bool {
    90  	return true
    91  }
    92  
    93  // CleanSSH implements vm.Provider, is a no-op, and returns nil.
    94  func (p *Provider) CleanSSH() error {
    95  	return nil
    96  }
    97  
    98  // ConfigSSH implements vm.Provider, is a no-op, and returns nil.
    99  // On Azure, the SSH public key is set as part of VM instance creation.
   100  func (p *Provider) ConfigSSH() error {
   101  	return nil
   102  }
   103  
   104  // Create implements vm.Provider.
   105  func (p *Provider) Create(names []string, opts vm.CreateOpts) error {
   106  	// Load the user's SSH public key to configure the resulting VMs.
   107  	var sshKey string
   108  	sshFile := os.ExpandEnv("${HOME}/.ssh/id_rsa.pub")
   109  	if _, err := os.Stat(sshFile); err == nil {
   110  		if bytes, err := ioutil.ReadFile(sshFile); err == nil {
   111  			sshKey = string(bytes)
   112  		} else {
   113  			return errors.Wrapf(err, "could not read SSH public key file")
   114  		}
   115  	} else {
   116  		return errors.Wrapf(err, "could not find SSH public key file")
   117  	}
   118  
   119  	ctx, cancel := context.WithTimeout(context.Background(), p.opts.operationTimeout)
   120  	defer cancel()
   121  
   122  	if len(p.opts.locations) == 0 {
   123  		if opts.GeoDistributed {
   124  			p.opts.locations = defaultLocations
   125  		} else {
   126  			p.opts.locations = []string{defaultLocations[0]}
   127  		}
   128  	}
   129  
   130  	if _, err := p.createVNets(ctx, p.opts.locations); err != nil {
   131  		return err
   132  	}
   133  
   134  	// Effectively a map of node number to location.
   135  	nodeLocations := vm.ZonePlacement(len(p.opts.locations), len(names))
   136  	// Invert it.
   137  	nodesByLocIdx := make(map[int][]int, len(p.opts.locations))
   138  	for nodeIdx, locIdx := range nodeLocations {
   139  		nodesByLocIdx[locIdx] = append(nodesByLocIdx[locIdx], nodeIdx)
   140  	}
   141  
   142  	errs, _ := errgroup.WithContext(ctx)
   143  	for locIdx, nodes := range nodesByLocIdx {
   144  		// Shadow variables for closure.
   145  		locIdx := locIdx
   146  		nodes := nodes
   147  		errs.Go(func() error {
   148  			location := p.opts.locations[locIdx]
   149  
   150  			// Create a resource group within the location.
   151  			group, err := p.getResourceGroup(ctx, opts.ClusterName, location, opts)
   152  			if err != nil {
   153  				return err
   154  			}
   155  
   156  			p.mu.Lock()
   157  			subnet, ok := p.mu.subnets[location]
   158  			p.mu.Unlock()
   159  			if !ok {
   160  				return errors.Errorf("missing subnet for location %q", location)
   161  			}
   162  
   163  			for _, nodeIdx := range nodes {
   164  				name := names[nodeIdx]
   165  				errs.Go(func() error {
   166  					_, err := p.createVM(ctx, group, subnet, name, sshKey, opts)
   167  					err = errors.Wrapf(err, "creating VM %s", name)
   168  					if err == nil {
   169  						log.Printf("created VM %s", name)
   170  					}
   171  					return err
   172  				})
   173  			}
   174  			return nil
   175  		})
   176  	}
   177  	return errs.Wait()
   178  }
   179  
   180  // Delete implements the vm.Provider interface.
   181  func (p *Provider) Delete(vms vm.List) error {
   182  	ctx, cancel := context.WithTimeout(context.Background(), p.opts.operationTimeout)
   183  	defer cancel()
   184  
   185  	sub, err := p.getSubscription(ctx)
   186  	if err != nil {
   187  		return err
   188  	}
   189  	client := compute.NewVirtualMachinesClient(*sub.ID)
   190  	if client.Authorizer, err = p.getAuthorizer(); err != nil {
   191  		return err
   192  	}
   193  
   194  	var futures []compute.VirtualMachinesDeleteFuture
   195  	for _, vm := range vms {
   196  		parts, err := parseAzureID(vm.ProviderID)
   197  		if err != nil {
   198  			return err
   199  		}
   200  		future, err := client.Delete(ctx, parts.resourceGroup, parts.resourceName)
   201  		if err != nil {
   202  			return errors.Wrapf(err, "could not delete %s", vm.ProviderID)
   203  		}
   204  		futures = append(futures, future)
   205  	}
   206  
   207  	if !p.opts.syncDelete {
   208  		return nil
   209  	}
   210  
   211  	for _, future := range futures {
   212  		if err := future.WaitForCompletionRef(ctx, client.Client); err != nil {
   213  			return err
   214  		}
   215  		if _, err := future.Result(client); err != nil {
   216  			return err
   217  		}
   218  	}
   219  	return nil
   220  }
   221  
   222  // DeleteCluster implements the vm.DeleteCluster interface, providing
   223  // a fast-path to tear down all resources associated with a cluster.
   224  func (p *Provider) DeleteCluster(name string) error {
   225  	ctx, cancel := context.WithTimeout(context.Background(), p.opts.operationTimeout)
   226  	defer cancel()
   227  
   228  	sub, err := p.getSubscription(ctx)
   229  	if err != nil {
   230  		return err
   231  	}
   232  	client := resources.NewGroupsClient(*sub.SubscriptionID)
   233  	if client.Authorizer, err = p.getAuthorizer(); err != nil {
   234  		return err
   235  	}
   236  
   237  	filter := fmt.Sprintf("tagName eq '%s' and tagValue eq '%s'", tagCluster, name)
   238  	it, err := client.ListComplete(ctx, filter, nil /* limit */)
   239  	if err != nil {
   240  		return err
   241  	}
   242  
   243  	var futures []resources.GroupsDeleteFuture
   244  	for it.NotDone() {
   245  		group := it.Value()
   246  		// Don't bother waiting for the cluster to get torn down.
   247  		future, err := client.Delete(ctx, *group.Name)
   248  		if err != nil {
   249  			return err
   250  		}
   251  		log.Printf("marked Azure resource group %s for deletion\n", *group.Name)
   252  		futures = append(futures, future)
   253  
   254  		if err := it.NextWithContext(ctx); err != nil {
   255  			return err
   256  		}
   257  	}
   258  
   259  	if !p.opts.syncDelete {
   260  		return nil
   261  	}
   262  
   263  	for _, future := range futures {
   264  		if err := future.WaitForCompletionRef(ctx, client.Client); err != nil {
   265  			return err
   266  		}
   267  		if _, err := future.Result(client); err != nil {
   268  			return err
   269  		}
   270  	}
   271  	return nil
   272  }
   273  
   274  // Extend implements the vm.Provider interface.
   275  func (p *Provider) Extend(vms vm.List, lifetime time.Duration) error {
   276  	ctx, cancel := context.WithTimeout(context.Background(), p.opts.operationTimeout)
   277  	defer cancel()
   278  
   279  	sub, err := p.getSubscription(ctx)
   280  	if err != nil {
   281  		return err
   282  	}
   283  	client := compute.NewVirtualMachinesClient(*sub.ID)
   284  	if client.Authorizer, err = p.getAuthorizer(); err != nil {
   285  		return err
   286  	}
   287  
   288  	futures := make([]compute.VirtualMachinesUpdateFuture, len(vms))
   289  	for idx, vm := range vms {
   290  		vmParts, err := parseAzureID(vm.ProviderID)
   291  		if err != nil {
   292  			return err
   293  		}
   294  		update := compute.VirtualMachineUpdate{
   295  			Tags: map[string]*string{
   296  				tagLifetime: to.StringPtr(lifetime.String()),
   297  			},
   298  		}
   299  		futures[idx], err = client.Update(ctx, vmParts.resourceGroup, vmParts.resourceName, update)
   300  		if err != nil {
   301  			return err
   302  		}
   303  	}
   304  
   305  	for _, future := range futures {
   306  		if err := future.WaitForCompletionRef(ctx, client.Client); err != nil {
   307  			return err
   308  		}
   309  		if _, err := future.Result(client); err != nil {
   310  			return err
   311  		}
   312  	}
   313  	return nil
   314  }
   315  
   316  // FindActiveAccount implements vm.Provider.
   317  func (p *Provider) FindActiveAccount() (string, error) {
   318  	// It's a JSON Web Token, so we'll just dissect it enough to get the
   319  	// data that we want. There are three base64-encoded segments
   320  	// separated by periods. The second segment is the "claims" JSON
   321  	// object.
   322  	token, err := p.getAuthToken()
   323  	if err != nil {
   324  		return "", err
   325  	}
   326  
   327  	parts := strings.Split(token, ".")
   328  	if len(parts) != 3 {
   329  		return "", errors.Errorf("unexpected number of segments; expected 3, had %d", len(parts))
   330  	}
   331  
   332  	a := base64.NewDecoder(base64.RawStdEncoding, strings.NewReader(parts[1]))
   333  	var data struct {
   334  		Username string `json:"upn"`
   335  	}
   336  	if err := json.NewDecoder(a).Decode(&data); err != nil {
   337  		return "", errors.Wrapf(err, "could not decode JWT claims segment")
   338  	}
   339  
   340  	// This is in an email address format, we just want the username.
   341  	return data.Username[:strings.Index(data.Username, "@")], nil
   342  }
   343  
   344  // Flags implements the vm.Provider interface.
   345  func (p *Provider) Flags() vm.ProviderFlags {
   346  	return &p.opts
   347  }
   348  
   349  // List implements the vm.Provider interface. This will query all
   350  // Azure VMs in the subscription and select those with a roachprod tag.
   351  func (p *Provider) List() (vm.List, error) {
   352  	ctx, cancel := context.WithTimeout(context.Background(), p.opts.operationTimeout)
   353  	defer cancel()
   354  
   355  	sub, err := p.getSubscription(ctx)
   356  	if err != nil {
   357  		return nil, err
   358  	}
   359  
   360  	// We're just going to list all VMs and filter.
   361  	client := compute.NewVirtualMachinesClient(*sub.SubscriptionID)
   362  	if client.Authorizer, err = p.getAuthorizer(); err != nil {
   363  		return nil, err
   364  	}
   365  
   366  	it, err := client.ListAllComplete(ctx)
   367  	if err != nil {
   368  		return nil, err
   369  	}
   370  
   371  	var ret vm.List
   372  	for it.NotDone() {
   373  		found := it.Value()
   374  
   375  		if _, ok := found.Tags[tagRoachprod]; !ok {
   376  			if err := it.NextWithContext(ctx); err != nil {
   377  				return nil, err
   378  			}
   379  			continue
   380  		}
   381  
   382  		m := vm.VM{
   383  			Name:        *found.Name,
   384  			Provider:    ProviderName,
   385  			ProviderID:  *found.ID,
   386  			RemoteUser:  remoteUser,
   387  			VPC:         "global",
   388  			MachineType: string(found.HardwareProfile.VMSize),
   389  			// We add a fake availability-zone suffix since other roachprod
   390  			// code assumes particular formats. For example, "eastus2z".
   391  			Zone: *found.Location + "z",
   392  		}
   393  
   394  		if createdPtr := found.Tags[tagCreated]; createdPtr == nil {
   395  			m.Errors = append(m.Errors, vm.ErrNoExpiration)
   396  		} else if parsed, err := time.Parse(time.RFC3339, *createdPtr); err == nil {
   397  			m.CreatedAt = parsed
   398  		} else {
   399  			m.Errors = append(m.Errors, vm.ErrNoExpiration)
   400  		}
   401  
   402  		if lifetimePtr := found.Tags[tagLifetime]; lifetimePtr == nil {
   403  			m.Errors = append(m.Errors, vm.ErrNoExpiration)
   404  		} else if parsed, err := time.ParseDuration(*lifetimePtr); err == nil {
   405  			m.Lifetime = parsed
   406  		} else {
   407  			m.Errors = append(m.Errors, vm.ErrNoExpiration)
   408  		}
   409  
   410  		// The network info needs a separate request.
   411  		nicID, err := parseAzureID(*(*found.NetworkProfile.NetworkInterfaces)[0].ID)
   412  		if err != nil {
   413  			return nil, err
   414  		}
   415  		if err := p.fillNetworkDetails(ctx, &m, nicID); errors.Is(err, vm.ErrBadNetwork) {
   416  			m.Errors = append(m.Errors, err)
   417  		} else if err != nil {
   418  			return nil, err
   419  		}
   420  
   421  		ret = append(ret, m)
   422  
   423  		if err := it.NextWithContext(ctx); err != nil {
   424  			return nil, err
   425  		}
   426  
   427  	}
   428  
   429  	return ret, nil
   430  }
   431  
   432  // Name implements vm.Provider.
   433  func (p *Provider) Name() string {
   434  	return ProviderName
   435  }
   436  
   437  func (p *Provider) createVM(
   438  	ctx context.Context,
   439  	group resources.Group,
   440  	subnet network.Subnet,
   441  	name, sshKey string,
   442  	opts vm.CreateOpts,
   443  ) (vm compute.VirtualMachine, err error) {
   444  	// We can inject a cloud-init script into the VM creation to perform
   445  	// the necessary pre-flight configuration. By default, a
   446  	// locally-attached SSD is available at /mnt, so we just need to
   447  	// create the necessary directory and preflight.
   448  	//
   449  	// https://cloudinit.readthedocs.io/en/latest/
   450  	cloudConfig := `#cloud-config
   451  final_message: "roachprod init completed"
   452  `
   453  
   454  	var cmds []string
   455  	if opts.SSDOpts.UseLocalSSD {
   456  		cmds = []string{
   457  			"mkdir -p /mnt/data1",
   458  			"touch /mnt/data1/.roachprod-initialized",
   459  			fmt.Sprintf("chown -R %s /data1", remoteUser),
   460  		}
   461  		if opts.SSDOpts.NoExt4Barrier {
   462  			cmds = append(cmds, "mount -o remount,nobarrier,discard /mnt/data")
   463  		}
   464  	} else {
   465  		// We define lun42 explicitly in the data disk request below.
   466  		cloudConfig += `
   467  disk_setup:
   468    /dev/disk/azure/scsi1/lun42:
   469      table_type: gpt
   470      layout: True
   471      overwrite: True
   472  
   473  fs_setup:
   474    - device: /dev/disk/azure/scsi1/lun42
   475      partition: 1
   476      filesystem: ext4
   477  
   478  mounts:
   479    - ["/dev/disk/azure/scsi1/lun42-part1", "/data1", "auto", "defaults"]
   480  `
   481  		cmds = []string{
   482  			"ln -s /data1 /mnt/data1",
   483  			"touch /data1/.roachprod-initialized",
   484  			fmt.Sprintf("chown -R %s /data1", remoteUser),
   485  		}
   486  	}
   487  
   488  	cloudConfig += "runcmd:\n"
   489  	for _, cmd := range cmds {
   490  		cloudConfig += fmt.Sprintf(" - %q\n", cmd)
   491  	}
   492  
   493  	sub, err := p.getSubscription(ctx)
   494  	if err != nil {
   495  		return
   496  	}
   497  
   498  	client := compute.NewVirtualMachinesClient(*sub.SubscriptionID)
   499  	if client.Authorizer, err = p.getAuthorizer(); err != nil {
   500  		return
   501  	}
   502  
   503  	// We first need to allocate a NIC to give the VM network access
   504  	ip, err := p.createIP(ctx, group, name)
   505  	if err != nil {
   506  		return
   507  	}
   508  	nic, err := p.createNIC(ctx, group, ip, subnet)
   509  	if err != nil {
   510  		return
   511  	}
   512  
   513  	tags := make(map[string]*string)
   514  	tags[tagCreated] = to.StringPtr(timeutil.Now().Format(time.RFC3339))
   515  	tags[tagLifetime] = to.StringPtr(opts.Lifetime.String())
   516  	tags[tagRoachprod] = to.StringPtr("true")
   517  
   518  	// Derived from
   519  	// https://github.com/Azure-Samples/azure-sdk-for-go-samples/blob/79e3f3af791c3873d810efe094f9d61e93a6ccaa/compute/vm.go#L41
   520  	vm = compute.VirtualMachine{
   521  		Location: group.Location,
   522  		Tags:     tags,
   523  		VirtualMachineProperties: &compute.VirtualMachineProperties{
   524  			HardwareProfile: &compute.HardwareProfile{
   525  				VMSize: compute.VirtualMachineSizeTypes(p.opts.machineType),
   526  			},
   527  			StorageProfile: &compute.StorageProfile{
   528  				ImageReference: &compute.ImageReference{
   529  					Publisher: to.StringPtr("Canonical"),
   530  					Offer:     to.StringPtr("UbuntuServer"),
   531  					Sku:       to.StringPtr("18.04-LTS"),
   532  					Version:   to.StringPtr("latest"),
   533  				},
   534  				OsDisk: &compute.OSDisk{
   535  					CreateOption: compute.DiskCreateOptionTypesFromImage,
   536  					ManagedDisk: &compute.ManagedDiskParameters{
   537  						StorageAccountType: compute.StorageAccountTypesStandardSSDLRS,
   538  					},
   539  				},
   540  			},
   541  			OsProfile: &compute.OSProfile{
   542  				ComputerName:  to.StringPtr(name),
   543  				AdminUsername: to.StringPtr(remoteUser),
   544  				// Per the docs, the cloud-init script should be uploaded already
   545  				// base64-encoded.
   546  				CustomData: to.StringPtr(base64.StdEncoding.EncodeToString([]byte(cloudConfig))),
   547  				LinuxConfiguration: &compute.LinuxConfiguration{
   548  					SSH: &compute.SSHConfiguration{
   549  						PublicKeys: &[]compute.SSHPublicKey{
   550  							{
   551  								Path:    to.StringPtr(fmt.Sprintf("/home/%s/.ssh/authorized_keys", remoteUser)),
   552  								KeyData: to.StringPtr(sshKey),
   553  							},
   554  						},
   555  					},
   556  				},
   557  			},
   558  			NetworkProfile: &compute.NetworkProfile{
   559  				NetworkInterfaces: &[]compute.NetworkInterfaceReference{
   560  					{
   561  						ID: nic.ID,
   562  						NetworkInterfaceReferenceProperties: &compute.NetworkInterfaceReferenceProperties{
   563  							Primary: to.BoolPtr(true),
   564  						},
   565  					},
   566  				},
   567  			},
   568  		},
   569  	}
   570  	if !opts.SSDOpts.UseLocalSSD {
   571  		vm.VirtualMachineProperties.StorageProfile.DataDisks = &[]compute.DataDisk{
   572  			{
   573  				CreateOption: compute.DiskCreateOptionTypesEmpty,
   574  				DiskSizeGB:   to.Int32Ptr(100),
   575  				Lun:          to.Int32Ptr(42),
   576  				ManagedDisk: &compute.ManagedDiskParameters{
   577  					StorageAccountType: compute.StorageAccountTypesPremiumLRS,
   578  				},
   579  			},
   580  		}
   581  	}
   582  	future, err := client.CreateOrUpdate(ctx, *group.Name, name, vm)
   583  	if err != nil {
   584  		return
   585  	}
   586  	if err = future.WaitForCompletionRef(ctx, client.Client); err != nil {
   587  		return
   588  	}
   589  	return future.Result(client)
   590  }
   591  
   592  // createNIC creates a network adapter that is bound to the given public IP address.
   593  func (p *Provider) createNIC(
   594  	ctx context.Context, group resources.Group, ip network.PublicIPAddress, subnet network.Subnet,
   595  ) (iface network.Interface, err error) {
   596  	sub, err := p.getSubscription(ctx)
   597  	if err != nil {
   598  		return
   599  	}
   600  	client := network.NewInterfacesClient(*sub.SubscriptionID)
   601  	if client.Authorizer, err = p.getAuthorizer(); err != nil {
   602  		return
   603  	}
   604  
   605  	future, err := client.CreateOrUpdate(ctx, *group.Name, *ip.Name, network.Interface{
   606  		Name:     ip.Name,
   607  		Location: group.Location,
   608  		InterfacePropertiesFormat: &network.InterfacePropertiesFormat{
   609  			IPConfigurations: &[]network.InterfaceIPConfiguration{
   610  				{
   611  					Name: to.StringPtr("ipConfig"),
   612  					InterfaceIPConfigurationPropertiesFormat: &network.InterfaceIPConfigurationPropertiesFormat{
   613  						Subnet:                    &subnet,
   614  						PrivateIPAllocationMethod: network.Dynamic,
   615  						PublicIPAddress:           &ip,
   616  					},
   617  				},
   618  			},
   619  		},
   620  	})
   621  	if err != nil {
   622  		return
   623  	}
   624  	if err = future.WaitForCompletionRef(ctx, client.Client); err != nil {
   625  		return
   626  	}
   627  	iface, err = future.Result(client)
   628  	if err == nil {
   629  		log.Printf("created NIC %s %s", *iface.Name, *(*iface.IPConfigurations)[0].PrivateIPAddress)
   630  	}
   631  	return
   632  }
   633  
   634  // createVNets will create a VNet in each of the given locations to be
   635  // shared across roachprod clusters. Thus, all roachprod clusters will
   636  // be able to communicate with one another, although this is scoped by
   637  // the value of the vnet-name flag.
   638  func (p *Provider) createVNets(
   639  	ctx context.Context, locations []string,
   640  ) (map[string]network.VirtualNetwork, error) {
   641  	sub, err := p.getSubscription(ctx)
   642  	if err != nil {
   643  		return nil, err
   644  	}
   645  
   646  	groupsClient := resources.NewGroupsClient(*sub.SubscriptionID)
   647  	if groupsClient.Authorizer, err = p.getAuthorizer(); err != nil {
   648  		return nil, err
   649  	}
   650  
   651  	vnetGroupName := func(location string) string {
   652  		return fmt.Sprintf("roachprod-vnets-%s", location)
   653  	}
   654  
   655  	// Supporting local functions to make the logic below easier to read.
   656  	createVNetGroup := func(location string) (resources.Group, error) {
   657  		return groupsClient.CreateOrUpdate(ctx, vnetGroupName(location), resources.Group{
   658  			Location: to.StringPtr(location),
   659  			Tags: map[string]*string{
   660  				tagComment:   to.StringPtr("DO NOT DELETE: Used by all roachprod clusters"),
   661  				tagRoachprod: to.StringPtr("true"),
   662  			},
   663  		})
   664  	}
   665  
   666  	getVNetGroup := func(location string) (resources.Group, bool, error) {
   667  		group, err := groupsClient.Get(ctx, vnetGroupName(location))
   668  		if err == nil {
   669  			return group, true, nil
   670  		}
   671  		var detail autorest.DetailedError
   672  		if errors.As(err, &detail) {
   673  			if code, ok := detail.StatusCode.(int); ok {
   674  				if code == 404 {
   675  					return resources.Group{}, false, nil
   676  				}
   677  			}
   678  		}
   679  		return resources.Group{}, false, err
   680  	}
   681  
   682  	setVNetSubnetPrefix := func(group resources.Group, subnet int) (resources.Group, error) {
   683  		return groupsClient.Update(ctx, *group.Name, resources.GroupPatchable{
   684  			Tags: map[string]*string{
   685  				tagSubnet: to.StringPtr(strconv.Itoa(subnet)),
   686  			},
   687  		})
   688  	}
   689  
   690  	// First, find or create a resource group for roachprod to create the
   691  	// VNets in. We need one per location.
   692  	groupsByLocation := make(map[string]resources.Group)
   693  	for _, location := range locations {
   694  		group, found, err := getVNetGroup(location)
   695  		if err == nil && !found {
   696  			group, err = createVNetGroup(location)
   697  		}
   698  		if err != nil {
   699  			return nil, errors.Wrapf(err, "for location %q", location)
   700  		}
   701  		groupsByLocation[location] = group
   702  	}
   703  
   704  	// In order to prevent overlapping subnets, we want to associate each
   705  	// of the roachprod-owned RG's with a network prefix. We're going to
   706  	// make an assumption that it's very unlikely that two users will try
   707  	// to allocate new subnets at the same time. If this happens, it can
   708  	// be easily fixed by deleting one of resource groups and re-running
   709  	// roachprod to select a new network prefix.
   710  	prefixesByLocation := make(map[string]int)
   711  	activePrefixes := make(map[int]bool)
   712  	var locationsWithoutSubnet []string
   713  	for location, group := range groupsByLocation {
   714  		if prefixString := group.Tags[tagSubnet]; prefixString != nil {
   715  			prefix, err := strconv.Atoi(*prefixString)
   716  			if err != nil {
   717  				return nil, errors.Wrapf(err, "for location %q", location)
   718  			}
   719  			activePrefixes[prefix] = true
   720  			prefixesByLocation[location] = prefix
   721  		} else {
   722  			locationsWithoutSubnet = append(locationsWithoutSubnet, location)
   723  		}
   724  	}
   725  
   726  	prefix := 1
   727  	for _, location := range locationsWithoutSubnet {
   728  		for activePrefixes[prefix] {
   729  			prefix++
   730  		}
   731  		activePrefixes[prefix] = true
   732  		prefixesByLocation[location] = prefix
   733  		group := groupsByLocation[location]
   734  		if groupsByLocation[location], err = setVNetSubnetPrefix(group, prefix); err != nil {
   735  			return nil, errors.Wrapf(err, "for location %q", location)
   736  		}
   737  	}
   738  
   739  	// Now, we can ensure that the VNet exists with the requested subnet.
   740  	ret := make(map[string]network.VirtualNetwork)
   741  	vnets := make([]network.VirtualNetwork, len(ret))
   742  	for location, prefix := range prefixesByLocation {
   743  		group := groupsByLocation[location]
   744  		if vnet, _, err := p.createVNet(ctx, group, prefix); err == nil {
   745  			ret[location] = vnet
   746  			vnets = append(vnets, vnet)
   747  		} else {
   748  			return nil, errors.Wrapf(err, "for location %q", location)
   749  		}
   750  	}
   751  
   752  	// We only need to create peerings if there are new subnets.
   753  	if locationsWithoutSubnet != nil {
   754  		return ret, p.createVNetPeerings(ctx, vnets)
   755  	}
   756  
   757  	return ret, nil
   758  }
   759  
   760  // createVNet creates or retrieves a named VNet object using the 10.<offset>/16 prefix.
   761  // A single /18 subnet will be created within the VNet.
   762  // The results  will be memoized in the Provider.
   763  func (p *Provider) createVNet(
   764  	ctx context.Context, group resources.Group, prefix int,
   765  ) (vnet network.VirtualNetwork, subnet network.Subnet, err error) {
   766  	vnetName := p.opts.vnetName
   767  
   768  	sub, err := p.getSubscription(ctx)
   769  	if err != nil {
   770  		return
   771  	}
   772  	client := network.NewVirtualNetworksClient(*sub.SubscriptionID)
   773  	if client.Authorizer, err = p.getAuthorizer(); err != nil {
   774  		return
   775  	}
   776  	vnet = network.VirtualNetwork{
   777  		Name:     group.Name,
   778  		Location: group.Location,
   779  		VirtualNetworkPropertiesFormat: &network.VirtualNetworkPropertiesFormat{
   780  			AddressSpace: &network.AddressSpace{
   781  				AddressPrefixes: &[]string{fmt.Sprintf("10.%d.0.0/16", prefix)},
   782  			},
   783  			Subnets: &[]network.Subnet{
   784  				{
   785  					Name: group.Name,
   786  					SubnetPropertiesFormat: &network.SubnetPropertiesFormat{
   787  						AddressPrefix: to.StringPtr(fmt.Sprintf("10.%d.0.0/18", prefix)),
   788  					},
   789  				},
   790  			},
   791  		},
   792  	}
   793  	future, err := client.CreateOrUpdate(ctx, *group.Name, *group.Name, vnet)
   794  	if err != nil {
   795  		err = errors.Wrapf(err, "creating Azure VNet %q in %q", vnetName, *group.Name)
   796  		return
   797  	}
   798  	if err = future.WaitForCompletionRef(ctx, client.Client); err != nil {
   799  		err = errors.Wrapf(err, "creating Azure VNet %q in %q", vnetName, *group.Name)
   800  		return
   801  	}
   802  	vnet, err = future.Result(client)
   803  	err = errors.Wrapf(err, "creating Azure VNet %q in %q", vnetName, *group.Name)
   804  	if err == nil {
   805  		subnet = (*vnet.Subnets)[0]
   806  		p.mu.Lock()
   807  		p.mu.subnets[*group.Location] = subnet
   808  		p.mu.Unlock()
   809  		log.Printf("created Azure VNet %q in %q with prefix %d", vnetName, *group.Name, prefix)
   810  	}
   811  	return
   812  }
   813  
   814  // createVNetPeerings creates a fully-connected graph of peerings
   815  // between the provided vnets.
   816  func (p *Provider) createVNetPeerings(ctx context.Context, vnets []network.VirtualNetwork) error {
   817  	sub, err := p.getSubscription(ctx)
   818  	if err != nil {
   819  		return err
   820  	}
   821  	client := network.NewVirtualNetworkPeeringsClient(*sub.SubscriptionID)
   822  	if client.Authorizer, err = p.getAuthorizer(); err != nil {
   823  		return err
   824  	}
   825  
   826  	// Create cross-product of vnets.
   827  	futures := make(map[string]network.VirtualNetworkPeeringsCreateOrUpdateFuture)
   828  	for _, outer := range vnets {
   829  		for _, inner := range vnets {
   830  			if *outer.ID == *inner.ID {
   831  				continue
   832  			}
   833  
   834  			linkName := fmt.Sprintf("%s-%s", *outer.Name, *inner.Name)
   835  			peering := network.VirtualNetworkPeering{
   836  				Name: to.StringPtr(linkName),
   837  				VirtualNetworkPeeringPropertiesFormat: &network.VirtualNetworkPeeringPropertiesFormat{
   838  					AllowForwardedTraffic:     to.BoolPtr(true),
   839  					AllowVirtualNetworkAccess: to.BoolPtr(true),
   840  					RemoteAddressSpace:        inner.AddressSpace,
   841  					RemoteVirtualNetwork: &network.SubResource{
   842  						ID: inner.ID,
   843  					},
   844  				},
   845  			}
   846  
   847  			outerParts, err := parseAzureID(*outer.ID)
   848  			if err != nil {
   849  				return err
   850  			}
   851  
   852  			future, err := client.CreateOrUpdate(ctx, outerParts.resourceGroup, *outer.Name, linkName, peering)
   853  			if err != nil {
   854  				return errors.Wrapf(err, "creating vnet peering %s", linkName)
   855  			}
   856  			futures[linkName] = future
   857  		}
   858  	}
   859  
   860  	for name, future := range futures {
   861  		if err := future.WaitForCompletionRef(ctx, client.Client); err != nil {
   862  			return errors.Wrapf(err, "creating vnet peering %s", name)
   863  		}
   864  		peering, err := future.Result(client)
   865  		if err != nil {
   866  			return errors.Wrapf(err, "creating vnet peering %s", name)
   867  		}
   868  		log.Printf("created vnet peering %s", *peering.Name)
   869  	}
   870  
   871  	return nil
   872  }
   873  
   874  // createIP allocates an IP address that will later be bound to a NIC.
   875  func (p *Provider) createIP(
   876  	ctx context.Context, group resources.Group, name string,
   877  ) (ip network.PublicIPAddress, err error) {
   878  	sub, err := p.getSubscription(ctx)
   879  	if err != nil {
   880  		return
   881  	}
   882  	ipc := network.NewPublicIPAddressesClient(*sub.SubscriptionID)
   883  	if ipc.Authorizer, err = p.getAuthorizer(); err != nil {
   884  		return
   885  	}
   886  	future, err := ipc.CreateOrUpdate(ctx, *group.Name, name,
   887  		network.PublicIPAddress{
   888  			Name:     to.StringPtr(name),
   889  			Location: group.Location,
   890  			PublicIPAddressPropertiesFormat: &network.PublicIPAddressPropertiesFormat{
   891  				PublicIPAddressVersion:   network.IPv4,
   892  				PublicIPAllocationMethod: network.Dynamic,
   893  			},
   894  		})
   895  	if err != nil {
   896  		err = errors.Wrapf(err, "creating IP %s", name)
   897  		return
   898  	}
   899  	err = future.WaitForCompletionRef(ctx, ipc.Client)
   900  	if err != nil {
   901  		err = errors.Wrapf(err, "creating IP %s", name)
   902  		return
   903  	}
   904  	if ip, err = future.Result(ipc); err == nil {
   905  		log.Printf("created Azure IP %s", *ip.Name)
   906  	} else {
   907  		err = errors.Wrapf(err, "creating IP %s", name)
   908  	}
   909  
   910  	return
   911  }
   912  
   913  // fillNetworkDetails makes some secondary requests to the Azure
   914  // API in order to populate the VM details. This will return
   915  // vm.ErrBadNetwork if the response payload is not of the expected form.
   916  func (p *Provider) fillNetworkDetails(ctx context.Context, m *vm.VM, nicID azureID) error {
   917  	sub, err := p.getSubscription(ctx)
   918  	if err != nil {
   919  		return err
   920  	}
   921  
   922  	nicClient := network.NewInterfacesClient(*sub.SubscriptionID)
   923  	if nicClient.Authorizer, err = p.getAuthorizer(); err != nil {
   924  		return err
   925  	}
   926  
   927  	ipClient := network.NewPublicIPAddressesClient(*sub.SubscriptionID)
   928  	ipClient.Authorizer = nicClient.Authorizer
   929  
   930  	iface, err := nicClient.Get(ctx, nicID.resourceGroup, nicID.resourceName, "" /*expand*/)
   931  	if err != nil {
   932  		return err
   933  	}
   934  	if iface.IPConfigurations == nil {
   935  		return vm.ErrBadNetwork
   936  	}
   937  	cfg := (*iface.IPConfigurations)[0]
   938  	if cfg.PrivateIPAddress == nil {
   939  		return vm.ErrBadNetwork
   940  	}
   941  	m.PrivateIP = *cfg.PrivateIPAddress
   942  	m.DNS = m.PrivateIP
   943  	if cfg.PublicIPAddress == nil || cfg.PublicIPAddress.ID == nil {
   944  		return vm.ErrBadNetwork
   945  	}
   946  	ipID, err := parseAzureID(*cfg.PublicIPAddress.ID)
   947  	if err != nil {
   948  		return vm.ErrBadNetwork
   949  	}
   950  
   951  	ip, err := ipClient.Get(ctx, ipID.resourceGroup, ipID.resourceName, "" /*expand*/)
   952  	if err != nil {
   953  		return vm.ErrBadNetwork
   954  	}
   955  	if ip.IPAddress == nil {
   956  		return vm.ErrBadNetwork
   957  	}
   958  	m.PublicIP = *ip.IPAddress
   959  
   960  	return nil
   961  }
   962  
   963  // getResourceGroup creates or retrieves a resource group within the
   964  // specified location. The base name will be combined with the location,
   965  // to allow for easy tear-down of multi-region clusters. Results are
   966  // memoized within the Provider instance.
   967  func (p *Provider) getResourceGroup(
   968  	ctx context.Context, cluster, location string, opts vm.CreateOpts,
   969  ) (group resources.Group, err error) {
   970  	groupName := fmt.Sprintf("%s-%s", cluster, location)
   971  
   972  	p.mu.Lock()
   973  	group, ok := p.mu.resourceGroups[groupName]
   974  	p.mu.Unlock()
   975  	if ok {
   976  		return
   977  	}
   978  
   979  	sub, err := p.getSubscription(ctx)
   980  	if err != nil {
   981  		return
   982  	}
   983  
   984  	client := resources.NewGroupsClient(*sub.SubscriptionID)
   985  	if client.Authorizer, err = p.getAuthorizer(); err != nil {
   986  		return
   987  	}
   988  
   989  	tags := make(map[string]*string)
   990  	tags[tagCluster] = to.StringPtr(cluster)
   991  	tags[tagCreated] = to.StringPtr(timeutil.Now().Format(time.RFC3339))
   992  	tags[tagLifetime] = to.StringPtr(opts.Lifetime.String())
   993  	tags[tagRoachprod] = to.StringPtr("true")
   994  
   995  	group, err = client.CreateOrUpdate(ctx, groupName,
   996  		resources.Group{
   997  			Location: to.StringPtr(location),
   998  			Tags:     tags,
   999  		})
  1000  	if err == nil {
  1001  		p.mu.Lock()
  1002  		p.mu.resourceGroups[groupName] = group
  1003  		p.mu.Unlock()
  1004  	}
  1005  	return
  1006  }
  1007  
  1008  // getSubscription chooses the first available subscription. The value
  1009  // is memoized in the Provider instance.
  1010  func (p *Provider) getSubscription(
  1011  	ctx context.Context,
  1012  ) (sub subscriptions.Subscription, err error) {
  1013  	p.mu.Lock()
  1014  	sub = p.mu.subscription
  1015  	p.mu.Unlock()
  1016  
  1017  	if sub.SubscriptionID != nil {
  1018  		return
  1019  	}
  1020  
  1021  	sc := subscriptions.NewClient()
  1022  	if sc.Authorizer, err = p.getAuthorizer(); err != nil {
  1023  		return
  1024  	}
  1025  
  1026  	if page, err := sc.List(ctx); err == nil {
  1027  		if len(page.Values()) == 0 {
  1028  			err = errors.New("did not find Azure subscription")
  1029  			return sub, err
  1030  		}
  1031  		sub = page.Values()[0]
  1032  
  1033  		p.mu.Lock()
  1034  		p.mu.subscription = page.Values()[0]
  1035  		p.mu.Unlock()
  1036  	}
  1037  	return
  1038  }