golang.org/x/build@v0.0.0-20240506185731-218518f32b70/cmd/gomote/ssh.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"flag"
    11  	"fmt"
    12  	"log"
    13  	"os"
    14  	"os/exec"
    15  	"path/filepath"
    16  	"strings"
    17  
    18  	"golang.org/x/build/internal/gomote/protos"
    19  )
    20  
    21  func ssh(args []string) error {
    22  	if activeGroup != nil {
    23  		return fmt.Errorf("command does not support groups")
    24  	}
    25  
    26  	fs := flag.NewFlagSet("ssh", flag.ContinueOnError)
    27  	fs.Usage = func() {
    28  		fmt.Fprintln(os.Stderr, "ssh usage: gomote ssh <instance>")
    29  		fs.PrintDefaults()
    30  		os.Exit(1)
    31  	}
    32  	fs.Parse(args)
    33  	if fs.NArg() != 1 {
    34  		fs.Usage()
    35  	}
    36  
    37  	name := fs.Arg(0)
    38  	sshKeyDir, err := sshConfigDirectory()
    39  	if err != nil {
    40  		return err
    41  	}
    42  	pubKey, priKey, err := localKeyPair(sshKeyDir)
    43  	if err != nil {
    44  		return err
    45  	}
    46  	pubKeyBytes, err := os.ReadFile(pubKey)
    47  	if err != nil {
    48  		return err
    49  	}
    50  	ctx := context.Background()
    51  	client := gomoteServerClient(ctx)
    52  	resp, err := client.SignSSHKey(ctx, &protos.SignSSHKeyRequest{
    53  		GomoteId:     name,
    54  		PublicSshKey: []byte(pubKeyBytes),
    55  	})
    56  	if err != nil {
    57  		return fmt.Errorf("unable to retrieve SSH certificate: %w", err)
    58  	}
    59  	certPath, err := writeCertificateToDisk(resp.GetSignedPublicSshKey())
    60  	if err != nil {
    61  		return err
    62  	}
    63  	return sshConnect(name, priKey, certPath)
    64  }
    65  
    66  func sshConfigDirectory() (string, error) {
    67  	configDir, err := os.UserConfigDir()
    68  	if err != nil {
    69  		return "", fmt.Errorf("unable to retrieve user configuration directory: %w", err)
    70  	}
    71  	sshConfigDir := filepath.Join(configDir, "gomote", ".ssh")
    72  	err = os.MkdirAll(sshConfigDir, 0700)
    73  	if err != nil {
    74  		return "", fmt.Errorf("unable to create user SSH configuration directory: %w", err)
    75  	}
    76  	return sshConfigDir, nil
    77  }
    78  
    79  func localKeyPair(sshDir string) (string, string, error) {
    80  	priKey := filepath.Join(sshDir, "id_ed25519")
    81  	pubKey := filepath.Join(sshDir, "id_ed25519.pub")
    82  	if !fileExists(priKey) || !fileExists(pubKey) {
    83  		log.Printf("local ssh keys do not exist, attempting to create them")
    84  		if err := createLocalKeyPair(pubKey, priKey); err != nil {
    85  			return "", "", fmt.Errorf("unable to create local SSH key pair: %w", err)
    86  		}
    87  	}
    88  	return pubKey, priKey, nil
    89  }
    90  
    91  func createLocalKeyPair(pubKey, priKey string) error {
    92  	cmd := exec.Command("ssh-keygen", "-o", "-a", "256", "-t", "ed25519", "-f", priKey)
    93  	cmd.Stdout = os.Stdout
    94  	cmd.Stdin = os.Stdin
    95  	cmd.Stderr = os.Stderr
    96  	return cmd.Run()
    97  }
    98  
    99  func writeCertificateToDisk(b []byte) (string, error) {
   100  	tmpDir := filepath.Join(os.TempDir(), ".gomote")
   101  	if err := os.MkdirAll(tmpDir, 0700); err != nil {
   102  		return "", fmt.Errorf("unable to create temp directory for certficates: %w", err)
   103  	}
   104  	tf, err := os.CreateTemp(tmpDir, "id_ed25519-*-cert.pub")
   105  	if err != nil {
   106  		return "", err
   107  	}
   108  	if err := tf.Chmod(0600); err != nil {
   109  		return "", err
   110  	}
   111  	if _, err := tf.Write(b); err != nil {
   112  		return "", err
   113  	}
   114  	return tf.Name(), tf.Close()
   115  }
   116  
   117  func sshConnect(name string, priKey, certPath string) error {
   118  	ssh, err := exec.LookPath("ssh")
   119  	if err != nil {
   120  		return fmt.Errorf("path to ssh not found: %w", err)
   121  	}
   122  	sshServer := "gomotessh.golang.org"
   123  	if luciDisabled() {
   124  		sshServer = "farmer.golang.org"
   125  	}
   126  	cli := []string{"-o", fmt.Sprintf("CertificateFile=%s", certPath), "-i", priKey, "-p", "2222", name + "@" + sshServer}
   127  	fmt.Printf("$ %s %s\n", ssh, strings.Join(cli, " "))
   128  	cmd := exec.Command(ssh, cli...)
   129  	cmd.Stdout = os.Stdout
   130  	cmd.Stdin = os.Stdin
   131  	cmd.Stderr = os.Stderr
   132  	if err := cmd.Run(); err != nil {
   133  		return fmt.Errorf("unable to ssh into instance: %w", err)
   134  	}
   135  	return nil
   136  }
   137  
   138  func fileExists(path string) bool {
   139  	if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
   140  		return false
   141  	}
   142  	return true
   143  }