github.com/pulumi/terraform@v1.4.0/pkg/communicator/ssh/provisioner.go (about)

     1  package ssh
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/pem"
     6  	"errors"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"log"
    10  	"net"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/pulumi/terraform/pkg/communicator/shared"
    17  	sshagent "github.com/xanzy/ssh-agent"
    18  	"github.com/zclconf/go-cty/cty"
    19  	"github.com/zclconf/go-cty/cty/gocty"
    20  	"golang.org/x/crypto/ssh"
    21  	"golang.org/x/crypto/ssh/agent"
    22  	"golang.org/x/crypto/ssh/knownhosts"
    23  )
    24  
const (
	// DefaultUser is the login user applied by parseConnectionInfo when the
	// connection block sets no user.
	DefaultUser = "root"

	// DefaultPort is the SSH port applied when the connection block sets no
	// port (or sets it to 0).
	DefaultPort = 22

	// DefaultUnixScriptPath is used as the path to copy the file to
	// for remote execution on unix if not provided otherwise.
	// NOTE(review): %RAND% is presumably replaced with a random value by the
	// code that uploads the script — confirm against the communicator.
	DefaultUnixScriptPath = "/tmp/terraform_%RAND%.sh"
	// DefaultWindowsScriptPath is used as the path to copy the file to
	// for remote execution on windows if not provided otherwise.
	DefaultWindowsScriptPath = "C:/windows/temp/terraform_%RAND%.cmd"

	// DefaultTimeout is used when no timeout is given, and also when the
	// configured timeout string cannot be parsed as a time.Duration.
	DefaultTimeout = 5 * time.Minute

	// TargetPlatformUnix names the unix target platform; it is also the
	// default when no target platform has been specified.
	TargetPlatformUnix = "unix"
	// TargetPlatformWindows names the windows target platform.
	TargetPlatformWindows = "windows"
)
    47  
