github.com/dacamp/packer@v0.10.2/provisioner/ansible/provisioner.go (about)

     1  package ansible
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/rand"
     7  	"crypto/rsa"
     8  	"crypto/x509"
     9  	"encoding/pem"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	"log"
    15  	"net"
    16  	"os"
    17  	"os/exec"
    18  	"path/filepath"
    19  	"regexp"
    20  	"strconv"
    21  	"strings"
    22  	"sync"
    23  
    24  	"golang.org/x/crypto/ssh"
    25  
    26  	"github.com/mitchellh/packer/common"
    27  	"github.com/mitchellh/packer/helper/config"
    28  	"github.com/mitchellh/packer/packer"
    29  	"github.com/mitchellh/packer/template/interpolate"
    30  )
    31  
// Config holds the template settings of the ansible provisioner. Prepare
// decodes it from the raw template values (the mapstructure tags name the
// template keys) and fills in defaults.
type Config struct {
	common.PackerConfig `mapstructure:",squash"`
	// Interpolation context used by Prepare while decoding the template.
	ctx                 interpolate.Context

	// The command to run ansible. Prepare defaults it to "ansible-playbook";
	// it is also run with --version to detect the Ansible version.
	Command string

	// Extra options to pass to the ansible command. They are appended after
	// the playbook, the -i inventory and any --private-key arguments.
	ExtraArguments []string `mapstructure:"extra_arguments"`

	// Extra environment entries added to the ansible command's environment.
	// When any are given, the ANSIBLE_HOST_KEY_CHECKING=False entry that is
	// otherwise added for a generated host key is NOT added.
	AnsibleEnvVars []string `mapstructure:"ansible_env_vars"`

	// The main playbook file to execute. Required; Prepare checks that it
	// exists and is not a directory.
	PlaybookFile         string   `mapstructure:"playbook_file"`
	// Inventory groups the host is placed in; each is written as a
	// "[group]" section containing the host line.
	Groups               []string `mapstructure:"groups"`
	// Inventory groups written as "[group]" headers with no hosts.
	EmptyGroups          []string `mapstructure:"empty_groups"`
	// Name of the host in the generated inventory; defaults to "default".
	HostAlias            string   `mapstructure:"host_alias"`
	// User Ansible connects as. The local SSH proxy rejects any other user
	// name. Defaults to $USER.
	User                 string   `mapstructure:"user"`
	// Port of the local SSH proxy, as a decimal string. Defaults to "0" (let
	// the OS choose); a non-zero port is retried on the following ports
	// when busy. Provision overwrites it with the port actually bound.
	LocalPort            string   `mapstructure:"local_port"`
	// Private key the proxy presents as its host key. When empty a key is
	// generated for the run, and host key checking is disabled unless
	// ansible_env_vars is set.
	SSHHostKeyFile       string   `mapstructure:"ssh_host_key_file"`
	// Public key the proxy accepts for User. When empty a key pair is
	// generated for the run and the private half is passed to Ansible with
	// --private-key.
	SSHAuthorizedKeyFile string   `mapstructure:"ssh_authorized_key_file"`
	// Passed through to the SSH proxy adapter; presumably the command used
	// to serve SFTP requests — TODO confirm against newAdapter.
	SFTPCmd              string   `mapstructure:"sftp_command"`
	// Inventory passed to Ansible with -i. When empty, Provision writes a
	// temporary inventory, uses it, and resets this field afterwards.
	inventoryFile        string
}
    56  
