github.com/terramate-io/tf@v0.0.0-20230830114523-fce866b4dfcd/communicator/ssh/provisioner.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package ssh
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/pem"
     9  	"errors"
    10  	"fmt"
    11  	"io/ioutil"
    12  	"log"
    13  	"net"
    14  	"os"
    15  	"path/filepath"
    16  	"strings"
    17  	"time"
    18  
    19  	"github.com/terramate-io/tf/communicator/shared"
    20  	sshagent "github.com/xanzy/ssh-agent"
    21  	"github.com/zclconf/go-cty/cty"
    22  	"github.com/zclconf/go-cty/cty/gocty"
    23  	"golang.org/x/crypto/ssh"
    24  	"golang.org/x/crypto/ssh/agent"
    25  	"golang.org/x/crypto/ssh/knownhosts"
    26  )
    27  
const (
	// DefaultUser is the SSH user name used when the connection
	// configuration does not provide one.
	DefaultUser = "root"

	// DefaultPort is the SSH port used when the connection configuration
	// does not provide one (a decoded port of 0 is treated as "not set").
	DefaultPort = 22

	// DefaultUnixScriptPath is used as the path to copy the file to
	// for remote execution on unix if not provided otherwise.
	DefaultUnixScriptPath = "/tmp/terraform_%RAND%.sh"
	// DefaultWindowsScriptPath is used as the path to copy the file to
	// for remote execution on windows if not provided otherwise.
	DefaultWindowsScriptPath = "C:/windows/temp/terraform_%RAND%.cmd"

	// DefaultTimeout is used when no timeout is given, and also when the
	// given timeout string cannot be parsed as a duration.
	DefaultTimeout = 5 * time.Minute

	// TargetPlatformUnix is the target_platform value for unix hosts. It is
	// also the value assumed when no target platform has been specified.
	TargetPlatformUnix = "unix"
	// TargetPlatformWindows is the target_platform value for windows hosts.
	TargetPlatformWindows = "windows"
)
    50  
