// Source: github.com/sberex/go-sberex@v1.8.2-0.20181113200658-ed96ac38f7d7/cmd/puppeth/ssh.go

// This file is part of the go-sberex library. The go-sberex library is
// free software: you can redistribute it and/or modify it under the terms
// of the GNU Lesser General Public License as published by the Free
// Software Foundation, either version 3 of the License, or (at your option)
// any later version.
//
// The go-sberex library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser
// General Public License <http://www.gnu.org/licenses/> for more details.

package main

import (
	"bufio"
	"bytes"
	"errors"
	"fmt"
	"io/ioutil"
	"net"
	"os"
	"os/user"
	"path/filepath"
	"strings"

	"github.com/Sberex/go-sberex/log"
	"golang.org/x/crypto/ssh"
	"golang.org/x/crypto/ssh/terminal"
)

// sshClient is a small wrapper around Go's SSH client with a few utility methods
// (remote command execution, output streaming and SCP upload) implemented on top.
type sshClient struct {
	server  string      // Server name or IP, without login prefix or port number
	address string      // First IP address the server name resolved to
	pubkey  []byte      // Marshalled host key the server must present (any key type, not only RSA)
	client  *ssh.Client // Established SSH connection to the server
	logger  log.Logger  // Logger tagged with the server label
}