// Provisioner runs an Ansible playbook against a machine by exposing the
// machine's communicator to Ansible through a local SSH proxy.
type Provisioner struct {
	config            Config
	adapter           *adapter      // SSH proxy adapter; set by Provision
	done              chan struct{} // made by Prepare; closed to shut the proxy down
	ansibleVersion    string        // version parsed from "<Command> --version"
	ansibleMajVersion uint          // major part of ansibleVersion; < 2 selects ansible_ssh_* inventory vars
}
    64  
    65  func (p *Provisioner) Prepare(raws ...interface{}) error {
    66  	p.done = make(chan struct{})
    67  
    68  	err := config.Decode(&p.config, &config.DecodeOpts{
    69  		Interpolate:        true,
    70  		InterpolateContext: &p.config.ctx,
    71  		InterpolateFilter: &interpolate.RenderFilter{
    72  			Exclude: []string{},
    73  		},
    74  	}, raws...)
    75  	if err != nil {
    76  		return err
    77  	}
    78  
    79  	// Defaults
    80  	if p.config.Command == "" {
    81  		p.config.Command = "ansible-playbook"
    82  	}
    83  
    84  	if p.config.HostAlias == "" {
    85  		p.config.HostAlias = "default"
    86  	}
    87  
    88  	var errs *packer.MultiError
    89  	err = validateFileConfig(p.config.PlaybookFile, "playbook_file", true)
    90  	if err != nil {
    91  		errs = packer.MultiErrorAppend(errs, err)
    92  	}
    93  
    94  	// Check that the authorized key file exists
    95  	if len(p.config.SSHAuthorizedKeyFile) > 0 {
    96  		err = validateFileConfig(p.config.SSHAuthorizedKeyFile, "ssh_authorized_key_file", true)
    97  		if err != nil {
    98  			log.Println(p.config.SSHAuthorizedKeyFile, "does not exist")
    99  			errs = packer.MultiErrorAppend(errs, err)
   100  		}
   101  	}
   102  	if len(p.config.SSHHostKeyFile) > 0 {
   103  		err = validateFileConfig(p.config.SSHHostKeyFile, "ssh_host_key_file", true)
   104  		if err != nil {
   105  			log.Println(p.config.SSHHostKeyFile, "does not exist")
   106  			errs = packer.MultiErrorAppend(errs, err)
   107  		}
   108  	}
   109  
   110  	if len(p.config.LocalPort) > 0 {
   111  		if _, err := strconv.ParseUint(p.config.LocalPort, 10, 16); err != nil {
   112  			errs = packer.MultiErrorAppend(errs, fmt.Errorf("local_port: %s must be a valid port", p.config.LocalPort))
   113  		}
   114  	} else {
   115  		p.config.LocalPort = "0"
   116  	}
   117  
   118  	err = p.getVersion()
   119  	if err != nil {
   120  		errs = packer.MultiErrorAppend(errs, err)
   121  	}
   122  
   123  	if p.config.User == "" {
   124  		p.config.User = os.Getenv("USER")
   125  	}
   126  	if p.config.User == "" {
   127  		errs = packer.MultiErrorAppend(errs, fmt.Errorf("user: could not determine current user from environment."))
   128  	}
   129  
   130  	if errs != nil && len(errs.Errors) > 0 {
   131  		return errs
   132  	}
   133  	return nil
   134  }
   135  
   136  func (p *Provisioner) getVersion() error {
   137  	out, err := exec.Command(p.config.Command, "--version").Output()
   138  	if err != nil {
   139  		return err
   140  	}
   141  
   142  	versionRe := regexp.MustCompile(`\w (\d+\.\d+[.\d+]*)`)
   143  	matches := versionRe.FindStringSubmatch(string(out))
   144  	if matches == nil {
   145  		return fmt.Errorf(
   146  			"Could not find %s version in output:\n%s", p.config.Command, string(out))
   147  	}
   148  
   149  	version := matches[1]
   150  	log.Printf("%s version: %s", p.config.Command, version)
   151  	p.ansibleVersion = version
   152  
   153  	majVer, err := strconv.ParseUint(strings.Split(version, ".")[0], 10, 0)
   154  	if err != nil {
   155  		return fmt.Errorf("Could not parse major version from \"%s\".", version)
   156  	}
   157  	p.ansibleMajVersion = uint(majVer)
   158  
   159  	return nil
   160  }
   161  
   162  func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error {
   163  	ui.Say("Provisioning with Ansible...")
   164  
   165  	k, err := newUserKey(p.config.SSHAuthorizedKeyFile)
   166  	if err != nil {
   167  		return err
   168  	}
   169  
   170  	hostSigner, err := newSigner(p.config.SSHHostKeyFile)
   171  	// Remove the private key file
   172  	if len(k.privKeyFile) > 0 {
   173  		defer os.Remove(k.privKeyFile)
   174  	}
   175  
   176  	keyChecker := ssh.CertChecker{
   177  		UserKeyFallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
   178  			if user := conn.User(); user != p.config.User {
   179  				ui.Say(fmt.Sprintf("%s is not a valid user", user))
   180  				return nil, errors.New("authentication failed")
   181  			}
   182  
   183  			if !bytes.Equal(k.Marshal(), pubKey.Marshal()) {
   184  				ui.Say("unauthorized key")
   185  				return nil, errors.New("authentication failed")
   186  			}
   187  
   188  			return nil, nil
   189  		},
   190  	}
   191  
   192  	config := &ssh.ServerConfig{
   193  		AuthLogCallback: func(conn ssh.ConnMetadata, method string, err error) {
   194  			ui.Say(fmt.Sprintf("authentication attempt from %s to %s as %s using %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User(), method))
   195  		},
   196  		PublicKeyCallback: keyChecker.Authenticate,
   197  		//NoClientAuth:      true,
   198  	}
   199  
   200  	config.AddHostKey(hostSigner)
   201  
   202  	localListener, err := func() (net.Listener, error) {
   203  		port, err := strconv.ParseUint(p.config.LocalPort, 10, 16)
   204  		if err != nil {
   205  			return nil, err
   206  		}
   207  
   208  		tries := 1
   209  		if port != 0 {
   210  			tries = 10
   211  		}
   212  		for i := 0; i < tries; i++ {
   213  			l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
   214  			port++
   215  			if err != nil {
   216  				ui.Say(err.Error())
   217  				continue
   218  			}
   219  			_, p.config.LocalPort, err = net.SplitHostPort(l.Addr().String())
   220  			if err != nil {
   221  				ui.Say(err.Error())
   222  				continue
   223  			}
   224  			return l, nil
   225  		}
   226  		return nil, errors.New("Error setting up SSH proxy connection")
   227  	}()
   228  
   229  	if err != nil {
   230  		return err
   231  	}
   232  
   233  	ui = newUi(ui)
   234  	p.adapter = newAdapter(p.done, localListener, config, p.config.SFTPCmd, ui, comm)
   235  
   236  	defer func() {
   237  		ui.Say("shutting down the SSH proxy")
   238  		close(p.done)
   239  		p.adapter.Shutdown()
   240  	}()
   241  
   242  	go p.adapter.Serve()
   243  
   244  	if len(p.config.inventoryFile) == 0 {
   245  		tf, err := ioutil.TempFile("", "packer-provisioner-ansible")
   246  		if err != nil {
   247  			return fmt.Errorf("Error preparing inventory file: %s", err)
   248  		}
   249  		defer os.Remove(tf.Name())
   250  
   251  		host := fmt.Sprintf("%s ansible_host=127.0.0.1 ansible_user=%s ansible_port=%s\n",
   252  			p.config.HostAlias, p.config.User, p.config.LocalPort)
   253  		if p.ansibleMajVersion < 2 {
   254  			host = fmt.Sprintf("%s ansible_ssh_host=127.0.0.1 ansible_ssh_user=%s ansible_ssh_port=%s\n",
   255  				p.config.HostAlias, p.config.User, p.config.LocalPort)
   256  		}
   257  
   258  		w := bufio.NewWriter(tf)
   259  		w.WriteString(host)
   260  		for _, group := range p.config.Groups {
   261  			fmt.Fprintf(w, "[%s]\n%s", group, host)
   262  		}
   263  
   264  		for _, group := range p.config.EmptyGroups {
   265  			fmt.Fprintf(w, "[%s]\n", group)
   266  		}
   267  
   268  		if err := w.Flush(); err != nil {
   269  			tf.Close()
   270  			return fmt.Errorf("Error preparing inventory file: %s", err)
   271  		}
   272  		tf.Close()
   273  		p.config.inventoryFile = tf.Name()
   274  		defer func() {
   275  			p.config.inventoryFile = ""
   276  		}()
   277  	}
   278  
   279  	if err := p.executeAnsible(ui, comm, k.privKeyFile, !hostSigner.generated); err != nil {
   280  		return fmt.Errorf("Error executing Ansible: %s", err)
   281  	}
   282  
   283  	return nil
   284  }
   285  
   286  func (p *Provisioner) Cancel() {
   287  	if p.done != nil {
   288  		close(p.done)
   289  	}
   290  	if p.adapter != nil {
   291  		p.adapter.Shutdown()
   292  	}
   293  	os.Exit(0)
   294  }
   295  
   296  func (p *Provisioner) executeAnsible(ui packer.Ui, comm packer.Communicator, privKeyFile string, checkHostKey bool) error {
   297  	playbook, _ := filepath.Abs(p.config.PlaybookFile)
   298  	inventory := p.config.inventoryFile
   299  	var envvars []string
   300  
   301  	args := []string{playbook, "-i", inventory}
   302  	if len(privKeyFile) > 0 {
   303  		args = append(args, "--private-key", privKeyFile)
   304  	}
   305  	args = append(args, p.config.ExtraArguments...)
   306  	if len(p.config.AnsibleEnvVars) > 0 {
   307  		envvars = append(envvars, p.config.AnsibleEnvVars...)
   308  	}
   309  
   310  	cmd := exec.Command(p.config.Command, args...)
   311  
   312  	cmd.Env = os.Environ()
   313  	if len(envvars) > 0 {
   314  		cmd.Env = append(cmd.Env, envvars...)
   315  	} else if !checkHostKey {
   316  		cmd.Env = append(cmd.Env, "ANSIBLE_HOST_KEY_CHECKING=False")
   317  	}
   318  
   319  	stdout, err := cmd.StdoutPipe()
   320  	if err != nil {
   321  		return err
   322  	}
   323  	stderr, err := cmd.StderrPipe()
   324  	if err != nil {
   325  		return err
   326  	}
   327  
   328  	wg := sync.WaitGroup{}
   329  	repeat := func(r io.ReadCloser) {
   330  		scanner := bufio.NewScanner(r)
   331  		for scanner.Scan() {
   332  			ui.Message(scanner.Text())
   333  		}
   334  		if err := scanner.Err(); err != nil {
   335  			ui.Error(err.Error())
   336  		}
   337  		wg.Done()
   338  	}
   339  	wg.Add(2)
   340  	go repeat(stdout)
   341  	go repeat(stderr)
   342  
   343  	ui.Say(fmt.Sprintf("Executing Ansible: %s", strings.Join(cmd.Args, " ")))
   344  	cmd.Start()
   345  	wg.Wait()
   346  	err = cmd.Wait()
   347  	if err != nil {
   348  		return fmt.Errorf("Non-zero exit status: %s", err)
   349  	}
   350  
   351  	return nil
   352  }
   353  
   354  func validateFileConfig(name string, config string, req bool) error {
   355  	if req {
   356  		if name == "" {
   357  			return fmt.Errorf("%s must be specified.", config)
   358  		}
   359  	}
   360  	info, err := os.Stat(name)
   361  	if err != nil {
   362  		return fmt.Errorf("%s: %s is invalid: %s", config, name, err)
   363  	} else if info.IsDir() {
   364  		return fmt.Errorf("%s: %s must point to a file", config, name)
   365  	}
   366  	return nil
   367  }
   368  
   369  type userKey struct {
   370  	ssh.PublicKey
   371  	privKeyFile string
   372  }
   373  
   374  func newUserKey(pubKeyFile string) (*userKey, error) {
   375  	userKey := new(userKey)
   376  	if len(pubKeyFile) > 0 {
   377  		pubKeyBytes, err := ioutil.ReadFile(pubKeyFile)
   378  		if err != nil {
   379  			return nil, errors.New("Failed to read public key")
   380  		}
   381  		userKey.PublicKey, _, _, _, err = ssh.ParseAuthorizedKey(pubKeyBytes)
   382  		if err != nil {
   383  			return nil, errors.New("Failed to parse authorized key")
   384  		}
   385  
   386  		return userKey, nil
   387  	}
   388  
   389  	key, err := rsa.GenerateKey(rand.Reader, 2048)
   390  	if err != nil {
   391  		return nil, errors.New("Failed to generate key pair")
   392  	}
   393  	userKey.PublicKey, err = ssh.NewPublicKey(key.Public())
   394  	if err != nil {
   395  		return nil, errors.New("Failed to extract public key from generated key pair")
   396  	}
   397  
   398  	// To support Ansible calling back to us we need to write
   399  	// this file down
   400  	privateKeyDer := x509.MarshalPKCS1PrivateKey(key)
   401  	privateKeyBlock := pem.Block{
   402  		Type:    "RSA PRIVATE KEY",
   403  		Headers: nil,
   404  		Bytes:   privateKeyDer,
   405  	}
   406  	tf, err := ioutil.TempFile("", "ansible-key")
   407  	if err != nil {
   408  		return nil, errors.New("failed to create temp file for generated key")
   409  	}
   410  	_, err = tf.Write(pem.EncodeToMemory(&privateKeyBlock))
   411  	if err != nil {
   412  		return nil, errors.New("failed to write private key to temp file")
   413  	}
   414  
   415  	err = tf.Close()
   416  	if err != nil {
   417  		return nil, errors.New("failed to close private key temp file")
   418  	}
   419  	userKey.privKeyFile = tf.Name()
   420  
   421  	return userKey, nil
   422  }
   423  
   424  type signer struct {
   425  	ssh.Signer
   426  	generated bool
   427  }
   428  
   429  func newSigner(privKeyFile string) (*signer, error) {
   430  	signer := new(signer)
   431  
   432  	if len(privKeyFile) > 0 {
   433  		privateBytes, err := ioutil.ReadFile(privKeyFile)
   434  		if err != nil {
   435  			return nil, errors.New("Failed to load private host key")
   436  		}
   437  
   438  		signer.Signer, err = ssh.ParsePrivateKey(privateBytes)
   439  		if err != nil {
   440  			return nil, errors.New("Failed to parse private host key")
   441  		}
   442  
   443  		return signer, nil
   444  	}
   445  
   446  	key, err := rsa.GenerateKey(rand.Reader, 2048)
   447  	if err != nil {
   448  		return nil, errors.New("Failed to generate server key pair")
   449  	}
   450  
   451  	signer.Signer, err = ssh.NewSignerFromKey(key)
   452  	if err != nil {
   453  		return nil, errors.New("Failed to extract private key from generated key pair")
   454  	}
   455  	signer.generated = true
   456  
   457  	return signer, nil
   458  }
   459  
   460  // Ui provides concurrency-safe access to packer.Ui.
   461  type Ui struct {
   462  	sem chan int
   463  	ui  packer.Ui
   464  }
   465  
   466  func newUi(ui packer.Ui) packer.Ui {
   467  	return &Ui{sem: make(chan int, 1), ui: ui}
   468  }
   469  
   470  func (ui *Ui) Ask(s string) (string, error) {
   471  	ui.sem <- 1
   472  	ret, err := ui.ui.Ask(s)
   473  	<-ui.sem
   474  
   475  	return ret, err
   476  }
   477  
   478  func (ui *Ui) Say(s string) {
   479  	ui.sem <- 1
   480  	ui.ui.Say(s)
   481  	<-ui.sem
   482  }
   483  
   484  func (ui *Ui) Message(s string) {
   485  	ui.sem <- 1
   486  	ui.ui.Message(s)
   487  	<-ui.sem
   488  }
   489  
   490  func (ui *Ui) Error(s string) {
   491  	ui.sem <- 1
   492  	ui.ui.Error(s)
   493  	<-ui.sem
   494  }
   495  
   496  func (ui *Ui) Machine(t string, args ...string) {
   497  	ui.sem <- 1
   498  	ui.ui.Machine(t, args...)
   499  	<-ui.sem
   500  }