github.com/amanya/packer@v0.12.1-0.20161117214323-902ac5ab2eb6/provisioner/ansible/scp.go (about)

     1  package ansible
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"log"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/mitchellh/packer/packer"
    17  )
    18  
    19  const (
    20  	scpOK         = "\x00"
    21  	scpEmptyError = "\x02\n"
    22  )
    23  
    24  /*
    25  scp is a simple, but poorly documented, protocol. Thankfully, its source is
    26  freely available, and there is at least one page that describes it reasonably
    27  well.
    28  
    29  * https://raw.githubusercontent.com/openssh/openssh-portable/master/scp.c
    30  * https://opensource.apple.com/source/OpenSSH/OpenSSH-7.1/openssh/scp.c
    31  * https://blogs.oracle.com/janp/entry/how_the_scp_protocol_works is a great
    32  	resource, but has some bad information. Its first problem is that it doesn't
    33  	correctly describe why the producer has to read more responses than messages
    34  	it sends (because it has to read the 0 sent by the sink to start the
    35  	transfer). The second problem is that it omits that the producer needs to
    36  	send a 0 byte after file contents.
    37  */
    38  
    39  func scpUploadSession(opts []byte, rest string, in io.Reader, out io.Writer, comm packer.Communicator) error {
    40  	rest = strings.TrimSpace(rest)
    41  	if len(rest) == 0 {
    42  		fmt.Fprintf(out, scpEmptyError)
    43  		return errors.New("no scp target specified")
    44  	}
    45  
    46  	d, err := ioutil.TempDir("", "packer-ansible-upload")
    47  	if err != nil {
    48  		fmt.Fprintf(out, scpEmptyError)
    49  		return err
    50  	}
    51  	defer os.RemoveAll(d)
    52  
    53  	state := &scpUploadState{destRoot: rest, srcRoot: d, comm: comm}
    54  
    55  	fmt.Fprintf(out, scpOK) // signal the client to start the transfer.
    56  	return state.Protocol(bufio.NewReader(in), out)
    57  }
    58  
    59  func scpDownloadSession(opts []byte, rest string, in io.Reader, out io.Writer, comm packer.Communicator) error {
    60  	rest = strings.TrimSpace(rest)
    61  	if len(rest) == 0 {
    62  		fmt.Fprintf(out, scpEmptyError)
    63  		return errors.New("no scp source specified")
    64  	}
    65  
    66  	d, err := ioutil.TempDir("", "packer-ansible-download")
    67  	if err != nil {
    68  		fmt.Fprintf(out, scpEmptyError)
    69  		return err
    70  	}
    71  	defer os.RemoveAll(d)
    72  
    73  	if bytes.Contains([]byte{'d'}, opts) {
    74  		// the only ansible module that supports downloading via scp is fetch,
    75  		// fetch only supports file downloads as of Ansible 2.1.
    76  		fmt.Fprintf(out, scpEmptyError)
    77  		return errors.New("directory downloads not supported")
    78  	}
    79  
    80  	f, err := os.Create(filepath.Join(d, filepath.Base(rest)))
    81  	if err != nil {
    82  		fmt.Fprintf(out, scpEmptyError)
    83  		return err
    84  	}
    85  	defer f.Close()
    86  
    87  	err = comm.Download(rest, f)
    88  	if err != nil {
    89  		fmt.Fprintf(out, scpEmptyError)
    90  		return err
    91  	}
    92  
    93  	state := &scpDownloadState{srcRoot: d}
    94  
    95  	return state.Protocol(bufio.NewReader(in), out)
    96  }
    97  
    98  func (state *scpDownloadState) FileProtocol(path string, info os.FileInfo, in *bufio.Reader, out io.Writer) error {
    99  	size := info.Size()
   100  	perms := fmt.Sprintf("C%04o", info.Mode().Perm())
   101  	fmt.Fprintln(out, perms, size, info.Name())
   102  	err := scpResponse(in)
   103  	if err != nil {
   104  		return err
   105  	}
   106  
   107  	f, err := os.Open(path)
   108  	if err != nil {
   109  		return err
   110  	}
   111  	defer f.Close()
   112  
   113  	io.CopyN(out, f, size)
   114  	fmt.Fprintf(out, scpOK)
   115  
   116  	return scpResponse(in)
   117  }
   118  
