github.com/rahart/packer@v0.12.2-0.20161229105310-282bb6ad370f/provisioner/ansible/provisioner.go (about)

     1  package ansible
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/rand"
     7  	"crypto/rsa"
     8  	"crypto/x509"
     9  	"encoding/pem"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"io/ioutil"
    14  	"log"
    15  	"net"
    16  	"os"
    17  	"os/exec"
    18  	"path/filepath"
    19  	"regexp"
    20  	"strconv"
    21  	"strings"
    22  	"sync"
    23  	"unicode"
    24  
    25  	"golang.org/x/crypto/ssh"
    26  
    27  	"github.com/mitchellh/packer/common"
    28  	"github.com/mitchellh/packer/helper/config"
    29  	"github.com/mitchellh/packer/packer"
    30  	"github.com/mitchellh/packer/template/interpolate"
    31  )
    32  
// Config holds the settings of the ansible provisioner. Prepare fills it from
// the template via config.Decode (using the mapstructure tags below) and then
// applies defaults to some fields.
type Config struct {
	common.PackerConfig `mapstructure:",squash"`
	// ctx is the interpolation context handed to config.Decode in Prepare.
	ctx                 interpolate.Context

	// The command to run ansible
	// (Prepare defaults it to "ansible-playbook").
	Command string

	// Extra options to pass to the ansible command
	// (appended after the playbook, inventory and --private-key arguments).
	ExtraArguments []string `mapstructure:"extra_arguments"`

	// KEY=VALUE strings appended to the environment of the ansible command.
	// Prepare may append ANSIBLE_HOST_KEY_CHECKING=False and
	// ANSIBLE_SCP_IF_SSH=True to this list.
	AnsibleEnvVars []string `mapstructure:"ansible_env_vars"`

	// The main playbook file to execute.
	// Groups: inventory groups that each list the host.
	// EmptyGroups: inventory groups written without any hosts.
	// HostAlias: inventory name of the host (default "default").
	// User: user the local SSH proxy accepts and the inventory connects as
	//   (defaults to $USER).
	// LocalPort: port for the local SSH proxy on 127.0.0.1; "0" lets the OS
	//   choose, and Provision overwrites it with the port actually bound.
	// SSHHostKeyFile: private host key for the proxy; when empty a key is
	//   generated per run and Ansible host key checking is disabled.
	// SSHAuthorizedKeyFile: public key the proxy accepts; when empty a key
	//   pair is generated per run.
	// SFTPCmd: passed through to newAdapter — presumably the command used to
	//   serve SFTP requests; confirm in the adapter.
	// UseSFTP: when false, ANSIBLE_SCP_IF_SSH=True is set for Ansible.
	// inventoryFile: inventory passed to Ansible; when empty Provision writes
	//   a temporary one.
	PlaybookFile         string   `mapstructure:"playbook_file"`
	Groups               []string `mapstructure:"groups"`
	EmptyGroups          []string `mapstructure:"empty_groups"`
	HostAlias            string   `mapstructure:"host_alias"`
	User                 string   `mapstructure:"user"`
	LocalPort            string   `mapstructure:"local_port"`
	SSHHostKeyFile       string   `mapstructure:"ssh_host_key_file"`
	SSHAuthorizedKeyFile string   `mapstructure:"ssh_authorized_key_file"`
	SFTPCmd              string   `mapstructure:"sftp_command"`
	UseSFTP              bool     `mapstructure:"use_sftp"`
	inventoryFile        string
}
    58  