// connectionInfo is decoded from the ConnInfo of the resource. These are the
// only keys we look at. If both a PrivateKey and a Password are given, both
// authentication methods are offered, with the private key tried first.
type connectionInfo struct {
	User           string
	Password       string
	PrivateKey     string
	// Certificate is a client certificate; it is only consulted when
	// PrivateKey is also set.
	Certificate    string
	Host           string
	// HostKey is the public key (or signing CA key) the host must present.
	// When empty, host key verification is disabled.
	HostKey        string
	Port           uint16
	// Agent selects authentication through a local SSH agent. It defaults to
	// true when unset in config and SSH_AUTH_SOCK is present.
	Agent          bool
	// ScriptPath is where scripts are copied on the remote side; it defaults
	// per TargetPlatform.
	ScriptPath     string
	// TargetPlatform is TargetPlatformUnix or TargetPlatformWindows.
	TargetPlatform string
	// Timeout is the raw configured duration string; TimeoutVal is its parsed
	// value, or DefaultTimeout when unset or unparseable.
	Timeout        string
	TimeoutVal     time.Duration

	// Optional proxy the target connection is routed through when ProxyHost
	// is non-empty.
	ProxyScheme       string
	ProxyHost         string
	ProxyPort         uint16
	ProxyUserName     string
	ProxyUserPassword string

	// Optional bastion (jump) host settings. When BastionHost is set,
	// parseConnectionInfo fills empty user/password/key/certificate/port
	// fields from their non-bastion counterparts.
	BastionUser        string
	BastionPassword    string
	BastionPrivateKey  string
	BastionCertificate string
	BastionHost        string
	BastionHostKey     string
	BastionPort        uint16

	// AgentIdentity selects which agent key to try first (see sortSigners).
	AgentIdentity string
}
    81  
    82  // decodeConnInfo decodes the given cty.Value using the same behavior as the
    83  // lgeacy mapstructure decoder in order to preserve as much of the existing
    84  // logic as possible for compatibility.
    85  func decodeConnInfo(v cty.Value) (*connectionInfo, error) {
    86  	connInfo := &connectionInfo{}
    87  	if v.IsNull() {
    88  		return connInfo, nil
    89  	}
    90  
    91  	for k, v := range v.AsValueMap() {
    92  		if v.IsNull() {
    93  			continue
    94  		}
    95  
    96  		switch k {
    97  		case "user":
    98  			connInfo.User = v.AsString()
    99  		case "password":
   100  			connInfo.Password = v.AsString()
   101  		case "private_key":
   102  			connInfo.PrivateKey = v.AsString()
   103  		case "certificate":
   104  			connInfo.Certificate = v.AsString()
   105  		case "host":
   106  			connInfo.Host = v.AsString()
   107  		case "host_key":
   108  			connInfo.HostKey = v.AsString()
   109  		case "port":
   110  			if err := gocty.FromCtyValue(v, &connInfo.Port); err != nil {
   111  				return nil, err
   112  			}
   113  		case "agent":
   114  			connInfo.Agent = v.True()
   115  		case "script_path":
   116  			connInfo.ScriptPath = v.AsString()
   117  		case "target_platform":
   118  			connInfo.TargetPlatform = v.AsString()
   119  		case "timeout":
   120  			connInfo.Timeout = v.AsString()
   121  		case "proxy_scheme":
   122  			connInfo.ProxyScheme = v.AsString()
   123  		case "proxy_host":
   124  			connInfo.ProxyHost = v.AsString()
   125  		case "proxy_port":
   126  			if err := gocty.FromCtyValue(v, &connInfo.ProxyPort); err != nil {
   127  				return nil, err
   128  			}
   129  		case "proxy_user_name":
   130  			connInfo.ProxyUserName = v.AsString()
   131  		case "proxy_user_password":
   132  			connInfo.ProxyUserPassword = v.AsString()
   133  		case "bastion_user":
   134  			connInfo.BastionUser = v.AsString()
   135  		case "bastion_password":
   136  			connInfo.BastionPassword = v.AsString()
   137  		case "bastion_private_key":
   138  			connInfo.BastionPrivateKey = v.AsString()
   139  		case "bastion_certificate":
   140  			connInfo.BastionCertificate = v.AsString()
   141  		case "bastion_host":
   142  			connInfo.BastionHost = v.AsString()
   143  		case "bastion_host_key":
   144  			connInfo.BastionHostKey = v.AsString()
   145  		case "bastion_port":
   146  			if err := gocty.FromCtyValue(v, &connInfo.BastionPort); err != nil {
   147  				return nil, err
   148  			}
   149  		case "agent_identity":
   150  			connInfo.AgentIdentity = v.AsString()
   151  		}
   152  	}
   153  	return connInfo, nil
   154  }
   155  
   156  // parseConnectionInfo is used to convert the raw configuration into the
   157  // *connectionInfo struct.
   158  func parseConnectionInfo(v cty.Value) (*connectionInfo, error) {
   159  	v, err := shared.ConnectionBlockSupersetSchema.CoerceValue(v)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  
   164  	connInfo, err := decodeConnInfo(v)
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  
   169  	// To default Agent to true, we need to check the raw string, since the
   170  	// decoded boolean can't represent "absence of config".
   171  	//
   172  	// And if SSH_AUTH_SOCK is not set, there's no agent to connect to, so we
   173  	// shouldn't try.
   174  	agent := v.GetAttr("agent")
   175  	if agent.IsNull() && os.Getenv("SSH_AUTH_SOCK") != "" {
   176  		connInfo.Agent = true
   177  	}
   178  
   179  	if connInfo.User == "" {
   180  		connInfo.User = DefaultUser
   181  	}
   182  
   183  	// Check if host is empty.
   184  	// Otherwise return error.
   185  	if connInfo.Host == "" {
   186  		return nil, fmt.Errorf("host for provisioner cannot be empty")
   187  	}
   188  
   189  	// Format the host if needed.
   190  	// Needed for IPv6 support.
   191  	connInfo.Host = shared.IpFormat(connInfo.Host)
   192  
   193  	if connInfo.Port == 0 {
   194  		connInfo.Port = DefaultPort
   195  	}
   196  	// Set default targetPlatform to unix if it's empty
   197  	if connInfo.TargetPlatform == "" {
   198  		connInfo.TargetPlatform = TargetPlatformUnix
   199  	} else if connInfo.TargetPlatform != TargetPlatformUnix && connInfo.TargetPlatform != TargetPlatformWindows {
   200  		return nil, fmt.Errorf("target_platform for provisioner has to be either %s or %s", TargetPlatformUnix, TargetPlatformWindows)
   201  	}
   202  	// Choose an appropriate default script path based on the target platform. There is no single
   203  	// suitable default script path which works on both UNIX and Windows targets.
   204  	if connInfo.ScriptPath == "" && connInfo.TargetPlatform == TargetPlatformUnix {
   205  		connInfo.ScriptPath = DefaultUnixScriptPath
   206  	}
   207  	if connInfo.ScriptPath == "" && connInfo.TargetPlatform == TargetPlatformWindows {
   208  		connInfo.ScriptPath = DefaultWindowsScriptPath
   209  	}
   210  	if connInfo.Timeout != "" {
   211  		connInfo.TimeoutVal = safeDuration(connInfo.Timeout, DefaultTimeout)
   212  	} else {
   213  		connInfo.TimeoutVal = DefaultTimeout
   214  	}
   215  
   216  	// Default all bastion config attrs to their non-bastion counterparts
   217  	if connInfo.BastionHost != "" {
   218  		// Format the bastion host if needed.
   219  		// Needed for IPv6 support.
   220  		connInfo.BastionHost = shared.IpFormat(connInfo.BastionHost)
   221  
   222  		if connInfo.BastionUser == "" {
   223  			connInfo.BastionUser = connInfo.User
   224  		}
   225  		if connInfo.BastionPassword == "" {
   226  			connInfo.BastionPassword = connInfo.Password
   227  		}
   228  		if connInfo.BastionPrivateKey == "" {
   229  			connInfo.BastionPrivateKey = connInfo.PrivateKey
   230  		}
   231  		if connInfo.BastionCertificate == "" {
   232  			connInfo.BastionCertificate = connInfo.Certificate
   233  		}
   234  		if connInfo.BastionPort == 0 {
   235  			connInfo.BastionPort = connInfo.Port
   236  		}
   237  	}
   238  
   239  	return connInfo, nil
   240  }
   241  
   242  // safeDuration returns either the parsed duration or a default value
   243  func safeDuration(dur string, defaultDur time.Duration) time.Duration {
   244  	d, err := time.ParseDuration(dur)
   245  	if err != nil {
   246  		log.Printf("Invalid duration '%s', using default of %s", dur, defaultDur)
   247  		return defaultDur
   248  	}
   249  	return d
   250  }
   251  
   252  // prepareSSHConfig is used to turn the *ConnectionInfo provided into a
   253  // usable *SSHConfig for client initialization.
   254  func prepareSSHConfig(connInfo *connectionInfo) (*sshConfig, error) {
   255  	sshAgent, err := connectToAgent(connInfo)
   256  	if err != nil {
   257  		return nil, err
   258  	}
   259  
   260  	host := fmt.Sprintf("%s:%d", connInfo.Host, connInfo.Port)
   261  
   262  	sshConf, err := buildSSHClientConfig(sshClientConfigOpts{
   263  		user:        connInfo.User,
   264  		host:        host,
   265  		privateKey:  connInfo.PrivateKey,
   266  		password:    connInfo.Password,
   267  		hostKey:     connInfo.HostKey,
   268  		certificate: connInfo.Certificate,
   269  		sshAgent:    sshAgent,
   270  	})
   271  	if err != nil {
   272  		return nil, err
   273  	}
   274  
   275  	var p *proxyInfo
   276  
   277  	if connInfo.ProxyHost != "" {
   278  		p = newProxyInfo(
   279  			fmt.Sprintf("%s:%d", connInfo.ProxyHost, connInfo.ProxyPort),
   280  			connInfo.ProxyScheme,
   281  			connInfo.ProxyUserName,
   282  			connInfo.ProxyUserPassword,
   283  		)
   284  	}
   285  
   286  	connectFunc := ConnectFunc("tcp", host, p)
   287  
   288  	var bastionConf *ssh.ClientConfig
   289  	if connInfo.BastionHost != "" {
   290  		bastionHost := fmt.Sprintf("%s:%d", connInfo.BastionHost, connInfo.BastionPort)
   291  
   292  		bastionConf, err = buildSSHClientConfig(sshClientConfigOpts{
   293  			user:        connInfo.BastionUser,
   294  			host:        bastionHost,
   295  			privateKey:  connInfo.BastionPrivateKey,
   296  			password:    connInfo.BastionPassword,
   297  			hostKey:     connInfo.HostKey,
   298  			certificate: connInfo.BastionCertificate,
   299  			sshAgent:    sshAgent,
   300  		})
   301  		if err != nil {
   302  			return nil, err
   303  		}
   304  
   305  		connectFunc = BastionConnectFunc("tcp", bastionHost, bastionConf, "tcp", host, p)
   306  	}
   307  
   308  	config := &sshConfig{
   309  		config:     sshConf,
   310  		connection: connectFunc,
   311  		sshAgent:   sshAgent,
   312  	}
   313  	return config, nil
   314  }
   315  
