github.com/rahart/packer@v0.12.2-0.20161229105310-282bb6ad370f/provisioner/ansible/scp.go (about)

     1  package ansible
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"log"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/mitchellh/packer/packer"
    17  )
    18  
    19  const (
    20  	scpOK         = "\x00"
    21  	scpEmptyError = "\x02\n"
    22  )
    23  
    24  /*
    25  scp is a simple, but poorly documented, protocol. Thankfully, its source is
    26  freely available, and there is at least one page that describes it reasonably
    27  well.
    28  
    29  * https://raw.githubusercontent.com/openssh/openssh-portable/master/scp.c
    30  * https://opensource.apple.com/source/OpenSSH/OpenSSH-7.1/openssh/scp.c
    31  * https://blogs.oracle.com/janp/entry/how_the_scp_protocol_works is a great
    32  	resource, but has some bad information. Its first problem is that it doesn't
    33  	correctly describe why the producer has to read more responses than messages
    34  	it sends (because it has to read the 0 sent by the sink to start the
    35  	transfer). The second problem is that it omits that the producer needs to
    36  	send a 0 byte after file contents.
    37  */
    38  
    39  func scpUploadSession(opts []byte, rest string, in io.Reader, out io.Writer, comm packer.Communicator) error {
    40  	rest = strings.TrimSpace(rest)
    41  	if len(rest) == 0 {
    42  		fmt.Fprintf(out, scpEmptyError)
    43  		return errors.New("no scp target specified")
    44  	}
    45  
    46  	d, err := ioutil.TempDir("", "packer-ansible-upload")
    47  	if err != nil {
    48  		fmt.Fprintf(out, scpEmptyError)
    49  		return err
    50  	}
    51  	defer os.RemoveAll(d)
    52  
    53  	// To properly implement scp, rest should be checked to see if it is a
    54  	// directory on the remote side, but ansible only sends files, so there's no
    55  	// need to set targetIsDir, because it can be safely assumed that rest is
    56  	// intended to be a file, and whatever names are used in 'C' commands are
    57  	// irrelavant.
    58  	state := &scpUploadState{target: rest, srcRoot: d, comm: comm}
    59  
    60  	fmt.Fprintf(out, scpOK) // signal the client to start the transfer.
    61  	return state.Protocol(bufio.NewReader(in), out)
    62  }
    63  
    64  func scpDownloadSession(opts []byte, rest string, in io.Reader, out io.Writer, comm packer.Communicator) error {
    65  	rest = strings.TrimSpace(rest)
    66  	if len(rest) == 0 {
    67  		fmt.Fprintf(out, scpEmptyError)
    68  		return errors.New("no scp source specified")
    69  	}
    70  
    71  	d, err := ioutil.TempDir("", "packer-ansible-download")
    72  	if err != nil {
    73  		fmt.Fprintf(out, scpEmptyError)
    74  		return err
    75  	}
    76  	defer os.RemoveAll(d)
    77  
    78  	if bytes.Contains([]byte{'d'}, opts) {
    79  		// the only ansible module that supports downloading via scp is fetch,
    80  		// fetch only supports file downloads as of Ansible 2.1.
    81  		fmt.Fprintf(out, scpEmptyError)
    82  		return errors.New("directory downloads not supported")
    83  	}
    84  
    85  	f, err := os.Create(filepath.Join(d, filepath.Base(rest)))
    86  	if err != nil {
    87  		fmt.Fprintf(out, scpEmptyError)
    88  		return err
    89  	}
    90  	defer f.Close()
    91  
    92  	err = comm.Download(rest, f)
    93  	if err != nil {
    94  		fmt.Fprintf(out, scpEmptyError)
    95  		return err
    96  	}
    97  
    98  	state := &scpDownloadState{srcRoot: d}
    99  
   100  	return state.Protocol(bufio.NewReader(in), out)
   101  }
   102  
   103  func (state *scpDownloadState) FileProtocol(path string, info os.FileInfo, in *bufio.Reader, out io.Writer) error {
   104  	size := info.Size()
   105  	perms := fmt.Sprintf("C%04o", info.Mode().Perm())
   106  	fmt.Fprintln(out, perms, size, info.Name())
   107  	err := scpResponse(in)
   108  	if err != nil {
   109  		return err
   110  	}
   111  
   112  	f, err := os.Open(path)
   113  	if err != nil {
   114  		return err
   115  	}
   116  	defer f.Close()
   117  
   118  	io.CopyN(out, f, size)
   119  	fmt.Fprintf(out, scpOK)
   120  
   121  	return scpResponse(in)
   122  }
   123  
   124  type scpUploadState struct {
   125  	comm        packer.Communicator
   126  	target      string // target is the directory on the target
   127  	srcRoot     string // srcRoot is the directory on the host
   128  	mtime       time.Time
   129  	atime       time.Time
   130  	dir         string // dir is a path relative to the roots
   131  	targetIsDir bool
   132  }
   133  
   134  func (scp scpUploadState) DestPath() string {
   135  	return filepath.Join(scp.target, scp.dir)
   136  }
   137  
   138  func (scp scpUploadState) SrcPath() string {
   139  	return filepath.Join(scp.srcRoot, scp.dir)
   140  }
   141  
   142  func (state *scpUploadState) Protocol(in *bufio.Reader, out io.Writer) error {
   143  	for {
   144  		b, err := in.ReadByte()
   145  		if err != nil {
   146  			return err
   147  		}
   148  		switch b {
   149  		case 'T':
   150  			err := state.TimeProtocol(in, out)
   151  			if err != nil {
   152  				return err
   153  			}
   154  		case 'C':
   155  			return state.FileProtocol(in, out)
   156  		case 'E':
   157  			state.dir = filepath.Dir(state.dir)
   158  			fmt.Fprintf(out, scpOK)
   159  			return nil
   160  		case 'D':
   161  			return state.DirProtocol(in, out)
   162  		default:
   163  			fmt.Fprintf(out, scpEmptyError)
   164  			return fmt.Errorf("unexpected message: %c", b)
   165  		}
   166  	}
   167  }
   168  
   169  func (state *scpUploadState) FileProtocol(in *bufio.Reader, out io.Writer) error {
   170  	defer func() {
   171  		state.mtime = time.Time{}
   172  	}()
   173  
   174  	var mode os.FileMode
   175  	var size int64
   176  	var name string
   177  	_, err := fmt.Fscanf(in, "%04o %d %s\n", &mode, &size, &name)
   178  	if err != nil {
   179  		fmt.Fprintf(out, scpEmptyError)
   180  		return fmt.Errorf("invalid file message: %v", err)
   181  	}
   182  	fmt.Fprintf(out, scpOK)
   183  
   184  	var fi os.FileInfo = fileInfo{name: name, size: size, mode: mode, mtime: state.mtime}
   185  
   186  	dest := state.DestPath()
   187  	if state.targetIsDir {
   188  		dest = filepath.Join(dest, fi.Name())
   189  	}
   190  
   191  	err = state.comm.Upload(dest, io.LimitReader(in, fi.Size()), &fi)
   192  	if err != nil {
   193  		fmt.Fprintf(out, scpEmptyError)
   194  		return err
   195  	}
   196  
   197  	err = scpResponse(in)
   198  	if err != nil {
   199  		return err
   200  	}
   201  
   202  	fmt.Fprintf(out, scpOK)
   203  	return nil
   204  }
   205  
   206  func (state *scpUploadState) TimeProtocol(in *bufio.Reader, out io.Writer) error {
   207  	var m, a int64
   208  	if _, err := fmt.Fscanf(in, "%d 0 %d 0\n", &m, &a); err != nil {
   209  		fmt.Fprintf(out, scpEmptyError)
   210  		return err
   211  	}
   212  	fmt.Fprintf(out, scpOK)
   213  
   214  	state.atime = time.Unix(a, 0)
   215  	state.mtime = time.Unix(m, 0)
   216  	return nil
   217  }
   218  
   219  func (state *scpUploadState) DirProtocol(in *bufio.Reader, out io.Writer) error {
   220  	var mode os.FileMode
   221  	var length uint
   222  	var name string
   223  
   224  	if _, err := fmt.Fscanf(in, "%04o %d %s\n", &mode, &length, &name); err != nil {
   225  		fmt.Fprintf(out, scpEmptyError)
   226  		return fmt.Errorf("invalid directory message: %v", err)
   227  	}
   228  	fmt.Fprintf(out, scpOK)
   229  
   230  	path := filepath.Join(state.dir, name)
   231  	if err := os.Mkdir(path, mode); err != nil {
   232  		return err
   233  	}
   234  	state.dir = path
   235  
   236  	if state.atime.IsZero() {
   237  		state.atime = time.Now()
   238  	}
   239  	if state.mtime.IsZero() {
   240  		state.mtime = time.Now()
   241  	}
   242  
   243  	if err := os.Chtimes(path, state.atime, state.mtime); err != nil {
   244  		return err
   245  	}
   246  
   247  	if err := state.comm.UploadDir(filepath.Dir(state.DestPath()), state.SrcPath(), nil); err != nil {
   248  		return err
   249  	}
   250  
   251  	state.mtime = time.Time{}
   252  	state.atime = time.Time{}
   253  	return state.Protocol(in, out)
   254  }
   255  
   256  type scpDownloadState struct {
   257  	srcRoot string // srcRoot is the directory on the host
   258  }
   259  
   260  func (state *scpDownloadState) Protocol(in *bufio.Reader, out io.Writer) error {
   261  	r := bufio.NewReader(in)
   262  	// read the byte sent by the other side to start the transfer
   263  	scpResponse(r)
   264  
   265  	return filepath.Walk(state.srcRoot, func(path string, info os.FileInfo, err error) error {
   266  		if err != nil {
   267  			return err
   268  		}
   269  
   270  		if path == state.srcRoot {
   271  			return nil
   272  		}
   273  
   274  		if info.IsDir() {
   275  			// no need to get fancy; srcRoot should only contain one file, because
   276  			// Ansible only allows fetching a single file.
   277  			return errors.New("unexpected directory")
   278  		}
   279  
   280  		return state.FileProtocol(path, info, r, out)
   281  	})
   282  }
   283  
   284  func scpOptions(s string) (opts []byte, rest string) {
   285  	end := 0
   286  	opt := false
   287  
   288  Loop:
   289  	for i := 0; i < len(s); i++ {
   290  		b := s[i]
   291  		switch {
   292  		case b == ' ':
   293  			opt = false
   294  			end++
   295  		case b == '-':
   296  			opt = true
   297  			end++
   298  		case opt:
   299  			opts = append(opts, b)
   300  			end++
   301  		default:
   302  			break Loop
   303  		}
   304  	}
   305  
   306  	rest = s[end:]
   307  	return
   308  }
   309  
   310  func scpResponse(r *bufio.Reader) error {
   311  	code, err := r.ReadByte()
   312  	if err != nil {
   313  		return err
   314  	}
   315  
   316  	if code != 0 {
   317  		message, err := r.ReadString('\n')
   318  		if err != nil {
   319  			return fmt.Errorf("Error reading error message: %s", err)
   320  		}
   321  
   322  		// 1 is a warning. Anything higher (really just 2) is an error.
   323  		if code > 1 {
   324  			return errors.New(string(message))
   325  		}
   326  
   327  		log.Println("WARNING:", err)
   328  	}
   329  	return nil
   330  }
   331  
   332  type fileInfo struct {
   333  	name  string
   334  	size  int64
   335  	mode  os.FileMode
   336  	mtime time.Time
   337  }
   338  
   339  func (fi fileInfo) Name() string      { return fi.name }
   340  func (fi fileInfo) Size() int64       { return fi.size }
   341  func (fi fileInfo) Mode() os.FileMode { return fi.mode }
   342  func (fi fileInfo) ModTime() time.Time {
   343  	if fi.mtime.IsZero() {
   344  		return time.Now()
   345  	}
   346  	return fi.mtime
   347  }
   348  func (fi fileInfo) IsDir() bool      { return fi.mode.IsDir() }
   349  func (fi fileInfo) Sys() interface{} { return nil }