github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/cmd/roachprod/ssh/ssh.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package ssh
    12  
    13  import (
    14  	"fmt"
    15  	"io"
    16  	"io/ioutil"
    17  	"log"
    18  	"net"
    19  	"os"
    20  	"path/filepath"
    21  	"strings"
    22  	"sync"
    23  	"time"
    24  
    25  	"github.com/cockroachdb/cockroach/pkg/cmd/roachprod/config"
    26  	"github.com/cockroachdb/cockroach/pkg/util/syncutil"
    27  	"github.com/cockroachdb/errors"
    28  	"golang.org/x/crypto/ssh"
    29  	"golang.org/x/crypto/ssh/agent"
    30  	"golang.org/x/crypto/ssh/knownhosts"
    31  )
    32  
    33  var knownHosts ssh.HostKeyCallback
    34  var knownHostsOnce sync.Once
    35  
    36  // InsecureIgnoreHostKey TODO(peter): document
    37  var InsecureIgnoreHostKey bool
    38  
    39  func getKnownHosts() ssh.HostKeyCallback {
    40  	knownHostsOnce.Do(func() {
    41  		var err error
    42  		if InsecureIgnoreHostKey {
    43  			knownHosts = ssh.InsecureIgnoreHostKey()
    44  		} else {
    45  			knownHosts, err = knownhosts.New(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts"))
    46  			if err != nil {
    47  				log.Fatal(err)
    48  			}
    49  		}
    50  	})
    51  	return knownHosts
    52  }
    53  
    54  func getSSHAgentSigners() []ssh.Signer {
    55  	const authSockEnv = "SSH_AUTH_SOCK"
    56  	agentSocket := os.Getenv(authSockEnv)
    57  	if agentSocket == "" {
    58  		return nil
    59  	}
    60  	sock, err := net.Dial("unix", agentSocket)
    61  	if err != nil {
    62  		log.Printf("SSH_AUTH_SOCK set but unable to connect to agent: %s", err)
    63  		return nil
    64  	}
    65  	agent := agent.NewClient(sock)
    66  	signers, err := agent.Signers()
    67  	if err != nil {
    68  		log.Printf("unable to retrieve keys from agent: %s", err)
    69  		return nil
    70  	}
    71  	return signers
    72  }
    73  
    74  func getSSHKeySigner(path string, haveAgent bool) ssh.Signer {
    75  	key, err := ioutil.ReadFile(path)
    76  	if err != nil {
    77  		if !os.IsNotExist(err) {
    78  			log.Printf("unable to read SSH key %q: %s", path, err)
    79  		}
    80  		return nil
    81  	}
    82  
    83  	signer, err := ssh.ParsePrivateKey(key)
    84  	if err != nil {
    85  		if strings.Contains(err.Error(), "cannot decode encrypted private key") {
    86  			if !haveAgent {
    87  				log.Printf(
    88  					"skipping encrypted SSH key %q; if necessary, add the key to your SSH agent", path)
    89  			}
    90  		} else {
    91  			log.Printf("unable to parse SSH key %q: %s", path, err)
    92  		}
    93  		return nil
    94  	}
    95  	return signer
    96  }
    97  
    98  func getDefaultSSHKeySigners(haveAgent bool) []ssh.Signer {
    99  	var signers []ssh.Signer
   100  	for _, name := range []string{"id_rsa", "google_compute_engine"} {
   101  		s := getSSHKeySigner(filepath.Join(config.OSUser.HomeDir, ".ssh", name), haveAgent)
   102  		if s != nil {
   103  			signers = append(signers, s)
   104  		}
   105  	}
   106  	return signers
   107  }
   108  
   109  func newSSHClient(user, host string) (*ssh.Client, net.Conn, error) {
   110  	config := &ssh.ClientConfig{
   111  		User:            user,
   112  		Auth:            []ssh.AuthMethod{ssh.PublicKeys(sshState.signers...)},
   113  		HostKeyCallback: getKnownHosts(),
   114  	}
   115  	config.SetDefaults()
   116  
   117  	addr := fmt.Sprintf("%s:22", host)
   118  	conn, err := net.DialTimeout("tcp", addr, 30*time.Second)
   119  	if err != nil {
   120  		return nil, nil, err
   121  	}
   122  	c, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
   123  	if err != nil {
   124  		return nil, nil, err
   125  	}
   126  	return ssh.NewClient(c, chans, reqs), conn, nil
   127  }
   128  
   129  type sshClient struct {
   130  	syncutil.Mutex
   131  	*ssh.Client
   132  }
   133  
   134  var sshState = struct {
   135  	signers     []ssh.Signer
   136  	signersInit sync.Once
   137  
   138  	clients  map[string]*sshClient
   139  	clientMu syncutil.Mutex
   140  }{
   141  	clients: map[string]*sshClient{},
   142  }
   143  
   144  // NewSSHSession TODO(peter): document
   145  func NewSSHSession(user, host string) (*ssh.Session, error) {
   146  	if host == "127.0.0.1" || host == "localhost" {
   147  		return nil, errors.New("unable to ssh to localhost; file a bug")
   148  	}
   149  
   150  	sshState.clientMu.Lock()
   151  	target := fmt.Sprintf("%s@%s", user, host)
   152  	client := sshState.clients[target]
   153  	if client == nil {
   154  		client = &sshClient{}
   155  		sshState.clients[target] = client
   156  	}
   157  	sshState.clientMu.Unlock()
   158  
   159  	sshState.signersInit.Do(func() {
   160  		sshState.signers = append(sshState.signers, getSSHAgentSigners()...)
   161  		haveAgentSigner := len(sshState.signers) > 0
   162  		sshState.signers = append(sshState.signers, getDefaultSSHKeySigners(haveAgentSigner)...)
   163  	})
   164  
   165  	client.Lock()
   166  	defer client.Unlock()
   167  	if client.Client == nil {
   168  		var err error
   169  		client.Client, _, err = newSSHClient(user, host)
   170  		if err != nil {
   171  			return nil, err
   172  		}
   173  	}
   174  	return client.NewSession()
   175  }
   176  
   177  // ProgressWriter TODO(peter): document
   178  type ProgressWriter struct {
   179  	Writer   io.Writer
   180  	Done     int64
   181  	Total    int64
   182  	Progress func(float64)
   183  }
   184  
   185  func (p *ProgressWriter) Write(b []byte) (int, error) {
   186  	n, err := p.Writer.Write(b)
   187  	if err == nil {
   188  		p.Done += int64(n)
   189  		p.Progress(float64(p.Done) / float64(p.Total))
   190  	}
   191  	return n, err
   192  }