github.com/muratcelep/terraform@v1.1.0-beta2-not-internal-4/not-internal/communicator/ssh/provisioner.go (about)

     1  package ssh
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/pem"
     6  	"errors"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"log"
    10  	"net"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/muratcelep/terraform/not-internal/communicator/shared"
    17  	sshagent "github.com/xanzy/ssh-agent"
    18  	"github.com/zclconf/go-cty/cty"
    19  	"github.com/zclconf/go-cty/cty/gocty"
    20  	"golang.org/x/crypto/ssh"
    21  	"golang.org/x/crypto/ssh/agent"
    22  	"golang.org/x/crypto/ssh/knownhosts"
    23  )
    24  
    25  const (
    26  	// DefaultUser is used if there is no user given
    27  	DefaultUser = "root"
    28  
    29  	// DefaultPort is used if there is no port given
    30  	DefaultPort = 22
    31  
    32  	// DefaultUnixScriptPath is used as the path to copy the file to
    33  	// for remote execution on unix if not provided otherwise.
    34  	DefaultUnixScriptPath = "/tmp/terraform_%RAND%.sh"
    35  	// DefaultWindowsScriptPath is used as the path to copy the file to
    36  	// for remote execution on windows if not provided otherwise.
    37  	DefaultWindowsScriptPath = "C:/windows/temp/terraform_%RAND%.cmd"
    38  
    39  	// DefaultTimeout is used if there is no timeout given
    40  	DefaultTimeout = 5 * time.Minute
    41  
    42  	// TargetPlatformUnix used for cleaner code, and is used if no target platform has been specified
    43  	TargetPlatformUnix = "unix"
    44  	//TargetPlatformWindows used for cleaner code
    45  	TargetPlatformWindows = "windows"
    46  )
    47  
    48  // connectionInfo is decoded from the ConnInfo of the resource. These are the
    49  // only keys we look at. If a PrivateKey is given, that is used instead
    50  // of a password.
    51  type connectionInfo struct {
    52  	User           string
    53  	Password       string
    54  	PrivateKey     string
    55  	Certificate    string
    56  	Host           string
    57  	HostKey        string
    58  	Port           uint16
    59  	Agent          bool
    60  	ScriptPath     string
    61  	TargetPlatform string
    62  	Timeout        string
    63  	TimeoutVal     time.Duration
    64  
    65  	BastionUser        string
    66  	BastionPassword    string
    67  	BastionPrivateKey  string
    68  	BastionCertificate string
    69  	BastionHost        string
    70  	BastionHostKey     string
    71  	BastionPort        uint16
    72  
    73  	AgentIdentity string
    74  }
    75  
    76  // decodeConnInfo decodes the given cty.Value using the same behavior as the
    77  // lgeacy mapstructure decoder in order to preserve as much of the existing
    78  // logic as possible for compatibility.
    79  func decodeConnInfo(v cty.Value) (*connectionInfo, error) {
    80  	connInfo := &connectionInfo{}
    81  	if v.IsNull() {
    82  		return connInfo, nil
    83  	}
    84  
    85  	for k, v := range v.AsValueMap() {
    86  		if v.IsNull() {
    87  			continue
    88  		}
    89  
    90  		switch k {
    91  		case "user":
    92  			connInfo.User = v.AsString()
    93  		case "password":
    94  			connInfo.Password = v.AsString()
    95  		case "private_key":
    96  			connInfo.PrivateKey = v.AsString()
    97  		case "certificate":
    98  			connInfo.Certificate = v.AsString()
    99  		case "host":
   100  			connInfo.Host = v.AsString()
   101  		case "host_key":
   102  			connInfo.HostKey = v.AsString()
   103  		case "port":
   104  			if err := gocty.FromCtyValue(v, &connInfo.Port); err != nil {
   105  				return nil, err
   106  			}
   107  		case "agent":
   108  			connInfo.Agent = v.True()
   109  		case "script_path":
   110  			connInfo.ScriptPath = v.AsString()
   111  		case "target_platform":
   112  			connInfo.TargetPlatform = v.AsString()
   113  		case "timeout":
   114  			connInfo.Timeout = v.AsString()
   115  		case "bastion_user":
   116  			connInfo.BastionUser = v.AsString()
   117  		case "bastion_password":
   118  			connInfo.BastionPassword = v.AsString()
   119  		case "bastion_private_key":
   120  			connInfo.BastionPrivateKey = v.AsString()
   121  		case "bastion_certificate":
   122  			connInfo.BastionCertificate = v.AsString()
   123  		case "bastion_host":
   124  			connInfo.BastionHost = v.AsString()
   125  		case "bastion_host_key":
   126  			connInfo.BastionHostKey = v.AsString()
   127  		case "bastion_port":
   128  			if err := gocty.FromCtyValue(v, &connInfo.BastionPort); err != nil {
   129  				return nil, err
   130  			}
   131  		case "agent_identity":
   132  			connInfo.AgentIdentity = v.AsString()
   133  		}
   134  	}
   135  	return connInfo, nil
   136  }
   137  
   138  // parseConnectionInfo is used to convert the raw configuration into the
   139  // *connectionInfo struct.
   140  func parseConnectionInfo(v cty.Value) (*connectionInfo, error) {
   141  	v, err := shared.ConnectionBlockSupersetSchema.CoerceValue(v)
   142  	if err != nil {
   143  		return nil, err
   144  	}
   145  
   146  	connInfo, err := decodeConnInfo(v)
   147  	if err != nil {
   148  		return nil, err
   149  	}
   150  
   151  	// To default Agent to true, we need to check the raw string, since the
   152  	// decoded boolean can't represent "absence of config".
   153  	//
   154  	// And if SSH_AUTH_SOCK is not set, there's no agent to connect to, so we
   155  	// shouldn't try.
   156  	agent := v.GetAttr("agent")
   157  	if agent.IsNull() && os.Getenv("SSH_AUTH_SOCK") != "" {
   158  		connInfo.Agent = true
   159  	}
   160  
   161  	if connInfo.User == "" {
   162  		connInfo.User = DefaultUser
   163  	}
   164  
   165  	// Check if host is empty.
   166  	// Otherwise return error.
   167  	if connInfo.Host == "" {
   168  		return nil, fmt.Errorf("host for provisioner cannot be empty")
   169  	}
   170  
   171  	// Format the host if needed.
   172  	// Needed for IPv6 support.
   173  	connInfo.Host = shared.IpFormat(connInfo.Host)
   174  
   175  	if connInfo.Port == 0 {
   176  		connInfo.Port = DefaultPort
   177  	}
   178  	// Set default targetPlatform to unix if it's empty
   179  	if connInfo.TargetPlatform == "" {
   180  		connInfo.TargetPlatform = TargetPlatformUnix
   181  	} else if connInfo.TargetPlatform != TargetPlatformUnix && connInfo.TargetPlatform != TargetPlatformWindows {
   182  		return nil, fmt.Errorf("target_platform for provisioner has to be either %s or %s", TargetPlatformUnix, TargetPlatformWindows)
   183  	}
   184  	// Choose an appropriate default script path based on the target platform. There is no single
   185  	// suitable default script path which works on both UNIX and Windows targets.
   186  	if connInfo.ScriptPath == "" && connInfo.TargetPlatform == TargetPlatformUnix {
   187  		connInfo.ScriptPath = DefaultUnixScriptPath
   188  	}
   189  	if connInfo.ScriptPath == "" && connInfo.TargetPlatform == TargetPlatformWindows {
   190  		connInfo.ScriptPath = DefaultWindowsScriptPath
   191  	}
   192  	if connInfo.Timeout != "" {
   193  		connInfo.TimeoutVal = safeDuration(connInfo.Timeout, DefaultTimeout)
   194  	} else {
   195  		connInfo.TimeoutVal = DefaultTimeout
   196  	}
   197  
   198  	// Default all bastion config attrs to their non-bastion counterparts
   199  	if connInfo.BastionHost != "" {
   200  		// Format the bastion host if needed.
   201  		// Needed for IPv6 support.
   202  		connInfo.BastionHost = shared.IpFormat(connInfo.BastionHost)
   203  
   204  		if connInfo.BastionUser == "" {
   205  			connInfo.BastionUser = connInfo.User
   206  		}
   207  		if connInfo.BastionPassword == "" {
   208  			connInfo.BastionPassword = connInfo.Password
   209  		}
   210  		if connInfo.BastionPrivateKey == "" {
   211  			connInfo.BastionPrivateKey = connInfo.PrivateKey
   212  		}
   213  		if connInfo.BastionCertificate == "" {
   214  			connInfo.BastionCertificate = connInfo.Certificate
   215  		}
   216  		if connInfo.BastionPort == 0 {
   217  			connInfo.BastionPort = connInfo.Port
   218  		}
   219  	}
   220  
   221  	return connInfo, nil
   222  }
   223  
   224  // safeDuration returns either the parsed duration or a default value
   225  func safeDuration(dur string, defaultDur time.Duration) time.Duration {
   226  	d, err := time.ParseDuration(dur)
   227  	if err != nil {
   228  		log.Printf("Invalid duration '%s', using default of %s", dur, defaultDur)
   229  		return defaultDur
   230  	}
   231  	return d
   232  }
   233  
   234  // prepareSSHConfig is used to turn the *ConnectionInfo provided into a
   235  // usable *SSHConfig for client initialization.
   236  func prepareSSHConfig(connInfo *connectionInfo) (*sshConfig, error) {
   237  	sshAgent, err := connectToAgent(connInfo)
   238  	if err != nil {
   239  		return nil, err
   240  	}
   241  
   242  	host := fmt.Sprintf("%s:%d", connInfo.Host, connInfo.Port)
   243  
   244  	sshConf, err := buildSSHClientConfig(sshClientConfigOpts{
   245  		user:        connInfo.User,
   246  		host:        host,
   247  		privateKey:  connInfo.PrivateKey,
   248  		password:    connInfo.Password,
   249  		hostKey:     connInfo.HostKey,
   250  		certificate: connInfo.Certificate,
   251  		sshAgent:    sshAgent,
   252  	})
   253  	if err != nil {
   254  		return nil, err
   255  	}
   256  
   257  	connectFunc := ConnectFunc("tcp", host)
   258  
   259  	var bastionConf *ssh.ClientConfig
   260  	if connInfo.BastionHost != "" {
   261  		bastionHost := fmt.Sprintf("%s:%d", connInfo.BastionHost, connInfo.BastionPort)
   262  
   263  		bastionConf, err = buildSSHClientConfig(sshClientConfigOpts{
   264  			user:        connInfo.BastionUser,
   265  			host:        bastionHost,
   266  			privateKey:  connInfo.BastionPrivateKey,
   267  			password:    connInfo.BastionPassword,
   268  			hostKey:     connInfo.HostKey,
   269  			certificate: connInfo.BastionCertificate,
   270  			sshAgent:    sshAgent,
   271  		})
   272  		if err != nil {
   273  			return nil, err
   274  		}
   275  
   276  		connectFunc = BastionConnectFunc("tcp", bastionHost, bastionConf, "tcp", host)
   277  	}
   278  
   279  	config := &sshConfig{
   280  		config:     sshConf,
   281  		connection: connectFunc,
   282  		sshAgent:   sshAgent,
   283  	}
   284  	return config, nil
   285  }
   286  
   287  type sshClientConfigOpts struct {
   288  	privateKey  string
   289  	password    string
   290  	sshAgent    *sshAgent
   291  	certificate string
   292  	user        string
   293  	host        string
   294  	hostKey     string
   295  }
   296  
   297  func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) {
   298  	hkCallback := ssh.InsecureIgnoreHostKey()
   299  
   300  	if opts.hostKey != "" {
   301  		// The knownhosts package only takes paths to files, but terraform
   302  		// generally wants to handle config data in-memory. Rather than making
   303  		// the known_hosts file an exception, write out the data to a temporary
   304  		// file to create the HostKeyCallback.
   305  		tf, err := ioutil.TempFile("", "tf-known_hosts")
   306  		if err != nil {
   307  			return nil, fmt.Errorf("failed to create temp known_hosts file: %s", err)
   308  		}
   309  		defer tf.Close()
   310  		defer os.RemoveAll(tf.Name())
   311  
   312  		// we mark this as a CA as well, but the host key fallback will still
   313  		// use it as a direct match if the remote host doesn't return a
   314  		// certificate.
   315  		if _, err := tf.WriteString(fmt.Sprintf("@cert-authority %s %s\n", opts.host, opts.hostKey)); err != nil {
   316  			return nil, fmt.Errorf("failed to write temp known_hosts file: %s", err)
   317  		}
   318  		tf.Sync()
   319  
   320  		hkCallback, err = knownhosts.New(tf.Name())
   321  		if err != nil {
   322  			return nil, err
   323  		}
   324  	}
   325  
   326  	conf := &ssh.ClientConfig{
   327  		HostKeyCallback: hkCallback,
   328  		User:            opts.user,
   329  	}
   330  
   331  	if opts.privateKey != "" {
   332  		if opts.certificate != "" {
   333  			log.Println("using client certificate for authentication")
   334  
   335  			certSigner, err := signCertWithPrivateKey(opts.privateKey, opts.certificate)
   336  			if err != nil {
   337  				return nil, err
   338  			}
   339  			conf.Auth = append(conf.Auth, certSigner)
   340  		} else {
   341  			log.Println("using private key for authentication")
   342  
   343  			pubKeyAuth, err := readPrivateKey(opts.privateKey)
   344  			if err != nil {
   345  				return nil, err
   346  			}
   347  			conf.Auth = append(conf.Auth, pubKeyAuth)
   348  		}
   349  	}
   350  
   351  	if opts.password != "" {
   352  		conf.Auth = append(conf.Auth, ssh.Password(opts.password))
   353  		conf.Auth = append(conf.Auth, ssh.KeyboardInteractive(
   354  			PasswordKeyboardInteractive(opts.password)))
   355  	}
   356  
   357  	if opts.sshAgent != nil {
   358  		conf.Auth = append(conf.Auth, opts.sshAgent.Auth())
   359  	}
   360  
   361  	return conf, nil
   362  }
   363  
   364  // Create a Cert Signer and return ssh.AuthMethod
   365  func signCertWithPrivateKey(pk string, certificate string) (ssh.AuthMethod, error) {
   366  	rawPk, err := ssh.ParseRawPrivateKey([]byte(pk))
   367  	if err != nil {
   368  		return nil, fmt.Errorf("failed to parse private key %q: %s", pk, err)
   369  	}
   370  
   371  	pcert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(certificate))
   372  	if err != nil {
   373  		return nil, fmt.Errorf("failed to parse certificate %q: %s", certificate, err)
   374  	}
   375  
   376  	usigner, err := ssh.NewSignerFromKey(rawPk)
   377  	if err != nil {
   378  		return nil, fmt.Errorf("failed to create signer from raw private key %q: %s", rawPk, err)
   379  	}
   380  
   381  	ucertSigner, err := ssh.NewCertSigner(pcert.(*ssh.Certificate), usigner)
   382  	if err != nil {
   383  		return nil, fmt.Errorf("failed to create cert signer %q: %s", usigner, err)
   384  	}
   385  
   386  	return ssh.PublicKeys(ucertSigner), nil
   387  }
   388  
   389  func readPrivateKey(pk string) (ssh.AuthMethod, error) {
   390  	// We parse the private key on our own first so that we can
   391  	// show a nicer error if the private key has a password.
   392  	block, _ := pem.Decode([]byte(pk))
   393  	if block == nil {
   394  		return nil, errors.New("Failed to read ssh private key: no key found")
   395  	}
   396  	if block.Headers["Proc-Type"] == "4,ENCRYPTED" {
   397  		return nil, errors.New(
   398  			"Failed to read ssh private key: password protected keys are\n" +
   399  				"not supported. Please decrypt the key prior to use.")
   400  	}
   401  
   402  	signer, err := ssh.ParsePrivateKey([]byte(pk))
   403  	if err != nil {
   404  		return nil, fmt.Errorf("Failed to parse ssh private key: %s", err)
   405  	}
   406  
   407  	return ssh.PublicKeys(signer), nil
   408  }
   409  
   410  func connectToAgent(connInfo *connectionInfo) (*sshAgent, error) {
   411  	if !connInfo.Agent {
   412  		// No agent configured
   413  		return nil, nil
   414  	}
   415  
   416  	agent, conn, err := sshagent.New()
   417  	if err != nil {
   418  		return nil, err
   419  	}
   420  
   421  	// connection close is handled over in Communicator
   422  	return &sshAgent{
   423  		agent: agent,
   424  		conn:  conn,
   425  		id:    connInfo.AgentIdentity,
   426  	}, nil
   427  
   428  }
   429  
   430  // A tiny wrapper around an agent.Agent to expose the ability to close its
   431  // associated connection on request.
   432  type sshAgent struct {
   433  	agent agent.Agent
   434  	conn  net.Conn
   435  	id    string
   436  }
   437  
   438  func (a *sshAgent) Close() error {
   439  	if a.conn == nil {
   440  		return nil
   441  	}
   442  
   443  	return a.conn.Close()
   444  }
   445  
   446  // make an attempt to either read the identity file or find a corresponding
   447  // public key file using the typical openssh naming convention.
   448  // This returns the public key in wire format, or nil when a key is not found.
   449  func findIDPublicKey(id string) []byte {
   450  	for _, d := range idKeyData(id) {
   451  		signer, err := ssh.ParsePrivateKey(d)
   452  		if err == nil {
   453  			log.Println("[DEBUG] parsed id private key")
   454  			pk := signer.PublicKey()
   455  			return pk.Marshal()
   456  		}
   457  
   458  		// try it as a publicKey
   459  		pk, err := ssh.ParsePublicKey(d)
   460  		if err == nil {
   461  			log.Println("[DEBUG] parsed id public key")
   462  			return pk.Marshal()
   463  		}
   464  
   465  		// finally try it as an authorized key
   466  		pk, _, _, _, err = ssh.ParseAuthorizedKey(d)
   467  		if err == nil {
   468  			log.Println("[DEBUG] parsed id authorized key")
   469  			return pk.Marshal()
   470  		}
   471  	}
   472  
   473  	return nil
   474  }
   475  
   476  // Try to read an id file using the id as the file path. Also read the .pub
   477  // file if it exists, as the id file may be encrypted. Return only the file
   478  // data read. We don't need to know what data came from which path, as we will
   479  // try parsing each as a private key, a public key and an authorized key
   480  // regardless.
   481  func idKeyData(id string) [][]byte {
   482  	idPath, err := filepath.Abs(id)
   483  	if err != nil {
   484  		return nil
   485  	}
   486  
   487  	var fileData [][]byte
   488  
   489  	paths := []string{idPath}
   490  
   491  	if !strings.HasSuffix(idPath, ".pub") {
   492  		paths = append(paths, idPath+".pub")
   493  	}
   494  
   495  	for _, p := range paths {
   496  		d, err := ioutil.ReadFile(p)
   497  		if err != nil {
   498  			log.Printf("[DEBUG] error reading %q: %s", p, err)
   499  			continue
   500  		}
   501  		log.Printf("[DEBUG] found identity data at %q", p)
   502  		fileData = append(fileData, d)
   503  	}
   504  
   505  	return fileData
   506  }
   507  
   508  // sortSigners moves a signer with an agent comment field matching the
   509  // agent_identity to the head of the list when attempting authentication. This
   510  // helps when there are more keys loaded in an agent than the host will allow
   511  // attempts.
   512  func (s *sshAgent) sortSigners(signers []ssh.Signer) {
   513  	if s.id == "" || len(signers) < 2 {
   514  		return
   515  	}
   516  
   517  	// if we can locate the public key, either by extracting it from the id or
   518  	// locating the .pub file, then we can more easily determine an exact match
   519  	idPk := findIDPublicKey(s.id)
   520  
   521  	// if we have a signer with a connect field that matches the id, send that
   522  	// first, otherwise put close matches at the front of the list.
   523  	head := 0
   524  	for i := range signers {
   525  		pk := signers[i].PublicKey()
   526  		k, ok := pk.(*agent.Key)
   527  		if !ok {
   528  			continue
   529  		}
   530  
   531  		// check for an exact match first
   532  		if bytes.Equal(pk.Marshal(), idPk) || s.id == k.Comment {
   533  			signers[0], signers[i] = signers[i], signers[0]
   534  			break
   535  		}
   536  
   537  		// no exact match yet, move it to the front if it's close. The agent
   538  		// may have loaded as a full filepath, while the config refers to it by
   539  		// filename only.
   540  		if strings.HasSuffix(k.Comment, s.id) {
   541  			signers[head], signers[i] = signers[i], signers[head]
   542  			head++
   543  			continue
   544  		}
   545  	}
   546  }
   547  
   548  func (s *sshAgent) Signers() ([]ssh.Signer, error) {
   549  	signers, err := s.agent.Signers()
   550  	if err != nil {
   551  		return nil, err
   552  	}
   553  
   554  	s.sortSigners(signers)
   555  	return signers, nil
   556  }
   557  
   558  func (a *sshAgent) Auth() ssh.AuthMethod {
   559  	return ssh.PublicKeysCallback(a.Signers)
   560  }
   561  
   562  func (a *sshAgent) ForwardToAgent(client *ssh.Client) error {
   563  	return agent.ForwardToAgent(client, a.agent)
   564  }