github.com/mmcquillan/packer@v1.1.1-0.20171009221028-c85cf0483a5d/provisioner/ansible/provisioner.go (about)

     1  package ansible
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/rand"
     7  	"crypto/rsa"
     8  	"crypto/x509"
     9  	"encoding/pem"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	"log"
    15  	"net"
    16  	"os"
    17  	"os/exec"
    18  	"path/filepath"
    19  	"regexp"
    20  	"strconv"
    21  	"strings"
    22  	"sync"
    23  	"unicode"
    24  
    25  	"golang.org/x/crypto/ssh"
    26  
    27  	"github.com/hashicorp/packer/common"
    28  	"github.com/hashicorp/packer/helper/config"
    29  	"github.com/hashicorp/packer/packer"
    30  	"github.com/hashicorp/packer/template/interpolate"
    31  )
    32  
// Config holds the settings of the ansible provisioner. Prepare decodes it
// from the template, fills in defaults and validates it.
type Config struct {
	common.PackerConfig `mapstructure:",squash"`
	ctx                 interpolate.Context

	// The command to run ansible. Prepare defaults it to "ansible-playbook".
	Command string

	// Extra options to pass to the ansible command. They are appended after
	// the arguments the provisioner generates itself.
	ExtraArguments []string `mapstructure:"extra_arguments"`

	// Additional KEY=VALUE entries added to the ansible process environment.
	// Prepare may append ANSIBLE_HOST_KEY_CHECKING=False and
	// ANSIBLE_SCP_IF_SSH=True to this list.
	AnsibleEnvVars []string `mapstructure:"ansible_env_vars"`

	// The main playbook file to execute (required).
	PlaybookFile string `mapstructure:"playbook_file"`
	// Groups the host is listed under in the generated inventory.
	Groups []string `mapstructure:"groups"`
	// Groups written to the generated inventory with no hosts.
	EmptyGroups []string `mapstructure:"empty_groups"`
	// Inventory name of the host; Prepare defaults it to "default".
	HostAlias string `mapstructure:"host_alias"`
	// User Ansible connects as; Prepare defaults it to $USER.
	User string `mapstructure:"user"`
	// Port of the local SSH proxy. It must parse as a 16-bit unsigned
	// integer. Prepare defaults it to "0", which lets the OS choose a port.
	LocalPort string `mapstructure:"local_port"`
	// Private host key for the proxy. When empty, a key is generated and
	// Ansible host key checking is disabled.
	SSHHostKeyFile string `mapstructure:"ssh_host_key_file"`
	// Public key Ansible must authenticate with. When empty, a key pair is
	// generated and the private half is passed via --private-key.
	SSHAuthorizedKeyFile string `mapstructure:"ssh_authorized_key_file"`
	// Passed through to newAdapter; presumably the sftp server command the
	// proxy runs on the remote side (TODO confirm in adapter).
	SFTPCmd string `mapstructure:"sftp_command"`
	// When true, Prepare does not run "<Command> --version".
	SkipVersionCheck bool `mapstructure:"skip_version_check"`
	// When false, Prepare adds ANSIBLE_SCP_IF_SSH=True to the environment.
	UseSFTP bool `mapstructure:"use_sftp"`
	// Directory in which the temporary inventory file is created. When empty,
	// the system temp directory is used.
	InventoryDirectory string `mapstructure:"inventory_directory"`
	// Path of the inventory passed to ansible. Provision sets it to a
	// generated temp file for the duration of a run when it is empty.
	inventoryFile string
}
    60  
