github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/roachprod/ssh/ssh.go (about) 1 // Copyright 2018 The Cockroach Authors. 2 // 3 // Use of this software is governed by the Business Source License 4 // included in the file licenses/BSL.txt. 5 // 6 // As of the Change Date specified in that file, in accordance with 7 // the Business Source License, use of this software will be governed 8 // by the Apache License, Version 2.0, included in the file 9 // licenses/APL.txt. 10 11 package ssh 12 13 import ( 14 "fmt" 15 "io" 16 "io/ioutil" 17 "log" 18 "net" 19 "os" 20 "path/filepath" 21 "strings" 22 "sync" 23 "time" 24 25 "github.com/cockroachdb/cockroach/pkg/cmd/roachprod/config" 26 "github.com/cockroachdb/cockroach/pkg/util/syncutil" 27 "github.com/cockroachdb/errors" 28 "golang.org/x/crypto/ssh" 29 "golang.org/x/crypto/ssh/agent" 30 "golang.org/x/crypto/ssh/knownhosts" 31 ) 32 33 var knownHosts ssh.HostKeyCallback 34 var knownHostsOnce sync.Once 35 36 // InsecureIgnoreHostKey TODO(peter): document 37 var InsecureIgnoreHostKey bool 38 39 func getKnownHosts() ssh.HostKeyCallback { 40 knownHostsOnce.Do(func() { 41 var err error 42 if InsecureIgnoreHostKey { 43 knownHosts = ssh.InsecureIgnoreHostKey() 44 } else { 45 knownHosts, err = knownhosts.New(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")) 46 if err != nil { 47 log.Fatal(err) 48 } 49 } 50 }) 51 return knownHosts 52 } 53 54 func getSSHAgentSigners() []ssh.Signer { 55 const authSockEnv = "SSH_AUTH_SOCK" 56 agentSocket := os.Getenv(authSockEnv) 57 if agentSocket == "" { 58 return nil 59 } 60 sock, err := net.Dial("unix", agentSocket) 61 if err != nil { 62 log.Printf("SSH_AUTH_SOCK set but unable to connect to agent: %s", err) 63 return nil 64 } 65 agent := agent.NewClient(sock) 66 signers, err := agent.Signers() 67 if err != nil { 68 log.Printf("unable to retrieve keys from agent: %s", err) 69 return nil 70 } 71 return signers 72 } 73 74 func getSSHKeySigner(path string, haveAgent bool) ssh.Signer { 75 key, err := ioutil.ReadFile(path) 76 if err != nil { 77 if !os.IsNotExist(err) { 78 log.Printf("unable to read SSH key %q: %s", path, err) 79 } 80 return nil 81 } 82 83 signer, err := ssh.ParsePrivateKey(key) 84 if err != nil { 85 if strings.Contains(err.Error(), "cannot decode encrypted private key") { 86 if !haveAgent { 87 log.Printf( 88 "skipping encrypted SSH key %q; if necessary, add the key to your SSH agent", path) 89 } 90 } else { 91 log.Printf("unable to parse SSH key %q: %s", path, err) 92 } 93 return nil 94 } 95 return signer 96 } 97 98 func getDefaultSSHKeySigners(haveAgent bool) []ssh.Signer { 99 var signers []ssh.Signer 100 for _, name := range []string{"id_rsa", "google_compute_engine"} { 101 s := getSSHKeySigner(filepath.Join(config.OSUser.HomeDir, ".ssh", name), haveAgent) 102 if s != nil { 103 signers = append(signers, s) 104 } 105 } 106 return signers 107 } 108 109 func newSSHClient(user, host string) (*ssh.Client, net.Conn, error) { 110 config := &ssh.ClientConfig{ 111 User: user, 112 Auth: []ssh.AuthMethod{ssh.PublicKeys(sshState.signers...)}, 113 HostKeyCallback: getKnownHosts(), 114 } 115 config.SetDefaults() 116 117 addr := fmt.Sprintf("%s:22", host) 118 conn, err := net.DialTimeout("tcp", addr, 30*time.Second) 119 if err != nil { 120 return nil, nil, err 121 } 122 c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) 123 if err != nil { 124 return nil, nil, err 125 } 126 return ssh.NewClient(c, chans, reqs), conn, nil 127 } 128 129 type sshClient struct { 130 syncutil.Mutex 131 *ssh.Client 132 } 133 134 var sshState = struct { 135 signers []ssh.Signer 136 signersInit sync.Once 137 138 clients map[string]*sshClient 139 clientMu syncutil.Mutex 140 }{ 141 clients: map[string]*sshClient{}, 142 } 143 144 // NewSSHSession TODO(peter): document 145 func NewSSHSession(user, host string) (*ssh.Session, error) { 146 if host == "127.0.0.1" || host == "localhost" { 147 return nil, errors.New("unable to ssh to localhost; file a bug") 148 } 149 150 sshState.clientMu.Lock() 151 target := fmt.Sprintf("%s@%s", user, host) 152 client := sshState.clients[target] 153 if client == nil { 154 client = &sshClient{} 155 sshState.clients[target] = client 156 } 157 sshState.clientMu.Unlock() 158 159 sshState.signersInit.Do(func() { 160 sshState.signers = append(sshState.signers, getSSHAgentSigners()...) 161 haveAgentSigner := len(sshState.signers) > 0 162 sshState.signers = append(sshState.signers, getDefaultSSHKeySigners(haveAgentSigner)...) 163 }) 164 165 client.Lock() 166 defer client.Unlock() 167 if client.Client == nil { 168 var err error 169 client.Client, _, err = newSSHClient(user, host) 170 if err != nil { 171 return nil, err 172 } 173 } 174 return client.NewSession() 175 } 176 177 // ProgressWriter TODO(peter): document 178 type ProgressWriter struct { 179 Writer io.Writer 180 Done int64 181 Total int64 182 Progress func(float64) 183 } 184 185 func (p *ProgressWriter) Write(b []byte) (int, error) { 186 n, err := p.Writer.Write(b) 187 if err == nil { 188 p.Done += int64(n) 189 p.Progress(float64(p.Done) / float64(p.Total)) 190 } 191 return n, err 192 }