github.com/ddnomad/packer@v1.3.2/provisioner/ansible/scp.go (about)

     1  package ansible
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"log"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/hashicorp/packer/packer"
    17  )
    18  
    19  const (
    20  	scpOK         = "\x00"
    21  	scpEmptyError = "\x02\n"
    22  )
    23  
    24  /*
    25  scp is a simple, but poorly documented, protocol. Thankfully, its source is
    26  freely available, and there is at least one page that describes it reasonably
    27  well.
    28  
    29  * https://raw.githubusercontent.com/openssh/openssh-portable/master/scp.c
    30  * https://opensource.apple.com/source/OpenSSH/OpenSSH-7.1/openssh/scp.c
    31  * https://blogs.oracle.com/janp/entry/how_the_scp_protocol_works is a great
    32  	resource, but has some bad information. Its first problem is that it doesn't
    33  	correctly describe why the producer has to read more responses than messages
    34  	it sends (because it has to read the 0 sent by the sink to start the
    35  	transfer). The second problem is that it omits that the producer needs to
    36  	send a 0 byte after file contents.
    37  */
    38  
    39  func scpUploadSession(opts []byte, rest string, in io.Reader, out io.Writer, comm packer.Communicator) error {
    40  	rest = strings.TrimSpace(rest)
    41  	if len(rest) == 0 {
    42  		fmt.Fprintf(out, scpEmptyError)
    43  		return errors.New("no scp target specified")
    44  	}
    45  
    46  	d, err := ioutil.TempDir("", "packer-ansible-upload")
    47  	if err != nil {
    48  		fmt.Fprintf(out, scpEmptyError)
    49  		return err
    50  	}
    51  	defer os.RemoveAll(d)
    52  
    53  	// To properly implement scp, rest should be checked to see if it is a
    54  	// directory on the remote side, but ansible only sends files, so there's no
    55  	// need to set targetIsDir, because it can be safely assumed that rest is
    56  	// intended to be a file, and whatever names are used in 'C' commands are
    57  	// irrelevant.
    58  	state := &scpUploadState{target: rest, srcRoot: d, comm: comm}
    59  
    60  	fmt.Fprintf(out, scpOK) // signal the client to start the transfer.
    61  	return state.Protocol(bufio.NewReader(in), out)
    62  }
    63  
    64  func scpDownloadSession(opts []byte, rest string, in io.Reader, out io.Writer, comm packer.Communicator) error {
    65  	rest = strings.TrimSpace(rest)
    66  	if len(rest) == 0 {
    67  		fmt.Fprintf(out, scpEmptyError)
    68  		return errors.New("no scp source specified")
    69  	}
    70  
    71  	d, err := ioutil.TempDir("", "packer-ansible-download")
    72  	if err != nil {
    73  		fmt.Fprintf(out, scpEmptyError)
    74  		return err
    75  	}
    76  	defer os.RemoveAll(d)
    77  
    78  	if bytes.Contains([]byte{'d'}, opts) {
    79  		// the only ansible module that supports downloading via scp is fetch,
    80  		// fetch only supports file downloads as of Ansible 2.1.
    81  		fmt.Fprintf(out, scpEmptyError)
    82  		return errors.New("directory downloads not supported")
    83  	}
    84  
    85  	f, err := os.Create(filepath.Join(d, filepath.Base(rest)))
    86  	if err != nil {
    87  		fmt.Fprintf(out, scpEmptyError)
    88  		return err
    89  	}
    90  	defer f.Close()
    91  
    92  	err = comm.Download(rest, f)
    93  	if err != nil {
    94  		fmt.Fprintf(out, scpEmptyError)
    95  		return err
    96  	}
    97  
    98  	state := &scpDownloadState{srcRoot: d}
    99  
   100  	return state.Protocol(bufio.NewReader(in), out)
   101  }
   102  
   103  func (state *scpDownloadState) FileProtocol(path string, info os.FileInfo, in *bufio.Reader, out io.Writer) error {
   104  	size := info.Size()
   105  	perms := fmt.Sprintf("C%04o", info.Mode().Perm())
   106  	fmt.Fprintln(out, perms, size, info.Name())
   107  	if err := scpResponse(in); err != nil {
   108  		return err
   109  	}
   110  
   111  	f, err := os.Open(path)
   112  	if err != nil {
   113  		return err
   114  	}
   115  	defer f.Close()
   116  
   117  	io.CopyN(out, f, size)
   118  	fmt.Fprintf(out, scpOK)
   119  
   120  	return scpResponse(in)
   121  }
   122  