// connectionInfo is decoded from the ConnInfo of the resource. These are the
// only keys we look at. When both a PrivateKey and a Password are given, the
// key-based authentication method is offered to the server ahead of the
// password-based ones.
type connectionInfo struct {
	// Settings for the target host.
	User           string
	Password       string
	PrivateKey     string
	Certificate    string
	Host           string
	HostKey        string
	Port           uint16
	Agent          bool
	ScriptPath     string
	TargetPlatform string
	// Timeout is the raw duration string from the configuration; TimeoutVal
	// is its parsed form, filled in by parseConnectionInfo.
	Timeout    string
	TimeoutVal time.Duration

	// Settings for an optional proxy, only used when ProxyHost is non-empty.
	ProxyScheme       string
	ProxyHost         string
	ProxyPort         uint16
	ProxyUserName     string
	ProxyUserPassword string

	// Settings for an optional bastion host, only used when BastionHost is
	// non-empty. parseConnectionInfo fills empty user, password, private key,
	// certificate and port values from their non-bastion counterparts.
	BastionUser        string
	BastionPassword    string
	BastionPrivateKey  string
	BastionCertificate string
	BastionHost        string
	BastionHostKey     string
	BastionPort        uint16

	// AgentIdentity selects which SSH agent key should be tried first.
	AgentIdentity string
}
    84  
    85  // decodeConnInfo decodes the given cty.Value using the same behavior as the
    86  // lgeacy mapstructure decoder in order to preserve as much of the existing
    87  // logic as possible for compatibility.
    88  func decodeConnInfo(v cty.Value) (*connectionInfo, error) {
    89  	connInfo := &connectionInfo{}
    90  	if v.IsNull() {
    91  		return connInfo, nil
    92  	}
    93  
    94  	for k, v := range v.AsValueMap() {
    95  		if v.IsNull() {
    96  			continue
    97  		}
    98  
    99  		switch k {
   100  		case "user":
   101  			connInfo.User = v.AsString()
   102  		case "password":
   103  			connInfo.Password = v.AsString()
   104  		case "private_key":
   105  			connInfo.PrivateKey = v.AsString()
   106  		case "certificate":
   107  			connInfo.Certificate = v.AsString()
   108  		case "host":
   109  			connInfo.Host = v.AsString()
   110  		case "host_key":
   111  			connInfo.HostKey = v.AsString()
   112  		case "port":
   113  			if err := gocty.FromCtyValue(v, &connInfo.Port); err != nil {
   114  				return nil, err
   115  			}
   116  		case "agent":
   117  			connInfo.Agent = v.True()
   118  		case "script_path":
   119  			connInfo.ScriptPath = v.AsString()
   120  		case "target_platform":
   121  			connInfo.TargetPlatform = v.AsString()
   122  		case "timeout":
   123  			connInfo.Timeout = v.AsString()
   124  		case "proxy_scheme":
   125  			connInfo.ProxyScheme = v.AsString()
   126  		case "proxy_host":
   127  			connInfo.ProxyHost = v.AsString()
   128  		case "proxy_port":
   129  			if err := gocty.FromCtyValue(v, &connInfo.ProxyPort); err != nil {
   130  				return nil, err
   131  			}
   132  		case "proxy_user_name":
   133  			connInfo.ProxyUserName = v.AsString()
   134  		case "proxy_user_password":
   135  			connInfo.ProxyUserPassword = v.AsString()
   136  		case "bastion_user":
   137  			connInfo.BastionUser = v.AsString()
   138  		case "bastion_password":
   139  			connInfo.BastionPassword = v.AsString()
   140  		case "bastion_private_key":
   141  			connInfo.BastionPrivateKey = v.AsString()
   142  		case "bastion_certificate":
   143  			connInfo.BastionCertificate = v.AsString()
   144  		case "bastion_host":
   145  			connInfo.BastionHost = v.AsString()
   146  		case "bastion_host_key":
   147  			connInfo.BastionHostKey = v.AsString()
   148  		case "bastion_port":
   149  			if err := gocty.FromCtyValue(v, &connInfo.BastionPort); err != nil {
   150  				return nil, err
   151  			}
   152  		case "agent_identity":
   153  			connInfo.AgentIdentity = v.AsString()
   154  		}
   155  	}
   156  	return connInfo, nil
   157  }
   158  
   159  // parseConnectionInfo is used to convert the raw configuration into the
   160  // *connectionInfo struct.
   161  func parseConnectionInfo(v cty.Value) (*connectionInfo, error) {
   162  	v, err := shared.ConnectionBlockSupersetSchema.CoerceValue(v)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  
   167  	connInfo, err := decodeConnInfo(v)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  
   172  	// To default Agent to true, we need to check the raw string, since the
   173  	// decoded boolean can't represent "absence of config".
   174  	//
   175  	// And if SSH_AUTH_SOCK is not set, there's no agent to connect to, so we
   176  	// shouldn't try.
   177  	agent := v.GetAttr("agent")
   178  	if agent.IsNull() && os.Getenv("SSH_AUTH_SOCK") != "" {
   179  		connInfo.Agent = true
   180  	}
   181  
   182  	if connInfo.User == "" {
   183  		connInfo.User = DefaultUser
   184  	}
   185  
   186  	// Check if host is empty.
   187  	// Otherwise return error.
   188  	if connInfo.Host == "" {
   189  		return nil, fmt.Errorf("host for provisioner cannot be empty")
   190  	}
   191  
   192  	// Format the host if needed.
   193  	// Needed for IPv6 support.
   194  	connInfo.Host = shared.IpFormat(connInfo.Host)
   195  
   196  	if connInfo.Port == 0 {
   197  		connInfo.Port = DefaultPort
   198  	}
   199  	// Set default targetPlatform to unix if it's empty
   200  	if connInfo.TargetPlatform == "" {
   201  		connInfo.TargetPlatform = TargetPlatformUnix
   202  	} else if connInfo.TargetPlatform != TargetPlatformUnix && connInfo.TargetPlatform != TargetPlatformWindows {
   203  		return nil, fmt.Errorf("target_platform for provisioner has to be either %s or %s", TargetPlatformUnix, TargetPlatformWindows)
   204  	}
   205  	// Choose an appropriate default script path based on the target platform. There is no single
   206  	// suitable default script path which works on both UNIX and Windows targets.
   207  	if connInfo.ScriptPath == "" && connInfo.TargetPlatform == TargetPlatformUnix {
   208  		connInfo.ScriptPath = DefaultUnixScriptPath
   209  	}
   210  	if connInfo.ScriptPath == "" && connInfo.TargetPlatform == TargetPlatformWindows {
   211  		connInfo.ScriptPath = DefaultWindowsScriptPath
   212  	}
   213  	if connInfo.Timeout != "" {
   214  		connInfo.TimeoutVal = safeDuration(connInfo.Timeout, DefaultTimeout)
   215  	} else {
   216  		connInfo.TimeoutVal = DefaultTimeout
   217  	}
   218  
   219  	// Default all bastion config attrs to their non-bastion counterparts
   220  	if connInfo.BastionHost != "" {
   221  		// Format the bastion host if needed.
   222  		// Needed for IPv6 support.
   223  		connInfo.BastionHost = shared.IpFormat(connInfo.BastionHost)
   224  
   225  		if connInfo.BastionUser == "" {
   226  			connInfo.BastionUser = connInfo.User
   227  		}
   228  		if connInfo.BastionPassword == "" {
   229  			connInfo.BastionPassword = connInfo.Password
   230  		}
   231  		if connInfo.BastionPrivateKey == "" {
   232  			connInfo.BastionPrivateKey = connInfo.PrivateKey
   233  		}
   234  		if connInfo.BastionCertificate == "" {
   235  			connInfo.BastionCertificate = connInfo.Certificate
   236  		}
   237  		if connInfo.BastionPort == 0 {
   238  			connInfo.BastionPort = connInfo.Port
   239  		}
   240  	}
   241  
   242  	return connInfo, nil
   243  }
   244  
   245  // safeDuration returns either the parsed duration or a default value
   246  func safeDuration(dur string, defaultDur time.Duration) time.Duration {
   247  	d, err := time.ParseDuration(dur)
   248  	if err != nil {
   249  		log.Printf("Invalid duration '%s', using default of %s", dur, defaultDur)
   250  		return defaultDur
   251  	}
   252  	return d
   253  }
   254  
   255  // prepareSSHConfig is used to turn the *ConnectionInfo provided into a
   256  // usable *SSHConfig for client initialization.
   257  func prepareSSHConfig(connInfo *connectionInfo) (*sshConfig, error) {
   258  	sshAgent, err := connectToAgent(connInfo)
   259  	if err != nil {
   260  		return nil, err
   261  	}
   262  
   263  	host := fmt.Sprintf("%s:%d", connInfo.Host, connInfo.Port)
   264  
   265  	sshConf, err := buildSSHClientConfig(sshClientConfigOpts{
   266  		user:        connInfo.User,
   267  		host:        host,
   268  		privateKey:  connInfo.PrivateKey,
   269  		password:    connInfo.Password,
   270  		hostKey:     connInfo.HostKey,
   271  		certificate: connInfo.Certificate,
   272  		sshAgent:    sshAgent,
   273  	})
   274  	if err != nil {
   275  		return nil, err
   276  	}
   277  
   278  	var p *proxyInfo
   279  
   280  	if connInfo.ProxyHost != "" {
   281  		p = newProxyInfo(
   282  			fmt.Sprintf("%s:%d", connInfo.ProxyHost, connInfo.ProxyPort),
   283  			connInfo.ProxyScheme,
   284  			connInfo.ProxyUserName,
   285  			connInfo.ProxyUserPassword,
   286  		)
   287  	}
   288  
   289  	connectFunc := ConnectFunc("tcp", host, p)
   290  
   291  	var bastionConf *ssh.ClientConfig
   292  	if connInfo.BastionHost != "" {
   293  		bastionHost := fmt.Sprintf("%s:%d", connInfo.BastionHost, connInfo.BastionPort)
   294  
   295  		bastionConf, err = buildSSHClientConfig(sshClientConfigOpts{
   296  			user:        connInfo.BastionUser,
   297  			host:        bastionHost,
   298  			privateKey:  connInfo.BastionPrivateKey,
   299  			password:    connInfo.BastionPassword,
   300  			hostKey:     connInfo.HostKey,
   301  			certificate: connInfo.BastionCertificate,
   302  			sshAgent:    sshAgent,
   303  		})
   304  		if err != nil {
   305  			return nil, err
   306  		}
   307  
   308  		connectFunc = BastionConnectFunc("tcp", bastionHost, bastionConf, "tcp", host, p)
   309  	}
   310  
   311  	config := &sshConfig{
   312  		config:     sshConf,
   313  		connection: connectFunc,
   314  		sshAgent:   sshAgent,
   315  	}
   316  	return config, nil
   317  }
   318  
