github.com/Pankov404/juju@v0.0.0-20150703034450-be266991dceb/environs/sshstorage/storage.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package sshstorage
     5  
     6  import (
     7  	"bufio"
     8  	"bytes"
     9  	"encoding/base64"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	"path"
    14  	"sort"
    15  	"strconv"
    16  	"strings"
    17  
    18  	"github.com/juju/errors"
    19  	"github.com/juju/loggo"
    20  	"github.com/juju/utils"
    21  
    22  	"github.com/juju/juju/utils/ssh"
    23  )
    24  
// logger is this package's logger; Get and Put emit debug messages through it.
var logger = loggo.GetLogger("juju.environs.sshstorage")

// base64LineLength is the default line length for wrapping
// output generated by the base64 command line utility.
// copyAsBase64 passes it to newLineWrapWriter (presumably so that uploaded
// data is wrapped the same way the remote base64 utility wraps its output).
const base64LineLength = 76
    30  
    31  // SSHStorage implements storage.Storage.
    32  //
    33  // The storage is created under sudo, and ownership given over to the
    34  // login uid/gid. This is done so that we don't require sudo, and by
    35  // consequence, don't require a pty, so we can interact with a script
    36  // via stdin.
    37  type SSHStorage struct {
    38  	host       string
    39  	remotepath string
    40  	tmpdir     string
    41  
    42  	cmd     *ssh.Cmd
    43  	stdin   io.WriteCloser
    44  	stdout  io.ReadCloser
    45  	scanner *bufio.Scanner
    46  }
    47  
// sshCommand builds the *ssh.Cmd that runs command on host, passing nil
// options to ssh.Command. It is a package variable rather than a function,
// presumably so tests can substitute a fake — TODO confirm.
var sshCommand = func(host string, command ...string) *ssh.Cmd {
	return ssh.Command(host, command, nil)
}
    51  
    52  type flockmode string
    53  
    54  const (
    55  	flockShared    flockmode = "-s"
    56  	flockExclusive flockmode = "-x"
    57  )
    58  
