github.com/sberex/go-sberex@v1.8.2-0.20181113200658-ed96ac38f7d7/cmd/puppeth/ssh.go (about)

     1  // This file is part of the go-sberex library. The go-sberex library is 
     2  // free software: you can redistribute it and/or modify it under the terms 
     3  // of the GNU Lesser General Public License as published by the Free 
     4  // Software Foundation, either version 3 of the License, or (at your option)
     5  // any later version.
     6  //
     7  // The go-sberex library is distributed in the hope that it will be useful, 
     8  // but WITHOUT ANY WARRANTY; without even the implied warranty of
     9  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser 
    10  // General Public License <http://www.gnu.org/licenses/> for more details.
    11  
    12  package main
    13  
    14  import (
    15  	"bufio"
    16  	"bytes"
    17  	"errors"
    18  	"fmt"
    19  	"io/ioutil"
    20  	"net"
    21  	"os"
    22  	"os/user"
    23  	"path/filepath"
    24  	"strings"
    25  
    26  	"github.com/Sberex/go-sberex/log"
    27  	"golang.org/x/crypto/ssh"
    28  	"golang.org/x/crypto/ssh/terminal"
    29  )
    30  
    31  // sshClient is a small wrapper around Go's SSH client with a few utility methods
    32  // implemented on top.
    33  type sshClient struct {
    34  	server  string // Server name or IP without port number
    35  	address string // IP address of the remote server
    36  	pubkey  []byte // RSA public key to authenticate the server
    37  	client  *ssh.Client
    38  	logger  log.Logger
    39  }
    40  
    41  // dial establishes an SSH connection to a remote node using the current user and
    42  // the user's configured private RSA key. If that fails, password authentication
    43  // is fallen back to. The caller may override the login user via user@server:port.
    44  func dial(server string, pubkey []byte) (*sshClient, error) {
    45  	// Figure out a label for the server and a logger
    46  	label := server
    47  	if strings.Contains(label, ":") {
    48  		label = label[:strings.Index(label, ":")]
    49  	}
    50  	login := ""
    51  	if strings.Contains(server, "@") {
    52  		login = label[:strings.Index(label, "@")]
    53  		label = label[strings.Index(label, "@")+1:]
    54  		server = server[strings.Index(server, "@")+1:]
    55  	}
    56  	logger := log.New("server", label)
    57  	logger.Debug("Attempting to establish SSH connection")
    58  
    59  	user, err := user.Current()
    60  	if err != nil {
    61  		return nil, err
    62  	}
    63  	if login == "" {
    64  		login = user.Username
    65  	}
    66  	// Configure the supported authentication methods (private key and password)
    67  	var auths []ssh.AuthMethod
    68  
    69  	path := filepath.Join(user.HomeDir, ".ssh", "id_rsa")
    70  	if buf, err := ioutil.ReadFile(path); err != nil {
    71  		log.Warn("No SSH key, falling back to passwords", "path", path, "err", err)
    72  	} else {
    73  		key, err := ssh.ParsePrivateKey(buf)
    74  		if err != nil {
    75  			fmt.Printf("What's the decryption password for %s? (won't be echoed)\n>", path)
    76  			blob, err := terminal.ReadPassword(int(os.Stdin.Fd()))
    77  			fmt.Println()
    78  			if err != nil {
    79  				log.Warn("Couldn't read password", "err", err)
    80  			}
    81  			key, err := ssh.ParsePrivateKeyWithPassphrase(buf, blob)
    82  			if err != nil {
    83  				log.Warn("Failed to decrypt SSH key, falling back to passwords", "path", path, "err", err)
    84  			} else {
    85  				auths = append(auths, ssh.PublicKeys(key))
    86  			}
    87  		} else {
    88  			auths = append(auths, ssh.PublicKeys(key))
    89  		}
    90  	}
    91  	auths = append(auths, ssh.PasswordCallback(func() (string, error) {
    92  		fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", login, server)
    93  		blob, err := terminal.ReadPassword(int(os.Stdin.Fd()))
    94  
    95  		fmt.Println()
    96  		return string(blob), err
    97  	}))
    98  	// Resolve the IP address of the remote server
    99  	addr, err := net.LookupHost(label)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  	if len(addr) == 0 {
   104  		return nil, errors.New("no IPs associated with domain")
   105  	}
   106  	// Try to dial in to the remote server
   107  	logger.Trace("Dialing remote SSH server", "user", login)
   108  	if !strings.Contains(server, ":") {
   109  		server += ":22"
   110  	}
   111  	keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
   112  		// If no public key is known for SSH, ask the user to confirm
   113  		if pubkey == nil {
   114  			fmt.Println()
   115  			fmt.Printf("The authenticity of host '%s (%s)' can't be established.\n", hostname, remote)
   116  			fmt.Printf("SSH key fingerprint is %s [MD5]\n", ssh.FingerprintLegacyMD5(key))
   117  			fmt.Printf("Are you sure you want to continue connecting (yes/no)? ")
   118  
   119  			text, err := bufio.NewReader(os.Stdin).ReadString('\n')
   120  			switch {
   121  			case err != nil:
   122  				return err
   123  			case strings.TrimSpace(text) == "yes":
   124  				pubkey = key.Marshal()
   125  				return nil
   126  			default:
   127  				return fmt.Errorf("unknown auth choice: %v", text)
   128  			}
   129  		}
   130  		// If a public key exists for this SSH server, check that it matches
   131  		if bytes.Equal(pubkey, key.Marshal()) {
   132  			return nil
   133  		}
   134  		// We have a mismatch, forbid connecting
   135  		return errors.New("ssh key mismatch, readd the machine to update")
   136  	}
   137  	client, err := ssh.Dial("tcp", server, &ssh.ClientConfig{User: login, Auth: auths, HostKeyCallback: keycheck})
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  	// Connection established, return our utility wrapper
   142  	c := &sshClient{
   143  		server:  label,
   144  		address: addr[0],
   145  		pubkey:  pubkey,
   146  		client:  client,
   147  		logger:  logger,
   148  	}
   149  	if err := c.init(); err != nil {
   150  		client.Close()
   151  		return nil, err
   152  	}
   153  	return c, nil
   154  }
   155  
   156  // init runs some initialization commands on the remote server to ensure it's
   157  // capable of acting as puppeth target.
   158  func (client *sshClient) init() error {
   159  	client.logger.Debug("Verifying if docker is available")
   160  	if out, err := client.Run("docker version"); err != nil {
   161  		if len(out) == 0 {
   162  			return err
   163  		}
   164  		return fmt.Errorf("docker configured incorrectly: %s", out)
   165  	}
   166  	client.logger.Debug("Verifying if docker-compose is available")
   167  	if out, err := client.Run("docker-compose version"); err != nil {
   168  		if len(out) == 0 {
   169  			return err
   170  		}
   171  		return fmt.Errorf("docker-compose configured incorrectly: %s", out)
   172  	}
   173  	return nil
   174  }
   175  
   176  // Close terminates the connection to an SSH server.
   177  func (client *sshClient) Close() error {
   178  	return client.client.Close()
   179  }
   180  
   181  // Run executes a command on the remote server and returns the combined output
   182  // along with any error status.
   183  func (client *sshClient) Run(cmd string) ([]byte, error) {
   184  	// Establish a single command session
   185  	session, err := client.client.NewSession()
   186  	if err != nil {
   187  		return nil, err
   188  	}
   189  	defer session.Close()
   190  
   191  	// Execute the command and return any output
   192  	client.logger.Trace("Running command on remote server", "cmd", cmd)
   193  	return session.CombinedOutput(cmd)
   194  }
   195  
   196  // Stream executes a command on the remote server and streams all outputs into
   197  // the local stdout and stderr streams.
   198  func (client *sshClient) Stream(cmd string) error {
   199  	// Establish a single command session
   200  	session, err := client.client.NewSession()
   201  	if err != nil {
   202  		return err
   203  	}
   204  	defer session.Close()
   205  
   206  	session.Stdout = os.Stdout
   207  	session.Stderr = os.Stderr
   208  
   209  	// Execute the command and return any output
   210  	client.logger.Trace("Streaming command on remote server", "cmd", cmd)
   211  	return session.Run(cmd)
   212  }
   213  
   214  // Upload copies the set of files to a remote server via SCP, creating any non-
   215  // existing folders in the mean time.
   216  func (client *sshClient) Upload(files map[string][]byte) ([]byte, error) {
   217  	// Establish a single command session
   218  	session, err := client.client.NewSession()
   219  	if err != nil {
   220  		return nil, err
   221  	}
   222  	defer session.Close()
   223  
   224  	// Create a goroutine that streams the SCP content
   225  	go func() {
   226  		out, _ := session.StdinPipe()
   227  		defer out.Close()
   228  
   229  		for file, content := range files {
   230  			client.logger.Trace("Uploading file to server", "file", file, "bytes", len(content))
   231  
   232  			fmt.Fprintln(out, "D0755", 0, filepath.Dir(file))             // Ensure the folder exists
   233  			fmt.Fprintln(out, "C0644", len(content), filepath.Base(file)) // Create the actual file
   234  			out.Write(content)                                            // Stream the data content
   235  			fmt.Fprint(out, "\x00")                                       // Transfer end with \x00
   236  			fmt.Fprintln(out, "E")                                        // Leave directory (simpler)
   237  		}
   238  	}()
   239  	return session.CombinedOutput("/usr/bin/scp -v -tr ./")
   240  }