github.com/grailbio/base@v0.0.11/cmd/grail-access/remote/bless.go (about)

     1  // Copyright 2022 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package remote
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"os/exec"
    11  	"strings"
    12  	"text/template"
    13  	"time"
    14  
    15  	"github.com/aws/aws-sdk-go/aws"
    16  	"github.com/aws/aws-sdk-go/aws/session"
    17  	"github.com/aws/aws-sdk-go/service/ec2"
    18  	"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
    19  	"github.com/aws/aws-sdk-go/service/s3"
    20  	"github.com/grailbio/base/cloud/awssession"
    21  	"github.com/grailbio/base/must"
    22  	v23 "v.io/v23"
    23  	"v.io/v23/context"
    24  	"v.io/v23/security"
    25  )
    26  
    27  const (
    28  	// awsTicketPath is the path of the ticket that provides AWS credentials
    29  	// for querying AWS/EC2 for running instances.
    30  	awsTicketPath = "tickets/eng/dev/aws"
    31  	// blessingsExtension is the extension added to the blessings sent to
    32  	// remotes.
    33  	blessingsExtension = "remote"
    34  
    35  	// remoteExecS3Bucket is the bucket in which the known-compatible
    36  	// grail-access binary installed on remote targets is stored.
    37  	remoteExecS3Bucket = "grail-bin-public"
    38  	// remoteExecS3Key is the object key of the known-compatible grail-access
    39  	// binary installed on remote targets.
    40  	// TODO: Stop assuming single platform (Linux/AMD64) of targets.
    41  	remoteExecS3Key = "linux/amd64/2023-02-10.dev-201357/grail-access"
    42  	// remoteExecExpiry is the expiry of the presigned URL we generate to
    43  	// download (remoteExecS3Bucket, remoteExecS3Key).
    44  	remoteExecExpiry = 15 * time.Minute
    45  	// remoteExecSHA256 is the expected SHA-256 of the executable at
    46  	// (remoteExecS3Bucket, remoteExecS3Key).
    47  	remoteExecSHA256 = "eeede8ad76ee106735867facfe70d5ae917f645de3d7c6a7274cbd25da34460d"
    48  	// remoteExecPath is the path on the remote target at which we install and
    49  	// later invoke the grail-access executable.  This string will be
    50  	// double-quoted in a bash script, so variable expansions can be used.
    51  	//
    52  	// See XDG Base Directory Specification:
    53  	// https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html
    54  	remoteExecPath = "${XDG_DATA_HOME:-${HOME}/.local/share}/grail-access/grail-access"
    55  )
    56  
    57  // Bless blesses the principals of targets with unconstrained extensions of
    58  // the default blessings of the principal of ctx.  See package documentation
    59  // (doc.go) for a description of target strings.
    60  func Bless(ctx *context.T, targets []string) error {
    61  	fmt.Println("---------------- Bless Remotes ----------------")
    62  	sess, err := awssession.NewWithTicket(ctx, awsTicketPath)
    63  	if err != nil {
    64  		return fmt.Errorf("creating AWS session: %v", err)
    65  	}
    66  	dests, err := resolveTargets(ctx, sess, targets)
    67  	if err != nil {
    68  		return fmt.Errorf("resolving targets: %v", err)
    69  	}
    70  	p := v23.GetPrincipal(ctx)
    71  	if p == nil {
    72  		return fmt.Errorf("no local principal")
    73  	}
    74  	blessings, _ := p.BlessingStore().Default()
    75  	for i, target := range targets {
    76  		fmt.Printf("%s:\n", target)
    77  		if len(dests[i]) == 0 {
    78  			fmt.Println("  <no matching targets>")
    79  			continue
    80  		}
    81  		for _, d := range dests[i] {
    82  			if !d.running {
    83  				fmt.Printf("  %-60s [ NOT RUNNING ]\n", d.s)
    84  				continue
    85  			}
    86  			if err := blessSSHDest(ctx, sess, p, blessings, d.s); err != nil {
    87  				return fmt.Errorf("blessing %q: %v", d.s, err)
    88  			}
    89  			fmt.Printf("  %-60s [ OK ]\n", d.s)
    90  		}
    91  	}
    92  	return nil
    93  }
    94  
    95  type sshDest struct {
    96  	// s represents this destination.  If running is true, then it is a valid
    97  	// SSH destination, i.e. we can connect to it using SSH.
    98  	s string
    99  	// running is false if we believe that the host is not currently running,
   100  	// e.g. because EC2 tells us so.  Otherwise, it is true.
   101  	running bool
   102  }
   103  
   104  // blessSSHDest uses commands over SSH to bless dest's principal.  p is the
   105  // blesser, and with are the blessings with which to bless dest's principal.
   106  func blessSSHDest(
   107  	ctx *context.T,
   108  	sess *session.Session,
   109  	p security.Principal,
   110  	with security.Blessings,
   111  	dest string,
   112  ) error {
   113  	if err := ensureRemoteExec(ctx, sess, dest); err != nil {
   114  		return fmt.Errorf("ensuring remote executable (grail-access) is available: %v", err)
   115  	}
   116  	key, err := remotePublicKey(ctx, dest)
   117  	if err != nil {
   118  		return fmt.Errorf("getting remote public key: %v", err)
   119  	}
   120  	blessingSelf, err := keysEqual(key, p.PublicKey())
   121  	if err != nil {
   122  		return fmt.Errorf("checking if blessing self: %v", err)
   123  	}
   124  	if blessingSelf {
   125  		return fmt.Errorf("cannot bless self; check that target is a remote machine/principal")
   126  	}
   127  	b, err := p.Bless(key, with, blessingsExtension, security.UnconstrainedUse())
   128  	if err != nil {
   129  		return fmt.Errorf("blessing %v with %v: %v", key, with, err)
   130  	}
   131  	if err := sendBlessings(ctx, b, dest); err != nil {
   132  		return fmt.Errorf("sending blessings to %s: %v", dest, err)
   133  	}
   134  	return nil
   135  }
   136  
   137  func ensureRemoteExec(ctx *context.T, sess *session.Session, dest string) error {
   138  	script, err := makeEnsureRemoteExecScript(sess)
   139  	if err != nil {
   140  		return fmt.Errorf(
   141  			"making script to ensure remote grail-access executable is available: %v",
   142  			err,
   143  		)
   144  	}
   145  	cmd := sshCommand(ctx, dest, "bash -s")
   146  	cmd.Stdin = strings.NewReader(script)
   147  	output, err := cmd.CombinedOutput()
   148  	if err != nil {
   149  		return fmt.Errorf(
   150  			"running installation script on %q: %v"+
   151  				"\n--- std{err,out} ---\n%s",
   152  			dest,
   153  			err,
   154  			output,
   155  		)
   156  	}
   157  	return nil
   158  }
   159  
   160  func makeEnsureRemoteExecScript(sess *session.Session) (string, error) {
   161  	url, err := presignRemoteExecURL(sess)
   162  	if err != nil {
   163  		return "", fmt.Errorf("presigning URL of grail-access executable: %v", err)
   164  	}
   165  	// "Escape" single quotes, as we enclose the URL in single quotes in our
   166  	// generated script.
   167  	url = strings.ReplaceAll(url, "'", "'\\''")
   168  	var b strings.Builder
   169  	ensureRemoteExecTemplate.Execute(&b, map[string]string{
   170  		"url":    url,
   171  		"sha256": remoteExecSHA256,
   172  		"path":   remoteExecPath,
   173  	})
   174  	return b.String(), nil
   175  }
   176  
   177  // ensureRemoteExecTemplate is the template for building the script used to
   178  // ensure that the remote has a compatible grail-access binary installed.  We
   179  // inject the configuration for installation.
   180  var ensureRemoteExecTemplate *template.Template
   181  
   182  func init() {
   183  	must.True(!strings.Contains(remoteExecSHA256, "'"))
   184  	ensureRemoteExecTemplate = template.Must(template.New("script").Parse(`
   185  set -euxo pipefail
   186  
   187  # url is the S3 URL from which to fetch the grail-access binary that will run
   188  # on the target.
   189  url='{{.url}}'
   190  # sha256 is the expected SHA-256 hash of the grail-access binary.
   191  sha256='{{.sha256}}'
   192  
   193  # path is the path at which will we ultimately place the grail-access binary.
   194  path="{{.path}}"
   195  dir="$(dirname "${path}")"
   196  
   197  sha_bad=0
   198  echo "${sha256} ${path}" | sha256sum --check --quiet - || sha_bad=$?
   199  if [ $sha_bad == 0 ]; then
   200  	# We already have the right binary.  Ensure that it is executable.  This
   201  	# should be a no-op unless it was changed externally.
   202  	chmod 700 "${path}"
   203  	exit
   204  fi
   205  
   206  mkdir --mode=700 --parents "${dir}"
   207  chmod 700 "${dir}"
   208  path_download="$(mktemp "${path}.XXXXXXXXXX")"
   209  trap "rm --force -- \"${path_download}\"" EXIT
   210  curl --fail "${url}" --output "${path_download}"
   211  echo "${sha256} ${path_download}" | sha256sum --check --quiet -
   212  chmod 700 "${path_download}"
   213  mv --force "${path_download}" "${path}"
   214  `))
   215  }
   216  
   217  func remotePublicKey(ctx *context.T, dest string) (security.PublicKey, error) {
   218  	var (
   219  		cmd    = remoteExecCommand(ctx, dest, ModePublicKey)
   220  		stderr bytes.Buffer
   221  	)
   222  	cmd.Stderr = &stderr
   223  	output, err := cmd.Output()
   224  	if err != nil {
   225  		return nil, fmt.Errorf(
   226  			"running grail-access(in mode: %s) on remote: %v;"+
   227  				"\n--- stderr ---\n%s",
   228  			ModePublicKey,
   229  			err,
   230  			stderr.String(),
   231  		)
   232  	}
   233  	key, err := decodePublicKey(string(output))
   234  	if err != nil {
   235  		return nil, fmt.Errorf("decoding public key %q: %v", string(output), err)
   236  	}
   237  	return key, nil
   238  }
   239  
   240  func keysEqual(lhs, rhs security.PublicKey) (bool, error) {
   241  	lhsBytes, err := lhs.MarshalBinary()
   242  	if err != nil {
   243  		return false, fmt.Errorf("left-hand side of comparison invalid: %v", err)
   244  	}
   245  	rhsBytes, err := rhs.MarshalBinary()
   246  	if err != nil {
   247  		return false, fmt.Errorf("right-hand side of comparison invalid: %v", err)
   248  	}
   249  	return bytes.Equal(lhsBytes, rhsBytes), nil
   250  }
   251  
   252  func sendBlessings(ctx *context.T, b security.Blessings, dest string) error {
   253  	var (
   254  		cmd                  = remoteExecCommand(ctx, dest, ModeReceive)
   255  		blessingsString, err = encodeBlessings(b)
   256  	)
   257  	if err != nil {
   258  		return fmt.Errorf("encoding blessings: %v", err)
   259  	}
   260  	_ = blessingsString
   261  	cmd.Stdin = strings.NewReader(blessingsString)
   262  	var stderr bytes.Buffer
   263  	cmd.Stderr = &stderr
   264  	if err := cmd.Run(); err != nil {
   265  		return fmt.Errorf(
   266  			"running grail-access(in mode: %s) on remote: %v;"+
   267  				"\n--- stderr ---\n%s",
   268  			ModeReceive,
   269  			err,
   270  			stderr.String(),
   271  		)
   272  	}
   273  	return nil
   274  }
   275  
   276  func remoteExecCommand(ctx *context.T, dest, mode string) *exec.Cmd {
   277  	return sshCommand(
   278  		ctx,
   279  		dest,
   280  		// Set a reasonable value V23_CREDENTIALS in case the target's bash
   281  		// does not configure it (in non-login shells).
   282  		"V23_CREDENTIALS=${HOME}/.v23",
   283  		remoteExecPath, "-"+FlagNameMode+"="+mode,
   284  	)
   285  }
   286  
   287  func sshCommand(ctx *context.T, dest string, args ...string) *exec.Cmd {
   288  	cmdArgs := []string{
   289  		// Use batch mode which prevents prompting for an SSH passphrase.  The
   290  		// prompt is more confusing than failing outright, as we run multiple
   291  		// SSH commands, so even if the user enters the correct passphrase,
   292  		// they will see more prompts.
   293  		"-o", "BatchMode yes",
   294  		// Don't check the identity of the remote host.
   295  		"-o", "StrictHostKeyChecking no",
   296  		// Don't store the identity of the remote host.
   297  		"-o", "UserKnownHostsFile /dev/null",
   298  		dest,
   299  	}
   300  	cmdArgs = append(cmdArgs, args...)
   301  	return exec.CommandContext(ctx, "ssh", cmdArgs...)
   302  }
   303  
   304  // resolveTargets resolves targets into SSH destinations.  Destinations are
   305  // returned as a two-dimensional slice of length len(targets).  Each entry
   306  // corresponds to the input target and is a slice of the matching SSH
   307  // destinations, if any.
   308  //
   309  // Note that for ec2-name targets, we make API calls to EC2 to resolve the
   310  // corresponding hosts.  A single ec2-name target may resolve to multiple (or
   311  // zero) SSH destinations, as names are given as filters.
   312  func resolveTargets(ctx *context.T, sess *session.Session, targets []string) ([][]sshDest, error) {
   313  	var dests = make([][]sshDest, len(targets))
   314  	for i, target := range targets {
   315  		parts := strings.SplitN(target, ":", 2)
   316  		if len(parts) != 2 {
   317  			return nil, fmt.Errorf("target not in \"type:value\" format: %v", target)
   318  		}
   319  		var (
   320  			typ    = parts[0]
   321  			val    = parts[1]
   322  			ec2API = ec2.New(sess)
   323  		)
   324  		switch typ {
   325  		case "ssh":
   326  			dests[i] = append(dests[i], sshDest{s: val, running: true})
   327  		case "ec2-name":
   328  			ec2Dests, err := resolveEC2Target(ctx, ec2API, val)
   329  			if err != nil {
   330  				return nil, fmt.Errorf("resolving EC2 target %v: %v", val, err)
   331  			}
   332  			dests[i] = append(dests[i], ec2Dests...)
   333  		default:
   334  			return nil, fmt.Errorf("invalid target type for %q: %v", target, typ)
   335  		}
   336  	}
   337  	return dests, nil
   338  }
   339  
   340  func resolveEC2Target(ctx *context.T, ec2API ec2iface.EC2API, s string) ([]sshDest, error) {
   341  	var (
   342  		user string
   343  		name string
   344  	)
   345  	parts := strings.SplitN(s, "@", 2)
   346  	switch len(parts) {
   347  	case 1:
   348  		user = "ubuntu"
   349  		name = parts[0]
   350  	case 2:
   351  		user = parts[0]
   352  		name = parts[1]
   353  	default:
   354  		must.Never("SplitN returned invalid result")
   355  	}
   356  	instances, err := findInstances(ctx, ec2API, name)
   357  	if err != nil {
   358  		return nil, fmt.Errorf("finding instances matching %q: %v", name, err)
   359  	}
   360  	var dests []sshDest
   361  	for _, i := range instances {
   362  		if i.InstanceId == nil {
   363  			return nil, fmt.Errorf("instance has no ID: %s", i.String())
   364  		}
   365  		if i.State == nil || i.State.Name == nil {
   366  			return nil, fmt.Errorf("instance has no state: %s", i.String())
   367  		}
   368  		if *i.State.Name != ec2.InstanceStateNameRunning {
   369  			dests = append(dests, sshDest{
   370  				s:       fmt.Sprintf("%s@%s", user, *i.InstanceId),
   371  				running: false,
   372  			})
   373  			continue
   374  		}
   375  		if i.PublicIpAddress == nil {
   376  			return nil, fmt.Errorf("running instance %q has no public IP address", *i.InstanceId)
   377  		}
   378  		dests = append(dests, sshDest{
   379  			s:       fmt.Sprintf("%s@%s", user, *i.PublicIpAddress),
   380  			running: true,
   381  		})
   382  	}
   383  	return dests, nil
   384  }
   385  
   386  func presignRemoteExecURL(sess *session.Session) (string, error) {
   387  	s3API := s3.New(sess)
   388  	req, _ := s3API.GetObjectRequest(&s3.GetObjectInput{
   389  		Bucket: aws.String(remoteExecS3Bucket),
   390  		Key:    aws.String(remoteExecS3Key),
   391  	})
   392  	url, err := req.Presign(remoteExecExpiry)
   393  	if err != nil {
   394  		return "", fmt.Errorf(
   395  			"presigning URL for s3://%s/%s: %v",
   396  			remoteExecS3Bucket,
   397  			remoteExecS3Key,
   398  			err,
   399  		)
   400  	}
   401  	return url, nil
   402  }
   403  
   404  func findInstances(ctx *context.T, api ec2iface.EC2API, name string) ([]*ec2.Instance, error) {
   405  	input := &ec2.DescribeInstancesInput{
   406  		Filters: []*ec2.Filter{
   407  			{
   408  				Name:   aws.String("tag:Name"),
   409  				Values: aws.StringSlice([]string{name}),
   410  			},
   411  		},
   412  	}
   413  	output, err := api.DescribeInstancesWithContext(ctx, input)
   414  	if err != nil {
   415  		return nil, fmt.Errorf(
   416  			"DescribeInstances error:\n%v\nDescribeInstances request:\n%v",
   417  			err,
   418  			input,
   419  		)
   420  	}
   421  	return reservationsInstances(output.Reservations), nil
   422  }
   423  
   424  func reservationsInstances(reservations []*ec2.Reservation) []*ec2.Instance {
   425  	instances := []*ec2.Instance{}
   426  	for _, r := range reservations {
   427  		instances = append(instances, r.Instances...)
   428  	}
   429  	return instances
   430  }