github.com/grailbio/base@v0.0.11/cmd/grail-access/remote/bless.go (about) 1 // Copyright 2022 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache-2.0 3 // license that can be found in the LICENSE file. 4 5 package remote 6 7 import ( 8 "bytes" 9 "fmt" 10 "os/exec" 11 "strings" 12 "text/template" 13 "time" 14 15 "github.com/aws/aws-sdk-go/aws" 16 "github.com/aws/aws-sdk-go/aws/session" 17 "github.com/aws/aws-sdk-go/service/ec2" 18 "github.com/aws/aws-sdk-go/service/ec2/ec2iface" 19 "github.com/aws/aws-sdk-go/service/s3" 20 "github.com/grailbio/base/cloud/awssession" 21 "github.com/grailbio/base/must" 22 v23 "v.io/v23" 23 "v.io/v23/context" 24 "v.io/v23/security" 25 ) 26 27 const ( 28 // awsTicketPath is the path of the ticket that provides AWS credentials 29 // for querying AWS/EC2 for running instances. 30 awsTicketPath = "tickets/eng/dev/aws" 31 // blessingsExtension is the extension added to the blessings sent to 32 // remotes. 33 blessingsExtension = "remote" 34 35 // remoteExecS3Bucket is the bucket in which the known-compatible 36 // grail-access binary installed on remote targets is stored. 37 remoteExecS3Bucket = "grail-bin-public" 38 // remoteExecS3Key is the object key of the known-compatible grail-access 39 // binary installed on remote targets. 40 // TODO: Stop assuming single platform (Linux/AMD64) of targets. 41 remoteExecS3Key = "linux/amd64/2023-02-10.dev-201357/grail-access" 42 // remoteExecExpiry is the expiry of the presigned URL we generate to 43 // download (remoteExecS3Bucket, remoteExecS3Key). 44 remoteExecExpiry = 15 * time.Minute 45 // remoteExecSHA256 is the expected SHA-256 of the executable at 46 // (remoteExecS3Bucket, remoteExecS3Key). 47 remoteExecSHA256 = "eeede8ad76ee106735867facfe70d5ae917f645de3d7c6a7274cbd25da34460d" 48 // remoteExecPath is the path on the remote target at which we install and 49 // later invoke the grail-access executable. This string will be 50 // double-quoted in a bash script, so variable expansions can be used. 51 // 52 // See XDG Base Directory Specification: 53 // https://specifications.freedesktop.org/basedir-spec/basedir-spec-latest.html 54 remoteExecPath = "${XDG_DATA_HOME:-${HOME}/.local/share}/grail-access/grail-access" 55 ) 56 57 // Bless blesses the principals of targets with unconstrained extensions of 58 // the default blessings of the principal of ctx. See package documentation 59 // (doc.go) for a description of target strings. 60 func Bless(ctx *context.T, targets []string) error { 61 fmt.Println("---------------- Bless Remotes ----------------") 62 sess, err := awssession.NewWithTicket(ctx, awsTicketPath) 63 if err != nil { 64 return fmt.Errorf("creating AWS session: %v", err) 65 } 66 dests, err := resolveTargets(ctx, sess, targets) 67 if err != nil { 68 return fmt.Errorf("resolving targets: %v", err) 69 } 70 p := v23.GetPrincipal(ctx) 71 if p == nil { 72 return fmt.Errorf("no local principal") 73 } 74 blessings, _ := p.BlessingStore().Default() 75 for i, target := range targets { 76 fmt.Printf("%s:\n", target) 77 if len(dests[i]) == 0 { 78 fmt.Println(" <no matching targets>") 79 continue 80 } 81 for _, d := range dests[i] { 82 if !d.running { 83 fmt.Printf(" %-60s [ NOT RUNNING ]\n", d.s) 84 continue 85 } 86 if err := blessSSHDest(ctx, sess, p, blessings, d.s); err != nil { 87 return fmt.Errorf("blessing %q: %v", d.s, err) 88 } 89 fmt.Printf(" %-60s [ OK ]\n", d.s) 90 } 91 } 92 return nil 93 } 94 95 type sshDest struct { 96 // s represents this destination. If running is true, then it is a valid 97 // SSH destination, i.e. we can connect to it using SSH. 98 s string 99 // running is false if we believe that the host is not currently running, 100 // e.g. because EC2 tells us so. Otherwise, it is true. 101 running bool 102 } 103 104 // blessSSHDest uses commands over SSH to bless dest's principal. p is the 105 // blesser, and with are the blessings with which to bless dest's principal. 106 func blessSSHDest( 107 ctx *context.T, 108 sess *session.Session, 109 p security.Principal, 110 with security.Blessings, 111 dest string, 112 ) error { 113 if err := ensureRemoteExec(ctx, sess, dest); err != nil { 114 return fmt.Errorf("ensuring remote executable (grail-access) is available: %v", err) 115 } 116 key, err := remotePublicKey(ctx, dest) 117 if err != nil { 118 return fmt.Errorf("getting remote public key: %v", err) 119 } 120 blessingSelf, err := keysEqual(key, p.PublicKey()) 121 if err != nil { 122 return fmt.Errorf("checking if blessing self: %v", err) 123 } 124 if blessingSelf { 125 return fmt.Errorf("cannot bless self; check that target is a remote machine/principal") 126 } 127 b, err := p.Bless(key, with, blessingsExtension, security.UnconstrainedUse()) 128 if err != nil { 129 return fmt.Errorf("blessing %v with %v: %v", key, with, err) 130 } 131 if err := sendBlessings(ctx, b, dest); err != nil { 132 return fmt.Errorf("sending blessings to %s: %v", dest, err) 133 } 134 return nil 135 } 136 137 func ensureRemoteExec(ctx *context.T, sess *session.Session, dest string) error { 138 script, err := makeEnsureRemoteExecScript(sess) 139 if err != nil { 140 return fmt.Errorf( 141 "making script to ensure remote grail-access executable is available: %v", 142 err, 143 ) 144 } 145 cmd := sshCommand(ctx, dest, "bash -s") 146 cmd.Stdin = strings.NewReader(script) 147 output, err := cmd.CombinedOutput() 148 if err != nil { 149 return fmt.Errorf( 150 "running installation script on %q: %v"+ 151 "\n--- std{err,out} ---\n%s", 152 dest, 153 err, 154 output, 155 ) 156 } 157 return nil 158 } 159 160 func makeEnsureRemoteExecScript(sess *session.Session) (string, error) { 161 url, err := presignRemoteExecURL(sess) 162 if err != nil { 163 return "", fmt.Errorf("presigning URL of grail-access executable: %v", err) 164 } 165 // "Escape" single quotes, as we enclose the URL in single quotes in our 166 // generated script. 167 url = strings.ReplaceAll(url, "'", "'\\''") 168 var b strings.Builder 169 ensureRemoteExecTemplate.Execute(&b, map[string]string{ 170 "url": url, 171 "sha256": remoteExecSHA256, 172 "path": remoteExecPath, 173 }) 174 return b.String(), nil 175 } 176 177 // ensureRemoteExecTemplate is the template for building the script used to 178 // ensure that the remote has a compatible grail-access binary installed. We 179 // inject the configuration for installation. 180 var ensureRemoteExecTemplate *template.Template 181 182 func init() { 183 must.True(!strings.Contains(remoteExecSHA256, "'")) 184 ensureRemoteExecTemplate = template.Must(template.New("script").Parse(` 185 set -euxo pipefail 186 187 # url is the S3 URL from which to fetch the grail-access binary that will run 188 # on the target. 189 url='{{.url}}' 190 # sha256 is the expected SHA-256 hash of the grail-access binary. 191 sha256='{{.sha256}}' 192 193 # path is the path at which will we ultimately place the grail-access binary. 194 path="{{.path}}" 195 dir="$(dirname "${path}")" 196 197 sha_bad=0 198 echo "${sha256} ${path}" | sha256sum --check --quiet - || sha_bad=$? 199 if [ $sha_bad == 0 ]; then 200 # We already have the right binary. Ensure that it is executable. This 201 # should be a no-op unless it was changed externally. 202 chmod 700 "${path}" 203 exit 204 fi 205 206 mkdir --mode=700 --parents "${dir}" 207 chmod 700 "${dir}" 208 path_download="$(mktemp "${path}.XXXXXXXXXX")" 209 trap "rm --force -- \"${path_download}\"" EXIT 210 curl --fail "${url}" --output "${path_download}" 211 echo "${sha256} ${path_download}" | sha256sum --check --quiet - 212 chmod 700 "${path_download}" 213 mv --force "${path_download}" "${path}" 214 `)) 215 } 216 217 func remotePublicKey(ctx *context.T, dest string) (security.PublicKey, error) { 218 var ( 219 cmd = remoteExecCommand(ctx, dest, ModePublicKey) 220 stderr bytes.Buffer 221 ) 222 cmd.Stderr = &stderr 223 output, err := cmd.Output() 224 if err != nil { 225 return nil, fmt.Errorf( 226 "running grail-access(in mode: %s) on remote: %v;"+ 227 "\n--- stderr ---\n%s", 228 ModePublicKey, 229 err, 230 stderr.String(), 231 ) 232 } 233 key, err := decodePublicKey(string(output)) 234 if err != nil { 235 return nil, fmt.Errorf("decoding public key %q: %v", string(output), err) 236 } 237 return key, nil 238 } 239 240 func keysEqual(lhs, rhs security.PublicKey) (bool, error) { 241 lhsBytes, err := lhs.MarshalBinary() 242 if err != nil { 243 return false, fmt.Errorf("left-hand side of comparison invalid: %v", err) 244 } 245 rhsBytes, err := rhs.MarshalBinary() 246 if err != nil { 247 return false, fmt.Errorf("right-hand side of comparison invalid: %v", err) 248 } 249 return bytes.Equal(lhsBytes, rhsBytes), nil 250 } 251 252 func sendBlessings(ctx *context.T, b security.Blessings, dest string) error { 253 var ( 254 cmd = remoteExecCommand(ctx, dest, ModeReceive) 255 blessingsString, err = encodeBlessings(b) 256 ) 257 if err != nil { 258 return fmt.Errorf("encoding blessings: %v", err) 259 } 260 _ = blessingsString 261 cmd.Stdin = strings.NewReader(blessingsString) 262 var stderr bytes.Buffer 263 cmd.Stderr = &stderr 264 if err := cmd.Run(); err != nil { 265 return fmt.Errorf( 266 "running grail-access(in mode: %s) on remote: %v;"+ 267 "\n--- stderr ---\n%s", 268 ModeReceive, 269 err, 270 stderr.String(), 271 ) 272 } 273 return nil 274 } 275 276 func remoteExecCommand(ctx *context.T, dest, mode string) *exec.Cmd { 277 return sshCommand( 278 ctx, 279 dest, 280 // Set a reasonable value V23_CREDENTIALS in case the target's bash 281 // does not configure it (in non-login shells). 282 "V23_CREDENTIALS=${HOME}/.v23", 283 remoteExecPath, "-"+FlagNameMode+"="+mode, 284 ) 285 } 286 287 func sshCommand(ctx *context.T, dest string, args ...string) *exec.Cmd { 288 cmdArgs := []string{ 289 // Use batch mode which prevents prompting for an SSH passphrase. The 290 // prompt is more confusing than failing outright, as we run multiple 291 // SSH commands, so even if the user enters the correct passphrase, 292 // they will see more prompts. 293 "-o", "BatchMode yes", 294 // Don't check the identity of the remote host. 295 "-o", "StrictHostKeyChecking no", 296 // Don't store the identity of the remote host. 297 "-o", "UserKnownHostsFile /dev/null", 298 dest, 299 } 300 cmdArgs = append(cmdArgs, args...) 301 return exec.CommandContext(ctx, "ssh", cmdArgs...) 302 } 303 304 // resolveTargets resolves targets into SSH destinations. Destinations are 305 // returned as a two-dimensional slice of length len(targets). Each entry 306 // corresponds to the input target and is a slice of the matching SSH 307 // destinations, if any. 308 // 309 // Note that for ec2-name targets, we make API calls to EC2 to resolve the 310 // corresponding hosts. A single ec2-name target may resolve to multiple (or 311 // zero) SSH destinations, as names are given as filters. 312 func resolveTargets(ctx *context.T, sess *session.Session, targets []string) ([][]sshDest, error) { 313 var dests = make([][]sshDest, len(targets)) 314 for i, target := range targets { 315 parts := strings.SplitN(target, ":", 2) 316 if len(parts) != 2 { 317 return nil, fmt.Errorf("target not in \"type:value\" format: %v", target) 318 } 319 var ( 320 typ = parts[0] 321 val = parts[1] 322 ec2API = ec2.New(sess) 323 ) 324 switch typ { 325 case "ssh": 326 dests[i] = append(dests[i], sshDest{s: val, running: true}) 327 case "ec2-name": 328 ec2Dests, err := resolveEC2Target(ctx, ec2API, val) 329 if err != nil { 330 return nil, fmt.Errorf("resolving EC2 target %v: %v", val, err) 331 } 332 dests[i] = append(dests[i], ec2Dests...) 333 default: 334 return nil, fmt.Errorf("invalid target type for %q: %v", target, typ) 335 } 336 } 337 return dests, nil 338 } 339 340 func resolveEC2Target(ctx *context.T, ec2API ec2iface.EC2API, s string) ([]sshDest, error) { 341 var ( 342 user string 343 name string 344 ) 345 parts := strings.SplitN(s, "@", 2) 346 switch len(parts) { 347 case 1: 348 user = "ubuntu" 349 name = parts[0] 350 case 2: 351 user = parts[0] 352 name = parts[1] 353 default: 354 must.Never("SplitN returned invalid result") 355 } 356 instances, err := findInstances(ctx, ec2API, name) 357 if err != nil { 358 return nil, fmt.Errorf("finding instances matching %q: %v", name, err) 359 } 360 var dests []sshDest 361 for _, i := range instances { 362 if i.InstanceId == nil { 363 return nil, fmt.Errorf("instance has no ID: %s", i.String()) 364 } 365 if i.State == nil || i.State.Name == nil { 366 return nil, fmt.Errorf("instance has no state: %s", i.String()) 367 } 368 if *i.State.Name != ec2.InstanceStateNameRunning { 369 dests = append(dests, sshDest{ 370 s: fmt.Sprintf("%s@%s", user, *i.InstanceId), 371 running: false, 372 }) 373 continue 374 } 375 if i.PublicIpAddress == nil { 376 return nil, fmt.Errorf("running instance %q has no public IP address", *i.InstanceId) 377 } 378 dests = append(dests, sshDest{ 379 s: fmt.Sprintf("%s@%s", user, *i.PublicIpAddress), 380 running: true, 381 }) 382 } 383 return dests, nil 384 } 385 386 func presignRemoteExecURL(sess *session.Session) (string, error) { 387 s3API := s3.New(sess) 388 req, _ := s3API.GetObjectRequest(&s3.GetObjectInput{ 389 Bucket: aws.String(remoteExecS3Bucket), 390 Key: aws.String(remoteExecS3Key), 391 }) 392 url, err := req.Presign(remoteExecExpiry) 393 if err != nil { 394 return "", fmt.Errorf( 395 "presigning URL for s3://%s/%s: %v", 396 remoteExecS3Bucket, 397 remoteExecS3Key, 398 err, 399 ) 400 } 401 return url, nil 402 } 403 404 func findInstances(ctx *context.T, api ec2iface.EC2API, name string) ([]*ec2.Instance, error) { 405 input := &ec2.DescribeInstancesInput{ 406 Filters: []*ec2.Filter{ 407 { 408 Name: aws.String("tag:Name"), 409 Values: aws.StringSlice([]string{name}), 410 }, 411 }, 412 } 413 output, err := api.DescribeInstancesWithContext(ctx, input) 414 if err != nil { 415 return nil, fmt.Errorf( 416 "DescribeInstances error:\n%v\nDescribeInstances request:\n%v", 417 err, 418 input, 419 ) 420 } 421 return reservationsInstances(output.Reservations), nil 422 } 423 424 func reservationsInstances(reservations []*ec2.Reservation) []*ec2.Instance { 425 instances := []*ec2.Instance{} 426 for _, r := range reservations { 427 instances = append(instances, r.Instances...) 428 } 429 return instances 430 }