github.com/StackPointCloud/packer@v0.10.2-0.20180716202532-b28098e0f79b/provisioner/ansible/provisioner.go (about)

     1  package ansible
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/rand"
     7  	"crypto/rsa"
     8  	"crypto/x509"
     9  	"encoding/pem"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	"log"
    15  	"net"
    16  	"os"
    17  	"os/exec"
    18  	"os/user"
    19  	"path/filepath"
    20  	"regexp"
    21  	"strconv"
    22  	"strings"
    23  	"sync"
    24  	"unicode"
    25  
    26  	"golang.org/x/crypto/ssh"
    27  
    28  	"github.com/hashicorp/packer/common"
    29  	"github.com/hashicorp/packer/helper/config"
    30  	"github.com/hashicorp/packer/packer"
    31  	"github.com/hashicorp/packer/template/interpolate"
    32  )
    33  
    34  type Config struct {
    35  	common.PackerConfig `mapstructure:",squash"`
    36  	ctx                 interpolate.Context
    37  
    38  	// The command to run ansible
    39  	Command string
    40  
    41  	// Extra options to pass to the ansible command
    42  	ExtraArguments []string `mapstructure:"extra_arguments"`
    43  
    44  	AnsibleEnvVars []string `mapstructure:"ansible_env_vars"`
    45  
    46  	// The main playbook file to execute.
    47  	PlaybookFile         string   `mapstructure:"playbook_file"`
    48  	Groups               []string `mapstructure:"groups"`
    49  	EmptyGroups          []string `mapstructure:"empty_groups"`
    50  	HostAlias            string   `mapstructure:"host_alias"`
    51  	User                 string   `mapstructure:"user"`
    52  	LocalPort            string   `mapstructure:"local_port"`
    53  	SSHHostKeyFile       string   `mapstructure:"ssh_host_key_file"`
    54  	SSHAuthorizedKeyFile string   `mapstructure:"ssh_authorized_key_file"`
    55  	SFTPCmd              string   `mapstructure:"sftp_command"`
    56  	SkipVersionCheck     bool     `mapstructure:"skip_version_check"`
    57  	UseSFTP              bool     `mapstructure:"use_sftp"`
    58  	InventoryDirectory   string   `mapstructure:"inventory_directory"`
    59  	InventoryFile        string   `mapstructure:"inventory_file"`
    60  }
    61  
    62  type Provisioner struct {
    63  	config            Config
    64  	adapter           *adapter
    65  	done              chan struct{}
    66  	ansibleVersion    string
    67  	ansibleMajVersion uint
    68  }
    69  
    70  func (p *Provisioner) Prepare(raws ...interface{}) error {
    71  	p.done = make(chan struct{})
    72  
    73  	err := config.Decode(&p.config, &config.DecodeOpts{
    74  		Interpolate:        true,
    75  		InterpolateContext: &p.config.ctx,
    76  		InterpolateFilter: &interpolate.RenderFilter{
    77  			Exclude: []string{},
    78  		},
    79  	}, raws...)
    80  	if err != nil {
    81  		return err
    82  	}
    83  
    84  	// Defaults
    85  	if p.config.Command == "" {
    86  		p.config.Command = "ansible-playbook"
    87  	}
    88  
    89  	if p.config.HostAlias == "" {
    90  		p.config.HostAlias = "default"
    91  	}
    92  
    93  	var errs *packer.MultiError
    94  	err = validateFileConfig(p.config.PlaybookFile, "playbook_file", true)
    95  	if err != nil {
    96  		errs = packer.MultiErrorAppend(errs, err)
    97  	}
    98  
    99  	// Check that the authorized key file exists
   100  	if len(p.config.SSHAuthorizedKeyFile) > 0 {
   101  		err = validateFileConfig(p.config.SSHAuthorizedKeyFile, "ssh_authorized_key_file", true)
   102  		if err != nil {
   103  			log.Println(p.config.SSHAuthorizedKeyFile, "does not exist")
   104  			errs = packer.MultiErrorAppend(errs, err)
   105  		}
   106  	}
   107  	if len(p.config.SSHHostKeyFile) > 0 {
   108  		err = validateFileConfig(p.config.SSHHostKeyFile, "ssh_host_key_file", true)
   109  		if err != nil {
   110  			log.Println(p.config.SSHHostKeyFile, "does not exist")
   111  			errs = packer.MultiErrorAppend(errs, err)
   112  		}
   113  	} else {
   114  		p.config.AnsibleEnvVars = append(p.config.AnsibleEnvVars, "ANSIBLE_HOST_KEY_CHECKING=False")
   115  	}
   116  
   117  	if !p.config.UseSFTP {
   118  		p.config.AnsibleEnvVars = append(p.config.AnsibleEnvVars, "ANSIBLE_SCP_IF_SSH=True")
   119  	}
   120  
   121  	if len(p.config.LocalPort) > 0 {
   122  		if _, err := strconv.ParseUint(p.config.LocalPort, 10, 16); err != nil {
   123  			errs = packer.MultiErrorAppend(errs, fmt.Errorf("local_port: %s must be a valid port", p.config.LocalPort))
   124  		}
   125  	} else {
   126  		p.config.LocalPort = "0"
   127  	}
   128  
   129  	if len(p.config.InventoryDirectory) > 0 {
   130  		err = validateInventoryDirectoryConfig(p.config.InventoryDirectory)
   131  		if err != nil {
   132  			log.Println(p.config.InventoryDirectory, "does not exist")
   133  			errs = packer.MultiErrorAppend(errs, err)
   134  		}
   135  	}
   136  
   137  	if !p.config.SkipVersionCheck {
   138  		err = p.getVersion()
   139  		if err != nil {
   140  			errs = packer.MultiErrorAppend(errs, err)
   141  		}
   142  	}
   143  
   144  	if p.config.User == "" {
   145  		usr, err := user.Current()
   146  		if err != nil {
   147  			errs = packer.MultiErrorAppend(errs, err)
   148  		} else {
   149  			p.config.User = usr.Username
   150  		}
   151  	}
   152  	if p.config.User == "" {
   153  		errs = packer.MultiErrorAppend(errs, fmt.Errorf("user: could not determine current user from environment."))
   154  	}
   155  
   156  	if errs != nil && len(errs.Errors) > 0 {
   157  		return errs
   158  	}
   159  	return nil
   160  }
   161  
   162  func (p *Provisioner) getVersion() error {
   163  	out, err := exec.Command(p.config.Command, "--version").Output()
   164  	if err != nil {
   165  		return fmt.Errorf(
   166  			"Error running \"%s --version\": %s", p.config.Command, err.Error())
   167  	}
   168  
   169  	versionRe := regexp.MustCompile(`\w (\d+\.\d+[.\d+]*)`)
   170  	matches := versionRe.FindStringSubmatch(string(out))
   171  	if matches == nil {
   172  		return fmt.Errorf(
   173  			"Could not find %s version in output:\n%s", p.config.Command, string(out))
   174  	}
   175  
   176  	version := matches[1]
   177  	log.Printf("%s version: %s", p.config.Command, version)
   178  	p.ansibleVersion = version
   179  
   180  	majVer, err := strconv.ParseUint(strings.Split(version, ".")[0], 10, 0)
   181  	if err != nil {
   182  		return fmt.Errorf("Could not parse major version from \"%s\".", version)
   183  	}
   184  	p.ansibleMajVersion = uint(majVer)
   185  
   186  	return nil
   187  }
   188  
   189  func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error {
   190  	ui.Say("Provisioning with Ansible...")
   191  
   192  	k, err := newUserKey(p.config.SSHAuthorizedKeyFile)
   193  	if err != nil {
   194  		return err
   195  	}
   196  
   197  	hostSigner, err := newSigner(p.config.SSHHostKeyFile)
   198  	// Remove the private key file
   199  	if len(k.privKeyFile) > 0 {
   200  		defer os.Remove(k.privKeyFile)
   201  	}
   202  
   203  	keyChecker := ssh.CertChecker{
   204  		UserKeyFallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
   205  			if user := conn.User(); user != p.config.User {
   206  				return nil, errors.New(fmt.Sprintf("authentication failed: %s is not a valid user", user))
   207  			}
   208  
   209  			if !bytes.Equal(k.Marshal(), pubKey.Marshal()) {
   210  				return nil, errors.New("authentication failed: unauthorized key")
   211  			}
   212  
   213  			return nil, nil
   214  		},
   215  	}
   216  
   217  	config := &ssh.ServerConfig{
   218  		AuthLogCallback: func(conn ssh.ConnMetadata, method string, err error) {
   219  			log.Printf("authentication attempt from %s to %s as %s using %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User(), method)
   220  		},
   221  		PublicKeyCallback: keyChecker.Authenticate,
   222  		//NoClientAuth:      true,
   223  	}
   224  
   225  	config.AddHostKey(hostSigner)
   226  
   227  	localListener, err := func() (net.Listener, error) {
   228  		port, err := strconv.ParseUint(p.config.LocalPort, 10, 16)
   229  		if err != nil {
   230  			return nil, err
   231  		}
   232  
   233  		tries := 1
   234  		if port != 0 {
   235  			tries = 10
   236  		}
   237  		for i := 0; i < tries; i++ {
   238  			l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
   239  			port++
   240  			if err != nil {
   241  				ui.Say(err.Error())
   242  				continue
   243  			}
   244  			_, p.config.LocalPort, err = net.SplitHostPort(l.Addr().String())
   245  			if err != nil {
   246  				ui.Say(err.Error())
   247  				continue
   248  			}
   249  			return l, nil
   250  		}
   251  		return nil, errors.New("Error setting up SSH proxy connection")
   252  	}()
   253  
   254  	if err != nil {
   255  		return err
   256  	}
   257  
   258  	ui = newUi(ui)
   259  	p.adapter = newAdapter(p.done, localListener, config, p.config.SFTPCmd, ui, comm)
   260  
   261  	defer func() {
   262  		log.Print("shutting down the SSH proxy")
   263  		close(p.done)
   264  		p.adapter.Shutdown()
   265  	}()
   266  
   267  	go p.adapter.Serve()
   268  
   269  	if len(p.config.InventoryFile) == 0 {
   270  		tf, err := ioutil.TempFile(p.config.InventoryDirectory, "packer-provisioner-ansible")
   271  		if err != nil {
   272  			return fmt.Errorf("Error preparing inventory file: %s", err)
   273  		}
   274  		defer os.Remove(tf.Name())
   275  
   276  		host := fmt.Sprintf("%s ansible_host=127.0.0.1 ansible_user=%s ansible_port=%s\n",
   277  			p.config.HostAlias, p.config.User, p.config.LocalPort)
   278  		if p.ansibleMajVersion < 2 {
   279  			host = fmt.Sprintf("%s ansible_ssh_host=127.0.0.1 ansible_ssh_user=%s ansible_ssh_port=%s\n",
   280  				p.config.HostAlias, p.config.User, p.config.LocalPort)
   281  		}
   282  
   283  		w := bufio.NewWriter(tf)
   284  		w.WriteString(host)
   285  		for _, group := range p.config.Groups {
   286  			fmt.Fprintf(w, "[%s]\n%s", group, host)
   287  		}
   288  
   289  		for _, group := range p.config.EmptyGroups {
   290  			fmt.Fprintf(w, "[%s]\n", group)
   291  		}
   292  
   293  		if err := w.Flush(); err != nil {
   294  			tf.Close()
   295  			return fmt.Errorf("Error preparing inventory file: %s", err)
   296  		}
   297  		tf.Close()
   298  		p.config.InventoryFile = tf.Name()
   299  		defer func() {
   300  			p.config.InventoryFile = ""
   301  		}()
   302  	}
   303  
   304  	if err := p.executeAnsible(ui, comm, k.privKeyFile); err != nil {
   305  		return fmt.Errorf("Error executing Ansible: %s", err)
   306  	}
   307  
   308  	return nil
   309  }
   310  
   311  func (p *Provisioner) Cancel() {
   312  	if p.done != nil {
   313  		close(p.done)
   314  	}
   315  	if p.adapter != nil {
   316  		p.adapter.Shutdown()
   317  	}
   318  	os.Exit(0)
   319  }
   320  
   321  func (p *Provisioner) executeAnsible(ui packer.Ui, comm packer.Communicator, privKeyFile string) error {
   322  	playbook, _ := filepath.Abs(p.config.PlaybookFile)
   323  	inventory := p.config.InventoryFile
   324  	if len(p.config.InventoryDirectory) > 0 {
   325  		inventory = p.config.InventoryDirectory
   326  	}
   327  	var envvars []string
   328  
   329  	args := []string{"--extra-vars", fmt.Sprintf("packer_build_name=%s packer_builder_type=%s",
   330  		p.config.PackerBuildName, p.config.PackerBuilderType),
   331  		"-i", inventory, playbook}
   332  	if len(privKeyFile) > 0 {
   333  		// Changed this from using --private-key to supplying -e ansible_ssh_private_key_file as the latter
   334  		// is treated as a highest priority variable, and thus prevents overriding by dynamic variables
   335  		// as seen in #5852
   336  		// args = append(args, "--private-key", privKeyFile)
   337  		args = append(args, "-e", fmt.Sprintf("ansible_ssh_private_key_file=%s", privKeyFile))
   338  	}
   339  	args = append(args, p.config.ExtraArguments...)
   340  	if len(p.config.AnsibleEnvVars) > 0 {
   341  		envvars = append(envvars, p.config.AnsibleEnvVars...)
   342  	}
   343  
   344  	cmd := exec.Command(p.config.Command, args...)
   345  
   346  	cmd.Env = os.Environ()
   347  	if len(envvars) > 0 {
   348  		cmd.Env = append(cmd.Env, envvars...)
   349  	}
   350  
   351  	stdout, err := cmd.StdoutPipe()
   352  	if err != nil {
   353  		return err
   354  	}
   355  	stderr, err := cmd.StderrPipe()
   356  	if err != nil {
   357  		return err
   358  	}
   359  
   360  	wg := sync.WaitGroup{}
   361  	repeat := func(r io.ReadCloser) {
   362  		reader := bufio.NewReader(r)
   363  		for {
   364  			line, err := reader.ReadString('\n')
   365  			if line != "" {
   366  				line = strings.TrimRightFunc(line, unicode.IsSpace)
   367  				ui.Message(line)
   368  			}
   369  			if err != nil {
   370  				if err == io.EOF {
   371  					break
   372  				} else {
   373  					ui.Error(err.Error())
   374  					break
   375  				}
   376  			}
   377  		}
   378  		wg.Done()
   379  	}
   380  	wg.Add(2)
   381  	go repeat(stdout)
   382  	go repeat(stderr)
   383  
   384  	ui.Say(fmt.Sprintf("Executing Ansible: %s", strings.Join(cmd.Args, " ")))
   385  	if err := cmd.Start(); err != nil {
   386  		return err
   387  	}
   388  	wg.Wait()
   389  	err = cmd.Wait()
   390  	if err != nil {
   391  		return fmt.Errorf("Non-zero exit status: %s", err)
   392  	}
   393  
   394  	return nil
   395  }
   396  
   397  func validateFileConfig(name string, config string, req bool) error {
   398  	if req {
   399  		if name == "" {
   400  			return fmt.Errorf("%s must be specified.", config)
   401  		}
   402  	}
   403  	info, err := os.Stat(name)
   404  	if err != nil {
   405  		return fmt.Errorf("%s: %s is invalid: %s", config, name, err)
   406  	} else if info.IsDir() {
   407  		return fmt.Errorf("%s: %s must point to a file", config, name)
   408  	}
   409  	return nil
   410  }
   411  
   412  func validateInventoryDirectoryConfig(name string) error {
   413  	info, err := os.Stat(name)
   414  	if err != nil {
   415  		return fmt.Errorf("inventory_directory: %s is invalid: %s", name, err)
   416  	} else if !info.IsDir() {
   417  		return fmt.Errorf("inventory_directory: %s must point to a directory", name)
   418  	}
   419  	return nil
   420  }
   421  
   422  type userKey struct {
   423  	ssh.PublicKey
   424  	privKeyFile string
   425  }
   426  
   427  func newUserKey(pubKeyFile string) (*userKey, error) {
   428  	userKey := new(userKey)
   429  	if len(pubKeyFile) > 0 {
   430  		pubKeyBytes, err := ioutil.ReadFile(pubKeyFile)
   431  		if err != nil {
   432  			return nil, errors.New("Failed to read public key")
   433  		}
   434  		userKey.PublicKey, _, _, _, err = ssh.ParseAuthorizedKey(pubKeyBytes)
   435  		if err != nil {
   436  			return nil, errors.New("Failed to parse authorized key")
   437  		}
   438  
   439  		return userKey, nil
   440  	}
   441  
   442  	key, err := rsa.GenerateKey(rand.Reader, 2048)
   443  	if err != nil {
   444  		return nil, errors.New("Failed to generate key pair")
   445  	}
   446  	userKey.PublicKey, err = ssh.NewPublicKey(key.Public())
   447  	if err != nil {
   448  		return nil, errors.New("Failed to extract public key from generated key pair")
   449  	}
   450  
   451  	// To support Ansible calling back to us we need to write
   452  	// this file down
   453  	privateKeyDer := x509.MarshalPKCS1PrivateKey(key)
   454  	privateKeyBlock := pem.Block{
   455  		Type:    "RSA PRIVATE KEY",
   456  		Headers: nil,
   457  		Bytes:   privateKeyDer,
   458  	}
   459  	tf, err := ioutil.TempFile("", "ansible-key")
   460  	if err != nil {
   461  		return nil, errors.New("failed to create temp file for generated key")
   462  	}
   463  	_, err = tf.Write(pem.EncodeToMemory(&privateKeyBlock))
   464  	if err != nil {
   465  		return nil, errors.New("failed to write private key to temp file")
   466  	}
   467  
   468  	err = tf.Close()
   469  	if err != nil {
   470  		return nil, errors.New("failed to close private key temp file")
   471  	}
   472  	userKey.privKeyFile = tf.Name()
   473  
   474  	return userKey, nil
   475  }
   476  
   477  type signer struct {
   478  	ssh.Signer
   479  }
   480  
   481  func newSigner(privKeyFile string) (*signer, error) {
   482  	signer := new(signer)
   483  
   484  	if len(privKeyFile) > 0 {
   485  		privateBytes, err := ioutil.ReadFile(privKeyFile)
   486  		if err != nil {
   487  			return nil, errors.New("Failed to load private host key")
   488  		}
   489  
   490  		signer.Signer, err = ssh.ParsePrivateKey(privateBytes)
   491  		if err != nil {
   492  			return nil, errors.New("Failed to parse private host key")
   493  		}
   494  
   495  		return signer, nil
   496  	}
   497  
   498  	key, err := rsa.GenerateKey(rand.Reader, 2048)
   499  	if err != nil {
   500  		return nil, errors.New("Failed to generate server key pair")
   501  	}
   502  
   503  	signer.Signer, err = ssh.NewSignerFromKey(key)
   504  	if err != nil {
   505  		return nil, errors.New("Failed to extract private key from generated key pair")
   506  	}
   507  
   508  	return signer, nil
   509  }
   510  
   511  // Ui provides concurrency-safe access to packer.Ui.
   512  type Ui struct {
   513  	sem chan int
   514  	ui  packer.Ui
   515  }
   516  
   517  func newUi(ui packer.Ui) packer.Ui {
   518  	return &Ui{sem: make(chan int, 1), ui: ui}
   519  }
   520  
   521  func (ui *Ui) Ask(s string) (string, error) {
   522  	ui.sem <- 1
   523  	ret, err := ui.ui.Ask(s)
   524  	<-ui.sem
   525  
   526  	return ret, err
   527  }
   528  
   529  func (ui *Ui) Say(s string) {
   530  	ui.sem <- 1
   531  	ui.ui.Say(s)
   532  	<-ui.sem
   533  }
   534  
   535  func (ui *Ui) Message(s string) {
   536  	ui.sem <- 1
   537  	ui.ui.Message(s)
   538  	<-ui.sem
   539  }
   540  
   541  func (ui *Ui) Error(s string) {
   542  	ui.sem <- 1
   543  	ui.ui.Error(s)
   544  	<-ui.sem
   545  }
   546  
   547  func (ui *Ui) Machine(t string, args ...string) {
   548  	ui.sem <- 1
   549  	ui.ui.Machine(t, args...)
   550  	<-ui.sem
   551  }