// scpUploadState holds the state of one scp upload session as protocol
// messages are processed.
type scpUploadState struct {
	comm        packer.Communicator // comm performs the uploads to the remote machine
	target      string              // target is the destination path on the remote side
	srcRoot     string              // srcRoot is the directory on the host
	mtime       time.Time           // mtime is set by a 'T' message and cleared after the next file or directory
	atime       time.Time           // atime is set by a 'T' message and cleared after the next directory
	dir         string              // dir is a path relative to the roots
	targetIsDir bool                // targetIsDir makes FileProtocol append the file's name to the destination path
}
   132  
   133  func (scp scpUploadState) DestPath() string {
   134  	return filepath.Join(scp.target, scp.dir)
   135  }
   136  
   137  func (scp scpUploadState) SrcPath() string {
   138  	return filepath.Join(scp.srcRoot, scp.dir)
   139  }
   140  
   141  func (state *scpUploadState) Protocol(in *bufio.Reader, out io.Writer) error {
   142  	for {
   143  		b, err := in.ReadByte()
   144  		if err != nil {
   145  			return err
   146  		}
   147  		switch b {
   148  		case 'T':
   149  			err := state.TimeProtocol(in, out)
   150  			if err != nil {
   151  				return err
   152  			}
   153  		case 'C':
   154  			return state.FileProtocol(in, out)
   155  		case 'E':
   156  			state.dir = filepath.Dir(state.dir)
   157  			fmt.Fprintf(out, scpOK)
   158  			return nil
   159  		case 'D':
   160  			return state.DirProtocol(in, out)
   161  		default:
   162  			fmt.Fprintf(out, scpEmptyError)
   163  			return fmt.Errorf("unexpected message: %c", b)
   164  		}
   165  	}
   166  }
   167  
   168  func (state *scpUploadState) FileProtocol(in *bufio.Reader, out io.Writer) error {
   169  	defer func() {
   170  		state.mtime = time.Time{}
   171  	}()
   172  
   173  	var mode os.FileMode
   174  	var size int64
   175  	var name string
   176  	_, err := fmt.Fscanf(in, "%04o %d %s\n", &mode, &size, &name)
   177  	if err != nil {
   178  		fmt.Fprintf(out, scpEmptyError)
   179  		return fmt.Errorf("invalid file message: %v", err)
   180  	}
   181  	fmt.Fprintf(out, scpOK)
   182  
   183  	var fi os.FileInfo = fileInfo{name: name, size: size, mode: mode, mtime: state.mtime}
   184  
   185  	dest := state.DestPath()
   186  	if state.targetIsDir {
   187  		dest = filepath.Join(dest, fi.Name())
   188  	}
   189  
   190  	err = state.comm.Upload(dest, io.LimitReader(in, fi.Size()), &fi)
   191  	if err != nil {
   192  		fmt.Fprintf(out, scpEmptyError)
   193  		return err
   194  	}
   195  
   196  	if err := scpResponse(in); err != nil {
   197  		return err
   198  	}
   199  
   200  	fmt.Fprintf(out, scpOK)
   201  	return nil
   202  }
   203  
   204  func (state *scpUploadState) TimeProtocol(in *bufio.Reader, out io.Writer) error {
   205  	var m, a int64
   206  	if _, err := fmt.Fscanf(in, "%d 0 %d 0\n", &m, &a); err != nil {
   207  		fmt.Fprintf(out, scpEmptyError)
   208  		return err
   209  	}
   210  	fmt.Fprintf(out, scpOK)
   211  
   212  	state.atime = time.Unix(a, 0)
   213  	state.mtime = time.Unix(m, 0)
   214  	return nil
   215  }
   216  
   217  func (state *scpUploadState) DirProtocol(in *bufio.Reader, out io.Writer) error {
   218  	var mode os.FileMode
   219  	var length uint
   220  	var name string
   221  
   222  	if _, err := fmt.Fscanf(in, "%04o %d %s\n", &mode, &length, &name); err != nil {
   223  		fmt.Fprintf(out, scpEmptyError)
   224  		return fmt.Errorf("invalid directory message: %v", err)
   225  	}
   226  	fmt.Fprintf(out, scpOK)
   227  
   228  	path := filepath.Join(state.dir, name)
   229  	if err := os.Mkdir(path, mode); err != nil {
   230  		return err
   231  	}
   232  	state.dir = path
   233  
   234  	if state.atime.IsZero() {
   235  		state.atime = time.Now()
   236  	}
   237  	if state.mtime.IsZero() {
   238  		state.mtime = time.Now()
   239  	}
   240  
   241  	if err := os.Chtimes(path, state.atime, state.mtime); err != nil {
   242  		return err
   243  	}
   244  
   245  	if err := state.comm.UploadDir(filepath.Dir(state.DestPath()), state.SrcPath(), nil); err != nil {
   246  		return err
   247  	}
   248  
   249  	state.mtime = time.Time{}
   250  	state.atime = time.Time{}
   251  	return state.Protocol(in, out)
   252  }
   253  