// Provisioner runs an Ansible playbook against the build machine through a
// local SSH proxy that forwards to the packer communicator.
type Provisioner struct {
	// Decoded and validated settings (see Prepare).
	config Config
	// SSH proxy created by Provision and shut down when Provision returns
	// or Cancel is called.
	adapter *adapter
	// Created by Prepare and handed to newAdapter. It is closed when
	// Provision returns or Cancel is called.
	done chan struct{}
	// Full version string reported by "<Command> --version".
	ansibleVersion string
	// Leading component of ansibleVersion. Values below 2 make Provision
	// write the legacy ansible_ssh_* inventory variables.
	ansibleMajVersion uint
}
    68  
    69  func (p *Provisioner) Prepare(raws ...interface{}) error {
    70  	p.done = make(chan struct{})
    71  
    72  	err := config.Decode(&p.config, &config.DecodeOpts{
    73  		Interpolate:        true,
    74  		InterpolateContext: &p.config.ctx,
    75  		InterpolateFilter: &interpolate.RenderFilter{
    76  			Exclude: []string{},
    77  		},
    78  	}, raws...)
    79  	if err != nil {
    80  		return err
    81  	}
    82  
    83  	// Defaults
    84  	if p.config.Command == "" {
    85  		p.config.Command = "ansible-playbook"
    86  	}
    87  
    88  	if p.config.HostAlias == "" {
    89  		p.config.HostAlias = "default"
    90  	}
    91  
    92  	var errs *packer.MultiError
    93  	err = validateFileConfig(p.config.PlaybookFile, "playbook_file", true)
    94  	if err != nil {
    95  		errs = packer.MultiErrorAppend(errs, err)
    96  	}
    97  
    98  	// Check that the authorized key file exists
    99  	if len(p.config.SSHAuthorizedKeyFile) > 0 {
   100  		err = validateFileConfig(p.config.SSHAuthorizedKeyFile, "ssh_authorized_key_file", true)
   101  		if err != nil {
   102  			log.Println(p.config.SSHAuthorizedKeyFile, "does not exist")
   103  			errs = packer.MultiErrorAppend(errs, err)
   104  		}
   105  	}
   106  	if len(p.config.SSHHostKeyFile) > 0 {
   107  		err = validateFileConfig(p.config.SSHHostKeyFile, "ssh_host_key_file", true)
   108  		if err != nil {
   109  			log.Println(p.config.SSHHostKeyFile, "does not exist")
   110  			errs = packer.MultiErrorAppend(errs, err)
   111  		}
   112  	} else {
   113  		p.config.AnsibleEnvVars = append(p.config.AnsibleEnvVars, "ANSIBLE_HOST_KEY_CHECKING=False")
   114  	}
   115  
   116  	if !p.config.UseSFTP {
   117  		p.config.AnsibleEnvVars = append(p.config.AnsibleEnvVars, "ANSIBLE_SCP_IF_SSH=True")
   118  	}
   119  
   120  	if len(p.config.LocalPort) > 0 {
   121  		if _, err := strconv.ParseUint(p.config.LocalPort, 10, 16); err != nil {
   122  			errs = packer.MultiErrorAppend(errs, fmt.Errorf("local_port: %s must be a valid port", p.config.LocalPort))
   123  		}
   124  	} else {
   125  		p.config.LocalPort = "0"
   126  	}
   127  
   128  	if len(p.config.InventoryDirectory) > 0 {
   129  		err = validateInventoryDirectoryConfig(p.config.InventoryDirectory)
   130  		if err != nil {
   131  			log.Println(p.config.InventoryDirectory, "does not exist")
   132  			errs = packer.MultiErrorAppend(errs, err)
   133  		}
   134  	}
   135  
   136  	if !p.config.SkipVersionCheck {
   137  		err = p.getVersion()
   138  		if err != nil {
   139  			errs = packer.MultiErrorAppend(errs, err)
   140  		}
   141  	}
   142  
   143  	if p.config.User == "" {
   144  		p.config.User = os.Getenv("USER")
   145  	}
   146  	if p.config.User == "" {
   147  		errs = packer.MultiErrorAppend(errs, fmt.Errorf("user: could not determine current user from environment."))
   148  	}
   149  
   150  	if errs != nil && len(errs.Errors) > 0 {
   151  		return errs
   152  	}
   153  	return nil
   154  }
   155  
   156  func (p *Provisioner) getVersion() error {
   157  	out, err := exec.Command(p.config.Command, "--version").Output()
   158  	if err != nil {
   159  		return fmt.Errorf(
   160  			"Error running \"%s --version\": %s", p.config.Command, err.Error())
   161  	}
   162  
   163  	versionRe := regexp.MustCompile(`\w (\d+\.\d+[.\d+]*)`)
   164  	matches := versionRe.FindStringSubmatch(string(out))
   165  	if matches == nil {
   166  		return fmt.Errorf(
   167  			"Could not find %s version in output:\n%s", p.config.Command, string(out))
   168  	}
   169  
   170  	version := matches[1]
   171  	log.Printf("%s version: %s", p.config.Command, version)
   172  	p.ansibleVersion = version
   173  
   174  	majVer, err := strconv.ParseUint(strings.Split(version, ".")[0], 10, 0)
   175  	if err != nil {
   176  		return fmt.Errorf("Could not parse major version from \"%s\".", version)
   177  	}
   178  	p.ansibleMajVersion = uint(majVer)
   179  
   180  	return nil
   181  }
   182  
   183  func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error {
   184  	ui.Say("Provisioning with Ansible...")
   185  
   186  	k, err := newUserKey(p.config.SSHAuthorizedKeyFile)
   187  	if err != nil {
   188  		return err
   189  	}
   190  
   191  	hostSigner, err := newSigner(p.config.SSHHostKeyFile)
   192  	// Remove the private key file
   193  	if len(k.privKeyFile) > 0 {
   194  		defer os.Remove(k.privKeyFile)
   195  	}
   196  
   197  	keyChecker := ssh.CertChecker{
   198  		UserKeyFallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
   199  			if user := conn.User(); user != p.config.User {
   200  				return nil, errors.New(fmt.Sprintf("authentication failed: %s is not a valid user", user))
   201  			}
   202  
   203  			if !bytes.Equal(k.Marshal(), pubKey.Marshal()) {
   204  				return nil, errors.New("authentication failed: unauthorized key")
   205  			}
   206  
   207  			return nil, nil
   208  		},
   209  	}
   210  
   211  	config := &ssh.ServerConfig{
   212  		AuthLogCallback: func(conn ssh.ConnMetadata, method string, err error) {
   213  			log.Printf("authentication attempt from %s to %s as %s using %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User(), method)
   214  		},
   215  		PublicKeyCallback: keyChecker.Authenticate,
   216  		//NoClientAuth:      true,
   217  	}
   218  
   219  	config.AddHostKey(hostSigner)
   220  
   221  	localListener, err := func() (net.Listener, error) {
   222  		port, err := strconv.ParseUint(p.config.LocalPort, 10, 16)
   223  		if err != nil {
   224  			return nil, err
   225  		}
   226  
   227  		tries := 1
   228  		if port != 0 {
   229  			tries = 10
   230  		}
   231  		for i := 0; i < tries; i++ {
   232  			l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
   233  			port++
   234  			if err != nil {
   235  				ui.Say(err.Error())
   236  				continue
   237  			}
   238  			_, p.config.LocalPort, err = net.SplitHostPort(l.Addr().String())
   239  			if err != nil {
   240  				ui.Say(err.Error())
   241  				continue
   242  			}
   243  			return l, nil
   244  		}
   245  		return nil, errors.New("Error setting up SSH proxy connection")
   246  	}()
   247  
   248  	if err != nil {
   249  		return err
   250  	}
   251  
   252  	ui = newUi(ui)
   253  	p.adapter = newAdapter(p.done, localListener, config, p.config.SFTPCmd, ui, comm)
   254  
   255  	defer func() {
   256  		log.Print("shutting down the SSH proxy")
   257  		close(p.done)
   258  		p.adapter.Shutdown()
   259  	}()
   260  
   261  	go p.adapter.Serve()
   262  
   263  	if len(p.config.inventoryFile) == 0 {
   264  		tf, err := ioutil.TempFile(p.config.InventoryDirectory, "packer-provisioner-ansible")
   265  		if err != nil {
   266  			return fmt.Errorf("Error preparing inventory file: %s", err)
   267  		}
   268  		defer os.Remove(tf.Name())
   269  
   270  		host := fmt.Sprintf("%s ansible_host=127.0.0.1 ansible_user=%s ansible_port=%s\n",
   271  			p.config.HostAlias, p.config.User, p.config.LocalPort)
   272  		if p.ansibleMajVersion < 2 {
   273  			host = fmt.Sprintf("%s ansible_ssh_host=127.0.0.1 ansible_ssh_user=%s ansible_ssh_port=%s\n",
   274  				p.config.HostAlias, p.config.User, p.config.LocalPort)
   275  		}
   276  
   277  		w := bufio.NewWriter(tf)
   278  		w.WriteString(host)
   279  		for _, group := range p.config.Groups {
   280  			fmt.Fprintf(w, "[%s]\n%s", group, host)
   281  		}
   282  
   283  		for _, group := range p.config.EmptyGroups {
   284  			fmt.Fprintf(w, "[%s]\n", group)
   285  		}
   286  
   287  		if err := w.Flush(); err != nil {
   288  			tf.Close()
   289  			return fmt.Errorf("Error preparing inventory file: %s", err)
   290  		}
   291  		tf.Close()
   292  		p.config.inventoryFile = tf.Name()
   293  		defer func() {
   294  			p.config.inventoryFile = ""
   295  		}()
   296  	}
   297  
   298  	if err := p.executeAnsible(ui, comm, k.privKeyFile); err != nil {
   299  		return fmt.Errorf("Error executing Ansible: %s", err)
   300  	}
   301  
   302  	return nil
   303  }
   304  
   305  func (p *Provisioner) Cancel() {
   306  	if p.done != nil {
   307  		close(p.done)
   308  	}
   309  	if p.adapter != nil {
   310  		p.adapter.Shutdown()
   311  	}
   312  	os.Exit(0)
   313  }
   314  
   315  func (p *Provisioner) executeAnsible(ui packer.Ui, comm packer.Communicator, privKeyFile string) error {
   316  	playbook, _ := filepath.Abs(p.config.PlaybookFile)
   317  	inventory := p.config.inventoryFile
   318  	var envvars []string
   319  
   320  	args := []string{"--extra-vars", fmt.Sprintf("packer_build_name=%s packer_builder_type=%s",
   321  		p.config.PackerBuildName, p.config.PackerBuilderType),
   322  		"-i", inventory, playbook}
   323  	if len(privKeyFile) > 0 {
   324  		args = append(args, "--private-key", privKeyFile)
   325  	}
   326  	args = append(args, p.config.ExtraArguments...)
   327  	if len(p.config.AnsibleEnvVars) > 0 {
   328  		envvars = append(envvars, p.config.AnsibleEnvVars...)
   329  	}
   330  
   331  	cmd := exec.Command(p.config.Command, args...)
   332  
   333  	cmd.Env = os.Environ()
   334  	if len(envvars) > 0 {
   335  		cmd.Env = append(cmd.Env, envvars...)
   336  	}
   337  
   338  	stdout, err := cmd.StdoutPipe()
   339  	if err != nil {
   340  		return err
   341  	}
   342  	stderr, err := cmd.StderrPipe()
   343  	if err != nil {
   344  		return err
   345  	}
   346  
   347  	wg := sync.WaitGroup{}
   348  	repeat := func(r io.ReadCloser) {
   349  		reader := bufio.NewReader(r)
   350  		for {
   351  			line, err := reader.ReadString('\n')
   352  			if line != "" {
   353  				line = strings.TrimRightFunc(line, unicode.IsSpace)
   354  				ui.Message(line)
   355  			}
   356  			if err != nil {
   357  				if err == io.EOF {
   358  					break
   359  				} else {
   360  					ui.Error(err.Error())
   361  					break
   362  				}
   363  			}
   364  		}
   365  		wg.Done()
   366  	}
   367  	wg.Add(2)
   368  	go repeat(stdout)
   369  	go repeat(stderr)
   370  
   371  	ui.Say(fmt.Sprintf("Executing Ansible: %s", strings.Join(cmd.Args, " ")))
   372  	if err := cmd.Start(); err != nil {
   373  		return err
   374  	}
   375  	wg.Wait()
   376  	err = cmd.Wait()
   377  	if err != nil {
   378  		return fmt.Errorf("Non-zero exit status: %s", err)
   379  	}
   380  
   381  	return nil
   382  }
   383  
   384  func validateFileConfig(name string, config string, req bool) error {
   385  	if req {
   386  		if name == "" {
   387  			return fmt.Errorf("%s must be specified.", config)
   388  		}
   389  	}
   390  	info, err := os.Stat(name)
   391  	if err != nil {
   392  		return fmt.Errorf("%s: %s is invalid: %s", config, name, err)
   393  	} else if info.IsDir() {
   394  		return fmt.Errorf("%s: %s must point to a file", config, name)
   395  	}
   396  	return nil
   397  }
   398  
   399  func validateInventoryDirectoryConfig(name string) error {
   400  	info, err := os.Stat(name)
   401  	if err != nil {
   402  		return fmt.Errorf("inventory_directory: %s is invalid: %s", name, err)
   403  	} else if !info.IsDir() {
   404  		return fmt.Errorf("inventory_directory: %s must point to a directory", name)
   405  	}
   406  	return nil
   407  }
   408  
