github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/cmds/exp/ssh/main.go (about)

     1  // Copyright 2022 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
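// SSH client.
//
// Synopsis:
//
//	ssh [OPTIONS] DEST [COMMAND...]
//
// Description:
//
//	Connects to the specified destination. If COMMAND is given, its words
//	are joined with spaces and run on the remote host; otherwise an
//	interactive shell is started.
//
// Options:
//
//	-d: enable debug prints
//	-i: key file
//	-F: config file
//
// Destination format:
//
//	[ssh://][user@]hostname[:port]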
    20  package main
    21  
    22  import (
    23  	"errors"
    24  	"flag"
    25  	"fmt"
    26  	"io"
    27  	"io/ioutil"
    28  	"log"
    29  	"net"
    30  	"os"
    31  	guser "os/user"
    32  	"path/filepath"
    33  	"strings"
    34  
    35  	config "github.com/kevinburke/ssh_config"
    36  	sshconfig "github.com/kevinburke/ssh_config"
    37  	"golang.org/x/crypto/ssh"
    38  	"golang.org/x/crypto/ssh/knownhosts"
    39  	"golang.org/x/term"
    40  )
    41  
    42  var (
    43  	flags = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
    44  
    45  	debug      = flags.Bool("d", false, "enable debug prints")
    46  	keyFile    = flags.String("i", "", "key file")
    47  	configFile = flags.String("F", defaultConfigFile, "config file")
    48  
    49  	v = func(string, ...interface{}) {}
    50  
    51  	// ssh config file
    52  	cfg *sshconfig.Config
    53  
    54  	errInvalidArgs = errors.New("Invalid command-line arguments")
    55  )
    56  
    57  // loadConfig loads the SSH config file
    58  func loadConfig(path string) (err error) {
    59  	var f *os.File
    60  	if f, err = os.Open(path); err != nil {
    61  		if os.IsNotExist(err) {
    62  			err = nil
    63  			cfg = &config.Config{}
    64  		}
    65  		return
    66  	}
    67  	cfg, err = config.Decode(f)
    68  	return
    69  }
    70  
    71  func main() {
    72  	if err := run(os.Args, os.Stdin, os.Stdout, os.Stderr); err != nil {
    73  		log.Fatalf("%v", err)
    74  	}
    75  }
    76  
    77  func knownHosts() (ssh.HostKeyCallback, error) {
    78  	etc, err := filepath.Glob("/etc/*/ssh_known_hosts")
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  	if home, ok := os.LookupEnv("HOME"); ok {
    83  		etc = append(etc, filepath.Join(home, ".ssh", "known_hosts"))
    84  	}
    85  	return knownhosts.New(etc...)
    86  }
    87  
    88  // we demand that stdin be a proper os.File because we need to be able to put it in raw mode
    89  func run(osArgs []string, stdin *os.File, stdout io.Writer, stderr io.Writer) error {
    90  	flags.SetOutput(stderr)
    91  	flags.Parse(osArgs[1:])
    92  	if *debug {
    93  		v = log.Printf
    94  	}
    95  	defer cleanup(stdin)
    96  
    97  	// Check if they're given appropriate arguments
    98  	args := flags.Args()
    99  	var dest string
   100  	if len(args) >= 1 {
   101  		dest = args[0]
   102  		args = args[1:]
   103  	} else {
   104  		fmt.Fprintf(stderr, "usage: %v [flags] [user@]dest[:port] [command]\n", osArgs[0])
   105  		flags.PrintDefaults()
   106  		return errInvalidArgs
   107  	}
   108  
   109  	// Read the config file (if any)
   110  	if err := loadConfig(*configFile); err != nil {
   111  		return fmt.Errorf("config parse failed: %v", err)
   112  	}
   113  
   114  	// Parse out the destination
   115  	user, host, port, err := parseDest(dest)
   116  	if err != nil {
   117  		return fmt.Errorf("destination parse failed: %v", err)
   118  	}
   119  
   120  	cb, err := knownHosts()
   121  	if err != nil {
   122  		return fmt.Errorf("known hosts:%v", err)
   123  	}
   124  	// Build a client config with appropriate auth methods
   125  	config := &ssh.ClientConfig{
   126  		User:            user,
   127  		HostKeyCallback: cb,
   128  	}
   129  	// Figure out if there's a keyfile or not
   130  	kf := getKeyFile(host, *keyFile)
   131  	key, err := ioutil.ReadFile(kf)
   132  	if err == nil {
   133  		// The key exists
   134  		signer, err := ssh.ParsePrivateKey(key)
   135  		if err != nil {
   136  			return fmt.Errorf("ParsePrivateKey %v: %v", kf, err)
   137  		}
   138  		config.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)}
   139  	} else if err != nil && *keyFile != "" {
   140  		return fmt.Errorf("Could not read user-specified keyfile %v: %v", kf, err)
   141  	}
   142  	v("Config: %+v\n", config)
   143  	if term.IsTerminal(int(stdin.Fd())) {
   144  		pwReader := func() (string, error) {
   145  			return readPassword(stdin, stdout)
   146  		}
   147  		config.Auth = append(config.Auth, ssh.PasswordCallback(pwReader))
   148  	}
   149  
   150  	// Now connect to the server
   151  	conn, err := ssh.Dial("tcp", net.JoinHostPort(host, port), config)
   152  	if err != nil {
   153  		return fmt.Errorf("unable to connect: %v", err)
   154  	}
   155  	defer conn.Close()
   156  	// Create a session on that connection
   157  	session, err := conn.NewSession()
   158  	if err != nil {
   159  		return fmt.Errorf("unable to create session: %v", err)
   160  	}
   161  	session.Stdin = stdin
   162  	session.Stdout = stdout
   163  	session.Stderr = stderr
   164  	defer session.Close()
   165  
   166  	if len(args) > 0 {
   167  		// run the command
   168  		if err := session.Run(strings.Join(args, " ")); err != nil {
   169  			return fmt.Errorf("Failed to run command: %v", err)
   170  		}
   171  	} else {
   172  		// Set up the terminal
   173  		modes := ssh.TerminalModes{
   174  			ssh.ECHO:          1,     // disable echoing
   175  			ssh.TTY_OP_ISPEED: 14400, // input speed = 14.4kbaud
   176  			ssh.TTY_OP_OSPEED: 14400, // output speed = 14.4kbaud
   177  		}
   178  		if term.IsTerminal(int(stdin.Fd())) {
   179  			if err := raw(stdin); err != nil {
   180  				// throw a notice but continue
   181  				log.Printf("failed to set raw mode: %v", err)
   182  			}
   183  			// Try to figure out the terminal size
   184  			width, height, err := getSize(stdin)
   185  			if err != nil {
   186  				return fmt.Errorf("failed to get terminal size: %v", err)
   187  			}
   188  			// Request pseudo terminal - "xterm" for now, may make this adjustable later.
   189  			if err := session.RequestPty("xterm", height, width, modes); err != nil {
   190  				log.Print("request for pseudo terminal failed: ", err)
   191  			}
   192  		}
   193  		// Start shell on remote system
   194  		if err := session.Shell(); err != nil {
   195  			log.Fatal("failed to start shell: ", err)
   196  		}
   197  		// Wait for the session to complete
   198  		session.Wait()
   199  	}
   200  	return nil
   201  }
   202  
   203  // parseDest splits an ssh destination spec into separate user, host, and port fields.
   204  // Example specs:
   205  //
   206  //	ssh://user@host:port
   207  //	user@host:port
   208  //	user@host
   209  //	host
   210  func parseDest(input string) (user, host, port string, err error) {
   211  	// strip off any preceding ssh://
   212  	input = strings.TrimPrefix(input, "ssh://")
   213  	// try to find user
   214  	i := strings.LastIndex(input, "@")
   215  	if i < 0 {
   216  		var u *guser.User
   217  		u, err = guser.Current()
   218  		if err != nil {
   219  			return
   220  		}
   221  		user = u.Username
   222  	} else {
   223  		user = input[0:i]
   224  		input = input[i+1:]
   225  	}
   226  	if host, port, err = net.SplitHostPort(input); err != nil {
   227  		err = nil
   228  		host = input
   229  		port = "22"
   230  	}
   231  	if host == "" {
   232  		err = errors.New("No host specified")
   233  	}
   234  	return
   235  }
   236  
   237  // getKeyFile picks a keyfile if none has been set.
   238  // It will use sshconfig, else use a default.
   239  // The kf parameter is a user-specified key file. We pass it
   240  // here so it can be re-written if it contains a ~
   241  func getKeyFile(host, kf string) string {
   242  	v("getKeyFile for %q", kf)
   243  	if len(kf) == 0 {
   244  		var err error
   245  		kf, err = cfg.Get(host, "IdentityFile")
   246  		v("key file from config is %q", kf)
   247  		if len(kf) == 0 || err != nil {
   248  			kf = defaultKeyFile
   249  		}
   250  	}
   251  	// The kf will always be non-zero at this point.
   252  	if strings.HasPrefix(kf, "~") {
   253  		kf = filepath.Join(os.Getenv("HOME"), kf[1:])
   254  	}
   255  	v("getKeyFile returns %q", kf)
   256  	// this is a tad annoying, but the config package doesn't handle ~.
   257  	return kf
   258  }