github.com/turbot/go-exec-communicator@v0.0.0-20230412124734-9374347749b6/ssh/provisioner.go (about)

     1  package ssh
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/pem"
     6  	"errors"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"log"
    10  	"net"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/turbot/go-exec-communicator/shared"
    17  	sshagent "github.com/xanzy/ssh-agent"
    18  	"golang.org/x/crypto/ssh"
    19  	"golang.org/x/crypto/ssh/agent"
    20  	"golang.org/x/crypto/ssh/knownhosts"
    21  )
    22  
    23  const (
    24  	// DefaultUser is used if there is no user given
    25  	DefaultUser = "root"
    26  
    27  	// DefaultPort is used if there is no port given
    28  	DefaultPort = 22
    29  
    30  	// DefaultUnixScriptPath is used as the path to copy the file to
    31  	// for remote execution on unix if not provided otherwise.
    32  	DefaultUnixScriptPath = "/tmp/terraform_%RAND%.sh"
    33  	// DefaultWindowsScriptPath is used as the path to copy the file to
    34  	// for remote execution on windows if not provided otherwise.
    35  	DefaultWindowsScriptPath = "C:/windows/temp/terraform_%RAND%.cmd"
    36  
    37  	// DefaultTimeout is used if there is no timeout given
    38  	DefaultTimeout = 5 * time.Minute
    39  
    40  	// TargetPlatformUnix used for cleaner code, and is used if no target platform has been specified
    41  	TargetPlatformUnix = "unix"
    42  	//TargetPlatformWindows used for cleaner code
    43  	TargetPlatformWindows = "windows"
    44  )
    45  
    46  /*
    47  
// ConnectionInfo is decoded from the ConnInfo of the resource. These are the
    49  // only keys we look at. If a PrivateKey is given, that is used instead
    50  // of a password.
    51  type ConnectionInfo struct {
    52  	User           string
    53  	Password       string
    54  	PrivateKey     string
    55  	Certificate    string
    56  	Host           string
    57  	HostKey        string
    58  	Port           uint16
    59  	Agent          bool
    60  	ScriptPath     string
    61  	TargetPlatform string
    62  	Timeout        string
    63  	TimeoutVal     time.Duration
    64  
    65  	ProxyScheme       string
    66  	ProxyHost         string
    67  	ProxyPort         uint16
    68  	ProxyUserName     string
    69  	ProxyUserPassword string
    70  
    71  	BastionUser        string
    72  	BastionPassword    string
    73  	BastionPrivateKey  string
    74  	BastionCertificate string
    75  	BastionHost        string
    76  	BastionHostKey     string
    77  	BastionPort        uint16
    78  
    79  	AgentIdentity string
    80  }
    81  
    82  */
    83  
    84  /*
    85  
    86  // decodeConnInfo decodes the given cty.Value using the same behavior as the
// legacy mapstructure decoder in order to preserve as much of the existing
    88  // logic as possible for compatibility.
    89  func decodeConnInfo(v cty.Value) (*shared.ConnectionInfo, error) {
    90  	connInfo := &shared.ConnectionInfo{}
    91  	if v.IsNull() {
    92  		return connInfo, nil
    93  	}
    94  
    95  	for k, v := range v.AsValueMap() {
    96  		if v.IsNull() {
    97  			continue
    98  		}
    99  
   100  		switch k {
   101  		case "user":
   102  			connInfo.User = v.AsString()
   103  		case "password":
   104  			connInfo.Password = v.AsString()
   105  		case "private_key":
   106  			connInfo.PrivateKey = v.AsString()
   107  		case "certificate":
   108  			connInfo.Certificate = v.AsString()
   109  		case "host":
   110  			connInfo.Host = v.AsString()
   111  		case "host_key":
   112  			connInfo.HostKey = v.AsString()
   113  		case "port":
   114  			if err := gocty.FromCtyValue(v, &connInfo.Port); err != nil {
   115  				return nil, err
   116  			}
   117  		case "agent":
   118  			connInfo.Agent = v.True()
   119  		case "script_path":
   120  			connInfo.ScriptPath = v.AsString()
   121  		case "target_platform":
   122  			connInfo.TargetPlatform = v.AsString()
   123  		case "timeout":
   124  			connInfo.Timeout = v.AsString()
   125  		case "proxy_scheme":
   126  			connInfo.ProxyScheme = v.AsString()
   127  		case "proxy_host":
   128  			connInfo.ProxyHost = v.AsString()
   129  		case "proxy_port":
   130  			if err := gocty.FromCtyValue(v, &connInfo.ProxyPort); err != nil {
   131  				return nil, err
   132  			}
   133  		case "proxy_user_name":
   134  			connInfo.ProxyUserName = v.AsString()
   135  		case "proxy_user_password":
   136  			connInfo.ProxyUserPassword = v.AsString()
   137  		case "bastion_user":
   138  			connInfo.BastionUser = v.AsString()
   139  		case "bastion_password":
   140  			connInfo.BastionPassword = v.AsString()
   141  		case "bastion_private_key":
   142  			connInfo.BastionPrivateKey = v.AsString()
   143  		case "bastion_certificate":
   144  			connInfo.BastionCertificate = v.AsString()
   145  		case "bastion_host":
   146  			connInfo.BastionHost = v.AsString()
   147  		case "bastion_host_key":
   148  			connInfo.BastionHostKey = v.AsString()
   149  		case "bastion_port":
   150  			if err := gocty.FromCtyValue(v, &connInfo.BastionPort); err != nil {
   151  				return nil, err
   152  			}
   153  		case "agent_identity":
   154  			connInfo.AgentIdentity = v.AsString()
   155  		}
   156  	}
   157  	return connInfo, nil
   158  }
   159  
   160  */
   161  
   162  // parseshared.ConnectionInfo is used to convert the raw configuration into the
   163  // *shared.ConnectionInfo struct.
   164  func parseConnectionInfo(ci shared.ConnectionInfo) (*shared.ConnectionInfo, error) {
   165  
   166  	// Start with the provided connection info, we'll modify it as appropriate
   167  	connInfo := ci
   168  
   169  	/*
   170  		v, err := shared.ConnectionBlockSupersetSchema.CoerceValue(v)
   171  		if err != nil {
   172  			return nil, err
   173  		}
   174  
   175  		connInfo, err := decodeConnInfo(v)
   176  		if err != nil {
   177  			return nil, err
   178  		}
   179  	*/
   180  
   181  	// To default Agent to true, we need to check the raw string, since the
   182  	// decoded boolean can't represent "absence of config".
   183  	//
   184  	// And if SSH_AUTH_SOCK is not set, there's no agent to connect to, so we
   185  	// shouldn't try.
   186  	if connInfo.Agent {
   187  		if os.Getenv("SSH_AUTH_SOCK") == "" {
   188  			connInfo.Agent = false
   189  		}
   190  	}
   191  
   192  	if connInfo.User == "" {
   193  		connInfo.User = DefaultUser
   194  	}
   195  
   196  	// Check if host is empty.
   197  	// Otherwise return error.
   198  	if connInfo.Host == "" {
   199  		return nil, fmt.Errorf("host for provisioner cannot be empty")
   200  	}
   201  
   202  	// Format the host if needed.
   203  	// Needed for IPv6 support.
   204  	connInfo.Host = shared.IpFormat(connInfo.Host)
   205  
   206  	if connInfo.Port == 0 {
   207  		connInfo.Port = DefaultPort
   208  	}
   209  	// Set default targetPlatform to unix if it's empty
   210  	if connInfo.TargetPlatform == "" {
   211  		connInfo.TargetPlatform = TargetPlatformUnix
   212  	} else if connInfo.TargetPlatform != TargetPlatformUnix && connInfo.TargetPlatform != TargetPlatformWindows {
   213  		return nil, fmt.Errorf("target_platform for provisioner has to be either %s or %s", TargetPlatformUnix, TargetPlatformWindows)
   214  	}
   215  	// Choose an appropriate default script path based on the target platform. There is no single
   216  	// suitable default script path which works on both UNIX and Windows targets.
   217  	if connInfo.ScriptPath == "" && connInfo.TargetPlatform == TargetPlatformUnix {
   218  		connInfo.ScriptPath = DefaultUnixScriptPath
   219  	}
   220  	if connInfo.ScriptPath == "" && connInfo.TargetPlatform == TargetPlatformWindows {
   221  		connInfo.ScriptPath = DefaultWindowsScriptPath
   222  	}
   223  	if connInfo.Timeout != "" {
   224  		connInfo.TimeoutVal = safeDuration(connInfo.Timeout, DefaultTimeout)
   225  	} else {
   226  		connInfo.TimeoutVal = DefaultTimeout
   227  	}
   228  
   229  	// Default all bastion config attrs to their non-bastion counterparts
   230  	if connInfo.BastionHost != "" {
   231  		// Format the bastion host if needed.
   232  		// Needed for IPv6 support.
   233  		connInfo.BastionHost = shared.IpFormat(connInfo.BastionHost)
   234  
   235  		if connInfo.BastionUser == "" {
   236  			connInfo.BastionUser = connInfo.User
   237  		}
   238  		if connInfo.BastionPassword == "" {
   239  			connInfo.BastionPassword = connInfo.Password
   240  		}
   241  		if connInfo.BastionPrivateKey == "" {
   242  			connInfo.BastionPrivateKey = connInfo.PrivateKey
   243  		}
   244  		if connInfo.BastionCertificate == "" {
   245  			connInfo.BastionCertificate = connInfo.Certificate
   246  		}
   247  		if connInfo.BastionPort == 0 {
   248  			connInfo.BastionPort = connInfo.Port
   249  		}
   250  	}
   251  
   252  	return &connInfo, nil
   253  }
   254  
   255  // safeDuration returns either the parsed duration or a default value
   256  func safeDuration(dur string, defaultDur time.Duration) time.Duration {
   257  	d, err := time.ParseDuration(dur)
   258  	if err != nil {
   259  		log.Printf("Invalid duration '%s', using default of %s", dur, defaultDur)
   260  		return defaultDur
   261  	}
   262  	return d
   263  }
   264  
   265  // prepareSSHConfig is used to turn the *shared.ConnectionInfo provided into a
   266  // usable *SSHConfig for client initialization.
   267  func prepareSSHConfig(connInfo *shared.ConnectionInfo) (*sshConfig, error) {
   268  	sshAgent, err := connectToAgent(connInfo)
   269  	if err != nil {
   270  		return nil, err
   271  	}
   272  
   273  	host := fmt.Sprintf("%s:%d", connInfo.Host, connInfo.Port)
   274  
   275  	sshConf, err := buildSSHClientConfig(sshClientConfigOpts{
   276  		user:        connInfo.User,
   277  		host:        host,
   278  		privateKey:  connInfo.PrivateKey,
   279  		password:    connInfo.Password,
   280  		hostKey:     connInfo.HostKey,
   281  		certificate: connInfo.Certificate,
   282  		sshAgent:    sshAgent,
   283  	})
   284  	if err != nil {
   285  		return nil, err
   286  	}
   287  
   288  	var p *proxyInfo
   289  
   290  	if connInfo.ProxyHost != "" {
   291  		p = newProxyInfo(
   292  			fmt.Sprintf("%s:%d", connInfo.ProxyHost, connInfo.ProxyPort),
   293  			connInfo.ProxyScheme,
   294  			connInfo.ProxyUserName,
   295  			connInfo.ProxyUserPassword,
   296  		)
   297  	}
   298  
   299  	connectFunc := ConnectFunc("tcp", host, p)
   300  
   301  	var bastionConf *ssh.ClientConfig
   302  	if connInfo.BastionHost != "" {
   303  		bastionHost := fmt.Sprintf("%s:%d", connInfo.BastionHost, connInfo.BastionPort)
   304  
   305  		bastionConf, err = buildSSHClientConfig(sshClientConfigOpts{
   306  			user:        connInfo.BastionUser,
   307  			host:        bastionHost,
   308  			privateKey:  connInfo.BastionPrivateKey,
   309  			password:    connInfo.BastionPassword,
   310  			hostKey:     connInfo.HostKey,
   311  			certificate: connInfo.BastionCertificate,
   312  			sshAgent:    sshAgent,
   313  		})
   314  		if err != nil {
   315  			return nil, err
   316  		}
   317  
   318  		connectFunc = BastionConnectFunc("tcp", bastionHost, bastionConf, "tcp", host, p)
   319  	}
   320  
   321  	config := &sshConfig{
   322  		config:     sshConf,
   323  		connection: connectFunc,
   324  		sshAgent:   sshAgent,
   325  	}
   326  	return config, nil
   327  }
   328  
   329  type sshClientConfigOpts struct {
   330  	privateKey  string
   331  	password    string
   332  	sshAgent    *sshAgent
   333  	certificate string
   334  	user        string
   335  	host        string
   336  	hostKey     string
   337  }
   338  
   339  func buildSSHClientConfig(opts sshClientConfigOpts) (*ssh.ClientConfig, error) {
   340  	hkCallback := ssh.InsecureIgnoreHostKey()
   341  
   342  	if opts.hostKey != "" {
   343  		// The knownhosts package only takes paths to files, but terraform
   344  		// generally wants to handle config data in-memory. Rather than making
   345  		// the known_hosts file an exception, write out the data to a temporary
   346  		// file to create the HostKeyCallback.
   347  		tf, err := ioutil.TempFile("", "tf-known_hosts")
   348  		if err != nil {
   349  			return nil, fmt.Errorf("failed to create temp known_hosts file: %s", err)
   350  		}
   351  		defer tf.Close()
   352  		defer os.RemoveAll(tf.Name())
   353  
   354  		// we mark this as a CA as well, but the host key fallback will still
   355  		// use it as a direct match if the remote host doesn't return a
   356  		// certificate.
   357  		if _, err := tf.WriteString(fmt.Sprintf("@cert-authority %s %s\n", opts.host, opts.hostKey)); err != nil {
   358  			return nil, fmt.Errorf("failed to write temp known_hosts file: %s", err)
   359  		}
   360  		tf.Sync()
   361  
   362  		hkCallback, err = knownhosts.New(tf.Name())
   363  		if err != nil {
   364  			return nil, err
   365  		}
   366  	}
   367  
   368  	conf := &ssh.ClientConfig{
   369  		HostKeyCallback: hkCallback,
   370  		User:            opts.user,
   371  	}
   372  
   373  	if opts.privateKey != "" {
   374  		if opts.certificate != "" {
   375  			log.Println("using client certificate for authentication")
   376  
   377  			certSigner, err := signCertWithPrivateKey(opts.privateKey, opts.certificate)
   378  			if err != nil {
   379  				return nil, err
   380  			}
   381  			conf.Auth = append(conf.Auth, certSigner)
   382  		} else {
   383  			log.Println("using private key for authentication")
   384  
   385  			pubKeyAuth, err := readPrivateKey(opts.privateKey)
   386  			if err != nil {
   387  				return nil, err
   388  			}
   389  			conf.Auth = append(conf.Auth, pubKeyAuth)
   390  		}
   391  	}
   392  
   393  	if opts.password != "" {
   394  		conf.Auth = append(conf.Auth, ssh.Password(opts.password))
   395  		conf.Auth = append(conf.Auth, ssh.KeyboardInteractive(
   396  			PasswordKeyboardInteractive(opts.password)))
   397  	}
   398  
   399  	if opts.sshAgent != nil {
   400  		conf.Auth = append(conf.Auth, opts.sshAgent.Auth())
   401  	}
   402  
   403  	return conf, nil
   404  }
   405  
   406  // Create a Cert Signer and return ssh.AuthMethod
   407  func signCertWithPrivateKey(pk string, certificate string) (ssh.AuthMethod, error) {
   408  	rawPk, err := ssh.ParseRawPrivateKey([]byte(pk))
   409  	if err != nil {
   410  		return nil, fmt.Errorf("failed to parse private key %q: %s", pk, err)
   411  	}
   412  
   413  	pcert, _, _, _, err := ssh.ParseAuthorizedKey([]byte(certificate))
   414  	if err != nil {
   415  		return nil, fmt.Errorf("failed to parse certificate %q: %s", certificate, err)
   416  	}
   417  
   418  	usigner, err := ssh.NewSignerFromKey(rawPk)
   419  	if err != nil {
   420  		return nil, fmt.Errorf("failed to create signer from raw private key %q: %s", rawPk, err)
   421  	}
   422  
   423  	ucertSigner, err := ssh.NewCertSigner(pcert.(*ssh.Certificate), usigner)
   424  	if err != nil {
   425  		return nil, fmt.Errorf("failed to create cert signer %q: %s", usigner, err)
   426  	}
   427  
   428  	return ssh.PublicKeys(ucertSigner), nil
   429  }
   430  
   431  func readPrivateKey(pk string) (ssh.AuthMethod, error) {
   432  	// We parse the private key on our own first so that we can
   433  	// show a nicer error if the private key has a password.
   434  	block, _ := pem.Decode([]byte(pk))
   435  	if block == nil {
   436  		return nil, errors.New("Failed to read ssh private key: no key found")
   437  	}
   438  	if block.Headers["Proc-Type"] == "4,ENCRYPTED" {
   439  		return nil, errors.New(
   440  			"Failed to read ssh private key: password protected keys are\n" +
   441  				"not supported. Please decrypt the key prior to use.")
   442  	}
   443  
   444  	signer, err := ssh.ParsePrivateKey([]byte(pk))
   445  	if err != nil {
   446  		return nil, fmt.Errorf("Failed to parse ssh private key: %s", err)
   447  	}
   448  
   449  	return ssh.PublicKeys(signer), nil
   450  }
   451  
   452  func connectToAgent(connInfo *shared.ConnectionInfo) (*sshAgent, error) {
   453  	if !connInfo.Agent {
   454  		// No agent configured
   455  		return nil, nil
   456  	}
   457  
   458  	agent, conn, err := sshagent.New()
   459  	if err != nil {
   460  		return nil, err
   461  	}
   462  
   463  	// connection close is handled over in Communicator
   464  	return &sshAgent{
   465  		agent: agent,
   466  		conn:  conn,
   467  		id:    connInfo.AgentIdentity,
   468  	}, nil
   469  
   470  }
   471  