// sshClientConfigOpts collects the inputs buildSSHClientConfig needs to build
// an *ssh.ClientConfig for a single host.
type sshClientConfigOpts struct {
	// privateKey enables key-based authentication when non-empty. Combined
	// with a non-empty certificate it is used for certificate authentication.
	privateKey string
	// password enables password and keyboard-interactive authentication when
	// non-empty.
	password string
	// sshAgent, when non-nil, contributes the agent's keys as an
	// authentication method.
	sshAgent *sshAgent
	// certificate is only consulted when privateKey is also set.
	certificate string
	// user is the user name to log in as.
	user string
	// host is the "host:port" address; it is written into the temporary
	// known_hosts entry when hostKey is set.
	host string
	// hostKey, when non-empty, is used to verify the remote host. When empty,
	// host keys are not verified.
	hostKey string
}
   328  
   329  func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) {
   330  	hkCallback := ssh.InsecureIgnoreHostKey()
   331  
   332  	if opts.hostKey != "" {
   333  		// The knownhosts package only takes paths to files, but terraform
   334  		// generally wants to handle config data in-memory. Rather than making
   335  		// the known_hosts file an exception, write out the data to a temporary
   336  		// file to create the HostKeyCallback.
   337  		tf, err := ioutil.TempFile("", "tf-known_hosts")
   338  		if err != nil {
   339  			return nil, fmt.Errorf("failed to create temp known_hosts file: %s", err)
   340  		}
   341  		defer tf.Close()
   342  		defer os.RemoveAll(tf.Name())
   343  
   344  		// we mark this as a CA as well, but the host key fallback will still
   345  		// use it as a direct match if the remote host doesn't return a
   346  		// certificate.
   347  		if _, err := tf.WriteString(fmt.Sprintf("@cert-authority %s %s\n", opts.host, opts.hostKey)); err != nil {
   348  			return nil, fmt.Errorf("failed to write temp known_hosts file: %s", err)
   349  		}
   350  		tf.Sync()
   351  
   352  		hkCallback, err = knownhosts.New(tf.Name())
   353  		if err != nil {
   354  			return nil, err
   355  		}
   356  	}
   357  
   358  	conf := &ssh.ClientConfig{
   359  		HostKeyCallback: hkCallback,
   360  		User:            opts.user,
   361  	}
   362  
   363  	if opts.privateKey != "" {
   364  		if opts.certificate != "" {
   365  			log.Println("using client certificate for authentication")
   366  
   367  			certSigner, err := signCertWithPrivateKey(opts.privateKey, opts.certificate)
   368  			if err != nil {
   369  				return nil, err
   370  			}
   371  			conf.Auth = append(conf.Auth, certSigner)
   372  		} else {
   373  			log.Println("using private key for authentication")
   374  
   375  			pubKeyAuth, err := readPrivateKey(opts.privateKey)
   376  			if err != nil {
   377  				return nil, err
   378  			}
   379  			conf.Auth = append(conf.Auth, pubKeyAuth)
   380  		}
   381  	}
   382  
   383  	if opts.password != "" {
   384  		conf.Auth = append(conf.Auth, ssh.Password(opts.password))
   385  		conf.Auth = append(conf.Auth, ssh.KeyboardInteractive(
   386  			PasswordKeyboardInteractive(opts.password)))
   387  	}
   388  
   389  	if opts.sshAgent != nil {
   390  		conf.Auth = append(conf.Auth, opts.sshAgent.Auth())
   391  	}
   392  
   393  	return conf, nil
   394  }
   395  
   396  // Create a Cert Signer and return ssh.AuthMethod
   397  func signCertWithPrivateKey(pk string, certificate string) (ssh.AuthMethod, error) {
   398  	rawPk, err := ssh.ParseRawPrivateKey([]byte(pk))
   399  	if err != nil {
   400  		return nil, fmt.Errorf("failed to parse private key %q: %s", pk, err)
   401  	}
   402  
   403  	pcert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(certificate))
   404  	if err != nil {
   405  		return nil, fmt.Errorf("failed to parse certificate %q: %s", certificate, err)
   406  	}
   407  
   408  	usigner, err := ssh.NewSignerFromKey(rawPk)
   409  	if err != nil {
   410  		return nil, fmt.Errorf("failed to create signer from raw private key %q: %s", rawPk, err)
   411  	}
   412  
   413  	ucertSigner, err := ssh.NewCertSigner(pcert.(*ssh.Certificate), usigner)
   414  	if err != nil {
   415  		return nil, fmt.Errorf("failed to create cert signer %q: %s", usigner, err)
   416  	}
   417  
   418  	return ssh.PublicKeys(ucertSigner), nil
   419  }
   420  
   421  func readPrivateKey(pk string) (ssh.AuthMethod, error) {
   422  	// We parse the private key on our own first so that we can
   423  	// show a nicer error if the private key has a password.
   424  	block, _ := pem.Decode([]byte(pk))
   425  	if block == nil {
   426  		return nil, errors.New("Failed to read ssh private key: no key found")
   427  	}
   428  	if block.Headers["Proc-Type"] == "4,ENCRYPTED" {
   429  		return nil, errors.New(
   430  			"Failed to read ssh private key: password protected keys are\n" +
   431  				"not supported. Please decrypt the key prior to use.")
   432  	}
   433  
   434  	signer, err := ssh.ParsePrivateKey([]byte(pk))
   435  	if err != nil {
   436  		return nil, fmt.Errorf("Failed to parse ssh private key: %s", err)
   437  	}
   438  
   439  	return ssh.PublicKeys(signer), nil
   440  }
   441  
   442  func connectToAgent(connInfo *connectionInfo) (*sshAgent, error) {
   443  	if !connInfo.Agent {
   444  		// No agent configured
   445  		return nil, nil
   446  	}
   447  
   448  	agent, conn, err := sshagent.New()
   449  	if err != nil {
   450  		return nil, err
   451  	}
   452  
   453  	// connection close is handled over in Communicator
   454  	return &sshAgent{
   455  		agent: agent,
   456  		conn:  conn,
   457  		id:    connInfo.AgentIdentity,
   458  	}, nil
   459  
   460  }
   461  
