github.com/rogpeppe/juju@v0.0.0-20140613142852-6337964b789e/environs/sshstorage/storage.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package sshstorage
     5  
     6  import (
     7  	"bufio"
     8  	"bytes"
     9  	"encoding/base64"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	"path"
    14  	"sort"
    15  	"strconv"
    16  	"strings"
    17  
    18  	"github.com/juju/errors"
    19  	"github.com/juju/loggo"
    20  	"github.com/juju/utils"
    21  
    22  	"github.com/juju/juju/utils/ssh"
    23  )
    24  
    25  var logger = loggo.GetLogger("juju.environs.sshstorage")
    26  
    27  // base64LineLength is the default line length for wrapping
    28  // output generated by the base64 command line utility.
    29  const base64LineLength = 76
    30  
// SSHStorage implements storage.Storage.
//
// All operations are sent as commands to a single interactive bash
// session that NewSSHStorage starts over SSH, and each command runs under
// flock(1) on the storage directory.
//
// The storage is created under sudo, and ownership given over to the
// login uid/gid. This is done so that we don't require sudo, and by
// consequence, don't require a pty, so we can interact with a script
// via stdin.
//
// NOTE(review): there is no local locking around the shared stdin and
// scanner, so concurrent use from several goroutines looks unsafe —
// confirm before sharing an instance.
type SSHStorage struct {
	host       string // [user@]hostname the session is connected to
	remotepath string // root of the remote storage directory
	tmpdir     string // remote directory holding in-progress Put uploads

	cmd     *ssh.Cmd       // the running "ssh host bash" command
	stdin   io.WriteCloser // commands and base64 input are written here
	stdout  io.ReadCloser  // remote shell output (stderr is merged into it)
	scanner *bufio.Scanner // line scanner over stdout
}
    47  
    48  var sshCommand = func(host string, command ...string) *ssh.Cmd {
    49  	return ssh.Command(host, command, nil)
    50  }
    51  
    52  type flockmode string
    53  
    54  const (
    55  	flockShared    flockmode = "-s"
    56  	flockExclusive flockmode = "-x"
    57  )
    58  