// A tiny wrapper around an agent.Agent to expose the ability to close its
// associated connection on request. It also remembers the configured agent
// identity, which sortSigners uses to put matching keys first.
type sshAgent struct {
	agent agent.Agent // agent client used for signing and forwarding
	conn  net.Conn    // connection backing agent; closed by Close, may be nil
	id    string      // configured agent identity; "" disables reordering
}
   479  
   480  func (a *sshAgent) Close() error {
   481  	if a.conn == nil {
   482  		return nil
   483  	}
   484  
   485  	return a.conn.Close()
   486  }
   487  
   488  // make an attempt to either read the identity file or find a corresponding
   489  // public key file using the typical openssh naming convention.
   490  // This returns the public key in wire format, or nil when a key is not found.
   491  func findIDPublicKey(id string) []byte {
   492  	for _, d := range idKeyData(id) {
   493  		signer, err := ssh.ParsePrivateKey(d)
   494  		if err == nil {
   495  			log.Println("[DEBUG] parsed id private key")
   496  			pk := signer.PublicKey()
   497  			return pk.Marshal()
   498  		}
   499  
   500  		// try it as a publicKey
   501  		pk, err := ssh.ParsePublicKey(d)
   502  		if err == nil {
   503  			log.Println("[DEBUG] parsed id public key")
   504  			return pk.Marshal()
   505  		}
   506  
   507  		// finally try it as an authorized key
   508  		pk, _, _, _, err = ssh.ParseAuthorizedKey(d)
   509  		if err == nil {
   510  			log.Println("[DEBUG] parsed id authorized key")
   511  			return pk.Marshal()
   512  		}
   513  	}
   514  
   515  	return nil
   516  }
   517  
   518  // Try to read an id file using the id as the file path. Also read the .pub
   519  // file if it exists, as the id file may be encrypted. Return only the file
   520  // data read. We don't need to know what data came from which path, as we will
   521  // try parsing each as a private key, a public key and an authorized key
   522  // regardless.
   523  func idKeyData(id string) [][]byte {
   524  	idPath, err := filepath.Abs(id)
   525  	if err != nil {
   526  		return nil
   527  	}
   528  
   529  	var fileData [][]byte
   530  
   531  	paths := []string{idPath}
   532  
   533  	if !strings.HasSuffix(idPath, ".pub") {
   534  		paths = append(paths, idPath+".pub")
   535  	}
   536  
   537  	for _, p := range paths {
   538  		d, err := ioutil.ReadFile(p)
   539  		if err != nil {
   540  			log.Printf("[DEBUG] error reading %q: %s", p, err)
   541  			continue
   542  		}
   543  		log.Printf("[DEBUG] found identity data at %q", p)
   544  		fileData = append(fileData, d)
   545  	}
   546  
   547  	return fileData
   548  }
   549  
   550  // sortSigners moves a signer with an agent comment field matching the
   551  // agent_identity to the head of the list when attempting authentication. This
   552  // helps when there are more keys loaded in an agent than the host will allow
   553  // attempts.
   554  func (s *sshAgent) sortSigners(signers []ssh.Signer) {
   555  	if s.id == "" || len(signers) < 2 {
   556  		return
   557  	}
   558  
   559  	// if we can locate the public key, either by extracting it from the id or
   560  	// locating the .pub file, then we can more easily determine an exact match
   561  	idPk := findIDPublicKey(s.id)
   562  
   563  	// if we have a signer with a connect field that matches the id, send that
   564  	// first, otherwise put close matches at the front of the list.
   565  	head := 0
   566  	for i := range signers {
   567  		pk := signers[i].PublicKey()
   568  		k, ok := pk.(*agent.Key)
   569  		if !ok {
   570  			continue
   571  		}
   572  
   573  		// check for an exact match first
   574  		if bytes.Equal(pk.Marshal(), idPk) || s.id == k.Comment {
   575  			signers[0], signers[i] = signers[i], signers[0]
   576  			break
   577  		}
   578  
   579  		// no exact match yet, move it to the front if it's close. The agent
   580  		// may have loaded as a full filepath, while the config refers to it by
   581  		// filename only.
   582  		if strings.HasSuffix(k.Comment, s.id) {
   583  			signers[head], signers[i] = signers[i], signers[head]
   584  			head++
   585  			continue
   586  		}
   587  	}
   588  }
   589  
   590  func (s *sshAgent) Signers() ([]ssh.Signer, error) {
   591  	signers, err := s.agent.Signers()
   592  	if err != nil {
   593  		return nil, err
   594  	}
   595  
   596  	s.sortSigners(signers)
   597  	return signers, nil
   598  }
   599  
   600  func (a *sshAgent) Auth() ssh.AuthMethod {
   601  	return ssh.PublicKeysCallback(a.Signers)
   602  }
   603  
   604  func (a *sshAgent) ForwardToAgent(client *ssh.Client) error {
   605  	return agent.ForwardToAgent(client, a.agent)
   606  }