github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/provider/common/bootstrap.go

// Copyright 2013 Canonical Ltd.
// Licensed under the AGPLv3, see LICENCE file for details.

package common

import (
	"bufio"
	"context"
	stderrors "errors"
	"fmt"
	"io"
	"net"
	"os"
	"path"
	"strings"
	"sync"
	"time"

	"github.com/juju/errors"
	"github.com/juju/loggo"
	"github.com/juju/utils/v3"
	"github.com/juju/utils/v3/parallel"
	"github.com/juju/utils/v3/shell"
	"github.com/juju/utils/v3/ssh"

	"github.com/juju/juju/cloudconfig"
	"github.com/juju/juju/cloudconfig/cloudinit"
	"github.com/juju/juju/cloudconfig/instancecfg"
	"github.com/juju/juju/cloudconfig/sshinit"
	"github.com/juju/juju/controller"
	corebase "github.com/juju/juju/core/base"
	"github.com/juju/juju/core/instance"
	"github.com/juju/juju/core/network"
	"github.com/juju/juju/core/network/firewall"
	"github.com/juju/juju/core/status"
	"github.com/juju/juju/environs"
	"github.com/juju/juju/environs/bootstrap"
	"github.com/juju/juju/environs/config"
	envcontext "github.com/juju/juju/environs/context"
	"github.com/juju/juju/environs/imagemetadata"
	"github.com/juju/juju/environs/instances"
	"github.com/juju/juju/environs/models"
	"github.com/juju/juju/environs/simplestreams"
	pkissh "github.com/juju/juju/pki/ssh"
	"github.com/juju/juju/storage"
	"github.com/juju/juju/storage/poolmanager"
	coretools "github.com/juju/juju/tools"
)

var logger = loggo.GetLogger("juju.provider.common")

// Bootstrap is a common implementation of the Bootstrap method defined on
// environs.Environ; we strongly recommend that this implementation be used
// when writing a new provider.
func Bootstrap(
	ctx environs.BootstrapContext,
	env environs.Environ,
	callCtx envcontext.ProviderCallContext,
	args environs.BootstrapParams,
) (*environs.BootstrapResult, error) {
	result, base, finalizer, err := BootstrapInstance(ctx, env, callCtx, args)
	if err != nil {
		return nil, errors.Trace(err)
	}

	bsResult := &environs.BootstrapResult{
		Arch:                    *result.Hardware.Arch,
		Base:                    *base,
		CloudBootstrapFinalizer: finalizer,
	}
	return bsResult, nil
}
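
// Illustrative sketch (not part of this package; myEnviron is a hypothetical
// provider type): a provider typically satisfies environs.Environ by
// delegating its Bootstrap method straight to this implementation:
//
//	func (e *myEnviron) Bootstrap(
//		ctx environs.BootstrapContext,
//		callCtx envcontext.ProviderCallContext,
//		args environs.BootstrapParams,
//	) (*environs.BootstrapResult, error) {
//		return common.Bootstrap(ctx, e, callCtx, args)
//	}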

// BootstrapInstance creates a new instance with the base of its choice,
// constrained to those of the available tools, and returns the instance
// result, base, and a function that must be called to finalize the
// bootstrap process by transferring the tools and installing the initial
// Juju controller. This method is called by Bootstrap above, which
// implements environs.Bootstrap, but is also exported so that providers
// can manipulate the started instance.
func BootstrapInstance(
	ctx environs.BootstrapContext,
	env environs.Environ,
	callCtx envcontext.ProviderCallContext,
	args environs.BootstrapParams,
) (_ *environs.StartInstanceResult, resultBase *corebase.Base, _ environs.CloudBootstrapFinalizer, err error) {
	// TODO make safe in the case of racing Bootstraps
	// If two Bootstraps are called concurrently, there's
	// no way to make sure that only one succeeds.

	// First, ensure we have tools; otherwise there's no point.
	supportedBootstrapBase := make([]corebase.Base, len(args.SupportedBootstrapSeries))
	for i, b := range args.SupportedBootstrapSeries.SortedValues() {
		sb, err := corebase.GetBaseFromSeries(b)
		if err != nil {
			return nil, nil, nil, errors.Trace(err)
		}
		supportedBootstrapBase[i] = sb
	}
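
	// For orientation: corebase.GetBaseFromSeries maps a series name to its
	// base; e.g. "jammy" yields ubuntu@22.04 (a well-known mapping, noted
	// here for illustration).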

	var bootstrapBase corebase.Base
	if args.BootstrapSeries != "" {
		b, err := corebase.GetBaseFromSeries(args.BootstrapSeries)
		if err != nil {
			return nil, nil, nil, errors.Trace(err)
		}
		bootstrapBase = b
	}

	requestedBootstrapBase, err := corebase.ValidateBase(
		supportedBootstrapBase,
		bootstrapBase,
		config.PreferredBase(env.Config()),
	)
	if !args.Force && err != nil {
		// If the base isn't valid (i.e. non-ubuntu) then don't prompt users to use
		// the --force flag.
		if requestedBootstrapBase.OS != corebase.UbuntuOS {
			return nil, nil, nil, errors.NotValidf("non-ubuntu bootstrap base %q", requestedBootstrapBase.String())
		}
		return nil, nil, nil, errors.Annotatef(err, "use --force to override")
	}
	// If the base we're attempting to bootstrap is empty, show a friendly
	// error message rather than the more cryptic error messages that would
	// follow.
	if requestedBootstrapBase.Empty() {
		return nil, nil, nil, errors.NotValidf("bootstrap instance base")
	}
	availableTools, err := args.AvailableTools.Match(coretools.Filter{
		OSType: requestedBootstrapBase.OS,
	})
	if err != nil {
		return nil, nil, nil, err
	}

	// Filter image metadata to the selected base.
	var imageMetadata []*imagemetadata.ImageMetadata
	for _, m := range args.ImageMetadata {
		if m.Version != requestedBootstrapBase.Channel.Track {
			continue
		}
		imageMetadata = append(imageMetadata, m)
	}
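
	// E.g. with base ubuntu@22.04 the channel track is "22.04", so only
	// metadata whose Version is "22.04" survives the filter above (an
	// illustrative note, assuming Ubuntu-style version tracks).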

	// Get the bootstrap SSH client. Do this early, so we know
	// not to bother with any of the below if we can't finish the job.
	client := ssh.DefaultClient
	if client == nil {
		// This should never happen: if we don't have OpenSSH, then
		// go.crypto/ssh should be used with an auto-generated key.
		return nil, nil, nil, fmt.Errorf("no SSH client available")
	}

	publicKey, err := simplestreams.UserPublicSigningKey()
	if err != nil {
		return nil, nil, nil, err
	}

	instanceConfig, err := instancecfg.NewBootstrapInstanceConfig(
		args.ControllerConfig, args.BootstrapConstraints, args.ModelConstraints, requestedBootstrapBase, publicKey,
		args.ExtraAgentValuesForTesting,
	)
	if err != nil {
		return nil, nil, nil, err
	}

	envCfg := env.Config()
	instanceConfig.EnableOSRefreshUpdate = envCfg.EnableOSRefreshUpdate()
	instanceConfig.EnableOSUpgrade = envCfg.EnableOSUpgrade()
	instanceConfig.NetBondReconfigureDelay = envCfg.NetBondReconfigureDelay()
	instanceConfig.Tags = instancecfg.InstanceTags(envCfg.UUID(), args.ControllerConfig.ControllerUUID(), envCfg, true)

	// We're creating a new instance; inject host keys so that we can then
	// make an SSH connection with known keys.
	initialSSHHostKeys, err := generateSSHHostKeys()
	if err != nil {
		return nil, nil, nil, errors.Annotate(err, "generating SSH host keys")
	}
	instanceConfig.Bootstrap.InitialSSHHostKeys = initialSSHHostKeys

	cloudRegion := args.CloudName
	if args.CloudRegion != "" {
		cloudRegion += "/" + args.CloudRegion
	}
	ctx.Infof("Launching controller instance(s) on %s...", cloudRegion)
	// The instanceStatus callback prints status changes during provisioning.
	// Note the carriage returns: subsequent prints go to the same line of
	// stderr, not a new line.
	lastLength := 0
	statusCleanedUp := false
	instanceStatus := func(settableStatus status.Status, info string, data map[string]interface{}) error {
		// The data arg is not expected to be used in this case, but
		// print it, rather than ignore it, if we get something.
		dataString := ""
		if len(data) > 0 {
			dataString = fmt.Sprintf(" %v", data)
		}
		length := len(info) + len(dataString)
		padding := ""
		if lastLength > length {
			padding = strings.Repeat(" ", lastLength-length)
		}
		lastLength = length
		statusCleanedUp = false
		fmt.Fprintf(ctx.GetStderr(), " - %s%s%s\r", info, dataString, padding)
		return nil
	}
	// statusCleanup is used after the final instanceStatus call to white-out
	// the current stderr line before the next use, removing any residual
	// status reporting output.
	statusCleanup := func() error {
		if statusCleanedUp {
			return nil
		}
		statusCleanedUp = true
		// The leading spaces account for the leading characters
		// emitted by instanceStatus above.
		padding := strings.Repeat(" ", lastLength)
		fmt.Fprintf(ctx.GetStderr(), "   %s\r", padding)
		return nil
	}
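
	// To sketch the effect: if a long status message is followed by a shorter
	// one, instanceStatus pads the shorter message with spaces so no tail of
	// the previous message remains visible, and statusCleanup then blanks the
	// whole line so subsequent output starts clean.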

	var startInstanceArgs = environs.StartInstanceParams{
		ControllerUUID:  args.ControllerConfig.ControllerUUID(),
		Constraints:     args.BootstrapConstraints,
		Tools:           availableTools,
		InstanceConfig:  instanceConfig,
		Placement:       args.Placement,
		ImageMetadata:   imageMetadata,
		StatusCallback:  instanceStatus,
		CleanupCallback: statusCleanup,
	}

	// If a root disk constraint is specified, see if it matches any
	// storage pools scheduled to be added to the controller model and
	// set up the root disk accordingly.
	if args.BootstrapConstraints.HasRootDiskSource() {
		sp, ok := args.StoragePools[*args.BootstrapConstraints.RootDiskSource]
		if ok {
			pType, _ := sp[poolmanager.Type].(string)
			startInstanceArgs.RootDisk = &storage.VolumeParams{
				Provider:   storage.ProviderType(pType),
				Attributes: sp,
			}
		}
	}
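
	// For example (hypothetical pool name): bootstrapping with the constraint
	// root-disk-source=fastpool, where the scheduled pool "fastpool" has type
	// "ebs", yields a RootDisk volume with Provider "ebs" and the pool's
	// attributes passed through as-is.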

	zones, err := startInstanceZones(env, callCtx, startInstanceArgs)
	if errors.IsNotImplemented(err) {
		// No zone support, so just call StartInstance with
		// a blank StartInstanceParams.AvailabilityZone.
		zones = []string{""}
		if args.BootstrapConstraints.HasZones() {
			logger.Debugf("environ doesn't support zones: ignoring bootstrap zone constraints")
		}
	} else if err != nil {
		return nil, nil, nil, errors.Annotate(err, "cannot start bootstrap instance")
	} else if args.BootstrapConstraints.HasZones() {
		// TODO(hpidcock): bootstrap and worker/provisioner should probably derive
		// from the same logic regarding placement.
		var filteredZones []string
		for _, zone := range zones {
			for _, zoneConstraint := range *args.BootstrapConstraints.Zones {
				if zone == zoneConstraint {
					filteredZones = append(filteredZones, zone)
					break
				}
			}
		}
		if len(filteredZones) == 0 {
			return nil, nil, nil, errors.Errorf(
				"no available zones (%+q) matching bootstrap zone constraints (%+q)",
				zones,
				*args.BootstrapConstraints.Zones,
			)
		}
		zones = filteredZones
	}

	var result *environs.StartInstanceResult
	zoneErrors := []error{} // collects the error encountered in each zone.
	for i, zone := range zones {
		startInstanceArgs.AvailabilityZone = zone
		result, err = env.StartInstance(callCtx, startInstanceArgs)
		if err == nil {
			break
		}
		zoneErrors = append(zoneErrors, fmt.Errorf("starting bootstrap instance in zone %q: %w", zone, err))

		select {
		case <-ctx.Context().Done():
			return nil, nil, nil, errors.Annotate(err, "starting controller (cancelled)")
		default:
		}

		if zone == "" || errors.Is(err, environs.ErrAvailabilityZoneIndependent) {
			return nil, nil, nil, errors.Annotate(err, "cannot start bootstrap instance")
		}

		if i < len(zones)-1 {
			// Try the next zone.
			logger.Debugf("failed to start instance in availability zone %q: %s", zone, err)
			continue
		}
		// This is the last zone in the list; fail with the collected errors.
		if len(zones) > 1 {
			return nil, nil, nil, fmt.Errorf(
				"cannot start bootstrap instance in any availability zone (%s):\n%w",
				strings.Join(zones, ", "), stderrors.Join(zoneErrors...),
			)
		}
		return nil, nil, nil, errors.Annotatef(err, "cannot start bootstrap instance in availability zone %q", zone)
	}
	modelFw, ok := env.(models.ModelFirewaller)
	if ok {
		if err := openControllerModelPorts(callCtx, modelFw, args.ControllerConfig, env.Config()); err != nil {
			return nil, nil, nil, errors.Annotate(err, "cannot open SSH")
		}
	}

	err = statusCleanup()
	if err != nil {
		return nil, nil, nil, errors.Annotate(err, "cleaning up status line")
	}
	msg := fmt.Sprintf(" - %s (%s)", result.Instance.Id(), formatHardware(result.Hardware))
	// We need some padding below to overwrite any previous messages.
	if len(msg) < 40 {
		padding := make([]string, 40-len(msg))
		msg += strings.Join(padding, " ")
	}
	ctx.Infof(msg)

	finalizer := func(ctx environs.BootstrapContext, icfg *instancecfg.InstanceConfig, opts environs.BootstrapDialOpts) error {
		icfg.Bootstrap.BootstrapMachineInstanceId = result.Instance.Id()
		icfg.Bootstrap.BootstrapMachineDisplayName = result.DisplayName
		icfg.Bootstrap.BootstrapMachineHardwareCharacteristics = result.Hardware
		icfg.Bootstrap.InitialSSHHostKeys = initialSSHHostKeys
		envConfig := env.Config()
		if result.Config != nil {
			updated, err := envConfig.Apply(result.Config.UnknownAttrs())
			if err != nil {
				return errors.Trace(err)
			}
			envConfig = updated
		}
		if err := instancecfg.FinishInstanceConfig(icfg, envConfig); err != nil {
			return err
		}
		return FinishBootstrap(ctx, client, env, callCtx, result.Instance, icfg, opts)
	}
	return result, &requestedBootstrapBase, finalizer, nil
}

func startInstanceZones(env environs.Environ, ctx envcontext.ProviderCallContext, args environs.StartInstanceParams) ([]string, error) {
	zonedEnviron, ok := env.(ZonedEnviron)
	if !ok {
		return nil, errors.NotImplementedf("ZonedEnviron")
	}

	// Attempt creating the instance in each of the availability
	// zones, unless the args imply a specific zone.
	zones, err := zonedEnviron.DeriveAvailabilityZones(ctx, args)
	if err != nil {
		return nil, errors.Trace(err)
	}
	if len(zones) > 0 {
		return zones, nil
	}
	allZones, err := zonedEnviron.AvailabilityZones(ctx)
	if err != nil {
		return nil, errors.Trace(err)
	}
	for _, zone := range allZones {
		if !zone.Available() {
			continue
		}
		zones = append(zones, zone.Name())
	}
	if len(zones) == 0 {
		return nil, errors.New("no usable availability zones")
	}
	return zones, nil
}

// openControllerModelPorts opens port 22 and the API port on the controller
// to the configured allow list. This is all that is required for bootstrap
// to continue; further configured rules will be opened by the firewaller
// once it has started.
func openControllerModelPorts(callCtx envcontext.ProviderCallContext,
	modelFw models.ModelFirewaller, controllerConfig controller.Config, cfg *config.Config) error {
	rules := firewall.IngressRules{
		firewall.NewIngressRule(network.MustParsePortRange("22"), cfg.SSHAllow()...),
		firewall.NewIngressRule(network.PortRange{
			Protocol: "tcp",
			FromPort: controllerConfig.APIPort(),
			ToPort:   controllerConfig.APIPort(),
		}),
	}

	if controllerConfig.AutocertDNSName() != "" {
		// Open port 80 as well as it handles Let's Encrypt HTTP challenge.
		rules = append(rules,
			firewall.NewIngressRule(network.PortRange{
				Protocol: "tcp",
				FromPort: 80,
				ToPort:   80,
			}),
		)
	}

	return modelFw.OpenModelPorts(callCtx, rules)
}
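
// For orientation: with a default controller configuration this opens 22/tcp
// (restricted to the configured ssh-allow list) and the API port (17070/tcp
// by default), plus 80/tcp when an autocert DNS name is set. (The default
// port number is noted here for illustration only.)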

func formatHardware(hw *instance.HardwareCharacteristics) string {
	if hw == nil {
		return ""
	}
	out := make([]string, 0, 3)
	if hw.Arch != nil && *hw.Arch != "" {
		out = append(out, fmt.Sprintf("arch=%s", *hw.Arch))
	}
	if hw.Mem != nil && *hw.Mem > 0 {
		out = append(out, fmt.Sprintf("mem=%s", formatMemory(*hw.Mem)))
	}
	if hw.CpuCores != nil && *hw.CpuCores > 0 {
		out = append(out, fmt.Sprintf("cores=%d", *hw.CpuCores))
	}
	// If the virt-type is the default, don't print it out, as it's just noise.
	if hw.VirtType != nil && *hw.VirtType != "" && *hw.VirtType != string(instance.DefaultInstanceType) {
		out = append(out, fmt.Sprintf("virt-type=%s", *hw.VirtType))
	}
	return strings.Join(out, " ")
}

func formatMemory(m uint64) string {
	if m < 1024 {
		return fmt.Sprintf("%dM", m)
	}
	s := fmt.Sprintf("%.1f", float32(m)/1024.0)
	return strings.TrimSuffix(s, ".0") + "G"
}
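
// Worked examples for formatMemory, which takes megabytes and renders a
// human-readable size: formatMemory(512) == "512M", formatMemory(1536) ==
// "1.5G", and formatMemory(2048) == "2G" (the ".0" suffix is trimmed).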

// FinishBootstrap completes the bootstrap process by connecting
// to the instance via SSH and carrying out the cloud-config.
//
// Note: FinishBootstrap is exposed so it can be replaced for testing.
var FinishBootstrap = func(
	ctx environs.BootstrapContext,
	client ssh.Client,
	env environs.Environ,
	callCtx envcontext.ProviderCallContext,
	inst instances.Instance,
	instanceConfig *instancecfg.InstanceConfig,
	opts environs.BootstrapDialOpts,
) error {
	interrupted := make(chan os.Signal, 1)
	ctx.InterruptNotify(interrupted)
	defer ctx.StopInterruptNotify(interrupted)

	hostSSHOptions := bootstrapSSHOptionsFunc(instanceConfig)
	addr, err := WaitSSH(
		ctx.Context(),
		ctx.GetStderr(),
		client,
		GetCheckNonceCommand(instanceConfig),
		&RefreshableInstance{inst, env},
		callCtx,
		opts,
		hostSSHOptions,
	)
	if err != nil {
		return err
	}
	ctx.Infof("Connected to %v", addr)

	sshOptions, cleanup, err := hostSSHOptions(addr)
	if err != nil {
		return err
	}
	defer cleanup()

	return ConfigureMachine(ctx, client, addr, instanceConfig, sshOptions)
}

func GetCheckNonceCommand(instanceConfig *instancecfg.InstanceConfig) string {
	// Each attempt to connect to an address must verify the machine is the
	// bootstrap machine by checking its nonce file exists and contains the
	// nonce in the InstanceConfig. This also blocks sshinit from proceeding
	// until cloud-init has completed, which is necessary to ensure apt
	// invocations don't trample each other.
	nonceFile := utils.ShQuote(path.Join(instanceConfig.DataDir, cloudconfig.NonceFile))
	checkNonceCommand := fmt.Sprintf(`
	noncefile=%s
	if [ ! -e "$noncefile" ]; then
		echo "$noncefile does not exist" >&2
		exit 1
	fi
	content=$(cat $noncefile)
	if [ "$content" != %s ]; then
		echo "$noncefile contents do not match machine nonce" >&2
		exit 1
	fi
	`, nonceFile, utils.ShQuote(instanceConfig.MachineNonce))
	return checkNonceCommand
}
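
// As a sketch (hypothetical values): with DataDir "/var/lib/juju" and machine
// nonce "user-admin:bootstrap", the rendered script compares the contents of
// the nonce file under DataDir against 'user-admin:bootstrap' and exits
// non-zero on any mismatch, so WaitSSH keeps polling until cloud-init has
// written the expected nonce.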

func ConfigureMachine(
	ctx environs.BootstrapContext,
	client ssh.Client,
	host string,
	instanceConfig *instancecfg.InstanceConfig,
	sshOptions *ssh.Options,
) error {
	// Bootstrap is synchronous, and will spawn a subprocess
	// to complete the procedure. If the user hits Ctrl-C,
	// SIGINT is sent to the foreground process attached to
	// the terminal, which will be the ssh subprocess at this
	// point. For that reason, we do not call StopInterruptNotify
	// until this function completes.
	cloudcfg, err := cloudinit.New(instanceConfig.Base.OS)
	if err != nil {
		return errors.Trace(err)
	}

	// Set packaging update here
	cloudcfg.SetSystemUpdate(instanceConfig.EnableOSRefreshUpdate)
	cloudcfg.SetSystemUpgrade(instanceConfig.EnableOSUpgrade)

	sshinitConfig := sshinit.ConfigureParams{
		Host:           "ubuntu@" + host,
		Client:         client,
		SSHOptions:     sshOptions,
		Config:         cloudcfg,
		ProgressWriter: ctx.GetStderr(),
	}

	ft := sshinit.NewFileTransporter(sshinitConfig)
	cloudcfg.SetFileTransporter(ft)

	udata, err := cloudconfig.NewUserdataConfig(instanceConfig, cloudcfg)
	if err != nil {
		return err
	}
	if err := udata.ConfigureJuju(); err != nil {
		return err
	}
	if err := udata.ConfigureCustomOverrides(); err != nil {
		return err
	}
	configScript, err := cloudcfg.RenderScript()
	if err != nil {
		return err
	}

	// Wait for the files to be sent to the machine.
	if err := ft.Dispatch(ctx.Context()); err != nil {
		return errors.Annotate(err, "transporting files to machine")
	}

	script := shell.DumpFileOnErrorScript(instanceConfig.CloudInitOutputLog) + configScript
	ctx.Infof("Running machine configuration script...")
	// TODO(benhoyt) - plumb context through juju/utils/ssh?
	return sshinit.RunConfigureScript(script, sshinitConfig)
}

// HostSSHOptionsFunc is a function that, given a hostname, returns
// an ssh.Options and a cleanup function, or an error.
type HostSSHOptionsFunc func(host string) (*ssh.Options, func(), error)

// DefaultHostSSHOptions returns a nil *ssh.Options, which means
// to use the defaults, and a no-op cleanup function.
func DefaultHostSSHOptions(string) (*ssh.Options, func(), error) {
	return nil, func() {}, nil
}

// bootstrapSSHOptionsFunc takes a bootstrap machine's InstanceConfig
// and returns a HostSSHOptionsFunc.
func bootstrapSSHOptionsFunc(instanceConfig *instancecfg.InstanceConfig) HostSSHOptionsFunc {
	return func(host string) (*ssh.Options, func(), error) {
		return hostBootstrapSSHOptions(host, instanceConfig)
	}
}

func hostBootstrapSSHOptions(
	host string,
	instanceConfig *instancecfg.InstanceConfig,
) (_ *ssh.Options, cleanup func(), err error) {
	cleanup = func() {}
	defer func() {
		if err != nil {
			cleanup()
		}
	}()

	options := &ssh.Options{}
	options.SetStrictHostKeyChecking(ssh.StrictHostChecksYes)

	// If any host keys are being injected, we'll set up a
	// known_hosts file with their contents, and accept only
	// them.
	hostKeys := instanceConfig.Bootstrap.InitialSSHHostKeys
	var algos []string
	var pubKeys []string
	for _, hostKey := range hostKeys {
		algos = append(algos, hostKey.PublicKeyAlgorithm)
		pubKeys = append(pubKeys, hostKey.Public)
	}
	if len(pubKeys) == 0 {
		return options, cleanup, nil
	}

	// Create a temporary known_hosts file.
	f, err := os.CreateTemp("", "juju-known-hosts")
	if err != nil {
		return nil, cleanup, errors.Trace(err)
	}
	cleanup = func() {
		_ = f.Close()
		_ = os.RemoveAll(f.Name())
	}
	w := bufio.NewWriter(f)
	for _, pubKey := range pubKeys {
		fmt.Fprintln(w, host, strings.TrimSpace(pubKey))
	}
	if err := w.Flush(); err != nil {
		return nil, cleanup, errors.Annotate(err, "writing known_hosts")
	}

	options.SetHostKeyAlgorithms(algos...)
	options.SetKnownHostsFile(f.Name())
	return options, cleanup, nil
}
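
// Each line written above follows the standard known_hosts format of
// "host algorithm key"; for example (hypothetical, abridged key material):
//
//	10.0.0.1 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAA...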

// InstanceRefresher is the subset of the Instance interface required
// for waiting for SSH access to become available.
type InstanceRefresher interface {
	// Refresh refreshes the addresses for the instance.
	Refresh(ctx envcontext.ProviderCallContext) error

	// Addresses returns the addresses for the instance.
	// To ensure that the results are up to date, call
	// Refresh first.
	Addresses(ctx envcontext.ProviderCallContext) (network.ProviderAddresses, error)

	// Status returns the provider-specific status for the
	// instance.
	Status(ctx envcontext.ProviderCallContext) instance.Status
}

type RefreshableInstance struct {
	instances.Instance
	Env environs.Environ
}

// Refresh refreshes the addresses for the instance.
func (i *RefreshableInstance) Refresh(ctx envcontext.ProviderCallContext) error {
	instances, err := i.Env.Instances(ctx, []instance.Id{i.Id()})
	if err != nil {
		return errors.Trace(err)
	}
	i.Instance = instances[0]
	return nil
}

type hostChecker struct {
	addr           network.ProviderAddress
	client         ssh.Client
	hostSSHOptions HostSSHOptionsFunc
	wg             *sync.WaitGroup

	// checkDelay is the amount of time to wait between retries.
	checkDelay time.Duration

	// checkHostScript is executed on the host via SSH.
	// hostChecker.loop will return once the script
	// runs without error.
	checkHostScript string

	// closed is closed to indicate that the host checker should
	// return, without waiting for the result of any ongoing
	// attempts.
	closed <-chan struct{}
}

// Close implements io.Closer, as required by parallel.Try.
func (*hostChecker) Close() error {
	return nil
}

func (hc *hostChecker) loop(dying <-chan struct{}) (io.Closer, error) {
	defer hc.wg.Done()

	address := hc.addr.Value
	sshOptions, cleanup, err := hc.hostSSHOptions(address)
	if err != nil {
		return nil, err
	}
	defer cleanup()

	// The value of connectSSH is taken outside the goroutine that may outlive
	// hostChecker.loop, or we incur the wrath of the race detector.
	connectSSH := connectSSH
	var lastErr error
	done := make(chan error, 1)
	for {
		go func() {
			done <- connectSSH(hc.client, address, hc.checkHostScript, sshOptions)
		}()
		select {
		case <-dying:
			return hc, lastErr
		case lastErr = <-done:
			if lastErr == nil {
				return hc, nil
			}
			logger.Debugf("connection attempt for %s failed: %v", address, lastErr)
		}
		select {
		case <-hc.closed:
			return hc, lastErr
		case <-dying:
		case <-time.After(hc.checkDelay):
		}
	}
}

type parallelHostChecker struct {
	*parallel.Try
	client         ssh.Client
	hostSSHOptions HostSSHOptionsFunc
	stderr         io.Writer
	wg             sync.WaitGroup

	// active is a map of addresses to channels for addresses actively
	// being tested. The goroutine testing the address will continue
	// to attempt connecting to the address until it succeeds, the Try
	// is killed, or the corresponding channel in this map is closed.
	active map[network.ProviderAddress]chan struct{}

	// checkDelay is how long each hostChecker waits between attempts.
	checkDelay time.Duration

	// checkHostScript is the script to run on each host to check that
	// it is the host we expect.
	checkHostScript string
}

func (p *parallelHostChecker) UpdateAddresses(addrs []network.ProviderAddress) {
	for _, addr := range addrs {
		if _, ok := p.active[addr]; ok {
			continue
		}
		fmt.Fprintf(p.stderr, "Attempting to connect to %s\n", net.JoinHostPort(addr.Value, "22"))
		closed := make(chan struct{})
		hc := &hostChecker{
			addr:            addr,
			client:          p.client,
			hostSSHOptions:  p.hostSSHOptions,
			checkDelay:      p.checkDelay,
			checkHostScript: p.checkHostScript,
			closed:          closed,
			wg:              &p.wg,
		}
		p.wg.Add(1)
		p.active[addr] = closed
		_ = p.Start(hc.loop)
	}
}

// Close prevents additional functions from being added to
// the Try, and tells each active hostChecker to exit.
func (p *parallelHostChecker) Close() error {
	// We signal each checker to stop and wait for them
	// each to complete; this allows us to get the error,
	// as opposed to when using try.Kill which does not
	// wait for the functions to complete.
	p.Try.Close()
	for _, ch := range p.active {
		close(ch)
	}
	return nil
}

// connectSSH is called to connect to the specified host and
// execute the "checkHostScript" bash script on it.
var connectSSH = func(client ssh.Client, host, checkHostScript string, options *ssh.Options) error {
	cmd := client.Command("ubuntu@"+host, []string{"/bin/bash"}, options)
	cmd.Stdin = strings.NewReader(checkHostScript)
	output, err := cmd.CombinedOutput()
	if err != nil && len(output) > 0 {
		err = fmt.Errorf("%s", strings.TrimSpace(string(output)))
	}
	return err
}

// WaitSSH waits for the instance to be assigned a routable
// address, then waits until we can connect to it via SSH.
//
// WaitSSH attempts to connect to all addresses returned by the
// instance in parallel; the first to succeed wins. We ensure that
// private addresses are for the correct machine by checking the
// presence of a file on the machine that contains the machine's
// nonce. The "checkHostScript" is a bash script that performs
// this file check.
func WaitSSH(
	ctx context.Context,
	stdErr io.Writer,
	client ssh.Client,
	checkHostScript string,
	inst InstanceRefresher,
	callCtx envcontext.ProviderCallContext,
	opts environs.BootstrapDialOpts,
	hostSSHOptions HostSSHOptionsFunc,
) (addr string, err error) {
	globalTimeout := time.After(opts.Timeout)
	pollAddresses := time.NewTimer(0)

	// checker checks each address in a loop, in parallel,
	// until one succeeds, the global timeout is reached,
	// or the tomb is killed.
	checker := parallelHostChecker{
		Try:             parallel.NewTry(0, nil),
		client:          client,
		stderr:          stdErr,
		active:          make(map[network.ProviderAddress]chan struct{}),
		checkDelay:      opts.RetryDelay,
		checkHostScript: checkHostScript,
		hostSSHOptions:  hostSSHOptions,
	}
	defer checker.wg.Wait()
	defer checker.Kill()

	fmt.Fprintln(stdErr, "Waiting for address")
	for {
		select {
		case <-pollAddresses.C:
			pollAddresses.Reset(opts.AddressesDelay)
			if err := inst.Refresh(callCtx); err != nil {
				return "", fmt.Errorf("refreshing addresses: %v", err)
			}
			instanceStatus := inst.Status(callCtx)
			if instanceStatus.Status == status.ProvisioningError {
				if instanceStatus.Message != "" {
					return "", errors.Errorf("instance provisioning failed (%v)", instanceStatus.Message)
				}
				return "", errors.Errorf("instance provisioning failed")
			}
			addresses, err := inst.Addresses(callCtx)
			if err != nil {
				return "", fmt.Errorf("getting addresses: %v", err)
			}
			checker.UpdateAddresses(addresses)
		case <-globalTimeout:
			checker.Close()
			lastErr := checker.Wait()
			format := "waited for %v "
			args := []interface{}{opts.Timeout}
			if len(checker.active) == 0 {
				format += "without getting any addresses"
			} else {
				format += "without being able to connect"
			}
			if lastErr != nil && lastErr != parallel.ErrStopped {
				format += ": %v"
				args = append(args, lastErr)
			}
			return "", fmt.Errorf(format, args...)
		case <-ctx.Done():
			return "", bootstrap.Cancelled()
		case <-checker.Dead():
			result, err := checker.Result()
			if err != nil {
				return "", err
			}
			return result.(*hostChecker).addr.Value, nil
		}
	}
}

func generateSSHHostKeys() (instancecfg.SSHHostKeys, error) {
	// Generate the initial SSH host keys. We'll configure the SSH client
	// such that these are the only host keys it will accept.
	var keys instancecfg.SSHHostKeys

	hostKeys, err := pkissh.GenerateHostKeys()
	if err != nil {
		return nil, errors.Annotate(err, "generating SSH keys")
	}

	for i, key := range hostKeys {
		private, public, keyType, err := pkissh.FormatKey(key, fmt.Sprintf("juju-bootstrap-%d", i))
		if err != nil {
			return nil, errors.Annotate(err, "generating SSH key")
		}

		keys = append(keys, instancecfg.SSHKeyPair{
			Private:            private,
			Public:             public,
			PublicKeyAlgorithm: keyType,
		})
	}
	return keys, nil
}