// scpUploadState carries the state of an scp upload session between its
// message handlers (Protocol, FileProtocol, TimeProtocol, DirProtocol).
// DestPath and SrcPath resolve dir against destRoot and srcRoot respectively.
type scpUploadState struct {
	comm     packer.Communicator // comm performs the uploads to the target
	destRoot string              // destRoot is the directory on the target
	srcRoot  string              // srcRoot is the directory on the host
	mtime    time.Time           // mtime is the modification time from the last "T" message, if any
	atime    time.Time           // atime is the access time from the last "T" message, if any
	dir      string              // dir is a path relative to the roots
}
   127  
   128  func (scp scpUploadState) DestPath() string {
   129  	return filepath.Join(scp.destRoot, scp.dir)
   130  }
   131  
   132  func (scp scpUploadState) SrcPath() string {
   133  	return filepath.Join(scp.srcRoot, scp.dir)
   134  }
   135  
   136  func (state *scpUploadState) Protocol(in *bufio.Reader, out io.Writer) error {
   137  	for {
   138  		b, err := in.ReadByte()
   139  		if err != nil {
   140  			return err
   141  		}
   142  		switch b {
   143  		case 'T':
   144  			err := state.TimeProtocol(in, out)
   145  			if err != nil {
   146  				return err
   147  			}
   148  		case 'C':
   149  			return state.FileProtocol(in, out)
   150  		case 'E':
   151  			state.dir = filepath.Dir(state.dir)
   152  			fmt.Fprintf(out, scpOK)
   153  			return nil
   154  		case 'D':
   155  			return state.DirProtocol(in, out)
   156  		default:
   157  			fmt.Fprintf(out, scpEmptyError)
   158  			return fmt.Errorf("unexpected message: %c", b)
   159  		}
   160  	}
   161  }
   162  
   163  func (state *scpUploadState) FileProtocol(in *bufio.Reader, out io.Writer) error {
   164  	defer func() {
   165  		state.mtime = time.Time{}
   166  	}()
   167  
   168  	var mode os.FileMode
   169  	var size int64
   170  	var name string
   171  	_, err := fmt.Fscanf(in, "%04o %d %s\n", &mode, &size, &name)
   172  	if err != nil {
   173  		fmt.Fprintf(out, scpEmptyError)
   174  		return fmt.Errorf("invalid file message: %v", err)
   175  	}
   176  	fmt.Fprintf(out, scpOK)
   177  
   178  	var fi os.FileInfo = fileInfo{name: name, size: size, mode: mode, mtime: state.mtime}
   179  
   180  	err = state.comm.Upload(filepath.Join(state.DestPath(), fi.Name()), io.LimitReader(in, fi.Size()), &fi)
   181  	if err != nil {
   182  		fmt.Fprintf(out, scpEmptyError)
   183  		return err
   184  	}
   185  
   186  	err = scpResponse(in)
   187  	if err != nil {
   188  		return err
   189  	}
   190  
   191  	fmt.Fprintf(out, scpOK)
   192  	return nil
   193  }
   194  
   195  func (state *scpUploadState) TimeProtocol(in *bufio.Reader, out io.Writer) error {
   196  	var m, a int64
   197  	if _, err := fmt.Fscanf(in, "%d 0 %d 0\n", &m, &a); err != nil {
   198  		fmt.Fprintf(out, scpEmptyError)
   199  		return err
   200  	}
   201  	fmt.Fprintf(out, scpOK)
   202  
   203  	state.atime = time.Unix(a, 0)
   204  	state.mtime = time.Unix(m, 0)
   205  	return nil
   206  }
   207  
   208  func (state *scpUploadState) DirProtocol(in *bufio.Reader, out io.Writer) error {
   209  	var mode os.FileMode
   210  	var length uint
   211  	var name string
   212  
   213  	if _, err := fmt.Fscanf(in, "%04o %d %s\n", &mode, &length, &name); err != nil {
   214  		fmt.Fprintf(out, scpEmptyError)
   215  		return fmt.Errorf("invalid directory message: %v", err)
   216  	}
   217  	fmt.Fprintf(out, scpOK)
   218  
   219  	path := filepath.Join(state.dir, name)
   220  	if err := os.Mkdir(path, mode); err != nil {
   221  		return err
   222  	}
   223  	state.dir = path
   224  
   225  	if state.atime.IsZero() {
   226  		state.atime = time.Now()
   227  	}
   228  	if state.mtime.IsZero() {
   229  		state.mtime = time.Now()
   230  	}
   231  
   232  	if err := os.Chtimes(path, state.atime, state.mtime); err != nil {
   233  		return err
   234  	}
   235  
   236  	if err := state.comm.UploadDir(filepath.Dir(state.DestPath()), state.SrcPath(), nil); err != nil {
   237  		return err
   238  	}
   239  
   240  	state.mtime = time.Time{}
   241  	state.atime = time.Time{}
   242  	return state.Protocol(in, out)
   243  }
   244  
