github.com/amanya/packer@v0.12.1-0.20161117214323-902ac5ab2eb6/provisioner/ansible/adapter.go (about)

     1  package ansible
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"net"
    11  	"strings"
    12  
    13  	"github.com/mitchellh/packer/packer"
    14  	"golang.org/x/crypto/ssh"
    15  )
    16  
// An adapter satisfies SSH requests (from an Ansible client) by delegating SSH
// exec and subsystem commands to a packer.Communicator.
//
// Fields:
//   - done: Serve returns as soon as a receive on it succeeds (presumably the
//     owner closes it before calling Shutdown — confirm with the caller).
//   - l: the listener Serve accepts SSH connections on.
//   - config: server configuration used for each SSH handshake.
//   - sftpCmd: remote command started for the "sftp" subsystem; when empty,
//     "/usr/lib/sftp-server -e" is used.
//   - ui: where errors are reported.
//   - comm: runs remote commands and is handed to the scp session helpers.
type adapter struct {
	done    <-chan struct{}
	l       net.Listener
	config  *ssh.ServerConfig
	sftpCmd string
	ui      packer.Ui
	comm    packer.Communicator
}
    27  
    28  func newAdapter(done <-chan struct{}, l net.Listener, config *ssh.ServerConfig, sftpCmd string, ui packer.Ui, comm packer.Communicator) *adapter {
    29  	return &adapter{
    30  		done:    done,
    31  		l:       l,
    32  		config:  config,
    33  		sftpCmd: sftpCmd,
    34  		ui:      ui,
    35  		comm:    comm,
    36  	}
    37  }
    38  
    39  func (c *adapter) Serve() {
    40  	log.Printf("SSH proxy: serving on %s", c.l.Addr())
    41  
    42  	for {
    43  		// Accept will return if either the underlying connection is closed or if a connection is made.
    44  		// after returning, check to see if c.done can be received. If so, then Accept() returned because
    45  		// the connection has been closed.
    46  		conn, err := c.l.Accept()
    47  		select {
    48  		case <-c.done:
    49  			return
    50  		default:
    51  			if err != nil {
    52  				c.ui.Error(fmt.Sprintf("listen.Accept failed: %v", err))
    53  				continue
    54  			}
    55  			go func(conn net.Conn) {
    56  				if err := c.Handle(conn, c.ui); err != nil {
    57  					c.ui.Error(err.Error())
    58  				}
    59  			}(conn)
    60  		}
    61  	}
    62  }
    63  
    64  func (c *adapter) Handle(conn net.Conn, ui packer.Ui) error {
    65  	log.Print("SSH proxy: accepted connection")
    66  	_, chans, reqs, err := ssh.NewServerConn(conn, c.config)
    67  	if err != nil {
    68  		return errors.New("failed to handshake")
    69  	}
    70  
    71  	// discard all global requests
    72  	go ssh.DiscardRequests(reqs)
    73  
    74  	// Service the incoming NewChannels
    75  	for newChannel := range chans {
    76  		if newChannel.ChannelType() != "session" {
    77  			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
    78  			continue
    79  		}
    80  
    81  		go func(ch ssh.NewChannel) {
    82  			if err := c.handleSession(ch); err != nil {
    83  				c.ui.Error(err.Error())
    84  			}
    85  		}(newChannel)
    86  	}
    87  
    88  	return nil
    89  }
    90  
    91  func (c *adapter) handleSession(newChannel ssh.NewChannel) error {
    92  	channel, requests, err := newChannel.Accept()
    93  	if err != nil {
    94  		return err
    95  	}
    96  	defer channel.Close()
    97  
    98  	done := make(chan struct{})
    99  
   100  	// Sessions have requests such as "pty-req", "shell", "env", and "exec".
   101  	// see RFC 4254, section 6
   102  	go func(in <-chan *ssh.Request) {
   103  		env := make([]envRequestPayload, 4)
   104  		for req := range in {
   105  			switch req.Type {
   106  			case "pty-req":
   107  				log.Println("ansible provisioner pty-req request")
   108  				// accept pty-req requests, but don't actually do anything. Necessary for OpenSSH and sudo.
   109  				req.Reply(true, nil)
   110  
   111  			case "env":
   112  				req, err := newEnvRequest(req)
   113  				if err != nil {
   114  					c.ui.Error(err.Error())
   115  					req.Reply(false, nil)
   116  					continue
   117  				}
   118  				env = append(env, req.Payload)
   119  				log.Printf("new env request: %s", req.Payload)
   120  				req.Reply(true, nil)
   121  			case "exec":
   122  				req, err := newExecRequest(req)
   123  				if err != nil {
   124  					c.ui.Error(err.Error())
   125  					req.Reply(false, nil)
   126  					close(done)
   127  					continue
   128  				}
   129  
   130  				log.Printf("new exec request: %s", req.Payload)
   131  
   132  				if len(req.Payload) == 0 {
   133  					req.Reply(false, nil)
   134  					close(done)
   135  					return
   136  				}
   137  
   138  				go func(channel ssh.Channel) {
   139  					exit := c.exec(string(req.Payload), channel, channel, channel.Stderr())
   140  
   141  					exitStatus := make([]byte, 4)
   142  					binary.BigEndian.PutUint32(exitStatus, uint32(exit))
   143  					channel.SendRequest("exit-status", false, exitStatus)
   144  					close(done)
   145  				}(channel)
   146  				req.Reply(true, nil)
   147  			case "subsystem":
   148  				req, err := newSubsystemRequest(req)
   149  				if err != nil {
   150  					c.ui.Error(err.Error())
   151  					req.Reply(false, nil)
   152  					continue
   153  				}
   154  
   155  				log.Printf("new subsystem request: %s", req.Payload)
   156  				switch req.Payload {
   157  				case "sftp":
   158  					sftpCmd := c.sftpCmd
   159  					if len(sftpCmd) == 0 {
   160  						sftpCmd = "/usr/lib/sftp-server -e"
   161  					}
   162  
   163  					log.Print("starting sftp subsystem")
   164  					go func() {
   165  						_ = c.remoteExec(sftpCmd, channel, channel, channel.Stderr())
   166  						close(done)
   167  					}()
   168  					req.Reply(true, nil)
   169  				default:
   170  					c.ui.Error(fmt.Sprintf("unsupported subsystem requested: %s", req.Payload))
   171  					req.Reply(false, nil)
   172  				}
   173  			default:
   174  				log.Printf("rejecting %s request", req.Type)
   175  				req.Reply(false, nil)
   176  			}
   177  		}
   178  	}(requests)
   179  
   180  	<-done
   181  	return nil
   182  }
   183  
   184  func (c *adapter) Shutdown() {
   185  	c.l.Close()
   186  }
   187  
   188  func (c *adapter) exec(command string, in io.Reader, out io.Writer, err io.Writer) int {
   189  	var exitStatus int
   190  	switch {
   191  	case strings.HasPrefix(command, "scp ") && serveSCP(command[4:]):
   192  		err := c.scpExec(command[4:], in, out, err)
   193  		if err != nil {
   194  			log.Println(err)
   195  			exitStatus = 1
   196  		}
   197  	default:
   198  		exitStatus = c.remoteExec(command, in, out, err)
   199  	}
   200  	return exitStatus
   201  }
   202  
   203  func serveSCP(args string) bool {
   204  	opts, _ := scpOptions(args)
   205  	return bytes.IndexAny(opts, "tf") >= 0
   206  }
   207  
   208  func (c *adapter) scpExec(args string, in io.Reader, out io.Writer, err io.Writer) error {
   209  	opts, rest := scpOptions(args)
   210  
   211  	if i := bytes.IndexByte(opts, 't'); i >= 0 {
   212  		return scpUploadSession(opts, rest, in, out, c.comm)
   213  	}
   214  
   215  	if i := bytes.IndexByte(opts, 'f'); i >= 0 {
   216  		return scpDownloadSession(opts, rest, in, out, c.comm)
   217  	}
   218  	return errors.New("no scp mode specified")
   219  }
   220  
   221  func (c *adapter) remoteExec(command string, in io.Reader, out io.Writer, err io.Writer) int {
   222  	cmd := &packer.RemoteCmd{
   223  		Stdin:   in,
   224  		Stdout:  out,
   225  		Stderr:  err,
   226  		Command: command,
   227  	}
   228  
   229  	if err := c.comm.Start(cmd); err != nil {
   230  		c.ui.Error(err.Error())
   231  		return cmd.ExitStatus
   232  	}
   233  
   234  	cmd.Wait()
   235  
   236  	return cmd.ExitStatus
   237  }
   238  