// NewSSHStorageParams holds the parameters for NewSSHStorage.
type NewSSHStorageParams struct {
	// Host is the host to connect to, in the format [user@]hostname.
	Host string

	// StorageDir is the root of the remote storage directory.
	// It must be non-empty. NewSSHStorage creates it via sudo if needed,
	// with ownership set to the invoking user ($SUDO_UID/$SUDO_GID).
	StorageDir string

	// TmpDir is the remote temporary directory for storage.
	// A temporary directory must be specified, and should be located on the
	// same filesystem as the storage directory to ensure atomic writes.
	// The temporary directory will be created when NewSSHStorage is invoked
	// if it doesn't already exist; it will never be removed. NewSSHStorage
	// will attempt to reassign ownership to the login user, and will return
	// an error if it cannot do so.
	TmpDir string
}
    75  
    76  // NewSSHStorage creates a new SSHStorage, connected to the
    77  // specified host, managing state under the specified remote path.
    78  func NewSSHStorage(params NewSSHStorageParams) (*SSHStorage, error) {
    79  	if params.StorageDir == "" {
    80  		return nil, errors.New("storagedir must be specified and non-empty")
    81  	}
    82  	if params.TmpDir == "" {
    83  		return nil, errors.New("tmpdir must be specified and non-empty")
    84  	}
    85  
    86  	script := fmt.Sprintf(
    87  		"install -d -g $SUDO_GID -o $SUDO_UID %s %s",
    88  		utils.ShQuote(params.StorageDir),
    89  		utils.ShQuote(params.TmpDir),
    90  	)
    91  
    92  	cmd := sshCommand(params.Host, "sudo", "-n", "/bin/bash")
    93  	var stderr bytes.Buffer
    94  	cmd.Stderr = &stderr
    95  	cmd.Stdin = strings.NewReader(script)
    96  	if err := cmd.Run(); err != nil {
    97  		err = fmt.Errorf("failed to create storage dir: %v (%v)", err, strings.TrimSpace(stderr.String()))
    98  		return nil, err
    99  	}
   100  
   101  	// We could use sftp, but then we'd be at the mercy of
   102  	// sftp's output messages for checking errors. Instead,
   103  	// we execute an interactive bash shell.
   104  	cmd = sshCommand(params.Host, "bash")
   105  	stdin, err := cmd.StdinPipe()
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	stdout, err := cmd.StdoutPipe()
   110  	if err != nil {
   111  		stdin.Close()
   112  		return nil, err
   113  	}
   114  	// Combine stdout and stderr, so we can easily
   115  	// get at catastrophic failure messages.
   116  	cmd.Stderr = cmd.Stdout
   117  	stor := &SSHStorage{
   118  		host:       params.Host,
   119  		remotepath: params.StorageDir,
   120  		tmpdir:     params.TmpDir,
   121  		cmd:        cmd,
   122  		stdin:      stdin,
   123  		stdout:     stdout,
   124  		scanner:    bufio.NewScanner(stdout),
   125  	}
   126  	cmd.Start()
   127  
   128  	// Verify we have write permissions.
   129  	_, err = stor.runf(flockExclusive, "touch %s", utils.ShQuote(params.StorageDir))
   130  	if err != nil {
   131  		stdin.Close()
   132  		stdout.Close()
   133  		cmd.Wait()
   134  		return nil, err
   135  	}
   136  	return stor, nil
   137  }
   138  
   139  // Close cleanly terminates the underlying SSH connection.
   140  func (s *SSHStorage) Close() error {
   141  	s.stdin.Close()
   142  	s.stdout.Close()
   143  	return s.cmd.Wait()
   144  }
   145  
   146  func (s *SSHStorage) runf(flockmode flockmode, command string, args ...interface{}) (string, error) {
   147  	command = fmt.Sprintf(command, args...)
   148  	return s.run(flockmode, command, nil, 0)
   149  }
   150  
   151  // terminate closes the stdin, and appends any output to the input error.
   152  func (s *SSHStorage) terminate(err error) error {
   153  	s.stdin.Close()
   154  	var output string
   155  	for s.scanner.Scan() {
   156  		if len(output) > 0 {
   157  			output += "\n"
   158  		}
   159  		output += s.scanner.Text()
   160  	}
   161  	if len(output) > 0 {
   162  		err = fmt.Errorf("%v (output: %q)", err, output)
   163  	}
   164  	return err
   165  }
   166  
   167  func (s *SSHStorage) run(flockmode flockmode, command string, input io.Reader, inputlen int64) (string, error) {
   168  	const rcPrefix = "JUJU-RC: "
   169  	command = fmt.Sprintf(
   170  		"SHELL=/bin/bash flock %s %s -c %s",
   171  		flockmode,
   172  		utils.ShQuote(s.remotepath),
   173  		utils.ShQuote(command),
   174  	)
   175  	stdin := bufio.NewWriter(s.stdin)
   176  	if input != nil {
   177  		command = fmt.Sprintf("base64 -d << '@EOF' | (%s)", command)
   178  	}
   179  	command = fmt.Sprintf("(%s) 2>&1; echo %s$?", command, rcPrefix)
   180  	if _, err := stdin.WriteString(command + "\n"); err != nil {
   181  		return "", fmt.Errorf("failed to write command: %v", err)
   182  	}
   183  	if input != nil {
   184  		if err := copyAsBase64(stdin, input); err != nil {
   185  			return "", s.terminate(fmt.Errorf("failed to write input: %v", err))
   186  		}
   187  	}
   188  	if err := stdin.Flush(); err != nil {
   189  		return "", s.terminate(fmt.Errorf("failed to write input: %v", err))
   190  	}
   191  	var output []string
   192  	for s.scanner.Scan() {
   193  		line := s.scanner.Text()
   194  		if strings.HasPrefix(line, rcPrefix) {
   195  			line := line[len(rcPrefix):]
   196  			rc, err := strconv.Atoi(line)
   197  			if err != nil {
   198  				return "", fmt.Errorf("failed to parse exit code %q: %v", line, err)
   199  			}
   200  			outputJoined := strings.Join(output, "\n")
   201  			if rc == 0 {
   202  				return outputJoined, nil
   203  			}
   204  			return "", SSHStorageError{outputJoined, rc}
   205  		} else {
   206  			output = append(output, line)
   207  		}
   208  	}
   209  
   210  	err := fmt.Errorf("failed to locate %q", rcPrefix)
   211  	if len(output) > 0 {
   212  		err = fmt.Errorf("%v (output: %q)", err, strings.Join(output, "\n"))
   213  	}
   214  	if scannerErr := s.scanner.Err(); scannerErr != nil {
   215  		err = fmt.Errorf("%v (scanner error: %v)", err, scannerErr)
   216  	}
   217  	return "", err
   218  }
   219  
   220  func copyAsBase64(w *bufio.Writer, r io.Reader) error {
   221  	wrapper := newLineWrapWriter(w, base64LineLength)
   222  	encoder := base64.NewEncoder(base64.StdEncoding, wrapper)
   223  	if _, err := io.Copy(encoder, r); err != nil {
   224  		return err
   225  	}
   226  	if err := encoder.Close(); err != nil {
   227  		return err
   228  	}
   229  	if _, err := w.WriteString("\n@EOF\n"); err != nil {
   230  		return err
   231  	}
   232  	return nil
   233  }
   234  
   235  // path returns a remote absolute path for a storage object name.
   236  func (s *SSHStorage) path(name string) (string, error) {
   237  	remotepath := path.Clean(path.Join(s.remotepath, name))
   238  	if !strings.HasPrefix(remotepath, s.remotepath) {
   239  		return "", fmt.Errorf("%q escapes storage directory", name)
   240  	}
   241  	return remotepath, nil
   242  }
   243  
   244  // Get implements storage.StorageReader.Get.
   245  func (s *SSHStorage) Get(name string) (io.ReadCloser, error) {
   246  	logger.Debugf("getting %q from storage", name)
   247  	path, err := s.path(name)
   248  	if err != nil {
   249  		return nil, err
   250  	}
   251  	filename := utils.ShQuote(path)
   252  	out, err := s.runf(flockShared, "(test -e %s || (echo No such file && exit 1)) && base64 < %s", filename, filename)
   253  	if err != nil {
   254  		err := err.(SSHStorageError)
   255  		if strings.Contains(err.Output, "No such file") {
   256  			return nil, errors.NewNotFound(err, path+" not found")
   257  		}
   258  		return nil, err
   259  	}
   260  	decoded, err := base64.StdEncoding.DecodeString(out)
   261  	if err != nil {
   262  		return nil, err
   263  	}
   264  	return ioutil.NopCloser(bytes.NewBuffer(decoded)), nil
   265  }
   266  
   267  // List implements storage.StorageReader.List.
   268  func (s *SSHStorage) List(prefix string) ([]string, error) {
   269  	remotepath, err := s.path(prefix)
   270  	if err != nil {
   271  		return nil, err
   272  	}
   273  	dir, prefix := path.Split(remotepath)
   274  	quotedDir := utils.ShQuote(dir)
   275  	out, err := s.runf(flockShared, "(test -d %s && find %s -type f) || true", quotedDir, quotedDir)
   276  	if err != nil {
   277  		return nil, err
   278  	}
   279  	if out == "" {
   280  		return nil, nil
   281  	}
   282  	var names []string
   283  	for _, name := range strings.Split(out, "\n") {
   284  		if strings.HasPrefix(name[len(dir):], prefix) {
   285  			names = append(names, name[len(s.remotepath)+1:])
   286  		}
   287  	}
   288  	sort.Strings(names)
   289  	return names, nil
   290  }
   291  
   292  // URL implements storage.StorageReader.URL.
   293  func (s *SSHStorage) URL(name string) (string, error) {
   294  	path, err := s.path(name)
   295  	if err != nil {
   296  		return "", err
   297  	}
   298  	return fmt.Sprintf("sftp://%s/%s", s.host, path), nil
   299  }
   300  
   301  // DefaultConsistencyStrategy implements storage.StorageReader.ConsistencyStrategy.
   302  func (s *SSHStorage) DefaultConsistencyStrategy() utils.AttemptStrategy {
   303  	return utils.AttemptStrategy{}
   304  }
   305  
   306  // ShouldRetry is specified in the StorageReader interface.
   307  func (s *SSHStorage) ShouldRetry(err error) bool {
   308  	return false
   309  }
   310  
   311  // Put implements storage.StorageWriter.Put
   312  func (s *SSHStorage) Put(name string, r io.Reader, length int64) error {
   313  	logger.Debugf("putting %q (len %d) to storage", name, length)
   314  	path, err := s.path(name)
   315  	if err != nil {
   316  		return err
   317  	}
   318  	path = utils.ShQuote(path)
   319  	tmpdir := utils.ShQuote(s.tmpdir)
   320  
   321  	// Write to a temporary file ($TMPFILE), then mv atomically.
   322  	command := fmt.Sprintf("mkdir -p `dirname %s` && cat >| $TMPFILE", path)
   323  	command = fmt.Sprintf(
   324  		"TMPFILE=`mktemp --tmpdir=%s` && ((%s && mv $TMPFILE %s) || rm -f $TMPFILE)",
   325  		tmpdir, command, path,
   326  	)
   327  
   328  	_, err = s.run(flockExclusive, command+"\n", r, length)
   329  	return err
   330  }
   331  
   332  // Remove implements storage.StorageWriter.Remove
   333  func (s *SSHStorage) Remove(name string) error {
   334  	path, err := s.path(name)
   335  	if err != nil {
   336  		return err
   337  	}
   338  	path = utils.ShQuote(path)
   339  	_, err = s.runf(flockExclusive, "rm -f %s", path)
   340  	return err
   341  }
   342  
   343  // RemoveAll implements storage.StorageWriter.RemoveAll
   344  func (s *SSHStorage) RemoveAll() error {
   345  	_, err := s.runf(flockExclusive, "rm -fr %s/*", utils.ShQuote(s.remotepath))
   346  	return err
   347  }