github.com/ddnomad/packer@v1.3.2/provisioner/ansible/adapter.go (about)

     1  package ansible
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"net"
    11  	"strings"
    12  
    13  	"github.com/google/shlex"
    14  	"github.com/hashicorp/packer/packer"
    15  	"golang.org/x/crypto/ssh"
    16  )
    17  
    18  // An adapter satisfies SSH requests (from an Ansible client) by delegating SSH
    19  // exec and subsystem commands to a packer.Communicator.
    20  type adapter struct {
    21  	done    <-chan struct{}
    22  	l       net.Listener
    23  	config  *ssh.ServerConfig
    24  	sftpCmd string
    25  	ui      packer.Ui
    26  	comm    packer.Communicator
    27  }
    28  
    29  func newAdapter(done <-chan struct{}, l net.Listener, config *ssh.ServerConfig, sftpCmd string, ui packer.Ui, comm packer.Communicator) *adapter {
    30  	return &adapter{
    31  		done:    done,
    32  		l:       l,
    33  		config:  config,
    34  		sftpCmd: sftpCmd,
    35  		ui:      ui,
    36  		comm:    comm,
    37  	}
    38  }
    39  
    40  func (c *adapter) Serve() {
    41  	log.Printf("SSH proxy: serving on %s", c.l.Addr())
    42  
    43  	for {
    44  		// Accept will return if either the underlying connection is closed or if a connection is made.
    45  		// after returning, check to see if c.done can be received. If so, then Accept() returned because
    46  		// the connection has been closed.
    47  		conn, err := c.l.Accept()
    48  		select {
    49  		case <-c.done:
    50  			return
    51  		default:
    52  			if err != nil {
    53  				c.ui.Error(fmt.Sprintf("listen.Accept failed: %v", err))
    54  				continue
    55  			}
    56  			go func(conn net.Conn) {
    57  				if err := c.Handle(conn, c.ui); err != nil {
    58  					c.ui.Error(err.Error())
    59  				}
    60  			}(conn)
    61  		}
    62  	}
    63  }
    64  
    65  func (c *adapter) Handle(conn net.Conn, ui packer.Ui) error {
    66  	log.Print("SSH proxy: accepted connection")
    67  	_, chans, reqs, err := ssh.NewServerConn(conn, c.config)
    68  	if err != nil {
    69  		return errors.New("failed to handshake")
    70  	}
    71  
    72  	// discard all global requests
    73  	go ssh.DiscardRequests(reqs)
    74  
    75  	// Service the incoming NewChannels
    76  	for newChannel := range chans {
    77  		if newChannel.ChannelType() != "session" {
    78  			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
    79  			continue
    80  		}
    81  
    82  		go func(ch ssh.NewChannel) {
    83  			if err := c.handleSession(ch); err != nil {
    84  				c.ui.Error(err.Error())
    85  			}
    86  		}(newChannel)
    87  	}
    88  
    89  	return nil
    90  }
    91  
    92  func (c *adapter) handleSession(newChannel ssh.NewChannel) error {
    93  	channel, requests, err := newChannel.Accept()
    94  	if err != nil {
    95  		return err
    96  	}
    97  	defer channel.Close()
    98  
    99  	done := make(chan struct{})
   100  
   101  	// Sessions have requests such as "pty-req", "shell", "env", and "exec".
   102  	// see RFC 4254, section 6
   103  	go func(in <-chan *ssh.Request) {
   104  		env := make([]envRequestPayload, 4)
   105  		for req := range in {
   106  			switch req.Type {
   107  			case "pty-req":
   108  				log.Println("ansible provisioner pty-req request")
   109  				// accept pty-req requests, but don't actually do anything. Necessary for OpenSSH and sudo.
   110  				req.Reply(true, nil)
   111  
   112  			case "env":
   113  				req, err := newEnvRequest(req)
   114  				if err != nil {
   115  					c.ui.Error(err.Error())
   116  					req.Reply(false, nil)
   117  					continue
   118  				}
   119  				env = append(env, req.Payload)
   120  				log.Printf("new env request: %s", req.Payload)
   121  				req.Reply(true, nil)
   122  			case "exec":
   123  				req, err := newExecRequest(req)
   124  				if err != nil {
   125  					c.ui.Error(err.Error())
   126  					req.Reply(false, nil)
   127  					close(done)
   128  					continue
   129  				}
   130  
   131  				log.Printf("new exec request: %s", req.Payload)
   132  
   133  				if len(req.Payload) == 0 {
   134  					req.Reply(false, nil)
   135  					close(done)
   136  					return
   137  				}
   138  
   139  				go func(channel ssh.Channel) {
   140  					exit := c.exec(string(req.Payload), channel, channel, channel.Stderr())
   141  
   142  					exitStatus := make([]byte, 4)
   143  					binary.BigEndian.PutUint32(exitStatus, uint32(exit))
   144  					channel.SendRequest("exit-status", false, exitStatus)
   145  					close(done)
   146  				}(channel)
   147  				req.Reply(true, nil)
   148  			case "subsystem":
   149  				req, err := newSubsystemRequest(req)
   150  				if err != nil {
   151  					c.ui.Error(err.Error())
   152  					req.Reply(false, nil)
   153  					continue
   154  				}
   155  
   156  				log.Printf("new subsystem request: %s", req.Payload)
   157  				switch req.Payload {
   158  				case "sftp":
   159  					sftpCmd := c.sftpCmd
   160  					if len(sftpCmd) == 0 {
   161  						sftpCmd = "/usr/lib/sftp-server -e"
   162  					}
   163  
   164  					log.Print("starting sftp subsystem")
   165  					go func() {
   166  						_ = c.remoteExec(sftpCmd, channel, channel, channel.Stderr())
   167  						close(done)
   168  					}()
   169  					req.Reply(true, nil)
   170  				default:
   171  					c.ui.Error(fmt.Sprintf("unsupported subsystem requested: %s", req.Payload))
   172  					req.Reply(false, nil)
   173  				}
   174  			default:
   175  				log.Printf("rejecting %s request", req.Type)
   176  				req.Reply(false, nil)
   177  			}
   178  		}
   179  	}(requests)
   180  
   181  	<-done
   182  	return nil
   183  }
   184  
   185  func (c *adapter) Shutdown() {
   186  	c.l.Close()
   187  }
   188  
   189  func (c *adapter) exec(command string, in io.Reader, out io.Writer, err io.Writer) int {
   190  	var exitStatus int
   191  	switch {
   192  	case strings.HasPrefix(command, "scp ") && serveSCP(command[4:]):
   193  		err := c.scpExec(command[4:], in, out)
   194  		if err != nil {
   195  			log.Println(err)
   196  			exitStatus = 1
   197  		}
   198  	default:
   199  		exitStatus = c.remoteExec(command, in, out, err)
   200  	}
   201  	return exitStatus
   202  }
   203  
   204  func serveSCP(args string) bool {
   205  	opts, _ := scpOptions(args)
   206  	return bytes.IndexAny(opts, "tf") >= 0
   207  }
   208  
   209  func (c *adapter) scpExec(args string, in io.Reader, out io.Writer) error {
   210  	opts, rest := scpOptions(args)
   211  
   212  	// remove the quoting that ansible added to rest for shell safety.
   213  	shargs, err := shlex.Split(rest)
   214  	if err != nil {
   215  		return err
   216  	}
   217  	rest = strings.Join(shargs, "")
   218  
   219  	if i := bytes.IndexByte(opts, 't'); i >= 0 {
   220  		return scpUploadSession(opts, rest, in, out, c.comm)
   221  	}
   222  
   223  	if i := bytes.IndexByte(opts, 'f'); i >= 0 {
   224  		return scpDownloadSession(opts, rest, in, out, c.comm)
   225  	}
   226  	return errors.New("no scp mode specified")
   227  }
   228  
   229  func (c *adapter) remoteExec(command string, in io.Reader, out io.Writer, err io.Writer) int {
   230  	cmd := &packer.RemoteCmd{
   231  		Stdin:   in,
   232  		Stdout:  out,
   233  		Stderr:  err,
   234  		Command: command,
   235  	}
   236  
   237  	if err := c.comm.Start(cmd); err != nil {
   238  		c.ui.Error(err.Error())
   239  		return cmd.ExitStatus
   240  	}
   241  
   242  	cmd.Wait()
   243  
   244  	return cmd.ExitStatus
   245  }
   246  
   247  type envRequest struct {
   248  	*ssh.Request
   249  	Payload envRequestPayload
   250  }
   251  
   252  type envRequestPayload struct {
   253  	Name  string
   254  	Value string
   255  }
   256  
   257  func (p envRequestPayload) String() string {
   258  	return fmt.Sprintf("%s=%s", p.Name, p.Value)
   259  }
   260  
   261  func newEnvRequest(raw *ssh.Request) (*envRequest, error) {
   262  	r := new(envRequest)
   263  	r.Request = raw
   264  
   265  	if err := ssh.Unmarshal(raw.Payload, &r.Payload); err != nil {
   266  		return nil, err
   267  	}
   268  
   269  	return r, nil
   270  }
   271  
   272  func sshString(buf io.Reader) (string, error) {
   273  	var size uint32
   274  	err := binary.Read(buf, binary.BigEndian, &size)
   275  	if err != nil {
   276  		return "", err
   277  	}
   278  
   279  	b := make([]byte, size)
   280  	err = binary.Read(buf, binary.BigEndian, b)
   281  	if err != nil {
   282  		return "", err
   283  	}
   284  	return string(b), nil
   285  }
   286  
   287  type execRequest struct {
   288  	*ssh.Request
   289  	Payload execRequestPayload
   290  }
   291  
   292  type execRequestPayload string
   293  
   294  func (p execRequestPayload) String() string {
   295  	return string(p)
   296  }
   297  
   298  func newExecRequest(raw *ssh.Request) (*execRequest, error) {
   299  	r := new(execRequest)
   300  	r.Request = raw
   301  	buf := bytes.NewReader(r.Request.Payload)
   302  
   303  	var err error
   304  	var payload string
   305  	if payload, err = sshString(buf); err != nil {
   306  		return nil, err
   307  	}
   308  
   309  	r.Payload = execRequestPayload(payload)
   310  	return r, nil
   311  }
   312  
   313  type subsystemRequest struct {
   314  	*ssh.Request
   315  	Payload subsystemRequestPayload
   316  }
   317  
   318  type subsystemRequestPayload string
   319  
   320  func (p subsystemRequestPayload) String() string {
   321  	return string(p)
   322  }
   323  
   324  func newSubsystemRequest(raw *ssh.Request) (*subsystemRequest, error) {
   325  	r := new(subsystemRequest)
   326  	r.Request = raw
   327  	buf := bytes.NewReader(r.Request.Payload)
   328  
   329  	var err error
   330  	var payload string
   331  	if payload, err = sshString(buf); err != nil {
   332  		return nil, err
   333  	}
   334  
   335  	r.Payload = subsystemRequestPayload(payload)
   336  	return r, nil
   337  }