// envRequest is a decoded SSH "env" channel request (RFC 4254, section 6.4).
// It embeds the raw *ssh.Request and adds the decoded variable name and value
// as Payload.
type envRequest struct {
	*ssh.Request
	Payload envRequestPayload
}
   243  
   244  type envRequestPayload struct {
   245  	Name  string
   246  	Value string
   247  }
   248  
   249  func (p envRequestPayload) String() string {
   250  	return fmt.Sprintf("%s=%s", p.Name, p.Value)
   251  }
   252  
   253  func newEnvRequest(raw *ssh.Request) (*envRequest, error) {
   254  	r := new(envRequest)
   255  	r.Request = raw
   256  
   257  	if err := ssh.Unmarshal(raw.Payload, &r.Payload); err != nil {
   258  		return nil, err
   259  	}
   260  
   261  	return r, nil
   262  }
   263  
   264  func sshString(buf io.Reader) (string, error) {
   265  	var size uint32
   266  	err := binary.Read(buf, binary.BigEndian, &size)
   267  	if err != nil {
   268  		return "", err
   269  	}
   270  
   271  	b := make([]byte, size)
   272  	err = binary.Read(buf, binary.BigEndian, b)
   273  	if err != nil {
   274  		return "", err
   275  	}
   276  	return string(b), nil
   277  }
   278  