// Provisioner runs the ansible command on the local machine and points it at
// a local SSH proxy that forwards to the machine being built.
type Provisioner struct {
	config            Config
	// adapter serves the local SSH proxy; it is set in Provision.
	adapter           *adapter
	// done is created in Prepare and closed when provisioning ends or is
	// cancelled; it is handed to the adapter.
	done              chan struct{}
	// ansibleVersion is the dotted version string found by getVersion.
	ansibleVersion    string
	// ansibleMajVersion is the leading component of ansibleVersion; values
	// below 2 select the ansible_ssh_* inventory variables.
	ansibleMajVersion uint
}
    66  
    67  func (p *Provisioner) Prepare(raws ...interface{}) error {
    68  	p.done = make(chan struct{})
    69  
    70  	err := config.Decode(&p.config, &config.DecodeOpts{
    71  		Interpolate:        true,
    72  		InterpolateContext: &p.config.ctx,
    73  		InterpolateFilter: &interpolate.RenderFilter{
    74  			Exclude: []string{},
    75  		},
    76  	}, raws...)
    77  	if err != nil {
    78  		return err
    79  	}
    80  
    81  	// Defaults
    82  	if p.config.Command == "" {
    83  		p.config.Command = "ansible-playbook"
    84  	}
    85  
    86  	if p.config.HostAlias == "" {
    87  		p.config.HostAlias = "default"
    88  	}
    89  
    90  	var errs *packer.MultiError
    91  	err = validateFileConfig(p.config.PlaybookFile, "playbook_file", true)
    92  	if err != nil {
    93  		errs = packer.MultiErrorAppend(errs, err)
    94  	}
    95  
    96  	// Check that the authorized key file exists
    97  	if len(p.config.SSHAuthorizedKeyFile) > 0 {
    98  		err = validateFileConfig(p.config.SSHAuthorizedKeyFile, "ssh_authorized_key_file", true)
    99  		if err != nil {
   100  			log.Println(p.config.SSHAuthorizedKeyFile, "does not exist")
   101  			errs = packer.MultiErrorAppend(errs, err)
   102  		}
   103  	}
   104  	if len(p.config.SSHHostKeyFile) > 0 {
   105  		err = validateFileConfig(p.config.SSHHostKeyFile, "ssh_host_key_file", true)
   106  		if err != nil {
   107  			log.Println(p.config.SSHHostKeyFile, "does not exist")
   108  			errs = packer.MultiErrorAppend(errs, err)
   109  		}
   110  	} else {
   111  		p.config.AnsibleEnvVars = append(p.config.AnsibleEnvVars, "ANSIBLE_HOST_KEY_CHECKING=False")
   112  	}
   113  
   114  	if !p.config.UseSFTP {
   115  		p.config.AnsibleEnvVars = append(p.config.AnsibleEnvVars, "ANSIBLE_SCP_IF_SSH=True")
   116  	}
   117  
   118  	if len(p.config.LocalPort) > 0 {
   119  		if _, err := strconv.ParseUint(p.config.LocalPort, 10, 16); err != nil {
   120  			errs = packer.MultiErrorAppend(errs, fmt.Errorf("local_port: %s must be a valid port", p.config.LocalPort))
   121  		}
   122  	} else {
   123  		p.config.LocalPort = "0"
   124  	}
   125  
   126  	err = p.getVersion()
   127  	if err != nil {
   128  		errs = packer.MultiErrorAppend(errs, err)
   129  	}
   130  
   131  	if p.config.User == "" {
   132  		p.config.User = os.Getenv("USER")
   133  	}
   134  	if p.config.User == "" {
   135  		errs = packer.MultiErrorAppend(errs, fmt.Errorf("user: could not determine current user from environment."))
   136  	}
   137  
   138  	if errs != nil && len(errs.Errors) > 0 {
   139  		return errs
   140  	}
   141  	return nil
   142  }
   143  
   144  func (p *Provisioner) getVersion() error {
   145  	out, err := exec.Command(p.config.Command, "--version").Output()
   146  	if err != nil {
   147  		return err
   148  	}
   149  
   150  	versionRe := regexp.MustCompile(`\w (\d+\.\d+[.\d+]*)`)
   151  	matches := versionRe.FindStringSubmatch(string(out))
   152  	if matches == nil {
   153  		return fmt.Errorf(
   154  			"Could not find %s version in output:\n%s", p.config.Command, string(out))
   155  	}
   156  
   157  	version := matches[1]
   158  	log.Printf("%s version: %s", p.config.Command, version)
   159  	p.ansibleVersion = version
   160  
   161  	majVer, err := strconv.ParseUint(strings.Split(version, ".")[0], 10, 0)
   162  	if err != nil {
   163  		return fmt.Errorf("Could not parse major version from \"%s\".", version)
   164  	}
   165  	p.ansibleMajVersion = uint(majVer)
   166  
   167  	return nil
   168  }
   169  
   170  func (p *Provisioner) Provision(ui packer.Ui, comm packer.Communicator) error {
   171  	ui.Say("Provisioning with Ansible...")
   172  
   173  	k, err := newUserKey(p.config.SSHAuthorizedKeyFile)
   174  	if err != nil {
   175  		return err
   176  	}
   177  
   178  	hostSigner, err := newSigner(p.config.SSHHostKeyFile)
   179  	// Remove the private key file
   180  	if len(k.privKeyFile) > 0 {
   181  		defer os.Remove(k.privKeyFile)
   182  	}
   183  
   184  	keyChecker := ssh.CertChecker{
   185  		UserKeyFallback: func(conn ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
   186  			if user := conn.User(); user != p.config.User {
   187  				return nil, errors.New(fmt.Sprintf("authentication failed: %s is not a valid user", user))
   188  			}
   189  
   190  			if !bytes.Equal(k.Marshal(), pubKey.Marshal()) {
   191  				return nil, errors.New("authentication failed: unauthorized key")
   192  			}
   193  
   194  			return nil, nil
   195  		},
   196  	}
   197  
   198  	config := &ssh.ServerConfig{
   199  		AuthLogCallback: func(conn ssh.ConnMetadata, method string, err error) {
   200  			log.Printf("authentication attempt from %s to %s as %s using %s", conn.RemoteAddr(), conn.LocalAddr(), conn.User(), method)
   201  		},
   202  		PublicKeyCallback: keyChecker.Authenticate,
   203  		//NoClientAuth:      true,
   204  	}
   205  
   206  	config.AddHostKey(hostSigner)
   207  
   208  	localListener, err := func() (net.Listener, error) {
   209  		port, err := strconv.ParseUint(p.config.LocalPort, 10, 16)
   210  		if err != nil {
   211  			return nil, err
   212  		}
   213  
   214  		tries := 1
   215  		if port != 0 {
   216  			tries = 10
   217  		}
   218  		for i := 0; i < tries; i++ {
   219  			l, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port))
   220  			port++
   221  			if err != nil {
   222  				ui.Say(err.Error())
   223  				continue
   224  			}
   225  			_, p.config.LocalPort, err = net.SplitHostPort(l.Addr().String())
   226  			if err != nil {
   227  				ui.Say(err.Error())
   228  				continue
   229  			}
   230  			return l, nil
   231  		}
   232  		return nil, errors.New("Error setting up SSH proxy connection")
   233  	}()
   234  
   235  	if err != nil {
   236  		return err
   237  	}
   238  
   239  	ui = newUi(ui)
   240  	p.adapter = newAdapter(p.done, localListener, config, p.config.SFTPCmd, ui, comm)
   241  
   242  	defer func() {
   243  		log.Print("shutting down the SSH proxy")
   244  		close(p.done)
   245  		p.adapter.Shutdown()
   246  	}()
   247  
   248  	go p.adapter.Serve()
   249  
   250  	if len(p.config.inventoryFile) == 0 {
   251  		tf, err := ioutil.TempFile("", "packer-provisioner-ansible")
   252  		if err != nil {
   253  			return fmt.Errorf("Error preparing inventory file: %s", err)
   254  		}
   255  		defer os.Remove(tf.Name())
   256  
   257  		host := fmt.Sprintf("%s ansible_host=127.0.0.1 ansible_user=%s ansible_port=%s\n",
   258  			p.config.HostAlias, p.config.User, p.config.LocalPort)
   259  		if p.ansibleMajVersion < 2 {
   260  			host = fmt.Sprintf("%s ansible_ssh_host=127.0.0.1 ansible_ssh_user=%s ansible_ssh_port=%s\n",
   261  				p.config.HostAlias, p.config.User, p.config.LocalPort)
   262  		}
   263  
   264  		w := bufio.NewWriter(tf)
   265  		w.WriteString(host)
   266  		for _, group := range p.config.Groups {
   267  			fmt.Fprintf(w, "[%s]\n%s", group, host)
   268  		}
   269  
   270  		for _, group := range p.config.EmptyGroups {
   271  			fmt.Fprintf(w, "[%s]\n", group)
   272  		}
   273  
   274  		if err := w.Flush(); err != nil {
   275  			tf.Close()
   276  			return fmt.Errorf("Error preparing inventory file: %s", err)
   277  		}
   278  		tf.Close()
   279  		p.config.inventoryFile = tf.Name()
   280  		defer func() {
   281  			p.config.inventoryFile = ""
   282  		}()
   283  	}
   284  
   285  	if err := p.executeAnsible(ui, comm, k.privKeyFile); err != nil {
   286  		return fmt.Errorf("Error executing Ansible: %s", err)
   287  	}
   288  
   289  	return nil
   290  }
   291  
   292  func (p *Provisioner) Cancel() {
   293  	if p.done != nil {
   294  		close(p.done)
   295  	}
   296  	if p.adapter != nil {
   297  		p.adapter.Shutdown()
   298  	}
   299  	os.Exit(0)
   300  }
   301  
   302  func (p *Provisioner) executeAnsible(ui packer.Ui, comm packer.Communicator, privKeyFile string) error {
   303  	playbook, _ := filepath.Abs(p.config.PlaybookFile)
   304  	inventory := p.config.inventoryFile
   305  	var envvars []string
   306  
   307  	args := []string{playbook, "-i", inventory}
   308  	if len(privKeyFile) > 0 {
   309  		args = append(args, "--private-key", privKeyFile)
   310  	}
   311  	args = append(args, p.config.ExtraArguments...)
   312  	if len(p.config.AnsibleEnvVars) > 0 {
   313  		envvars = append(envvars, p.config.AnsibleEnvVars...)
   314  	}
   315  
   316  	cmd := exec.Command(p.config.Command, args...)
   317  
   318  	cmd.Env = os.Environ()
   319  	if len(envvars) > 0 {
   320  		cmd.Env = append(cmd.Env, envvars...)
   321  	}
   322  
   323  	stdout, err := cmd.StdoutPipe()
   324  	if err != nil {
   325  		return err
   326  	}
   327  	stderr, err := cmd.StderrPipe()
   328  	if err != nil {
   329  		return err
   330  	}
   331  
   332  	wg := sync.WaitGroup{}
   333  	repeat := func(r io.ReadCloser) {
   334  		reader := bufio.NewReader(r)
   335  		for {
   336  			line, err := reader.ReadString('\n')
   337  			if line != "" {
   338  				line = strings.TrimRightFunc(line, unicode.IsSpace)
   339  				ui.Message(line)
   340  			}
   341  			if err != nil {
   342  				if err == io.EOF {
   343  					break
   344  				} else {
   345  					ui.Error(err.Error())
   346  					break
   347  				}
   348  			}
   349  		}
   350  		wg.Done()
   351  	}
   352  	wg.Add(2)
   353  	go repeat(stdout)
   354  	go repeat(stderr)
   355  
   356  	log.Printf("Executing Ansible: %s", strings.Join(cmd.Args, " "))
   357  	cmd.Start()
   358  	wg.Wait()
   359  	err = cmd.Wait()
   360  	if err != nil {
   361  		return fmt.Errorf("Non-zero exit status: %s", err)
   362  	}
   363  
   364  	return nil
   365  }
   366  
   367  func validateFileConfig(name string, config string, req bool) error {
   368  	if req {
   369  		if name == "" {
   370  			return fmt.Errorf("%s must be specified.", config)
   371  		}
   372  	}
   373  	info, err := os.Stat(name)
   374  	if err != nil {
   375  		return fmt.Errorf("%s: %s is invalid: %s", config, name, err)
   376  	} else if info.IsDir() {
   377  		return fmt.Errorf("%s: %s must point to a file", config, name)
   378  	}
   379  	return nil
   380  }
   381  