// sshClientConfigOpts carries the inputs buildSSHClientConfig needs to
// assemble an *ssh.ClientConfig for a single hop (the target or a bastion).
type sshClientConfigOpts struct {
	// privateKey is key material as text (not a file path); empty disables
	// key-based authentication.
	privateKey  string
	// password enables password and keyboard-interactive auth when non-empty.
	password    string
	// sshAgent, when non-nil, adds agent-backed public key authentication.
	sshAgent    *sshAgent
	// certificate is an authorized_keys-format client certificate; it is only
	// used together with privateKey.
	certificate string
	user        string
	// host is the "host:port" string used as the pattern in the generated
	// known_hosts entry.
	host        string
	// hostKey is the expected host (or CA) public key; empty disables host key
	// verification.
	hostKey     string
}
   325  
   326  func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) {
   327  	hkCallback := ssh.InsecureIgnoreHostKey()
   328  
   329  	if opts.hostKey != "" {
   330  		// The knownhosts package only takes paths to files, but terraform
   331  		// generally wants to handle config data in-memory. Rather than making
   332  		// the known_hosts file an exception, write out the data to a temporary
   333  		// file to create the HostKeyCallback.
   334  		tf, err := ioutil.TempFile("", "tf-known_hosts")
   335  		if err != nil {
   336  			return nil, fmt.Errorf("failed to create temp known_hosts file: %s", err)
   337  		}
   338  		defer tf.Close()
   339  		defer os.RemoveAll(tf.Name())
   340  
   341  		// we mark this as a CA as well, but the host key fallback will still
   342  		// use it as a direct match if the remote host doesn't return a
   343  		// certificate.
   344  		if _, err := tf.WriteString(fmt.Sprintf("@cert-authority %s %s\n", opts.host, opts.hostKey)); err != nil {
   345  			return nil, fmt.Errorf("failed to write temp known_hosts file: %s", err)
   346  		}
   347  		tf.Sync()
   348  
   349  		hkCallback, err = knownhosts.New(tf.Name())
   350  		if err != nil {
   351  			return nil, err
   352  		}
   353  	}
   354  
   355  	conf := &ssh.ClientConfig{
   356  		HostKeyCallback: hkCallback,
   357  		User:            opts.user,
   358  	}
   359  
   360  	if opts.privateKey != "" {
   361  		if opts.certificate != "" {
   362  			log.Println("using client certificate for authentication")
   363  
   364  			certSigner, err := signCertWithPrivateKey(opts.privateKey, opts.certificate)
   365  			if err != nil {
   366  				return nil, err
   367  			}
   368  			conf.Auth = append(conf.Auth, certSigner)
   369  		} else {
   370  			log.Println("using private key for authentication")
   371  
   372  			pubKeyAuth, err := readPrivateKey(opts.privateKey)
   373  			if err != nil {
   374  				return nil, err
   375  			}
   376  			conf.Auth = append(conf.Auth, pubKeyAuth)
   377  		}
   378  	}
   379  
   380  	if opts.password != "" {
   381  		conf.Auth = append(conf.Auth, ssh.Password(opts.password))
   382  		conf.Auth = append(conf.Auth, ssh.KeyboardInteractive(
   383  			PasswordKeyboardInteractive(opts.password)))
   384  	}
   385  
   386  	if opts.sshAgent != nil {
   387  		conf.Auth = append(conf.Auth, opts.sshAgent.Auth())
   388  	}
   389  
   390  	return conf, nil
   391  }
   392  
   393  // Create a Cert Signer and return ssh.AuthMethod
   394  func signCertWithPrivateKey(pk string, certificate string) (ssh.AuthMethod, error) {
   395  	rawPk, err := ssh.ParseRawPrivateKey([]byte(pk))
   396  	if err != nil {
   397  		return nil, fmt.Errorf("failed to parse private key %q: %s", pk, err)
   398  	}
   399  
   400  	pcert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(certificate))
   401  	if err != nil {
   402  		return nil, fmt.Errorf("failed to parse certificate %q: %s", certificate, err)
   403  	}
   404  
   405  	usigner, err := ssh.NewSignerFromKey(rawPk)
   406  	if err != nil {
   407  		return nil, fmt.Errorf("failed to create signer from raw private key %q: %s", rawPk, err)
   408  	}
   409  
   410  	ucertSigner, err := ssh.NewCertSigner(pcert.(*ssh.Certificate), usigner)
   411  	if err != nil {
   412  		return nil, fmt.Errorf("failed to create cert signer %q: %s", usigner, err)
   413  	}
   414  
   415  	return ssh.PublicKeys(ucertSigner), nil
   416  }
   417  
   418  func readPrivateKey(pk string) (ssh.AuthMethod, error) {
   419  	// We parse the private key on our own first so that we can
   420  	// show a nicer error if the private key has a password.
   421  	block, _ := pem.Decode([]byte(pk))
   422  	if block == nil {
   423  		return nil, errors.New("Failed to read ssh private key: no key found")
   424  	}
   425  	if block.Headers["Proc-Type"] == "4,ENCRYPTED" {
   426  		return nil, errors.New(
   427  			"Failed to read ssh private key: password protected keys are\n" +
   428  				"not supported. Please decrypt the key prior to use.")
   429  	}
   430  
   431  	signer, err := ssh.ParsePrivateKey([]byte(pk))
   432  	if err != nil {
   433  		return nil, fmt.Errorf("Failed to parse ssh private key: %s", err)
   434  	}
   435  
   436  	return ssh.PublicKeys(signer), nil
   437  }
   438  
   439  func connectToAgent(connInfo *connectionInfo) (*sshAgent, error) {
   440  	if !connInfo.Agent {
   441  		// No agent configured
   442  		return nil, nil
   443  	}
   444  
   445  	agent, conn, err := sshagent.New()
   446  	if err != nil {
   447  		return nil, err
   448  	}
   449  
   450  	// connection close is handled over in Communicator
   451  	return &sshAgent{
   452  		agent: agent,
   453  		conn:  conn,
   454  		id:    connInfo.AgentIdentity,
   455  	}, nil
   456  
   457  }
   458  
