github.com/coreos/mantle@v0.13.0/network/ssh.go (about)

     1  // Copyright 2015 CoreOS, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package network
    16  
    17  import (
    18  	"crypto/rand"
    19  	"crypto/rsa"
    20  	"fmt"
    21  	"io/ioutil"
    22  	"net"
    23  	"os"
    24  	"strings"
    25  
    26  	"golang.org/x/crypto/ssh"
    27  	"golang.org/x/crypto/ssh/agent"
    28  )
    29  
    30  const (
    31  	defaultPort = 22
    32  	defaultUser = "core"
    33  	rsaKeySize  = 2048
    34  )
    35  
    36  // Dialer is an interface for anything compatible with net.Dialer
    37  type Dialer interface {
    38  	Dial(network, address string) (net.Conn, error)
    39  }
    40  
    41  // SSHAgent can manage keys, updates cloud config, and loves ponies.
    42  // The embedded dialer is used for establishing new SSH connections.
    43  type SSHAgent struct {
    44  	agent.Agent
    45  	Dialer
    46  	User     string
    47  	Socket   string
    48  	sockDir  string
    49  	listener *net.UnixListener
    50  }
    51  
    52  // NewSSHAgent constructs a new SSHAgent using dialer to create ssh
    53  // connections.
    54  func NewSSHAgent(dialer Dialer) (*SSHAgent, error) {
    55  	key, err := rsa.GenerateKey(rand.Reader, rsaKeySize)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	addedkey := agent.AddedKey{
    61  		PrivateKey: key,
    62  		Comment:    "core@default",
    63  	}
    64  
    65  	keyring := agent.NewKeyring()
    66  	err = keyring.Add(addedkey)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	sockDir, err := ioutil.TempDir("", "mantle-ssh-")
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	// Use a similar naming scheme to ssh-agent
    77  	sockPath := fmt.Sprintf("%s/agent.%d", sockDir, os.Getpid())
    78  	sockAddr := &net.UnixAddr{Name: sockPath, Net: "unix"}
    79  	listener, err := net.ListenUnix("unix", sockAddr)
    80  	if err != nil {
    81  		os.RemoveAll(sockDir)
    82  		return nil, err
    83  	}
    84  
    85  	a := &SSHAgent{
    86  		Agent:    keyring,
    87  		Dialer:   dialer,
    88  		User:     defaultUser,
    89  		Socket:   sockPath,
    90  		sockDir:  sockDir,
    91  		listener: listener,
    92  	}
    93  
    94  	go func() {
    95  		for {
    96  			conn, err := listener.Accept()
    97  			if err != nil {
    98  				return
    99  			}
   100  			go agent.ServeAgent(a, conn)
   101  		}
   102  	}()
   103  
   104  	return a, nil
   105  }
   106  
   107  // Close closes the unix socket of the agent.
   108  func (a *SSHAgent) Close() error {
   109  	a.listener.Close()
   110  	return os.RemoveAll(a.sockDir)
   111  }
   112  
   113  // Add port to host if not already set.
   114  func ensurePortSuffix(host string, port int) string {
   115  	switch {
   116  	case !strings.Contains(host, ":"):
   117  		return fmt.Sprintf("%s:%d", host, port)
   118  	case strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]"):
   119  		return fmt.Sprintf("%s:%d", host, port)
   120  	case strings.HasPrefix(host, "[") && strings.Contains(host, "]:"):
   121  		return host
   122  	case strings.Count(host, ":") > 1:
   123  		return fmt.Sprintf("[%s]:%d", host, port)
   124  	default:
   125  		return host
   126  	}
   127  }
   128  
   129  func (a *SSHAgent) newClient(host string, user string, auth []ssh.AuthMethod) (*ssh.Client, error) {
   130  	sshcfg := ssh.ClientConfig{
   131  		User: user,
   132  		Auth: auth,
   133  	}
   134  	addr := ensurePortSuffix(host, defaultPort)
   135  	tcpconn, err := a.Dial("tcp", addr)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  
   140  	sshconn, chans, reqs, err := ssh.NewClientConn(tcpconn, addr, &sshcfg)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  
   145  	client := ssh.NewClient(sshconn, chans, reqs)
   146  	err = agent.ForwardToAgent(client, a)
   147  	if err != nil {
   148  		client.Close()
   149  		return nil, err
   150  	}
   151  
   152  	return client, nil
   153  }
   154  
   155  // NewClient connects to the given host via SSH, the client will support
   156  // agent forwarding but it must also be enabled per-session.
   157  func (a *SSHAgent) NewClient(host string) (*ssh.Client, error) {
   158  	return a.NewUserClient(host, a.User)
   159  }
   160  
   161  // NewUserClient connects to the given host via SSH using the provided username.
   162  // The client will support agent forwarding but it must also be enabled per-session.
   163  func (a *SSHAgent) NewUserClient(host string, user string) (*ssh.Client, error) {
   164  	client, err := a.newClient(host, user, []ssh.AuthMethod{ssh.PublicKeysCallback(a.Signers)})
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  	return client, nil
   169  }
   170  
   171  // NewPasswordClient connects to the given host via SSH using the
   172  // provided username and password
   173  func (a *SSHAgent) NewPasswordClient(host string, user string, password string) (*ssh.Client, error) {
   174  	client, err := a.newClient(host, user, []ssh.AuthMethod{ssh.Password(password)})
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  	return client, nil
   179  }