github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/scripts/ssh-keycheck/ssh_keycheck.go (about)

     1  // Copyright 2017 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package main
     5  
import (
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"os"
	"os/user"
	"path"
	"strings"
	"time"

	"github.com/juju/gnuflag"
	"github.com/juju/loggo"
	"golang.org/x/crypto/ssh"

	"github.com/juju/juju/network"
	jujussh "github.com/juju/juju/network/ssh"
)
    23  
    24  func knownHostFilename() string {
    25  	usr, err := user.Current()
    26  	if err != nil {
    27  		panic(fmt.Sprintf("unable to find current user: %v", err))
    28  	}
    29  	return path.Join(usr.HomeDir, ".ssh", "known_hosts")
    30  }
    31  
    32  // Juju reports the files in /etc/ssh/ssh_host_key_*_key.pub, so they are all
    33  // in AuthorizedKey format.
    34  func getKnownHostKeys(fname string) []string {
    35  	f, err := os.Open(fname)
    36  	if err != nil {
    37  		panic(fmt.Sprintf("unable to read known-hosts file: %q %v", fname, err))
    38  	}
    39  	defer f.Close()
    40  	content, err := ioutil.ReadAll(f)
    41  	if err != nil {
    42  		panic(fmt.Sprintf("failed while reading known-hosts file: %q %v", fname, err))
    43  	}
    44  	pubKeys := make([]string, 0)
    45  	for len(content) > 0 {
    46  		// marker, hosts, pubkey, comment, rest, err
    47  		_, _, pubkey, _, remaining, err := ssh.ParseKnownHosts(content)
    48  		if err != nil {
    49  			panic(fmt.Sprintf("failed while parsing known hosts: %q %v", fname, err))
    50  		}
    51  		content = remaining
    52  		// We convert the "known_hosts" format into AuthorizedKeys format to
    53  		// match what Juju records.
    54  		pubKeys = append(pubKeys, string(ssh.MarshalAuthorizedKey(pubkey)))
    55  	}
    56  	return pubKeys
    57  }
    58  
    59  var logger = loggo.GetLogger("juju.ssh_keyscan")
    60  
    61  func main() {
    62  	var verbose bool
    63  	var dialTimeout int = 500
    64  	var waitTimeout int = 5000
    65  	var hostFile string
    66  	gnuflag.BoolVar(&verbose, "v", false, "dump debugging information")
    67  	gnuflag.IntVar(&dialTimeout, "dial-timeout", 500, "time to try a single connection (in milliseconds)")
    68  	gnuflag.IntVar(&waitTimeout, "wait-timeout", 5000, "overall time to wait for answers (in milliseconds)")
    69  	gnuflag.StringVar(&hostFile, "known-hosts", knownHostFilename(), "point to an alternate known-hosts file")
    70  	gnuflag.Parse(true)
    71  	if verbose {
    72  		loggo.ConfigureLoggers(`<root>=DEBUG`)
    73  	}
    74  	args := gnuflag.Args()
    75  	pubKeys := getKnownHostKeys(hostFile)
    76  	hostPorts := make([]network.HostPort, 0, len(args))
    77  	for _, arg := range args {
    78  		if strings.Index(arg, ":") < 0 {
    79  			// Not valid for IPv6, but good enough for testing
    80  			arg = arg + ":22"
    81  		}
    82  		hp, err := network.ParseHostPort(arg)
    83  		if err != nil {
    84  			fmt.Fprintf(os.Stderr, "invalid host:port value: %v\n%v\n", arg, err)
    85  			return
    86  		}
    87  		hostPorts = append(hostPorts, *hp)
    88  	}
    89  	logger.Infof("host ports: %v\n", hostPorts)
    90  	logger.Infof("found %d known hosts\n", len(pubKeys))
    91  	logger.Debugf("known hosts: %v\n", pubKeys)
    92  	dialer := &net.Dialer{Timeout: time.Duration(dialTimeout) * time.Millisecond}
    93  	checker := jujussh.NewReachableChecker(dialer, time.Duration(waitTimeout)*time.Millisecond)
    94  	found, err := checker.FindHost(hostPorts, pubKeys)
    95  	if err != nil {
    96  		fmt.Fprintf(os.Stderr, "could not find valid host: %v\n", err)
    97  		return
    98  	}
    99  	fmt.Printf("%v\n", found)
   100  }