// A tiny wrapper around an agent.Agent to expose the ability to close its
// associated connection on request.
type sshAgent struct {
	// agent is used to list signers and to serve agent forwarding.
	agent agent.Agent
	// conn is the connection backing agent; Close treats nil as nothing to close.
	conn  net.Conn
	// id is the configured agent_identity used by sortSigners to order keys;
	// an empty id disables reordering.
	id    string
}
   466  
   467  func (a *sshAgent) Close() error {
   468  	if a.conn == nil {
   469  		return nil
   470  	}
   471  
   472  	return a.conn.Close()
   473  }
   474  
   475  // make an attempt to either read the identity file or find a corresponding
   476  // public key file using the typical openssh naming convention.
   477  // This returns the public key in wire format, or nil when a key is not found.
   478  func findIDPublicKey(id string) []byte {
   479  	for _, d := range idKeyData(id) {
   480  		signer, err := ssh.ParsePrivateKey(d)
   481  		if err == nil {
   482  			log.Println("[DEBUG] parsed id private key")
   483  			pk := signer.PublicKey()
   484  			return pk.Marshal()
   485  		}
   486  
   487  		// try it as a publicKey
   488  		pk, err := ssh.ParsePublicKey(d)
   489  		if err == nil {
   490  			log.Println("[DEBUG] parsed id public key")
   491  			return pk.Marshal()
   492  		}
   493  
   494  		// finally try it as an authorized key
   495  		pk, _, _, _, err = ssh.ParseAuthorizedKey(d)
   496  		if err == nil {
   497  			log.Println("[DEBUG] parsed id authorized key")
   498  			return pk.Marshal()
   499  		}
   500  	}
   501  
   502  	return nil
   503  }
   504  
   505  // Try to read an id file using the id as the file path. Also read the .pub
   506  // file if it exists, as the id file may be encrypted. Return only the file
   507  // data read. We don't need to know what data came from which path, as we will
   508  // try parsing each as a private key, a public key and an authorized key
   509  // regardless.
   510  func idKeyData(id string) [][]byte {
   511  	idPath, err := filepath.Abs(id)
   512  	if err != nil {
   513  		return nil
   514  	}
   515  
   516  	var fileData [][]byte
   517  
   518  	paths := []string{idPath}
   519  
   520  	if !strings.HasSuffix(idPath, ".pub") {
   521  		paths = append(paths, idPath+".pub")
   522  	}
   523  
   524  	for _, p := range paths {
   525  		d, err := ioutil.ReadFile(p)
   526  		if err != nil {
   527  			log.Printf("[DEBUG] error reading %q: %s", p, err)
   528  			continue
   529  		}
   530  		log.Printf("[DEBUG] found identity data at %q", p)
   531  		fileData = append(fileData, d)
   532  	}
   533  
   534  	return fileData
   535  }
   536  
   537  // sortSigners moves a signer with an agent comment field matching the
   538  // agent_identity to the head of the list when attempting authentication. This
   539  // helps when there are more keys loaded in an agent than the host will allow
   540  // attempts.
   541  func (s *sshAgent) sortSigners(signers []ssh.Signer) {
   542  	if s.id == "" || len(signers) < 2 {
   543  		return
   544  	}
   545  
   546  	// if we can locate the public key, either by extracting it from the id or
   547  	// locating the .pub file, then we can more easily determine an exact match
   548  	idPk := findIDPublicKey(s.id)
   549  
   550  	// if we have a signer with a connect field that matches the id, send that
   551  	// first, otherwise put close matches at the front of the list.
   552  	head := 0
   553  	for i := range signers {
   554  		pk := signers[i].PublicKey()
   555  		k, ok := pk.(*agent.Key)
   556  		if !ok {
   557  			continue
   558  		}
   559  
   560  		// check for an exact match first
   561  		if bytes.Equal(pk.Marshal(), idPk) || s.id == k.Comment {
   562  			signers[0], signers[i] = signers[i], signers[0]
   563  			break
   564  		}
   565  
   566  		// no exact match yet, move it to the front if it's close. The agent
   567  		// may have loaded as a full filepath, while the config refers to it by
   568  		// filename only.
   569  		if strings.HasSuffix(k.Comment, s.id) {
   570  			signers[head], signers[i] = signers[i], signers[head]
   571  			head++
   572  			continue
   573  		}
   574  	}
   575  }
   576  
   577  func (s *sshAgent) Signers() ([]ssh.Signer, error) {
   578  	signers, err := s.agent.Signers()
   579  	if err != nil {
   580  		return nil, err
   581  	}
   582  
   583  	s.sortSigners(signers)
   584  	return signers, nil
   585  }
   586  
   587  func (a *sshAgent) Auth() ssh.AuthMethod {
   588  	return ssh.PublicKeysCallback(a.Signers)
   589  }
   590  
   591  func (a *sshAgent) ForwardToAgent(client *ssh.Client) error {
   592  	return agent.ForwardToAgent(client, a.agent)
   593  }