github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/cmd/grail-ssh/ssh.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "io" 6 "io/ioutil" 7 "os" 8 "os/exec" 9 "regexp" 10 "strings" 11 "syscall" 12 "time" 13 14 "github.com/Schaudge/grailbase/security/ticket" 15 sshLib "golang.org/x/crypto/ssh" 16 terminal "golang.org/x/crypto/ssh/terminal" 17 "v.io/v23/context" 18 "v.io/x/lib/cmdline" 19 "v.io/x/lib/vlog" 20 ) 21 22 const ( 23 timeout = 10 * time.Second 24 ) 25 26 func runSsh(ctx *context.T, out io.Writer, env *cmdline.Env, args []string) error { 27 if len(args) == 0 { 28 return env.UsageErrorf("At least one argument (<ticket>) is required.") 29 } 30 31 ticketPath := args[0] 32 args = args[1:] // remove the ticket from the arguments 33 34 client := ticket.TicketServiceClient(ticketPath) 35 ctx, cancel := context.WithTimeout(ctx, timeout) 36 defer cancel() 37 38 // Read in the private key 39 privateKey, err := ioutil.ReadFile(idRsaFlag) 40 if err != nil { 41 return fmt.Errorf("Failed to read private key - %v", err) 42 } 43 44 // Load the private key 45 privateSigner, err := sshLib.ParsePrivateKey(privateKey) 46 if err != nil { 47 switch err.(type) { 48 case *sshLib.PassphraseMissingError: 49 // try to load the key with a passphrase 50 fmt.Print("Enter SSH Key Passphrase: ") 51 bytePassword, _ := terminal.ReadPassword(int(syscall.Stdin)) 52 privateSigner, err = sshLib.ParsePrivateKeyWithPassphrase(privateKey, bytePassword) 53 if err != nil { 54 return fmt.Errorf("Failed to read private key - %v", err) 55 } 56 fmt.Println("\nSSH Key decoded") 57 default: 58 return fmt.Errorf("Failed to parse private key - %v", err) 59 } 60 } 61 62 if err != nil { 63 return fmt.Errorf("Failed to parse private key - %v", err) 64 } 65 66 var parameters = []ticket.Parameter{ 67 ticket.Parameter{ 68 Key: "PublicKey", 69 Value: string(sshLib.MarshalAuthorizedKey(privateSigner.PublicKey())), 70 }, 71 } 72 73 t, err := client.GetWithParameters(ctx, parameters) 74 if err != nil { 75 return fmt.Errorf("Failed to communicate with the ticket-server - %v", err) 76 } 77 78 switch t.Index() { 79 case (ticket.TicketSshCertificateTicket{}).Index(): 80 { 81 creds := t.(ticket.TicketSshCertificateTicket).Value.Credentials 82 // pull the public certificate out and write to the id_rsa cert path location 83 if err = ioutil.WriteFile(idRsaFlag+"-cert.pub", []byte(creds.Cert), 0644); err != nil { 84 return fmt.Errorf("Failed to write ssh public key "+idRsaFlag+"-cert.pub"+" - %v", err) 85 } 86 } 87 default: 88 { 89 return fmt.Errorf("Provided ticket is not a SSHCertificateTicket") 90 } 91 } 92 93 var computeInstances []ticket.ComputeInstance = t.(ticket.TicketSshCertificateTicket).Value.ComputeInstances 94 var username = t.(ticket.TicketSshCertificateTicket).Value.Username 95 // Use the environment provided username if specified 96 if userFlag != "" { 97 username = userFlag 98 } 99 // Throw an error if no username is set 100 if username == "" { 101 vlog.Errorf("Username was not provided in ticket or via command line") 102 // TODO: return the exit code from the cmd. 103 os.Exit(1) 104 } 105 106 var host string 107 instanceMatch := regexp.MustCompile("^i-[a-zA-Z0-9]+$") 108 // Not the best regex (e.g. doesn't match IPV6) to use here ... better regexs are available at 109 // https://stackoverflow.com/questions/106179/regular-expression-to-match-dns-hostname-or-ip-address 110 hostIpMatch := regexp.MustCompile("^([a-zA-Z0-9]+\\.)+[a-zA-Z0-9]+$") 111 dnsMatch := regexp.MustCompile("^(([a-zA-Z]|[a-zA-Z][a-zA-Z0-9\\-]*[a-zA-Z0-9])\\.)*([A-Za-z]|[A-Za-z][A-Za-z0-9\\-]*[A-Za-z0-9])$") 112 stopMatch := regexp.MustCompile("^--$") 113 114 // Loop through the arguments provided to the CLI tool - and try to match to a hostname or an instanceID. 115 // Stop processing if -- is found. 116 // host is the last match found. 117 for i, arg := range args { 118 match := instanceMatch.MatchString(arg) 119 if err != nil { 120 return fmt.Errorf("Failed to check if input %s matched an instanceId - %v", arg, err) 121 } 122 123 // Find matching instanceId in list 124 if match { 125 // Remove the matched element from the list 126 args = append(args[:i], args[i+1:]...) 127 for _, instance := range computeInstances { 128 if instance.InstanceId == arg { 129 vlog.Errorf("Matched InstanceID %s - %s", instance.InstanceId, instance.PublicIp) 130 fmt.Printf("Matched InstanceID %s - %s \n", instance.InstanceId, instance.PublicIp) 131 host = instance.PublicIp 132 break 133 } 134 } 135 if host == "" { 136 return fmt.Errorf("Failed to find a match for InstanceId provided %s", arg) 137 } 138 break 139 } 140 141 // // check for a dns name to stop processing 142 match = dnsMatch.MatchString(arg) 143 if err != nil { 144 return fmt.Errorf("Failed to check if input %s matched a DNS name' - %v", arg, err) 145 } 146 if match { 147 host = arg 148 args = append(args[:i], args[i+1:]...) 149 fmt.Printf("Matched DNS %s \n", host) 150 break 151 } 152 153 // check for a dns/ip host name to stop processing 154 match = hostIpMatch.MatchString(arg) 155 if err != nil { 156 return fmt.Errorf("Failed to check if input %s matched an '^[a-zA-Z0-9]+\\.[a-zA-Z0-9]+' - %v", arg, err) 157 } 158 if match { 159 host = arg 160 args = append(args[:i], args[i+1:]...) 161 fmt.Printf("Matched Host IP %s \n", host) 162 break 163 } 164 165 // check for a -- to stop processing 166 match = stopMatch.MatchString(arg) 167 if err != nil { 168 return fmt.Errorf("Failed to check if input %s matched an '--' - %v", arg, err) 169 } 170 if match { 171 break 172 } 173 } 174 175 // If no host has been found present a list 176 if host == "" { 177 fmt.Printf("No host or InstanceId provided - please select from list provided by the ticket") 178 // prompt for which instance to connect too 179 for index, instance := range computeInstances { 180 fmt.Printf("[%d] %s:%s - %s\n", index, instance.InstanceId, getTagValueFromKey(instance, "Name"), instance.PublicIp) 181 } 182 var instanceSelection int = -1 // initialize to negative value 183 fmt.Printf("Enter number for corresponding system to connect to?") 184 if _, err := fmt.Scanf("%d", &instanceSelection); err != nil { 185 return err 186 } 187 188 if instanceSelection < 0 || instanceSelection > len(computeInstances) { 189 return fmt.Errorf("Selected index (%d) was not in the list", instanceSelection) 190 } 191 if computeInstances[instanceSelection].PublicIp != "" { 192 host = computeInstances[instanceSelection].PublicIp 193 } else { 194 host = computeInstances[instanceSelection].PrivateIp 195 } 196 } 197 198 if host == "" { 199 return fmt.Errorf("Host selection failed - please provide an ip, DNS name, or select host from list with no input") 200 } 201 202 var sshArgs = []string{ 203 // Forward the ssh agent. 204 "-A", 205 // Forward the X11 connections. 206 "-X", 207 // Don't check the identity of the remote host. 208 "-o", "StrictHostKeyChecking no", 209 // Don't store the identity of the remote host. 210 "-o", "UserKnownHostsFile /dev/null", 211 // Pass the private key to the ssh command 212 "-i", idRsaFlag, 213 } 214 215 // When using MOSH, SSH connection commands need to be passed like 216 // $ mosh --ssh="ssh -i ./identity" username@host 217 if sshFlag == "mosh" { 218 var moshSshArg = strings.Join(sshArgs, " ") 219 sshArgs = []string{ 220 "--ssh", moshSshArg, 221 } 222 } 223 224 sshArgs = append(sshArgs, 225 username+"@"+host, 226 ) 227 228 sshArgs = append(sshArgs, args...) 229 230 vlog.Infof("exec: %q %q", sshFlag, sshArgs) 231 cmd := exec.Command(sshFlag, sshArgs...) 232 cmd.Stdin = os.Stdin 233 cmd.Stdout = os.Stdout 234 cmd.Stderr = os.Stderr 235 if err := cmd.Run(); err != nil { 236 vlog.Errorf("ssh error: %s", err) 237 // TODO: return the exit code from the cmd. 238 os.Exit(1) 239 } 240 241 return nil 242 } 243 244 // Return the key value from the list of Tag Parameters 245 func getTagValueFromKey(instance ticket.ComputeInstance, key string) string { 246 for _, param := range instance.Tags { 247 if param.Key == key { 248 return param.Value 249 } 250 } 251 252 // Throwing a NoSuchKey value is overkill for cases where tag is not added 253 return "" 254 }