// NewSSHStorageParams holds the arguments to NewSSHStorage.
type NewSSHStorageParams struct {
	// Host is the host to connect to, in the format [user@]hostname.
	Host string

	// StorageDir is the root of the remote storage directory. It must be
	// non-empty. NewSSHStorage creates it under sudo if it does not exist,
	// with ownership assigned to the login user.
	StorageDir string

	// TmpDir is the remote temporary directory for storage.
	// A temporary directory must be specified, and should be located on the
	// same filesystem as the storage directory to ensure atomic writes.
	// The temporary directory will be created when NewSSHStorage is invoked
	// if it doesn't already exist; it will never be removed. NewSSHStorage
	// will attempt to reassign ownership to the login user, and will return
	// an error if it cannot do so.
	TmpDir string
}
    75  
    76  // NewSSHStorage creates a new SSHStorage, connected to the
    77  // specified host, managing state under the specified remote path.
    78  func NewSSHStorage(params NewSSHStorageParams) (*SSHStorage, error) {
    79  	if params.StorageDir == "" {
    80  		return nil, errors.New("storagedir must be specified and non-empty")
    81  	}
    82  	if params.TmpDir == "" {
    83  		return nil, errors.New("tmpdir must be specified and non-empty")
    84  	}
    85  
    86  	script := fmt.Sprintf(
    87  		"install -d -g $SUDO_GID -o $SUDO_UID %s %s",
    88  		utils.ShQuote(params.StorageDir),
    89  		utils.ShQuote(params.TmpDir),
    90  	)
    91  
    92  	cmd := sshCommand(params.Host, "sudo", "-n", "/bin/bash")
    93  	var stderr bytes.Buffer
    94  	cmd.Stderr = &stderr
    95  	cmd.Stdin = strings.NewReader(script)
    96  	if err := cmd.Run(); err != nil {
    97  		err = fmt.Errorf("failed to create storage dir: %v (%v)", err, strings.TrimSpace(stderr.String()))
    98  		return nil, err
    99  	}
   100  
   101  	// We could use sftp, but then we'd be at the mercy of
   102  	// sftp's output messages for checking errors. Instead,
   103  	// we execute an interactive bash shell.
   104  	cmd = sshCommand(params.Host, "bash")
   105  	stdin, err := cmd.StdinPipe()
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	stdout, err := cmd.StdoutPipe()
   110  	if err != nil {
   111  		stdin.Close()
   112  		return nil, err
   113  	}
   114  	// Combine stdout and stderr, so we can easily
   115  	// get at catastrophic failure messages.
   116  	cmd.Stderr = cmd.Stdout
   117  	stor := &SSHStorage{
   118  		host:       params.Host,
   119  		remotepath: params.StorageDir,
   120  		tmpdir:     params.TmpDir,
   121  		cmd:        cmd,
   122  		stdin:      stdin,
   123  		stdout:     stdout,
   124  		scanner:    bufio.NewScanner(stdout),
   125  	}
   126  	cmd.Start()
   127  
   128  	// Verify we have write permissions.
   129  	_, err = stor.runf(flockExclusive, "touch %s", utils.ShQuote(params.StorageDir))
   130  	if err != nil {
   131  		stdin.Close()
   132  		stdout.Close()
   133  		cmd.Wait()
   134  		return nil, err
   135  	}
   136  	return stor, nil
   137  }
   138  
   139  // Close cleanly terminates the underlying SSH connection.
   140  func (s *SSHStorage) Close() error {
   141  	s.stdin.Close()
   142  	s.stdout.Close()
   143  	return s.cmd.Wait()
   144  }
   145  
   146  func (s *SSHStorage) runf(flockmode flockmode, command string, args ...interface{}) (string, error) {
   147  	command = fmt.Sprintf(command, args...)
   148  	return s.run(flockmode, command, nil, 0)
   149  }
   150  
   151  // terminate closes the stdin, and appends any output to the input error.
   152  func (s *SSHStorage) terminate(err error) error {
   153  	s.stdin.Close()
   154  	var output string
   155  	for s.scanner.Scan() {
   156  		if len(output) > 0 {
   157  			output += "\n"
   158  		}
   159  		output += s.scanner.Text()
   160  	}
   161  	if len(output) > 0 {
   162  		err = fmt.Errorf("%v (output: %q)", err, output)
   163  	}
   164  	return err
   165  }
   166  
   167  func (s *SSHStorage) run(flockmode flockmode, command string, input io.Reader, inputlen int64) (string, error) {
   168  	const rcPrefix = "JUJU-RC: "
   169  	command = fmt.Sprintf(
   170  		"SHELL=/bin/bash flock %s %s -c %s",
   171  		flockmode,
   172  		utils.ShQuote(s.remotepath),
   173  		utils.ShQuote(command),
   174  	)
   175  	stdin := bufio.NewWriter(s.stdin)
   176  	if input != nil {
   177  		command = fmt.Sprintf("base64 -d << '@EOF' | (%s)", command)
   178  	}
   179  	command = fmt.Sprintf("(%s) 2>&1; echo %s$?", command, rcPrefix)
   180  	if _, err := stdin.WriteString(command + "\n"); err != nil {
   181  		return "", fmt.Errorf("failed to write command: %v", err)
   182  	}
   183  	if input != nil {
   184  		if err := copyAsBase64(stdin, input); err != nil {
   185  			return "", s.terminate(fmt.Errorf("failed to write input: %v", err))
   186  		}
   187  	}
   188  	if err := stdin.Flush(); err != nil {
   189  		return "", s.terminate(fmt.Errorf("failed to write input: %v", err))
   190  	}
   191  	var output []string
   192  	for s.scanner.Scan() {
   193  		line := s.scanner.Text()
   194  		if strings.HasPrefix(line, rcPrefix) {
   195  			line := line[len(rcPrefix):]
   196  			rc, err := strconv.Atoi(line)
   197  			if err != nil {
   198  				return "", fmt.Errorf("failed to parse exit code %q: %v", line, err)
   199  			}
   200  			outputJoined := strings.Join(output, "\n")
   201  			if rc == 0 {
   202  				return outputJoined, nil
   203  			}
   204  			return "", SSHStorageError{outputJoined, rc}
   205  		} else {
   206  			output = append(output, line)
   207  		}
   208  	}
   209  
   210  	err := fmt.Errorf("failed to locate %q", rcPrefix)
   211  	if len(output) > 0 {
   212  		err = fmt.Errorf("%v (output: %q)", err, strings.Join(output, "\n"))
   213  	}
   214  	if scannerErr := s.scanner.Err(); scannerErr != nil {
   215  		err = fmt.Errorf("%v (scanner error: %v)", err, scannerErr)
   216  	}
   217  	return "", err
   218  }
   219  
   220  func copyAsBase64(w *bufio.Writer, r io.Reader) error {
   221  	wrapper := newLineWrapWriter(w, base64LineLength)
   222  	encoder := base64.NewEncoder(base64.StdEncoding, wrapper)
   223  	if _, err := io.Copy(encoder, r); err != nil {
   224  		return err
   225  	}
   226  	if err := encoder.Close(); err != nil {
   227  		return err
   228  	}
   229  	if _, err := w.WriteString("\n@EOF\n"); err != nil {
   230  		return err
   231  	}
   232  	return nil
   233  }
   234  
   235  // path returns a remote absolute path for a storage object name.
   236  func (s *SSHStorage) path(name string) (string, error) {
   237  	remotepath := path.Clean(path.Join(s.remotepath, name))
   238  	if !strings.HasPrefix(remotepath, s.remotepath) {
   239  		return "", fmt.Errorf("%q escapes storage directory", name)
   240  	}
   241  	return remotepath, nil
   242  }
   243  
   244  // Get implements storage.StorageReader.Get.
   245  func (s *SSHStorage) Get(name string) (io.ReadCloser, error) {
   246  	logger.Debugf("getting %q from storage", name)
   247  	path, err := s.path(name)
   248  	if err != nil {
   249  		return nil, err
   250  	}
   251  	out, err := s.runf(flockShared, "base64 < %s", utils.ShQuote(path))
   252  	if err != nil {
   253  		err := err.(SSHStorageError)
   254  		if strings.Contains(err.Output, "No such file") {
   255  			return nil, errors.NewNotFound(err, "")
   256  		}
   257  		return nil, err
   258  	}
   259  	decoded, err := base64.StdEncoding.DecodeString(out)
   260  	if err != nil {
   261  		return nil, err
   262  	}
   263  	return ioutil.NopCloser(bytes.NewBuffer(decoded)), nil
   264  }
   265  
   266  // List implements storage.StorageReader.List.
   267  func (s *SSHStorage) List(prefix string) ([]string, error) {
   268  	remotepath, err := s.path(prefix)
   269  	if err != nil {
   270  		return nil, err
   271  	}
   272  	dir, prefix := path.Split(remotepath)
   273  	quotedDir := utils.ShQuote(dir)
   274  	out, err := s.runf(flockShared, "(test -d %s && find %s -type f) || true", quotedDir, quotedDir)
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  	if out == "" {
   279  		return nil, nil
   280  	}
   281  	var names []string
   282  	for _, name := range strings.Split(out, "\n") {
   283  		if strings.HasPrefix(name[len(dir):], prefix) {
   284  			names = append(names, name[len(s.remotepath)+1:])
   285  		}
   286  	}
   287  	sort.Strings(names)
   288  	return names, nil
   289  }
   290  
   291  // URL implements storage.StorageReader.URL.
   292  func (s *SSHStorage) URL(name string) (string, error) {
   293  	path, err := s.path(name)
   294  	if err != nil {
   295  		return "", err
   296  	}
   297  	return fmt.Sprintf("sftp://%s/%s", s.host, path), nil
   298  }
   299  
   300  // DefaultConsistencyStrategy implements storage.StorageReader.ConsistencyStrategy.
   301  func (s *SSHStorage) DefaultConsistencyStrategy() utils.AttemptStrategy {
   302  	return utils.AttemptStrategy{}
   303  }
   304  
   305  // ShouldRetry is specified in the StorageReader interface.
   306  func (s *SSHStorage) ShouldRetry(err error) bool {
   307  	return false
   308  }
   309  
   310  // Put implements storage.StorageWriter.Put
   311  func (s *SSHStorage) Put(name string, r io.Reader, length int64) error {
   312  	logger.Debugf("putting %q (len %d) to storage", name, length)
   313  	path, err := s.path(name)
   314  	if err != nil {
   315  		return err
   316  	}
   317  	path = utils.ShQuote(path)
   318  	tmpdir := utils.ShQuote(s.tmpdir)
   319  
   320  	// Write to a temporary file ($TMPFILE), then mv atomically.
   321  	command := fmt.Sprintf("mkdir -p `dirname %s` && cat > $TMPFILE", path)
   322  	command = fmt.Sprintf(
   323  		"TMPFILE=`mktemp --tmpdir=%s` && ((%s && mv $TMPFILE %s) || rm -f $TMPFILE)",
   324  		tmpdir, command, path,
   325  	)
   326  
   327  	_, err = s.run(flockExclusive, command+"\n", r, length)
   328  	return err
   329  }
   330  
   331  // Remove implements storage.StorageWriter.Remove
   332  func (s *SSHStorage) Remove(name string) error {
   333  	path, err := s.path(name)
   334  	if err != nil {
   335  		return err
   336  	}
   337  	path = utils.ShQuote(path)
   338  	_, err = s.runf(flockExclusive, "rm -f %s", path)
   339  	return err
   340  }
   341  
   342  // RemoveAll implements storage.StorageWriter.RemoveAll
   343  func (s *SSHStorage) RemoveAll() error {
   344  	_, err := s.runf(flockExclusive, "rm -fr %s/*", utils.ShQuote(s.remotepath))
   345  	return err
   346  }