github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/cmd/grail-ssh/ssh.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"io/ioutil"
     7  	"os"
     8  	"os/exec"
     9  	"regexp"
    10  	"strings"
    11  	"syscall"
    12  	"time"
    13  
    14  	"github.com/Schaudge/grailbase/security/ticket"
    15  	sshLib "golang.org/x/crypto/ssh"
    16  	terminal "golang.org/x/crypto/ssh/terminal"
    17  	"v.io/v23/context"
    18  	"v.io/x/lib/cmdline"
    19  	"v.io/x/lib/vlog"
    20  )
    21  
    22  const (
    23  	timeout = 10 * time.Second
    24  )
    25  
    26  func runSsh(ctx *context.T, out io.Writer, env *cmdline.Env, args []string) error {
    27  	if len(args) == 0 {
    28  		return env.UsageErrorf("At least one argument (<ticket>) is required.")
    29  	}
    30  
    31  	ticketPath := args[0]
    32  	args = args[1:] // remove the ticket from the arguments
    33  
    34  	client := ticket.TicketServiceClient(ticketPath)
    35  	ctx, cancel := context.WithTimeout(ctx, timeout)
    36  	defer cancel()
    37  
    38  	// Read in the private key
    39  	privateKey, err := ioutil.ReadFile(idRsaFlag)
    40  	if err != nil {
    41  		return fmt.Errorf("Failed to read private key - %v", err)
    42  	}
    43  
    44  	// Load the private key
    45  	privateSigner, err := sshLib.ParsePrivateKey(privateKey)
    46  	if err != nil {
    47  		switch err.(type) {
    48  		case *sshLib.PassphraseMissingError:
    49  			// try to load the key with a passphrase
    50  			fmt.Print("Enter SSH Key Passphrase: ")
    51  			bytePassword, _ := terminal.ReadPassword(int(syscall.Stdin))
    52  			privateSigner, err = sshLib.ParsePrivateKeyWithPassphrase(privateKey, bytePassword)
    53  			if err != nil {
    54  				return fmt.Errorf("Failed to read private key - %v", err)
    55  			}
    56  			fmt.Println("\nSSH Key decoded")
    57  		default:
    58  			return fmt.Errorf("Failed to parse private key - %v", err)
    59  		}
    60  	}
    61  
    62  	if err != nil {
    63  		return fmt.Errorf("Failed to parse private key - %v", err)
    64  	}
    65  
    66  	var parameters = []ticket.Parameter{
    67  		ticket.Parameter{
    68  			Key:   "PublicKey",
    69  			Value: string(sshLib.MarshalAuthorizedKey(privateSigner.PublicKey())),
    70  		},
    71  	}
    72  
    73  	t, err := client.GetWithParameters(ctx, parameters)
    74  	if err != nil {
    75  		return fmt.Errorf("Failed to communicate with the ticket-server - %v", err)
    76  	}
    77  
    78  	switch t.Index() {
    79  	case (ticket.TicketSshCertificateTicket{}).Index():
    80  		{
    81  			creds := t.(ticket.TicketSshCertificateTicket).Value.Credentials
    82  			// pull the public certificate out and write to the id_rsa cert path location
    83  			if err = ioutil.WriteFile(idRsaFlag+"-cert.pub", []byte(creds.Cert), 0644); err != nil {
    84  				return fmt.Errorf("Failed to write ssh public key "+idRsaFlag+"-cert.pub"+" - %v", err)
    85  			}
    86  		}
    87  	default:
    88  		{
    89  			return fmt.Errorf("Provided ticket is not a SSHCertificateTicket")
    90  		}
    91  	}
    92  
    93  	var computeInstances []ticket.ComputeInstance = t.(ticket.TicketSshCertificateTicket).Value.ComputeInstances
    94  	var username = t.(ticket.TicketSshCertificateTicket).Value.Username
    95  	// Use the environment provided username if specified
    96  	if userFlag != "" {
    97  		username = userFlag
    98  	}
    99  	// Throw an error if no username is set
   100  	if username == "" {
   101  		vlog.Errorf("Username was not provided in ticket or via command line")
   102  		// TODO: return the exit code from the cmd.
   103  		os.Exit(1)
   104  	}
   105  
   106  	var host string
   107  	instanceMatch := regexp.MustCompile("^i-[a-zA-Z0-9]+$")
   108  	// Not the best regex (e.g. doesn't match IPV6) to use here ... better regexs are available at
   109  	// https://stackoverflow.com/questions/106179/regular-expression-to-match-dns-hostname-or-ip-address
   110  	hostIpMatch := regexp.MustCompile("^([a-zA-Z0-9]+\\.)+[a-zA-Z0-9]+$")
   111  	dnsMatch := regexp.MustCompile("^(([a-zA-Z]|[a-zA-Z][a-zA-Z0-9\\-]*[a-zA-Z0-9])\\.)*([A-Za-z]|[A-Za-z][A-Za-z0-9\\-]*[A-Za-z0-9])$")
   112  	stopMatch := regexp.MustCompile("^--$")
   113  
   114  	// Loop through the arguments provided to the CLI tool - and try to match to a hostname or an instanceID.
   115  	// Stop processing if -- is found.
   116  	// host is the last match found.
   117  	for i, arg := range args {
   118  		match := instanceMatch.MatchString(arg)
   119  		if err != nil {
   120  			return fmt.Errorf("Failed to check if input %s matched an instanceId - %v", arg, err)
   121  		}
   122  
   123  		// Find matching instanceId in list
   124  		if match {
   125  			// Remove the matched element from the list
   126  			args = append(args[:i], args[i+1:]...)
   127  			for _, instance := range computeInstances {
   128  				if instance.InstanceId == arg {
   129  					vlog.Errorf("Matched InstanceID %s - %s", instance.InstanceId, instance.PublicIp)
   130  					fmt.Printf("Matched InstanceID %s - %s \n", instance.InstanceId, instance.PublicIp)
   131  					host = instance.PublicIp
   132  					break
   133  				}
   134  			}
   135  			if host == "" {
   136  				return fmt.Errorf("Failed to find a match for InstanceId provided %s", arg)
   137  			}
   138  			break
   139  		}
   140  
   141  		// // check for a dns name to stop processing
   142  		match = dnsMatch.MatchString(arg)
   143  		if err != nil {
   144  			return fmt.Errorf("Failed to check if input %s matched a DNS name' - %v", arg, err)
   145  		}
   146  		if match {
   147  			host = arg
   148  			args = append(args[:i], args[i+1:]...)
   149  			fmt.Printf("Matched DNS %s \n", host)
   150  			break
   151  		}
   152  
   153  		// check for a dns/ip host name to stop processing
   154  		match = hostIpMatch.MatchString(arg)
   155  		if err != nil {
   156  			return fmt.Errorf("Failed to check if input %s matched an '^[a-zA-Z0-9]+\\.[a-zA-Z0-9]+' - %v", arg, err)
   157  		}
   158  		if match {
   159  			host = arg
   160  			args = append(args[:i], args[i+1:]...)
   161  			fmt.Printf("Matched Host IP %s \n", host)
   162  			break
   163  		}
   164  
   165  		// check for a -- to stop processing
   166  		match = stopMatch.MatchString(arg)
   167  		if err != nil {
   168  			return fmt.Errorf("Failed to check if input %s matched an '--' - %v", arg, err)
   169  		}
   170  		if match {
   171  			break
   172  		}
   173  	}
   174  
   175  	// If no host has been found present a list
   176  	if host == "" {
   177  		fmt.Printf("No host or InstanceId provided - please select from list provided by the ticket")
   178  		// prompt for which instance to connect too
   179  		for index, instance := range computeInstances {
   180  			fmt.Printf("[%d] %s:%s - %s\n", index, instance.InstanceId, getTagValueFromKey(instance, "Name"), instance.PublicIp)
   181  		}
   182  		var instanceSelection int = -1 // initialize to negative value
   183  		fmt.Printf("Enter number for corresponding system to connect to?")
   184  		if _, err := fmt.Scanf("%d", &instanceSelection); err != nil {
   185  			return err
   186  		}
   187  
   188  		if instanceSelection < 0 || instanceSelection > len(computeInstances) {
   189  			return fmt.Errorf("Selected index (%d) was not in the list", instanceSelection)
   190  		}
   191  		if computeInstances[instanceSelection].PublicIp != "" {
   192  			host = computeInstances[instanceSelection].PublicIp
   193  		} else {
   194  			host = computeInstances[instanceSelection].PrivateIp
   195  		}
   196  	}
   197  
   198  	if host == "" {
   199  		return fmt.Errorf("Host selection failed - please provide an ip, DNS name, or select host from list with no input")
   200  	}
   201  
   202  	var sshArgs = []string{
   203  		// Forward the ssh agent.
   204  		"-A",
   205  		// Forward the X11 connections.
   206  		"-X",
   207  		// Don't check the identity of the remote host.
   208  		"-o", "StrictHostKeyChecking no",
   209  		// Don't store the identity of the remote host.
   210  		"-o", "UserKnownHostsFile /dev/null",
   211  		// Pass the private key to the ssh command
   212  		"-i", idRsaFlag,
   213  	}
   214  
   215  	// When using MOSH, SSH connection commands need to be passed like
   216  	// $ mosh --ssh="ssh -i ./identity" username@host
   217  	if sshFlag == "mosh" {
   218  		var moshSshArg = strings.Join(sshArgs, " ")
   219  		sshArgs = []string{
   220  			"--ssh", moshSshArg,
   221  		}
   222  	}
   223  
   224  	sshArgs = append(sshArgs,
   225  		username+"@"+host,
   226  	)
   227  
   228  	sshArgs = append(sshArgs, args...)
   229  
   230  	vlog.Infof("exec: %q %q", sshFlag, sshArgs)
   231  	cmd := exec.Command(sshFlag, sshArgs...)
   232  	cmd.Stdin = os.Stdin
   233  	cmd.Stdout = os.Stdout
   234  	cmd.Stderr = os.Stderr
   235  	if err := cmd.Run(); err != nil {
   236  		vlog.Errorf("ssh error: %s", err)
   237  		// TODO: return the exit code from the cmd.
   238  		os.Exit(1)
   239  	}
   240  
   241  	return nil
   242  }
   243  
   244  // Return the key value from the list of Tag Parameters
   245  func getTagValueFromKey(instance ticket.ComputeInstance, key string) string {
   246  	for _, param := range instance.Tags {
   247  		if param.Key == key {
   248  			return param.Value
   249  		}
   250  	}
   251  
   252  	// Throwing a NoSuchKey value is overkill for cases where tag is not added
   253  	return ""
   254  }