// userKey is the client public key accepted by the SSH proxy. privKeyFile is
// the path of the matching private key when newUserKey generated the pair,
// and empty when the public key was read from a file.
type userKey struct {
	ssh.PublicKey
	privKeyFile string
}
   386  
   387  func newUserKey(pubKeyFile string) (*userKey, error) {
   388  	userKey := new(userKey)
   389  	if len(pubKeyFile) > 0 {
   390  		pubKeyBytes, err := ioutil.ReadFile(pubKeyFile)
   391  		if err != nil {
   392  			return nil, errors.New("Failed to read public key")
   393  		}
   394  		userKey.PublicKey, _, _, _, err = ssh.ParseAuthorizedKey(pubKeyBytes)
   395  		if err != nil {
   396  			return nil, errors.New("Failed to parse authorized key")
   397  		}
   398  
   399  		return userKey, nil
   400  	}
   401  
   402  	key, err := rsa.GenerateKey(rand.Reader, 2048)
   403  	if err != nil {
   404  		return nil, errors.New("Failed to generate key pair")
   405  	}
   406  	userKey.PublicKey, err = ssh.NewPublicKey(key.Public())
   407  	if err != nil {
   408  		return nil, errors.New("Failed to extract public key from generated key pair")
   409  	}
   410  
   411  	// To support Ansible calling back to us we need to write
   412  	// this file down
   413  	privateKeyDer := x509.MarshalPKCS1PrivateKey(key)
   414  	privateKeyBlock := pem.Block{
   415  		Type:    "RSA PRIVATE KEY",
   416  		Headers: nil,
   417  		Bytes:   privateKeyDer,
   418  	}
   419  	tf, err := ioutil.TempFile("", "ansible-key")
   420  	if err != nil {
   421  		return nil, errors.New("failed to create temp file for generated key")
   422  	}
   423  	_, err = tf.Write(pem.EncodeToMemory(&privateKeyBlock))
   424  	if err != nil {
   425  		return nil, errors.New("failed to write private key to temp file")
   426  	}
   427  
   428  	err = tf.Close()
   429  	if err != nil {
   430  		return nil, errors.New("failed to close private key temp file")
   431  	}
   432  	userKey.privKeyFile = tf.Name()
   433  
   434  	return userKey, nil
   435  }
   436  