// dial establishes an SSH connection to a remote node as the current OS user,
// authenticating with the private key at <home>/.ssh/id_rsa (prompting for its
// passphrase if it cannot be parsed as-is). Password authentication is always
// offered as a fallback. The caller may override the login user via
// user@server:port; the port defaults to 22 when omitted.
//
// If pubkey is nil, the user is asked to confirm the fingerprint of the host key
// presented by the server, and the accepted key is stored on the returned client.
// Otherwise the presented host key must match pubkey exactly, or the connection
// is refused.
//
// Once connected, the remote server is checked for working docker and
// docker-compose installations; on failure the connection is closed and the
// error is returned.
44 func dial(server string, pubkey []byte) (*sshClient, error) { 45 // Figure out a label for the server and a logger 46 label := server 47 if strings.Contains(label, ":") { 48 label = label[:strings.Index(label, ":")] 49 } 50 login := "" 51 if strings.Contains(server, "@") { 52 login = label[:strings.Index(label, "@")] 53 label = label[strings.Index(label, "@")+1:] 54 server = server[strings.Index(server, "@")+1:] 55 } 56 logger := log.New("server", label) 57 logger.Debug("Attempting to establish SSH connection") 58 59 user, err := user.Current() 60 if err != nil { 61 return nil, err 62 } 63 if login == "" { 64 login = user.Username 65 } 66 // Configure the supported authentication methods (private key and password) 67 var auths []ssh.AuthMethod 68 69 path := filepath.Join(user.HomeDir, ".ssh", "id_rsa") 70 if buf, err := ioutil.ReadFile(path); err != nil { 71 log.Warn("No SSH key, falling back to passwords", "path", path, "err", err) 72 } else { 73 key, err := ssh.ParsePrivateKey(buf) 74 if err != nil { 75 fmt.Printf("What's the decryption password for %s? (won't be echoed)\n>", path) 76 blob, err := terminal.ReadPassword(int(os.Stdin.Fd())) 77 fmt.Println() 78 if err != nil { 79 log.Warn("Couldn't read password", "err", err) 80 } 81 key, err := ssh.ParsePrivateKeyWithPassphrase(buf, blob) 82 if err != nil { 83 log.Warn("Failed to decrypt SSH key, falling back to passwords", "path", path, "err", err) 84 } else { 85 auths = append(auths, ssh.PublicKeys(key)) 86 } 87 } else { 88 auths = append(auths, ssh.PublicKeys(key)) 89 } 90 } 91 auths = append(auths, ssh.PasswordCallback(func() (string, error) { 92 fmt.Printf("What's the login password for %s at %s? 
(won't be echoed)\n> ", login, server) 93 blob, err := terminal.ReadPassword(int(os.Stdin.Fd())) 94 95 fmt.Println() 96 return string(blob), err 97 })) 98 // Resolve the IP address of the remote server 99 addr, err := net.LookupHost(label) 100 if err != nil { 101 return nil, err 102 } 103 if len(addr) == 0 { 104 return nil, errors.New("no IPs associated with domain") 105 } 106 // Try to dial in to the remote server 107 logger.Trace("Dialing remote SSH server", "user", login) 108 if !strings.Contains(server, ":") { 109 server += ":22" 110 } 111 keycheck := func(hostname string, remote net.Addr, key ssh.PublicKey) error { 112 // If no public key is known for SSH, ask the user to confirm 113 if pubkey == nil { 114 fmt.Println() 115 fmt.Printf("The authenticity of host '%s (%s)' can't be established.\n", hostname, remote) 116 fmt.Printf("SSH key fingerprint is %s [MD5]\n", ssh.FingerprintLegacyMD5(key)) 117 fmt.Printf("Are you sure you want to continue connecting (yes/no)? ") 118 119 text, err := bufio.NewReader(os.Stdin).ReadString('\n') 120 switch { 121 case err != nil: 122 return err 123 case strings.TrimSpace(text) == "yes": 124 pubkey = key.Marshal() 125 return nil 126 default: 127 return fmt.Errorf("unknown auth choice: %v", text) 128 } 129 } 130 // If a public key exists for this SSH server, check that it matches 131 if bytes.Equal(pubkey, key.Marshal()) { 132 return nil 133 } 134 // We have a mismatch, forbid connecting 135 return errors.New("ssh key mismatch, readd the machine to update") 136 } 137 client, err := ssh.Dial("tcp", server, &ssh.ClientConfig{User: login, Auth: auths, HostKeyCallback: keycheck}) 138 if err != nil { 139 return nil, err 140 } 141 // Connection established, return our utility wrapper 142 c := &sshClient{ 143 server: label, 144 address: addr[0], 145 pubkey: pubkey, 146 client: client, 147 logger: logger, 148 } 149 if err := c.init(); err != nil { 150 client.Close() 151 return nil, err 152 } 153 return c, nil 154 } 155 156 // init runs 
some initialization commands on the remote server to ensure it's 157 // capable of acting as puppeth target. 158 func (client *sshClient) init() error { 159 client.logger.Debug("Verifying if docker is available") 160 if out, err := client.Run("docker version"); err != nil { 161 if len(out) == 0 { 162 return err 163 } 164 return fmt.Errorf("docker configured incorrectly: %s", out) 165 } 166 client.logger.Debug("Verifying if docker-compose is available") 167 if out, err := client.Run("docker-compose version"); err != nil { 168 if len(out) == 0 { 169 return err 170 } 171 return fmt.Errorf("docker-compose configured incorrectly: %s", out) 172 } 173 return nil 174 } 175 176 // Close terminates the connection to an SSH server. 177 func (client *sshClient) Close() error { 178 return client.client.Close() 179 } 180 181 // Run executes a command on the remote server and returns the combined output 182 // along with any error status. 183 func (client *sshClient) Run(cmd string) ([]byte, error) { 184 // Establish a single command session 185 session, err := client.client.NewSession() 186 if err != nil { 187 return nil, err 188 } 189 defer session.Close() 190 191 // Execute the command and return any output 192 client.logger.Trace("Running command on remote server", "cmd", cmd) 193 return session.CombinedOutput(cmd) 194 } 195 196 // Stream executes a command on the remote server and streams all outputs into 197 // the local stdout and stderr streams. 
198 func (client *sshClient) Stream(cmd string) error { 199 // Establish a single command session 200 session, err := client.client.NewSession() 201 if err != nil { 202 return err 203 } 204 defer session.Close() 205 206 session.Stdout = os.Stdout 207 session.Stderr = os.Stderr 208 209 // Execute the command and return any output 210 client.logger.Trace("Streaming command on remote server", "cmd", cmd) 211 return session.Run(cmd) 212 } 213 214 // Upload copies the set of files to a remote server via SCP, creating any non- 215 // existing folders in the mean time. 216 func (client *sshClient) Upload(files map[string][]byte) ([]byte, error) { 217 // Establish a single command session 218 session, err := client.client.NewSession() 219 if err != nil { 220 return nil, err 221 } 222 defer session.Close() 223 224 // Create a goroutine that streams the SCP content 225 go func() { 226 out, _ := session.StdinPipe() 227 defer out.Close() 228 229 for file, content := range files { 230 client.logger.Trace("Uploading file to server", "file", file, "bytes", len(content)) 231 232 fmt.Fprintln(out, "D0755", 0, filepath.Dir(file)) // Ensure the folder exists 233 fmt.Fprintln(out, "C0644", len(content), filepath.Base(file)) // Create the actual file 234 out.Write(content) // Stream the data content 235 fmt.Fprint(out, "\x00") // Transfer end with \x00 236 fmt.Fprintln(out, "E") // Leave directory (simpler) 237 } 238 }() 239 return session.CombinedOutput("/usr/bin/scp -v -tr ./") 240 }