// execRequest is a decoded SSH "exec" channel request (RFC 4254, section 6.5).
// It embeds the raw *ssh.Request and adds the decoded command line as Payload.
type execRequest struct {
	*ssh.Request
	Payload execRequestPayload
}
   283  
   284  type execRequestPayload string
   285  
   286  func (p execRequestPayload) String() string {
   287  	return string(p)
   288  }
   289  
   290  func newExecRequest(raw *ssh.Request) (*execRequest, error) {
   291  	r := new(execRequest)
   292  	r.Request = raw
   293  	buf := bytes.NewReader(r.Request.Payload)
   294  
   295  	var err error
   296  	var payload string
   297  	if payload, err = sshString(buf); err != nil {
   298  		return nil, err
   299  	}
   300  
   301  	r.Payload = execRequestPayload(payload)
   302  	return r, nil
   303  }
   304  
// subsystemRequest is a decoded SSH "subsystem" channel request (RFC 4254,
// section 6.5). It embeds the raw *ssh.Request and adds the requested
// subsystem name (e.g. "sftp") as Payload.
type subsystemRequest struct {
	*ssh.Request
	Payload subsystemRequestPayload
}
   309  
   310  type subsystemRequestPayload string
   311  
   312  func (p subsystemRequestPayload) String() string {
   313  	return string(p)
   314  }
   315  
   316  func newSubsystemRequest(raw *ssh.Request) (*subsystemRequest, error) {
   317  	r := new(subsystemRequest)
   318  	r.Request = raw
   319  	buf := bytes.NewReader(r.Request.Payload)
   320  
   321  	var err error
   322  	var payload string
   323  	if payload, err = sshString(buf); err != nil {
   324  		return nil, err
   325  	}
   326  
   327  	r.Payload = subsystemRequestPayload(payload)
   328  	return r, nil
   329  }