// userKey is the public key that Ansible must present to the local SSH proxy.
type userKey struct {
	ssh.PublicKey
	// Path of a temporary PEM file holding the matching private key when
	// newUserKey generated the pair. Empty when the public key was read from
	// ssh_authorized_key_file.
	privKeyFile string
}
   413  
   414  func newUserKey(pubKeyFile string) (*userKey, error) {
   415  	userKey := new(userKey)
   416  	if len(pubKeyFile) > 0 {
   417  		pubKeyBytes, err := ioutil.ReadFile(pubKeyFile)
   418  		if err != nil {
   419  			return nil, errors.New("Failed to read public key")
   420  		}
   421  		userKey.PublicKey, _, _, _, err = ssh.ParseAuthorizedKey(pubKeyBytes)
   422  		if err != nil {
   423  			return nil, errors.New("Failed to parse authorized key")
   424  		}
   425  
   426  		return userKey, nil
   427  	}
   428  
   429  	key, err := rsa.GenerateKey(rand.Reader, 2048)
   430  	if err != nil {
   431  		return nil, errors.New("Failed to generate key pair")
   432  	}
   433  	userKey.PublicKey, err = ssh.NewPublicKey(key.Public())
   434  	if err != nil {
   435  		return nil, errors.New("Failed to extract public key from generated key pair")
   436  	}
   437  
   438  	// To support Ansible calling back to us we need to write
   439  	// this file down
   440  	privateKeyDer := x509.MarshalPKCS1PrivateKey(key)
   441  	privateKeyBlock := pem.Block{
   442  		Type:    "RSA PRIVATE KEY",
   443  		Headers: nil,
   444  		Bytes:   privateKeyDer,
   445  	}
   446  	tf, err := ioutil.TempFile("", "ansible-key")
   447  	if err != nil {
   448  		return nil, errors.New("failed to create temp file for generated key")
   449  	}
   450  	_, err = tf.Write(pem.EncodeToMemory(&privateKeyBlock))
   451  	if err != nil {
   452  		return nil, errors.New("failed to write private key to temp file")
   453  	}
   454  
   455  	err = tf.Close()
   456  	if err != nil {
   457  		return nil, errors.New("failed to close private key temp file")
   458  	}
   459  	userKey.privKeyFile = tf.Name()
   460  
   461  	return userKey, nil
   462  }
   463  