// signer wraps the ssh.Signer used as the SSH proxy's host key
// (see newSigner).
type signer struct {
	ssh.Signer
}
   440  
   441  func newSigner(privKeyFile string) (*signer, error) {
   442  	signer := new(signer)
   443  
   444  	if len(privKeyFile) > 0 {
   445  		privateBytes, err := ioutil.ReadFile(privKeyFile)
   446  		if err != nil {
   447  			return nil, errors.New("Failed to load private host key")
   448  		}
   449  
   450  		signer.Signer, err = ssh.ParsePrivateKey(privateBytes)
   451  		if err != nil {
   452  			return nil, errors.New("Failed to parse private host key")
   453  		}
   454  
   455  		return signer, nil
   456  	}
   457  
   458  	key, err := rsa.GenerateKey(rand.Reader, 2048)
   459  	if err != nil {
   460  		return nil, errors.New("Failed to generate server key pair")
   461  	}
   462  
   463  	signer.Signer, err = ssh.NewSignerFromKey(key)
   464  	if err != nil {
   465  		return nil, errors.New("Failed to extract private key from generated key pair")
   466  	}
   467  
   468  	return signer, nil
   469  }
   470  
   471  // Ui provides concurrency-safe access to packer.Ui.
   472  type Ui struct {
   473  	sem chan int
   474  	ui  packer.Ui
   475  }
   476  
   477  func newUi(ui packer.Ui) packer.Ui {
   478  	return &Ui{sem: make(chan int, 1), ui: ui}
   479  }
   480  
   481  func (ui *Ui) Ask(s string) (string, error) {
   482  	ui.sem <- 1
   483  	ret, err := ui.ui.Ask(s)
   484  	<-ui.sem
   485  
   486  	return ret, err
   487  }
   488  
   489  func (ui *Ui) Say(s string) {
   490  	ui.sem <- 1
   491  	ui.ui.Say(s)
   492  	<-ui.sem
   493  }
   494  
   495  func (ui *Ui) Message(s string) {
   496  	ui.sem <- 1
   497  	ui.ui.Message(s)
   498  	<-ui.sem
   499  }
   500  
   501  func (ui *Ui) Error(s string) {
   502  	ui.sem <- 1
   503  	ui.ui.Error(s)
   504  	<-ui.sem
   505  }
   506  
   507  func (ui *Ui) Machine(t string, args ...string) {
   508  	ui.sem <- 1
   509  	ui.ui.Machine(t, args...)
   510  	<-ui.sem
   511  }