github.com/aiyaya188/klaytn@v0.0.0-20220629133911-2c66fd5546f4/cmd/puppeth/ssh.go (about)

     1  // Copyright 2017 The go-ethereum Authors
     2  // This file is part of go-ethereum.
     3  //
     4  // go-ethereum is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // go-ethereum is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU General Public License
    15  // along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package main
    18  
    19  import (
    20  	"bufio"
    21  	"bytes"
    22  	"errors"
    23  	"fmt"
    24  	"net"
    25  	"os"
    26  	"os/user"
    27  	"path/filepath"
    28  	"strings"
    29  
    30  	"github.com/aiyaya188/klaytn/log"
    31  	"golang.org/x/crypto/ssh"
    32  	"golang.org/x/crypto/ssh/agent"
    33  	"golang.org/x/term"
    34  )
    35  
    36  // sshClient is a small wrapper around Go's SSH client with a few utility methods
    37  // implemented on top.
    38  type sshClient struct {
    39  	server  string // Server name or IP without port number
    40  	address string // IP address of the remote server
    41  	pubkey  []byte // RSA public key to authenticate the server
    42  	client  *ssh.Client
    43  	logger  log.Logger
    44  }
    45  
    46  const EnvSSHAuthSock = "SSH_AUTH_SOCK"
    47  
    48  // dial establishes an SSH connection to a remote node using the current user and
    49  // the user's configured private RSA key. If that fails, password authentication
    50  // is fallen back to. server can be a string like user:identity@server:port.
    51  func dial(server string, pubkey []byte) (*sshClient, error) {
    52  	// Figure out username, identity, hostname and port
    53  	hostname := ""
    54  	hostport := server
    55  	username := ""
    56  	identity := "id_rsa" // default
    57  
    58  	if strings.Contains(server, "@") {
    59  		prefix := server[:strings.Index(server, "@")]
    60  		if strings.Contains(prefix, ":") {
    61  			username = prefix[:strings.Index(prefix, ":")]
    62  			identity = prefix[strings.Index(prefix, ":")+1:]
    63  		} else {
    64  			username = prefix
    65  		}
    66  		hostport = server[strings.Index(server, "@")+1:]
    67  	}
    68  	if strings.Contains(hostport, ":") {
    69  		hostname = hostport[:strings.Index(hostport, ":")]
    70  	} else {
    71  		hostname = hostport
    72  		hostport += ":22"
    73  	}
    74  	logger := log.New("server", server)
    75  	logger.Debug("Attempting to establish SSH connection")
    76  
    77  	user, err := user.Current()
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	if username == "" {
    82  		username = user.Username
    83  	}
    84  
    85  	// Configure the supported authentication methods (ssh agent, private key and password)
    86  	var (
    87  		auths []ssh.AuthMethod
    88  		conn  net.Conn
    89  	)
    90  	if conn, err = net.Dial("unix", os.Getenv(EnvSSHAuthSock)); err != nil {
    91  		log.Warn("Unable to dial SSH agent, falling back to private keys", "err", err)
    92  	} else {
    93  		client := agent.NewClient(conn)
    94  		auths = append(auths, ssh.PublicKeysCallback(client.Signers))
    95  	}
    96  	if err != nil {
    97  		path := filepath.Join(user.HomeDir, ".ssh", identity)
    98  		if buf, err := os.ReadFile(path); err != nil {
    99  			log.Warn("No SSH key, falling back to passwords", "path", path, "err", err)
   100  		} else {
   101  			key, err := ssh.ParsePrivateKey(buf)
   102  			if err != nil {
   103  				fmt.Printf("What's the decryption password for %s? (won't be echoed)\n>", path)
   104  				blob, err := term.ReadPassword(int(os.Stdin.Fd()))
   105  				fmt.Println()
   106  				if err != nil {
   107  					log.Warn("Couldn't read password", "err", err)
   108  				}
   109  				key, err := ssh.ParsePrivateKeyWithPassphrase(buf, blob)
   110  				if err != nil {
   111  					log.Warn("Failed to decrypt SSH key, falling back to passwords", "path", path, "err", err)
   112  				} else {
   113  					auths = append(auths, ssh.PublicKeys(key))
   114  				}
   115  			} else {
   116  				auths = append(auths, ssh.PublicKeys(key))
   117  			}
   118  		}
   119  		auths = append(auths, ssh.PasswordCallback(func() (string, error) {
   120  			fmt.Printf("What's the login password for %s at %s? (won't be echoed)\n> ", username, server)
   121  			blob, err := term.ReadPassword(int(os.Stdin.Fd()))
   122  
   123  			fmt.Println()
   124  			return string(blob), err
   125  		}))
   126  	}
   127  	// Resolve the IP address of the remote server
   128  	addr, err := net.LookupHost(hostname)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  	if len(addr) == 0 {
   133  		return nil, errors.New("no IPs associated with domain")
   134  	}
   135  	// Try to dial in to the remote server
   136  	logger.Trace("Dialing remote SSH server", "user", username)
   137  	keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error {
   138  		// If no public key is known for SSH, ask the user to confirm
   139  		if pubkey == nil {
   140  			fmt.Println()
   141  			fmt.Printf("The authenticity of host '%s (%s)' can't be established.\n", hostname, remote)
   142  			fmt.Printf("SSH key fingerprint is %s [MD5]\n", ssh.FingerprintLegacyMD5(key))
   143  			fmt.Printf("Are you sure you want to continue connecting (yes/no)? ")
   144  
   145  			for {
   146  				text, err := bufio.NewReader(os.Stdin).ReadString('\n')
   147  				switch {
   148  				case err != nil:
   149  					return err
   150  				case strings.TrimSpace(text) == "yes":
   151  					pubkey = key.Marshal()
   152  					return nil
   153  				case strings.TrimSpace(text) == "no":
   154  					return errors.New("users says no")
   155  				default:
   156  					fmt.Println("Please answer 'yes' or 'no'")
   157  					continue
   158  				}
   159  			}
   160  		}
   161  		// If a public key exists for this SSH server, check that it matches
   162  		if bytes.Equal(pubkey, key.Marshal()) {
   163  			return nil
   164  		}
   165  		// We have a mismatch, forbid connecting
   166  		return errors.New("ssh key mismatch, readd the machine to update")
   167  	}
   168  	client, err := ssh.Dial("tcp", hostport, &ssh.ClientConfig{User: username, Auth: auths, HostKeyCallback: keycheck})
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  	// Connection established, return our utility wrapper
   173  	c := &sshClient{
   174  		server:  hostname,
   175  		address: addr[0],
   176  		pubkey:  pubkey,
   177  		client:  client,
   178  		logger:  logger,
   179  	}
   180  	if err := c.init(); err != nil {
   181  		client.Close()
   182  		return nil, err
   183  	}
   184  	return c, nil
   185  }
   186  
   187  // init runs some initialization commands on the remote server to ensure it's
   188  // capable of acting as puppeth target.
   189  func (client *sshClient) init() error {
   190  	client.logger.Debug("Verifying if docker is available")
   191  	if out, err := client.Run("docker version"); err != nil {
   192  		if len(out) == 0 {
   193  			return err
   194  		}
   195  		return fmt.Errorf("docker configured incorrectly: %s", out)
   196  	}
   197  	client.logger.Debug("Verifying if docker-compose is available")
   198  	if out, err := client.Run("docker-compose version"); err != nil {
   199  		if len(out) == 0 {
   200  			return err
   201  		}
   202  		return fmt.Errorf("docker-compose configured incorrectly: %s", out)
   203  	}
   204  	return nil
   205  }
   206  
   207  // Close terminates the connection to an SSH server.
   208  func (client *sshClient) Close() error {
   209  	return client.client.Close()
   210  }
   211  
   212  // Run executes a command on the remote server and returns the combined output
   213  // along with any error status.
   214  func (client *sshClient) Run(cmd string) ([]byte, error) {
   215  	// Establish a single command session
   216  	session, err := client.client.NewSession()
   217  	if err != nil {
   218  		return nil, err
   219  	}
   220  	defer session.Close()
   221  
   222  	// Execute the command and return any output
   223  	client.logger.Trace("Running command on remote server", "cmd", cmd)
   224  	return session.CombinedOutput(cmd)
   225  }
   226  
   227  // Stream executes a command on the remote server and streams all outputs into
   228  // the local stdout and stderr streams.
   229  func (client *sshClient) Stream(cmd string) error {
   230  	// Establish a single command session
   231  	session, err := client.client.NewSession()
   232  	if err != nil {
   233  		return err
   234  	}
   235  	defer session.Close()
   236  
   237  	session.Stdout = os.Stdout
   238  	session.Stderr = os.Stderr
   239  
   240  	// Execute the command and return any output
   241  	client.logger.Trace("Streaming command on remote server", "cmd", cmd)
   242  	return session.Run(cmd)
   243  }
   244  
   245  // Upload copies the set of files to a remote server via SCP, creating any non-
   246  // existing folders in the mean time.
   247  func (client *sshClient) Upload(files map[string][]byte) ([]byte, error) {
   248  	// Establish a single command session
   249  	session, err := client.client.NewSession()
   250  	if err != nil {
   251  		return nil, err
   252  	}
   253  	defer session.Close()
   254  
   255  	// Create a goroutine that streams the SCP content
   256  	go func() {
   257  		out, _ := session.StdinPipe()
   258  		defer out.Close()
   259  
   260  		for file, content := range files {
   261  			client.logger.Trace("Uploading file to server", "file", file, "bytes", len(content))
   262  
   263  			fmt.Fprintln(out, "D0755", 0, filepath.Dir(file))             // Ensure the folder exists
   264  			fmt.Fprintln(out, "C0644", len(content), filepath.Base(file)) // Create the actual file
   265  			out.Write(content)                                            // Stream the data content
   266  			fmt.Fprint(out, "\x00")                                       // Transfer end with \x00
   267  			fmt.Fprintln(out, "E")                                        // Leave directory (simpler)
   268  		}
   269  	}()
   270  	return session.CombinedOutput("/usr/bin/scp -v -tr ./")
   271  }