github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/cmds/exp/ssh/main.go (about) 1 // Copyright 2022 the u-root Authors. All rights reserved 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // SSH client. 6 // 7 // Synopsis: 8 // 9 // ssh OPTIONS [DEST] 10 // 11 // Description: 12 // 13 // Connects to the specified destination. 14 // 15 // Options: 16 // 17 // Destination format: 18 // 19 // [user@]hostname or ssh://[user@]hostname[:port] 20 package main 21 22 import ( 23 "errors" 24 "flag" 25 "fmt" 26 "io" 27 "io/ioutil" 28 "log" 29 "net" 30 "os" 31 guser "os/user" 32 "path/filepath" 33 "strings" 34 35 config "github.com/kevinburke/ssh_config" 36 sshconfig "github.com/kevinburke/ssh_config" 37 "golang.org/x/crypto/ssh" 38 "golang.org/x/crypto/ssh/knownhosts" 39 "golang.org/x/term" 40 ) 41 42 var ( 43 flags = flag.NewFlagSet(os.Args[0], flag.ExitOnError) 44 45 debug = flags.Bool("d", false, "enable debug prints") 46 keyFile = flags.String("i", "", "key file") 47 configFile = flags.String("F", defaultConfigFile, "config file") 48 49 v = func(string, ...interface{}) {} 50 51 // ssh config file 52 cfg *sshconfig.Config 53 54 errInvalidArgs = errors.New("Invalid command-line arguments") 55 ) 56 57 // loadConfig loads the SSH config file 58 func loadConfig(path string) (err error) { 59 var f *os.File 60 if f, err = os.Open(path); err != nil { 61 if os.IsNotExist(err) { 62 err = nil 63 cfg = &config.Config{} 64 } 65 return 66 } 67 cfg, err = config.Decode(f) 68 return 69 } 70 71 func main() { 72 if err := run(os.Args, os.Stdin, os.Stdout, os.Stderr); err != nil { 73 log.Fatalf("%v", err) 74 } 75 } 76 77 func knownHosts() (ssh.HostKeyCallback, error) { 78 etc, err := filepath.Glob("/etc/*/ssh_known_hosts") 79 if err != nil { 80 return nil, err 81 } 82 if home, ok := os.LookupEnv("HOME"); ok { 83 etc = append(etc, filepath.Join(home, ".ssh", "known_hosts")) 84 } 85 return knownhosts.New(etc...) 86 } 87 88 // we demand that stdin be a proper os.File because we need to be able to put it in raw mode 89 func run(osArgs []string, stdin *os.File, stdout io.Writer, stderr io.Writer) error { 90 flags.SetOutput(stderr) 91 flags.Parse(osArgs[1:]) 92 if *debug { 93 v = log.Printf 94 } 95 defer cleanup(stdin) 96 97 // Check if they're given appropriate arguments 98 args := flags.Args() 99 var dest string 100 if len(args) >= 1 { 101 dest = args[0] 102 args = args[1:] 103 } else { 104 fmt.Fprintf(stderr, "usage: %v [flags] [user@]dest[:port] [command]\n", osArgs[0]) 105 flags.PrintDefaults() 106 return errInvalidArgs 107 } 108 109 // Read the config file (if any) 110 if err := loadConfig(*configFile); err != nil { 111 return fmt.Errorf("config parse failed: %v", err) 112 } 113 114 // Parse out the destination 115 user, host, port, err := parseDest(dest) 116 if err != nil { 117 return fmt.Errorf("destination parse failed: %v", err) 118 } 119 120 cb, err := knownHosts() 121 if err != nil { 122 return fmt.Errorf("known hosts:%v", err) 123 } 124 // Build a client config with appropriate auth methods 125 config := &ssh.ClientConfig{ 126 User: user, 127 HostKeyCallback: cb, 128 } 129 // Figure out if there's a keyfile or not 130 kf := getKeyFile(host, *keyFile) 131 key, err := ioutil.ReadFile(kf) 132 if err == nil { 133 // The key exists 134 signer, err := ssh.ParsePrivateKey(key) 135 if err != nil { 136 return fmt.Errorf("ParsePrivateKey %v: %v", kf, err) 137 } 138 config.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)} 139 } else if err != nil && *keyFile != "" { 140 return fmt.Errorf("Could not read user-specified keyfile %v: %v", kf, err) 141 } 142 v("Config: %+v\n", config) 143 if term.IsTerminal(int(stdin.Fd())) { 144 pwReader := func() (string, error) { 145 return readPassword(stdin, stdout) 146 } 147 config.Auth = append(config.Auth, ssh.PasswordCallback(pwReader)) 148 } 149 150 // Now connect to the server 151 conn, err := ssh.Dial("tcp", net.JoinHostPort(host, port), config) 152 if err != nil { 153 return fmt.Errorf("unable to connect: %v", err) 154 } 155 defer conn.Close() 156 // Create a session on that connection 157 session, err := conn.NewSession() 158 if err != nil { 159 return fmt.Errorf("unable to create session: %v", err) 160 } 161 session.Stdin = stdin 162 session.Stdout = stdout 163 session.Stderr = stderr 164 defer session.Close() 165 166 if len(args) > 0 { 167 // run the command 168 if err := session.Run(strings.Join(args, " ")); err != nil { 169 return fmt.Errorf("Failed to run command: %v", err) 170 } 171 } else { 172 // Set up the terminal 173 modes := ssh.TerminalModes{ 174 ssh.ECHO: 1, // disable echoing 175 ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud 176 ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud 177 } 178 if term.IsTerminal(int(stdin.Fd())) { 179 if err := raw(stdin); err != nil { 180 // throw a notice but continue 181 log.Printf("failed to set raw mode: %v", err) 182 } 183 // Try to figure out the terminal size 184 width, height, err := getSize(stdin) 185 if err != nil { 186 return fmt.Errorf("failed to get terminal size: %v", err) 187 } 188 // Request pseudo terminal - "xterm" for now, may make this adjustable later. 189 if err := session.RequestPty("xterm", height, width, modes); err != nil { 190 log.Print("request for pseudo terminal failed: ", err) 191 } 192 } 193 // Start shell on remote system 194 if err := session.Shell(); err != nil { 195 log.Fatal("failed to start shell: ", err) 196 } 197 // Wait for the session to complete 198 session.Wait() 199 } 200 return nil 201 } 202 203 // parseDest splits an ssh destination spec into separate user, host, and port fields. 204 // Example specs: 205 // 206 // ssh://user@host:port 207 // user@host:port 208 // user@host 209 // host 210 func parseDest(input string) (user, host, port string, err error) { 211 // strip off any preceding ssh:// 212 input = strings.TrimPrefix(input, "ssh://") 213 // try to find user 214 i := strings.LastIndex(input, "@") 215 if i < 0 { 216 var u *guser.User 217 u, err = guser.Current() 218 if err != nil { 219 return 220 } 221 user = u.Username 222 } else { 223 user = input[0:i] 224 input = input[i+1:] 225 } 226 if host, port, err = net.SplitHostPort(input); err != nil { 227 err = nil 228 host = input 229 port = "22" 230 } 231 if host == "" { 232 err = errors.New("No host specified") 233 } 234 return 235 } 236 237 // getKeyFile picks a keyfile if none has been set. 238 // It will use sshconfig, else use a default. 239 // The kf parameter is a user-specified key file. We pass it 240 // here so it can be re-written if it contains a ~ 241 func getKeyFile(host, kf string) string { 242 v("getKeyFile for %q", kf) 243 if len(kf) == 0 { 244 var err error 245 kf, err = cfg.Get(host, "IdentityFile") 246 v("key file from config is %q", kf) 247 if len(kf) == 0 || err != nil { 248 kf = defaultKeyFile 249 } 250 } 251 // The kf will always be non-zero at this point. 252 if strings.HasPrefix(kf, "~") { 253 kf = filepath.Join(os.Getenv("HOME"), kf[1:]) 254 } 255 v("getKeyFile returns %q", kf) 256 // this is a tad annoying, but the config package doesn't handle ~. 257 return kf 258 }