// sshAgent is a tiny wrapper around an agent.Agent to expose the ability to
// close its associated connection on request.
type sshAgent struct {
	// agent is the wrapped SSH agent client.
	agent agent.Agent
	// conn is the connection to the agent; it may be nil, in which case
	// Close does nothing.
	conn net.Conn
	// id is the configured agent identity, used to decide which of the
	// agent's signers is tried first. Empty means no preference.
	id string
}
   469  
   470  func (a *sshAgent) Close() error {
   471  	if a.conn == nil {
   472  		return nil
   473  	}
   474  
   475  	return a.conn.Close()
   476  }
   477  
   478  // make an attempt to either read the identity file or find a corresponding
   479  // public key file using the typical openssh naming convention.
   480  // This returns the public key in wire format, or nil when a key is not found.
   481  func findIDPublicKey(id string) []byte {
   482  	for _, d := range idKeyData(id) {
   483  		signer, err := ssh.ParsePrivateKey(d)
   484  		if err == nil {
   485  			log.Println("[DEBUG] parsed id private key")
   486  			pk := signer.PublicKey()
   487  			return pk.Marshal()
   488  		}
   489  
   490  		// try it as a publicKey
   491  		pk, err := ssh.ParsePublicKey(d)
   492  		if err == nil {
   493  			log.Println("[DEBUG] parsed id public key")
   494  			return pk.Marshal()
   495  		}
   496  
   497  		// finally try it as an authorized key
   498  		pk, _, _, _, err = ssh.ParseAuthorizedKey(d)
   499  		if err == nil {
   500  			log.Println("[DEBUG] parsed id authorized key")
   501  			return pk.Marshal()
   502  		}
   503  	}
   504  
   505  	return nil
   506  }
   507  
   508  // Try to read an id file using the id as the file path. Also read the .pub
   509  // file if it exists, as the id file may be encrypted. Return only the file
   510  // data read. We don't need to know what data came from which path, as we will
   511  // try parsing each as a private key, a public key and an authorized key
   512  // regardless.
   513  func idKeyData(id string) [][]byte {
   514  	idPath, err := filepath.Abs(id)
   515  	if err != nil {
   516  		return nil
   517  	}
   518  
   519  	var fileData [][]byte
   520  
   521  	paths := []string{idPath}
   522  
   523  	if !strings.HasSuffix(idPath, ".pub") {
   524  		paths = append(paths, idPath+".pub")
   525  	}
   526  
   527  	for _, p := range paths {
   528  		d, err := ioutil.ReadFile(p)
   529  		if err != nil {
   530  			log.Printf("[DEBUG] error reading %q: %s", p, err)
   531  			continue
   532  		}
   533  		log.Printf("[DEBUG] found identity data at %q", p)
   534  		fileData = append(fileData, d)
   535  	}
   536  
   537  	return fileData
   538  }
   539  
