github.com/timstclair/heapster@v0.20.0-alpha1/Godeps/_workspace/src/k8s.io/kubernetes/pkg/util/ssh.go (about)

     1  /*
     2  Copyright 2015 The Kubernetes Authors All rights reserved.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package util
    18  
    19  import (
    20  	"bytes"
    21  	"crypto/rand"
    22  	"crypto/rsa"
    23  	"crypto/x509"
    24  	"encoding/pem"
    25  	"errors"
    26  	"fmt"
    27  	"io"
    28  	"io/ioutil"
    29  	mathrand "math/rand"
    30  	"net"
    31  	"os"
    32  	"time"
    33  
    34  	"github.com/golang/glog"
    35  	"github.com/prometheus/client_golang/prometheus"
    36  	"golang.org/x/crypto/ssh"
    37  )
    38  
    39  var (
    40  	tunnelOpenCounter = prometheus.NewCounter(
    41  		prometheus.CounterOpts{
    42  			Name: "ssh_tunnel_open_count",
    43  			Help: "Counter of ssh tunnel total open attempts",
    44  		},
    45  	)
    46  	tunnelOpenFailCounter = prometheus.NewCounter(
    47  		prometheus.CounterOpts{
    48  			Name: "ssh_tunnel_open_fail_count",
    49  			Help: "Counter of ssh tunnel failed open attempts",
    50  		},
    51  	)
    52  )
    53  
    54  func init() {
    55  	prometheus.MustRegister(tunnelOpenCounter)
    56  	prometheus.MustRegister(tunnelOpenFailCounter)
    57  }
    58  
    59  // TODO: Unit tests for this code, we can spin up a test SSH server with instructions here:
    60  // https://godoc.org/golang.org/x/crypto/ssh#ServerConn
    61  type SSHTunnel struct {
    62  	Config  *ssh.ClientConfig
    63  	Host    string
    64  	SSHPort string
    65  	running bool
    66  	sock    net.Listener
    67  	client  *ssh.Client
    68  }
    69  
    70  func (s *SSHTunnel) copyBytes(out io.Writer, in io.Reader) {
    71  	if _, err := io.Copy(out, in); err != nil {
    72  		glog.Errorf("Error in SSH tunnel: %v", err)
    73  	}
    74  }
    75  
    76  func NewSSHTunnel(user, keyfile, host string) (*SSHTunnel, error) {
    77  	signer, err := MakePrivateKeySignerFromFile(keyfile)
    78  	if err != nil {
    79  		return nil, err
    80  	}
    81  	return makeSSHTunnel(user, signer, host)
    82  }
    83  
    84  func NewSSHTunnelFromBytes(user string, privateKey []byte, host string) (*SSHTunnel, error) {
    85  	signer, err := MakePrivateKeySignerFromBytes(privateKey)
    86  	if err != nil {
    87  		return nil, err
    88  	}
    89  	return makeSSHTunnel(user, signer, host)
    90  }
    91  
    92  func makeSSHTunnel(user string, signer ssh.Signer, host string) (*SSHTunnel, error) {
    93  	config := ssh.ClientConfig{
    94  		User: user,
    95  		Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
    96  	}
    97  	return &SSHTunnel{
    98  		Config:  &config,
    99  		Host:    host,
   100  		SSHPort: "22",
   101  	}, nil
   102  }
   103  
   104  func (s *SSHTunnel) Open() error {
   105  	var err error
   106  	s.client, err = ssh.Dial("tcp", net.JoinHostPort(s.Host, s.SSHPort), s.Config)
   107  	tunnelOpenCounter.Inc()
   108  	if err != nil {
   109  		tunnelOpenFailCounter.Inc()
   110  		return err
   111  	}
   112  	return nil
   113  }
   114  
   115  func (s *SSHTunnel) Dial(network, address string) (net.Conn, error) {
   116  	if s.client == nil {
   117  		return nil, errors.New("tunnel is not opened.")
   118  	}
   119  	return s.client.Dial(network, address)
   120  }
   121  
   122  func (s *SSHTunnel) tunnel(conn net.Conn, remoteHost, remotePort string) error {
   123  	if s.client == nil {
   124  		return errors.New("tunnel is not opened.")
   125  	}
   126  	tunnel, err := s.client.Dial("tcp", net.JoinHostPort(remoteHost, remotePort))
   127  	if err != nil {
   128  		return err
   129  	}
   130  	go s.copyBytes(tunnel, conn)
   131  	go s.copyBytes(conn, tunnel)
   132  	return nil
   133  }
   134  
   135  func (s *SSHTunnel) Close() error {
   136  	if s.client == nil {
   137  		return errors.New("Cannot close tunnel. Tunnel was not opened.")
   138  	}
   139  	if err := s.client.Close(); err != nil {
   140  		return err
   141  	}
   142  	return nil
   143  }
   144  
   145  // Interface to allow mocking of ssh.Dial, for testing SSH
   146  type sshDialer interface {
   147  	Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error)
   148  }
   149  
   150  // Real implementation of sshDialer
   151  type realSSHDialer struct{}
   152  
   153  var _ sshDialer = &realSSHDialer{}
   154  
   155  func (d *realSSHDialer) Dial(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
   156  	return ssh.Dial(network, addr, config)
   157  }
   158  
   159  // RunSSHCommand returns the stdout, stderr, and exit code from running cmd on
   160  // host as specific user, along with any SSH-level error.
   161  // If user=="", it will default (like SSH) to os.Getenv("USER")
   162  func RunSSHCommand(cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
   163  	return runSSHCommand(&realSSHDialer{}, cmd, user, host, signer)
   164  }
   165  
   166  // Internal implementation of runSSHCommand, for testing
   167  func runSSHCommand(dialer sshDialer, cmd, user, host string, signer ssh.Signer) (string, string, int, error) {
   168  	if user == "" {
   169  		user = os.Getenv("USER")
   170  	}
   171  	// Setup the config, dial the server, and open a session.
   172  	config := &ssh.ClientConfig{
   173  		User: user,
   174  		Auth: []ssh.AuthMethod{ssh.PublicKeys(signer)},
   175  	}
   176  	client, err := dialer.Dial("tcp", host, config)
   177  	if err != nil {
   178  		return "", "", 0, fmt.Errorf("error getting SSH client to %s@%s: '%v'", user, host, err)
   179  	}
   180  	session, err := client.NewSession()
   181  	if err != nil {
   182  		return "", "", 0, fmt.Errorf("error creating session to %s@%s: '%v'", user, host, err)
   183  	}
   184  	defer session.Close()
   185  
   186  	// Run the command.
   187  	code := 0
   188  	var bout, berr bytes.Buffer
   189  	session.Stdout, session.Stderr = &bout, &berr
   190  	if err = session.Run(cmd); err != nil {
   191  		// Check whether the command failed to run or didn't complete.
   192  		if exiterr, ok := err.(*ssh.ExitError); ok {
   193  			// If we got an ExitError and the exit code is nonzero, we'll
   194  			// consider the SSH itself successful (just that the command run
   195  			// errored on the host).
   196  			if code = exiterr.ExitStatus(); code != 0 {
   197  				err = nil
   198  			}
   199  		} else {
   200  			// Some other kind of error happened (e.g. an IOError); consider the
   201  			// SSH unsuccessful.
   202  			err = fmt.Errorf("failed running `%s` on %s@%s: '%v'", cmd, user, host, err)
   203  		}
   204  	}
   205  	return bout.String(), berr.String(), code, err
   206  }
   207  
   208  func MakePrivateKeySignerFromFile(key string) (ssh.Signer, error) {
   209  	// Create an actual signer.
   210  	buffer, err := ioutil.ReadFile(key)
   211  	if err != nil {
   212  		return nil, fmt.Errorf("error reading SSH key %s: '%v'", key, err)
   213  	}
   214  	return MakePrivateKeySignerFromBytes(buffer)
   215  }
   216  
   217  func MakePrivateKeySignerFromBytes(buffer []byte) (ssh.Signer, error) {
   218  	signer, err := ssh.ParsePrivateKey(buffer)
   219  	if err != nil {
   220  		return nil, fmt.Errorf("error parsing SSH key %s: '%v'", buffer, err)
   221  	}
   222  	return signer, nil
   223  }
   224  
   225  func ParsePublicKeyFromFile(keyFile string) (*rsa.PublicKey, error) {
   226  	buffer, err := ioutil.ReadFile(keyFile)
   227  	if err != nil {
   228  		return nil, fmt.Errorf("error reading SSH key %s: '%v'", keyFile, err)
   229  	}
   230  	keyBlock, _ := pem.Decode(buffer)
   231  	key, err := x509.ParsePKIXPublicKey(keyBlock.Bytes)
   232  	if err != nil {
   233  		return nil, fmt.Errorf("error parsing SSH key %s: '%v'", keyFile, err)
   234  	}
   235  	rsaKey, ok := key.(*rsa.PublicKey)
   236  	if !ok {
   237  		return nil, fmt.Errorf("SSH key could not be parsed as rsa public key")
   238  	}
   239  	return rsaKey, nil
   240  }
   241  
   242  // Should be thread safe.
   243  type SSHTunnelEntry struct {
   244  	Address string
   245  	Tunnel  *SSHTunnel
   246  }
   247  
   248  // Not thread safe!
   249  type SSHTunnelList struct {
   250  	entries []SSHTunnelEntry
   251  }
   252  
   253  func MakeSSHTunnels(user, keyfile string, addresses []string) *SSHTunnelList {
   254  	tunnels := []SSHTunnelEntry{}
   255  	for ix := range addresses {
   256  		addr := addresses[ix]
   257  		tunnel, err := NewSSHTunnel(user, keyfile, addr)
   258  		if err != nil {
   259  			glog.Errorf("Failed to create tunnel for %q: %v", addr, err)
   260  			continue
   261  		}
   262  		tunnels = append(tunnels, SSHTunnelEntry{addr, tunnel})
   263  	}
   264  	return &SSHTunnelList{tunnels}
   265  }
   266  
   267  // Open attempts to open all tunnels in the list, and removes any tunnels that
   268  // failed to open.
   269  func (l *SSHTunnelList) Open() error {
   270  	var openTunnels []SSHTunnelEntry
   271  	for ix := range l.entries {
   272  		if err := l.entries[ix].Tunnel.Open(); err != nil {
   273  			glog.Errorf("Failed to open tunnel %v: %v", l.entries[ix], err)
   274  		} else {
   275  			openTunnels = append(openTunnels, l.entries[ix])
   276  		}
   277  	}
   278  	l.entries = openTunnels
   279  	if len(l.entries) == 0 {
   280  		return errors.New("Failed to open any tunnels.")
   281  	}
   282  	return nil
   283  }
   284  
   285  // Close asynchronously closes all tunnels in the list after waiting for 1
   286  // minute. Tunnels will still be open upon this function's return, but should
   287  // no longer be used.
   288  func (l *SSHTunnelList) Close() {
   289  	for ix := range l.entries {
   290  		entry := l.entries[ix]
   291  		go func() {
   292  			defer HandleCrash()
   293  			time.Sleep(1 * time.Minute)
   294  			if err := entry.Tunnel.Close(); err != nil {
   295  				glog.Errorf("Failed to close tunnel %v: %v", entry, err)
   296  			}
   297  		}()
   298  	}
   299  }
   300  
   301  /* this will make sense if we move the lock into SSHTunnelList.
   302  func (l *SSHTunnelList) Dial(network, addr string) (net.Conn, error) {
   303  	if len(l.entries) == 0 {
   304  		return nil, fmt.Errorf("empty tunnel list.")
   305  	}
   306  	n := mathrand.Intn(len(l.entries))
   307  	return l.entries[n].Tunnel.Dial(network, addr)
   308  }
   309  */
   310  
   311  // Returns a random tunnel, xor an error if there are none.
   312  func (l *SSHTunnelList) PickRandomTunnel() (SSHTunnelEntry, error) {
   313  	if len(l.entries) == 0 {
   314  		return SSHTunnelEntry{}, fmt.Errorf("empty tunnel list.")
   315  	}
   316  	n := mathrand.Intn(len(l.entries))
   317  	return l.entries[n], nil
   318  }
   319  
   320  func (l *SSHTunnelList) Has(addr string) bool {
   321  	for ix := range l.entries {
   322  		if l.entries[ix].Address == addr {
   323  			return true
   324  		}
   325  	}
   326  	return false
   327  }
   328  
   329  func (l *SSHTunnelList) Len() int {
   330  	return len(l.entries)
   331  }
   332  
   333  func EncodePrivateKey(private *rsa.PrivateKey) []byte {
   334  	return pem.EncodeToMemory(&pem.Block{
   335  		Bytes: x509.MarshalPKCS1PrivateKey(private),
   336  		Type:  "RSA PRIVATE KEY",
   337  	})
   338  }
   339  
   340  func EncodePublicKey(public *rsa.PublicKey) ([]byte, error) {
   341  	publicBytes, err := x509.MarshalPKIXPublicKey(public)
   342  	if err != nil {
   343  		return nil, err
   344  	}
   345  	return pem.EncodeToMemory(&pem.Block{
   346  		Bytes: publicBytes,
   347  		Type:  "PUBLIC KEY",
   348  	}), nil
   349  }
   350  
   351  func EncodeSSHKey(public *rsa.PublicKey) ([]byte, error) {
   352  	publicKey, err := ssh.NewPublicKey(public)
   353  	if err != nil {
   354  		return nil, err
   355  	}
   356  	return ssh.MarshalAuthorizedKey(publicKey), nil
   357  }
   358  
   359  func GenerateKey(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) {
   360  	private, err := rsa.GenerateKey(rand.Reader, bits)
   361  	if err != nil {
   362  		return nil, nil, err
   363  	}
   364  	return private, &private.PublicKey, nil
   365  }