// scpDownloadState carries the state of an scp download session.
type scpDownloadState struct {
	srcRoot string // srcRoot is the directory on the host; its files are sent to the client
}
   248  
   249  func (state *scpDownloadState) Protocol(in *bufio.Reader, out io.Writer) error {
   250  	r := bufio.NewReader(in)
   251  	// read the byte sent by the other side to start the transfer
   252  	scpResponse(r)
   253  
   254  	return filepath.Walk(state.srcRoot, func(path string, info os.FileInfo, err error) error {
   255  		if err != nil {
   256  			return err
   257  		}
   258  
   259  		if path == state.srcRoot {
   260  			return nil
   261  		}
   262  
   263  		if info.IsDir() {
   264  			// no need to get fancy; srcRoot should only contain one file, because
   265  			// Ansible only allows fetching a single file.
   266  			return errors.New("unexpected directory")
   267  		}
   268  
   269  		return state.FileProtocol(path, info, r, out)
   270  	})
   271  }
   272  
   273  func scpOptions(s string) (opts []byte, rest string) {
   274  	end := 0
   275  	opt := false
   276  
   277  Loop:
   278  	for i := 0; i < len(s); i++ {
   279  		b := s[i]
   280  		switch {
   281  		case b == ' ':
   282  			opt = false
   283  			end++
   284  		case b == '-':
   285  			opt = true
   286  			end++
   287  		case opt:
   288  			opts = append(opts, b)
   289  			end++
   290  		default:
   291  			break Loop
   292  		}
   293  	}
   294  
   295  	rest = s[end:]
   296  	return
   297  }
   298  
   299  func scpResponse(r *bufio.Reader) error {
   300  	code, err := r.ReadByte()
   301  	if err != nil {
   302  		return err
   303  	}
   304  
   305  	if code != 0 {
   306  		message, err := r.ReadString('\n')
   307  		if err != nil {
   308  			return fmt.Errorf("Error reading error message: %s", err)
   309  		}
   310  
   311  		// 1 is a warning. Anything higher (really just 2) is an error.
   312  		if code > 1 {
   313  			return errors.New(string(message))
   314  		}
   315  
   316  		log.Println("WARNING:", err)
   317  	}
   318  	return nil
   319  }
   320  
   321  type fileInfo struct {
   322  	name  string
   323  	size  int64
   324  	mode  os.FileMode
   325  	mtime time.Time
   326  }
   327  
   328  func (fi fileInfo) Name() string      { return fi.name }
   329  func (fi fileInfo) Size() int64       { return fi.size }
   330  func (fi fileInfo) Mode() os.FileMode { return fi.mode }
   331  func (fi fileInfo) ModTime() time.Time {
   332  	if fi.mtime.IsZero() {
   333  		return time.Now()
   334  	}
   335  	return fi.mtime
   336  }
   337  func (fi fileInfo) IsDir() bool      { return fi.mode.IsDir() }
   338  func (fi fileInfo) Sys() interface{} { return nil }