// sortSigners reorders signers in place so that the signer matching the
// configured agent_identity is attempted first during authentication. This
// helps when there are more keys loaded in an agent than the host will allow
// attempts.
//
// A signer is an exact match when its public key equals the one found for the
// identity or when its agent comment equals the identity; it is a close match
// when its agent comment merely ends with the identity. Nothing is done when
// no identity is configured or when there are fewer than two signers.
func (s *sshAgent) sortSigners(signers []ssh.Signer) {
	if s.id == "" || len(signers) < 2 {
		return
	}

	// if we can locate the public key, either by extracting it from the id or
	// locating the .pub file, then we can more easily determine an exact match
	idPk := findIDPublicKey(s.id)

	// if we have a signer with a comment field that matches the id, send that
	// first, otherwise put close matches at the front of the list.
	// head is the index the next close match will be swapped into.
	head := 0
	for i := range signers {
		pk := signers[i].PublicKey()
		// Only keys that come from the agent carry a comment; signers with
		// any other public key type are left where they are.
		k, ok := pk.(*agent.Key)
		if !ok {
			continue
		}

		// check for an exact match first; it goes to the very front and ends
		// the search.
		if bytes.Equal(pk.Marshal(), idPk) || s.id == k.Comment {
			signers[0], signers[i] = signers[i], signers[0]
			break
		}

		// no exact match yet, move it to the front if it's close. The agent
		// may have loaded as a full filepath, while the config refers to it by
		// filename only.
		if strings.HasSuffix(k.Comment, s.id) {
			signers[head], signers[i] = signers[i], signers[head]
			head++
			continue
		}
	}
}
   579  
   580  func (s *sshAgent) Signers() ([]ssh.Signer, error) {
   581  	signers, err := s.agent.Signers()
   582  	if err != nil {
   583  		return nil, err
   584  	}
   585  
   586  	s.sortSigners(signers)
   587  	return signers, nil
   588  }
   589  
// Auth returns an ssh.AuthMethod backed by the agent's keys. The signers are
// obtained through Signers via a callback rather than being captured here.
func (a *sshAgent) Auth() ssh.AuthMethod {
	return ssh.PublicKeysCallback(a.Signers)
}
   593  
// ForwardToAgent hands the given client and the wrapped agent to
// agent.ForwardToAgent and returns its result.
// NOTE(review): presumably this sets up agent forwarding for the client's
// connection; confirm against the golang.org/x/crypto/ssh/agent docs.
func (a *sshAgent) ForwardToAgent(client *ssh.Client) error {
	return agent.ForwardToAgent(client, a.agent)
}