// signer wraps an ssh.Signer. Provision installs it as the host key of the
// local SSH proxy.
type signer struct {
	ssh.Signer
}
   467  
   468  func newSigner(privKeyFile string) (*signer, error) {
   469  	signer := new(signer)
   470  
   471  	if len(privKeyFile) > 0 {
   472  		privateBytes, err := ioutil.ReadFile(privKeyFile)
   473  		if err != nil {
   474  			return nil, errors.New("Failed to load private host key")
   475  		}
   476  
   477  		signer.Signer, err = ssh.ParsePrivateKey(privateBytes)
   478  		if err != nil {
   479  			return nil, errors.New("Failed to parse private host key")
   480  		}
   481  
   482  		return signer, nil
   483  	}
   484  
   485  	key, err := rsa.GenerateKey(rand.Reader, 2048)
   486  	if err != nil {
   487  		return nil, errors.New("Failed to generate server key pair")
   488  	}
   489  
   490  	signer.Signer, err = ssh.NewSignerFromKey(key)
   491  	if err != nil {
   492  		return nil, errors.New("Failed to extract private key from generated key pair")
   493  	}
   494  
   495  	return signer, nil
   496  }
   497  
// Ui provides concurrency-safe access to packer.Ui.
//
// Each method acquires sem before delegating to the wrapped ui and releases
// it afterwards. Because newUi gives sem a capacity of one, at most one call
// reaches the underlying packer.Ui at a time.
type Ui struct {
	// Used as a binary semaphore; the value sent is ignored.
	sem chan int
	// The wrapped implementation every call is forwarded to.
	ui packer.Ui
}
   503  
   504  func newUi(ui packer.Ui) packer.Ui {
   505  	return &Ui{sem: make(chan int, 1), ui: ui}
   506  }
   507  
   508  func (ui *Ui) Ask(s string) (string, error) {
   509  	ui.sem <- 1
   510  	ret, err := ui.ui.Ask(s)
   511  	<-ui.sem
   512  
   513  	return ret, err
   514  }
   515  
   516  func (ui *Ui) Say(s string) {
   517  	ui.sem <- 1
   518  	ui.ui.Say(s)
   519  	<-ui.sem
   520  }
   521  
   522  func (ui *Ui) Message(s string) {
   523  	ui.sem <- 1
   524  	ui.ui.Message(s)
   525  	<-ui.sem
   526  }
   527  
   528  func (ui *Ui) Error(s string) {
   529  	ui.sem <- 1
   530  	ui.ui.Error(s)
   531  	<-ui.sem
   532  }
   533  
   534  func (ui *Ui) Machine(t string, args ...string) {
   535  	ui.sem <- 1
   536  	ui.ui.Machine(t, args...)
   537  	<-ui.sem
   538  }