// scpDownloadState holds the state of one scp download session. Protocol
// walks srcRoot and sends each file found there to the scp client.
type scpDownloadState struct {
	srcRoot string // srcRoot is the directory on the host
}
   257  
   258  func (state *scpDownloadState) Protocol(in *bufio.Reader, out io.Writer) error {
   259  	r := bufio.NewReader(in)
   260  	// read the byte sent by the other side to start the transfer
   261  	if err := scpResponse(r); err != nil {
   262  		return err
   263  	}
   264  
   265  	return filepath.Walk(state.srcRoot, func(path string, info os.FileInfo, err error) error {
   266  		if err != nil {
   267  			return err
   268  		}
   269  
   270  		if path == state.srcRoot {
   271  			return nil
   272  		}
   273  
   274  		if info.IsDir() {
   275  			// no need to get fancy; srcRoot should only contain one file, because
   276  			// Ansible only allows fetching a single file.
   277  			return errors.New("unexpected directory")
   278  		}
   279  
   280  		return state.FileProtocol(path, info, r, out)
   281  	})
   282  }
   283  
   284  func scpOptions(s string) (opts []byte, rest string) {
   285  	end := 0
   286  	opt := false
   287  
   288  Loop:
   289  	for i := 0; i < len(s); i++ {
   290  		b := s[i]
   291  		switch {
   292  		case b == ' ':
   293  			opt = false
   294  			end++
   295  		case b == '-':
   296  			opt = true
   297  			end++
   298  		case opt:
   299  			opts = append(opts, b)
   300  			end++
   301  		default:
   302  			break Loop
   303  		}
   304  	}
   305  
   306  	rest = s[end:]
   307  	return
   308  }
   309  
   310  func scpResponse(r *bufio.Reader) error {
   311  	code, err := r.ReadByte()
   312  	if err != nil {
   313  		return err
   314  	}
   315  
   316  	if code != 0 {
   317  		message, err := r.ReadString('\n')
   318  		if err != nil {
   319  			return fmt.Errorf("Error reading error message: %s", err)
   320  		}
   321  
   322  		// 1 is a warning. Anything higher (really just 2) is an error.
   323  		if code > 1 {
   324  			return errors.New(message)
   325  		}
   326  
   327  		log.Println("WARNING:", err)
   328  	}
   329  	return nil
   330  }
   331  
   332  type fileInfo struct {
   333  	name  string
   334  	size  int64
   335  	mode  os.FileMode
   336  	mtime time.Time
   337  }
   338  
   339  func (fi fileInfo) Name() string      { return fi.name }
   340  func (fi fileInfo) Size() int64       { return fi.size }
   341  func (fi fileInfo) Mode() os.FileMode { return fi.mode }
   342  func (fi fileInfo) ModTime() time.Time {
   343  	if fi.mtime.IsZero() {
   344  		return time.Now()
   345  	}
   346  	return fi.mtime
   347  }
   348  func (fi fileInfo) IsDir() bool      { return fi.mode.IsDir() }
   349  func (fi fileInfo) Sys() interface{} { return nil }