github.com/victortrac/packer@v0.7.6-0.20160602180447-63c7fdb6e41f/provisioner/ansible/provisioner.go (about)

     1  package ansible
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/rand"
     7  	"crypto/rsa"
     8  	"crypto/x509"
     9  	"encoding/pem"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	"log"
    15  	"net"
    16  	"os"
    17  	"os/exec"
    18  	"path/filepath"
    19  	"regexp"
    20  	"strconv"
    21  	"strings"
    22  	"sync"
    23  
    24  	"golang.org/x/crypto/ssh"
    25  
    26  	"github.com/mitchellh/packer/common"
    27  	"github.com/mitchellh/packer/helper/config"
    28  	"github.com/mitchellh/packer/packer"
    29  	"github.com/mitchellh/packer/template/interpolate"
    30  )
    31  
    32  type Config struct {
    33  	common.PackerConfig `mapstructure:",squash"`
    34  	ctx                 interpolate.Context
    35  
    36  	// The command to run ansible
    37  	Command string
    38  
    39  	// Extra options to pass to the ansible command
    40  	ExtraArguments []string `mapstructure:"extra_arguments"`
    41  
    42  	AnsibleEnvVars []string `mapstructure:"ansible_env_vars"`
    43  
    44  	// The main playbook file to execute.
    45  	PlaybookFile         string   `mapstructure:"playbook_file"`
    46  	Groups               []string `mapstructure:"groups"`
    47  	EmptyGroups          []string `mapstructure:"empty_groups"`
    48  	HostAlias            string   `mapstructure:"host_alias"`
    49  	User                 string   `mapstructure:"user"`
    50  	LocalPort            string   `mapstructure:"local_port"`
    51  	SSHHostKeyFile       string   `mapstructure:"ssh_host_key_file"`
    52  	SSHAuthorizedKeyFile string   `mapstructure:"ssh_authorized_key_file"`
    53  	SFTPCmd              string   `mapstructure:"sftp_command"`
    54  	inventoryFile        string
    55  }
    56  
    57  type Provisioner struct {
    58  	config            Config
    59  	adapter           *adapter
    60  	done              chan struct{}
    61  	ansibleVersion    string
    62  	ansibleMajVersion uint
    63  }
    64  
    65  func (p *Provisioner) Prepare(raws ...interface{}) error {
    66  	p.done = make(chan struct{})
    67  
    68  	err := config.Decode(&p.config, &config.DecodeOpts{
    69  		Interpolate:        true,
    70  		InterpolateContext: &p.config.ctx,
    71  		InterpolateFilter: &interpolate.RenderFilter{
    72  			Exclude: []string{},
    73  		},
    74  	}, raws...)
    75  	if err != nil {
    76  		return err
    77  	}
    78  
    79  	// Defaults
    80  	if p.config.Command == "" {
    81  		p.config.Command = "ansible-playbook"
    82  	}
    83  
    84  	if p.config.HostAlias == "" {
    85  		p.config.HostAlias = "default"
    86  	}
    87  
    88  	var errs *packer.MultiError
    89  	err = validateFileConfig(p.config.PlaybookFile, "playbook_file", true)
    90  	if err != nil {
    91  		errs = packer.MultiErrorAppend(errs, err)
    92  	}
    93  
    94  	// Check that the authorized key file exists
    95  	if len(p.config.SSHAuthorizedKeyFile) > 0 {
    96  		err = validateFileConfig(p.config.SSHAuthorizedKeyFile, "ssh_authorized_key_file", true)
    97  		if err != nil {
    98  			log.Println(p.config.SSHAuthorizedKeyFile, "does not exist")
    99  			errs = packer.MultiErrorAppend(errs, err)
   100  		}
   101  	}
   102  	if len(p.config.SSHHostKeyFile) > 0 {
   103  		err = validateFileConfig(p.config.SSHHostKeyFile, "ssh_host_key_file", true)
   104  		if err != nil {
   105  			log.Println(p.config.SSHHostKeyFile, "does not exist")
   106  			errs = packer.MultiErrorAppend(errs, err)
   107  		}
   108  	}
   109  
   110  	if len(p.config.LocalPort) > 0 {
   111  		if _, err := strconv.ParseUint(p.config.LocalPort, 10, 16); err != nil {
   112  			errs = packer.MultiErrorAppend(errs, fmt.Errorf("local_port: %s must be a valid port", p.config.LocalPort))
   113  		}
   114  	} else {
   115  		p.config.LocalPort = "0"
   116  	}
   117  
   118  	err = p.getVersion()
   119  	if err != nil {
   120  		errs = packer.MultiErrorAppend(errs, err)
   121  	}
   122  
   123  	if p.config.User == "" {
   124  		p.config.User = os.Getenv("USER")
   125  	}
   126  	if p.config.User == "" {
   127  		errs = packer.MultiErrorAppend(errs, fmt.Errorf("user: could not determine current user from environment."))
   128  	}
   129  
   130  	if errs != nil && len(errs.Errors) > 0 {
   131  		return errs
   132  	}
   133  	return nil
   134  }
   135  
   136  func (p *Provisioner) getVersion() error {
   137  	out, err := exec.Command(p.config.Command, "--version").Output()
   138  	if err != nil {
   139  		return err
   140  	}
   141  
   142  	versionRe := regexp.MustCompile(`\w (\d+\.\d+[.\d+]*)`)
   143  	matches := versionRe.FindStringSubmatch(string(out))
   144  	if matches == nil {
   145  		return fmt.Errorf(
   146  			"Could not find %s version in output:\n%s", p.config.Command, string(out))
   147  	}
   148  
   149  	version := matches[1]
   150  	log.Printf("%s version: %s", p.config.Command, version)
   151  	p.ansibleVersion = version
   152  
   153  	majVer, err := strconv.ParseUint(strings.Split(version, ".")[0], 10, 0)
   154  	if err != nil {
   155  		return fmt.Errorf("Could not parse major version from \"%s\".", version)
   156  	}
   157  	p.ansibleMajVersion = uint(majVer)
   158  
   159  	return nil
   160  }
   161  
   162  func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error {
   163  	ui.Say("Provisioning with Ansible...")
   164  
   165  	k, err := newUserKey(p.config.SSHAuthorizedKeyFile)
   166  	if err != nil {
   167  		return err
   168  	}
   169  
   170  	hostSigner, err := newSigner(p.config.SSHHostKeyFile)
   171  	// Remove the private key file
   172  	if len(k.privKeyFile) > 0 {
   173  		defer os.Remove(k.privKeyFile)
   174  	}
   175  
   176  	keyChecker := ssh.CertChecker{
   177  		UserKeyFallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
   178  			if user := conn.User(); user != p.config.User {
   179  				ui.Say(fmt.Sprintf("%s is not a valid user", user))
   180  				return nil, errors.New("authentication failed")
   181  			}
   182  
   183  			if !bytes.Equal(k.Marshal(), pubKey.Marshal()) {
   184  				ui.Say("unauthorized key")
   185  				return nil, errors.New("authentication failed")
   186  			}
   187  
   188  			return nil, nil
   189  		},
   190  	}
   191  
   192  	config := &ssh.ServerConfig{
   193  		AuthLogCallback: func(conn ssh.ConnMetadata, method string, err error) {
   194  			ui.Say(fmt.Sprintf("authentication attempt from %s to %s as %s using %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User(), method))
   195  		},
   196  		PublicKeyCallback: keyChecker.Authenticate,
   197  		//NoClientAuth:      true,
   198  	}
   199  
   200  	config.AddHostKey(hostSigner)
   201  
   202  	localListener, err := func() (net.Listener, error) {
   203  		port, err := strconv.ParseUint(p.config.LocalPort, 10, 16)
   204  		if err != nil {
   205  			return nil, err
   206  		}
   207  
   208  		tries := 1
   209  		if port != 0 {
   210  			tries = 10
   211  		}
   212  		for i := 0; i < tries; i++ {
   213  			l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
   214  			port++
   215  			if err != nil {
   216  				ui.Say(err.Error())
   217  				continue
   218  			}
   219  			_, p.config.LocalPort, err = net.SplitHostPort(l.Addr().String())
   220  			if err != nil {
   221  				ui.Say(err.Error())
   222  				continue
   223  			}
   224  			return l, nil
   225  		}
   226  		return nil, errors.New("Error setting up SSH proxy connection")
   227  	}()
   228  
   229  	if err != nil {
   230  		return err
   231  	}
   232  
   233  	ui = newUi(ui)
   234  	p.adapter = newAdapter(p.done, localListener, config, p.config.SFTPCmd, ui, comm)
   235  
   236  	defer func() {
   237  		ui.Say("shutting down the SSH proxy")
   238  		close(p.done)
   239  		p.adapter.Shutdown()
   240  	}()
   241  
   242  	go p.adapter.Serve()
   243  
   244  	if len(p.config.inventoryFile) == 0 {
   245  		tf, err := ioutil.TempFile("", "packer-provisioner-ansible")
   246  		if err != nil {
   247  			return fmt.Errorf("Error preparing inventory file: %s", err)
   248  		}
   249  		defer os.Remove(tf.Name())
   250  
   251  		host := fmt.Sprintf("%s ansible_host=127.0.0.1 ansible_user=%s ansible_port=%s\n",
   252  			p.config.HostAlias, p.config.User, p.config.LocalPort)
   253  		if p.ansibleMajVersion < 2 {
   254  			host = fmt.Sprintf("%s ansible_ssh_host=127.0.0.1 ansible_ssh_user=%s ansible_ssh_port=%s\n",
   255  				p.config.HostAlias, p.config.User, p.config.LocalPort)
   256  		}
   257  
   258  		w := bufio.NewWriter(tf)
   259  		w.WriteString(host)
   260  		for _, group := range p.config.Groups {
   261  			fmt.Fprintf(w, "[%s]\n%s", group, host)
   262  		}
   263  
   264  		for _, group := range p.config.EmptyGroups {
   265  			fmt.Fprintf(w, "[%s]\n", group)
   266  		}
   267  
   268  		if err := w.Flush(); err != nil {
   269  			tf.Close()
   270  			return fmt.Errorf("Error preparing inventory file: %s", err)
   271  		}
   272  		tf.Close()
   273  		p.config.inventoryFile = tf.Name()
   274  		defer func() {
   275  			p.config.inventoryFile = ""
   276  		}()
   277  	}
   278  
   279  	if err := p.executeAnsible(ui, comm, k.privKeyFile, !hostSigner.generated); err != nil {
   280  		return fmt.Errorf("Error executing Ansible: %s", err)
   281  	}
   282  
   283  	return nil
   284  }
   285  
   286  func (p *Provisioner) Cancel() {
   287  	if p.done != nil {
   288  		close(p.done)
   289  	}
   290  	if p.adapter != nil {
   291  		p.adapter.Shutdown()
   292  	}
   293  	os.Exit(0)
   294  }
   295  
   296  func (p *Provisioner) executeAnsible(ui packer.Ui, comm packer.Communicator, privKeyFile string, checkHostKey bool) error {
   297  	playbook, _ := filepath.Abs(p.config.PlaybookFile)
   298  	inventory := p.config.inventoryFile
   299  	var envvars []string
   300  
   301  	args := []string{playbook, "-i", inventory}
   302  	if len(privKeyFile) > 0 {
   303  		args = append(args, "--private-key", privKeyFile)
   304  	}
   305  	args = append(args, p.config.ExtraArguments...)
   306  	if len(p.config.AnsibleEnvVars) > 0 {
   307  		envvars = append(envvars, p.config.AnsibleEnvVars...)
   308  	}
   309  
   310  	cmd := exec.Command(p.config.Command, args...)
   311  
   312  	cmd.Env = os.Environ()
   313  	if len(envvars) > 0 {
   314  		cmd.Env = append(cmd.Env, envvars...)
   315  	}
   316  
   317  	if !checkHostKey {
   318  		cmd.Env = append(cmd.Env, "ANSIBLE_HOST_KEY_CHECKING=False")
   319  	}
   320  
   321  	stdout, err := cmd.StdoutPipe()
   322  	if err != nil {
   323  		return err
   324  	}
   325  	stderr, err := cmd.StderrPipe()
   326  	if err != nil {
   327  		return err
   328  	}
   329  
   330  	wg := sync.WaitGroup{}
   331  	repeat := func(r io.ReadCloser) {
   332  		scanner := bufio.NewScanner(r)
   333  		for scanner.Scan() {
   334  			ui.Message(scanner.Text())
   335  		}
   336  		if err := scanner.Err(); err != nil {
   337  			ui.Error(err.Error())
   338  		}
   339  		wg.Done()
   340  	}
   341  	wg.Add(2)
   342  	go repeat(stdout)
   343  	go repeat(stderr)
   344  
   345  	ui.Say(fmt.Sprintf("Executing Ansible: %s", strings.Join(cmd.Args, " ")))
   346  	cmd.Start()
   347  	wg.Wait()
   348  	err = cmd.Wait()
   349  	if err != nil {
   350  		return fmt.Errorf("Non-zero exit status: %s", err)
   351  	}
   352  
   353  	return nil
   354  }
   355  
   356  func validateFileConfig(name string, config string, req bool) error {
   357  	if req {
   358  		if name == "" {
   359  			return fmt.Errorf("%s must be specified.", config)
   360  		}
   361  	}
   362  	info, err := os.Stat(name)
   363  	if err != nil {
   364  		return fmt.Errorf("%s: %s is invalid: %s", config, name, err)
   365  	} else if info.IsDir() {
   366  		return fmt.Errorf("%s: %s must point to a file", config, name)
   367  	}
   368  	return nil
   369  }
   370  
   371  type userKey struct {
   372  	ssh.PublicKey
   373  	privKeyFile string
   374  }
   375  
   376  func newUserKey(pubKeyFile string) (*userKey, error) {
   377  	userKey := new(userKey)
   378  	if len(pubKeyFile) > 0 {
   379  		pubKeyBytes, err := ioutil.ReadFile(pubKeyFile)
   380  		if err != nil {
   381  			return nil, errors.New("Failed to read public key")
   382  		}
   383  		userKey.PublicKey, _, _, _, err = ssh.ParseAuthorizedKey(pubKeyBytes)
   384  		if err != nil {
   385  			return nil, errors.New("Failed to parse authorized key")
   386  		}
   387  
   388  		return userKey, nil
   389  	}
   390  
   391  	key, err := rsa.GenerateKey(rand.Reader, 2048)
   392  	if err != nil {
   393  		return nil, errors.New("Failed to generate key pair")
   394  	}
   395  	userKey.PublicKey, err = ssh.NewPublicKey(key.Public())
   396  	if err != nil {
   397  		return nil, errors.New("Failed to extract public key from generated key pair")
   398  	}
   399  
   400  	// To support Ansible calling back to us we need to write
   401  	// this file down
   402  	privateKeyDer := x509.MarshalPKCS1PrivateKey(key)
   403  	privateKeyBlock := pem.Block{
   404  		Type:    "RSA PRIVATE KEY",
   405  		Headers: nil,
   406  		Bytes:   privateKeyDer,
   407  	}
   408  	tf, err := ioutil.TempFile("", "ansible-key")
   409  	if err != nil {
   410  		return nil, errors.New("failed to create temp file for generated key")
   411  	}
   412  	_, err = tf.Write(pem.EncodeToMemory(&privateKeyBlock))
   413  	if err != nil {
   414  		return nil, errors.New("failed to write private key to temp file")
   415  	}
   416  
   417  	err = tf.Close()
   418  	if err != nil {
   419  		return nil, errors.New("failed to close private key temp file")
   420  	}
   421  	userKey.privKeyFile = tf.Name()
   422  
   423  	return userKey, nil
   424  }
   425  
   426  type signer struct {
   427  	ssh.Signer
   428  	generated bool
   429  }
   430  
   431  func newSigner(privKeyFile string) (*signer, error) {
   432  	signer := new(signer)
   433  
   434  	if len(privKeyFile) > 0 {
   435  		privateBytes, err := ioutil.ReadFile(privKeyFile)
   436  		if err != nil {
   437  			return nil, errors.New("Failed to load private host key")
   438  		}
   439  
   440  		signer.Signer, err = ssh.ParsePrivateKey(privateBytes)
   441  		if err != nil {
   442  			return nil, errors.New("Failed to parse private host key")
   443  		}
   444  
   445  		return signer, nil
   446  	}
   447  
   448  	key, err := rsa.GenerateKey(rand.Reader, 2048)
   449  	if err != nil {
   450  		return nil, errors.New("Failed to generate server key pair")
   451  	}
   452  
   453  	signer.Signer, err = ssh.NewSignerFromKey(key)
   454  	if err != nil {
   455  		return nil, errors.New("Failed to extract private key from generated key pair")
   456  	}
   457  	signer.generated = true
   458  
   459  	return signer, nil
   460  }
   461  
   462  // Ui provides concurrency-safe access to packer.Ui.
   463  type Ui struct {
   464  	sem chan int
   465  	ui  packer.Ui
   466  }
   467  
   468  func newUi(ui packer.Ui) packer.Ui {
   469  	return &Ui{sem: make(chan int, 1), ui: ui}
   470  }
   471  
   472  func (ui *Ui) Ask(s string) (string, error) {
   473  	ui.sem <- 1
   474  	ret, err := ui.ui.Ask(s)
   475  	<-ui.sem
   476  
   477  	return ret, err
   478  }
   479  
   480  func (ui *Ui) Say(s string) {
   481  	ui.sem <- 1
   482  	ui.ui.Say(s)
   483  	<-ui.sem
   484  }
   485  
   486  func (ui *Ui) Message(s string) {
   487  	ui.sem <- 1
   488  	ui.ui.Message(s)
   489  	<-ui.sem
   490  }
   491  
   492  func (ui *Ui) Error(s string) {
   493  	ui.sem <- 1
   494  	ui.ui.Error(s)
   495  	<-ui.sem
   496  }
   497  
   498  func (ui *Ui) Machine(t string, args ...string) {
   499  	ui.sem <- 1
   500  	ui.ui.Machine(t, args...)
   501  	<-ui.sem
   502  }