github.com/meulengracht/snapd@v0.0.0-20210719210640-8bde69bcc84e/overlord/snapshotstate/backend/backend.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2018-2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package backend
    21  
    22  import (
    23  	"archive/tar"
    24  	"archive/zip"
    25  	"bytes"
    26  	"context"
    27  	"crypto"
    28  	"encoding/json"
    29  	"errors"
    30  	"fmt"
    31  	"io"
    32  	"io/ioutil"
    33  	"os"
    34  	"path"
    35  	"path/filepath"
    36  	"regexp"
    37  	"runtime"
    38  	"sort"
    39  	"strconv"
    40  	"strings"
    41  	"syscall"
    42  	"time"
    43  
    44  	"github.com/snapcore/snapd/client"
    45  	"github.com/snapcore/snapd/dirs"
    46  	"github.com/snapcore/snapd/logger"
    47  	"github.com/snapcore/snapd/osutil"
    48  	"github.com/snapcore/snapd/snap"
    49  	"github.com/snapcore/snapd/snapdenv"
    50  	"github.com/snapcore/snapd/strutil"
    51  )
    52  
// Names of the entries inside a snapshot zip file: the snap's system
// data archive, the JSON metadata and its SHA3-384 hash, and the
// prefix/suffix used to build per-user archive entry names.
const (
	archiveName  = "archive.tgz"
	metadataName = "meta.json"
	metaHashName = "meta.sha3_384"

	userArchivePrefix = "user/"
	userArchiveSuffix = ".tgz"
)
    61  
var (
	// Stop is used to ask Iter to stop iteration, without it being an error.
	Stop = errors.New("stop iteration")

	// indirections over os and backend functions so they can be
	// overridden (e.g. in tests)
	osOpen      = os.Open
	dirNames    = (*os.File).Readdirnames
	backendOpen = Open
	timeNow     = time.Now

	usersForUsernames = usersForUsernamesImpl
)
    73  
    74  // LastSnapshotSetID returns the highest set id number for the snapshots stored
    75  // in snapshots directory; set ids are inferred from the filenames.
    76  func LastSnapshotSetID() (uint64, error) {
    77  	dir, err := osOpen(dirs.SnapshotsDir)
    78  	if err != nil {
    79  		if osutil.IsDirNotExist(err) {
    80  			// no snapshots
    81  			return 0, nil
    82  		}
    83  		return 0, fmt.Errorf("cannot open snapshots directory: %v", err)
    84  	}
    85  	defer dir.Close()
    86  
    87  	var maxSetID uint64
    88  
    89  	var readErr error
    90  	for readErr == nil {
    91  		var names []string
    92  		// note os.Readdirnames can return a non-empty names and a non-nil err
    93  		names, readErr = dirNames(dir, 100)
    94  		for _, name := range names {
    95  			if ok, setID := isSnapshotFilename(name); ok {
    96  				if setID > maxSetID {
    97  					maxSetID = setID
    98  				}
    99  			}
   100  		}
   101  	}
   102  	if readErr != nil && readErr != io.EOF {
   103  		return 0, readErr
   104  	}
   105  	return maxSetID, nil
   106  }
   107  
// Iter loops over all snapshots in the snapshots directory, applying the given
// function to each. The snapshot will be closed after the function returns. If
// the function returns an error, iteration is stopped (and if the error isn't
// Stop, it's returned as the error of the iterator).
//
// Snapshot files whose set has an import in progress are skipped, so the
// callback never observes a partially imported set.
func Iter(ctx context.Context, f func(*Reader) error) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	dir, err := osOpen(dirs.SnapshotsDir)
	if err != nil {
		if osutil.IsDirNotExist(err) {
			// no dir -> no snapshots
			return nil
		}
		return fmt.Errorf("cannot open snapshots directory: %v", err)
	}
	defer dir.Close()

	importsInProgress := map[uint64]bool{}
	var names []string
	var readErr error
	for readErr == nil && err == nil {
		// read the directory in batches of 100 names to bound memory use
		names, readErr = dirNames(dir, 100)
		// note os.Readdirnames can return a non-empty names and a non-nil err
		for _, name := range names {
			if err = ctx.Err(); err != nil {
				break
			}

			// filter out non-snapshot directory entries
			ok, setID := isSnapshotFilename(name)
			if !ok {
				continue
			}
			// keep track of in-progress in a map as well
			// to avoid races. E.g.:
			// 1. the dirNames() are read
			// 2. 99_some-snap_1.0_x1.zip is returned
			// 3. the code checks if 99_importing is there,
			//    it is so 99_some-snap is skipped
			// 4. other snapshots are examined
			// 5. in-parallel 99_importing finishes
			// 6. 99_other-snap_1.0_x1.zip is now examined
			// 7. code checks if 99_importing is there, but it
			//    is no longer there because import
			//    finished in the meantime. We still
			//    want to not call the callback with
			//    99_other-snap or the callback would get
			//    an incomplete view about 99_snapshot.
			if importsInProgress[setID] {
				continue
			}
			if importInProgressFor(setID) {
				importsInProgress[setID] = true
				continue
			}

			filename := filepath.Join(dirs.SnapshotsDir, name)
			reader, openError := backendOpen(filename, setID)
			// reader can be non-nil even when openError is not nil (in
			// which case reader.Broken will have a reason). f can
			// check and either ignore or return an error when
			// finding a broken snapshot.
			if reader != nil {
				err = f(reader)
			} else {
				// TODO: use warnings instead
				logger.Noticef("Cannot open snapshot %q: %v.", name, openError)
			}
			if openError == nil {
				// if openError was nil the snapshot was opened and needs closing
				if closeError := reader.Close(); err == nil {
					err = closeError
				}
			}
			if err != nil {
				break
			}
		}
	}

	if readErr != nil && readErr != io.EOF {
		return readErr
	}

	// Stop is the callback's way to end iteration early, not a failure
	if err == Stop {
		err = nil
	}

	return err
}
   200  
   201  // List valid snapshots sets.
   202  func List(ctx context.Context, setID uint64, snapNames []string) ([]client.SnapshotSet, error) {
   203  	setshots := map[uint64][]*client.Snapshot{}
   204  	err := Iter(ctx, func(reader *Reader) error {
   205  		if setID == 0 || reader.SetID == setID {
   206  			if len(snapNames) == 0 || strutil.ListContains(snapNames, reader.Snap) {
   207  				setshots[reader.SetID] = append(setshots[reader.SetID], &reader.Snapshot)
   208  			}
   209  		}
   210  		return nil
   211  	})
   212  
   213  	sets := make([]client.SnapshotSet, 0, len(setshots))
   214  	for id, shots := range setshots {
   215  		sort.Sort(bySnap(shots))
   216  		sets = append(sets, client.SnapshotSet{ID: id, Snapshots: shots})
   217  	}
   218  
   219  	sort.Sort(byID(sets))
   220  
   221  	return sets, err
   222  }
   223  
   224  // Filename of the given client.Snapshot in this backend.
   225  func Filename(snapshot *client.Snapshot) string {
   226  	// this _needs_ the snap name and version to be valid
   227  	return filepath.Join(dirs.SnapshotsDir, fmt.Sprintf("%d_%s_%s_%s.zip", snapshot.SetID, snapshot.Snap, snapshot.Version, snapshot.Revision))
   228  }
   229  
   230  // isSnapshotFilename checks if the given filePath is a snapshot file name, i.e.
   231  // if it starts with a numeric set id and ends with .zip extension;
   232  // filePath can be just a file name, or a full path.
   233  func isSnapshotFilename(filePath string) (ok bool, setID uint64) {
   234  	fname := filepath.Base(filePath)
   235  	// XXX: we could use a regexp here to match very precisely all the elements
   236  	// of the filename following Filename() above, but perhaps it's better no to
   237  	// go overboard with it in case the format evolves in the future. Only check
   238  	// if the name starts with a set-id and ends with .zip.
   239  	//
   240  	// Filename is "<sid>_<snapName>_version_revision.zip", e.g. "16_snapcraft_4.2_5407.zip"
   241  	ext := filepath.Ext(fname)
   242  	if ext != ".zip" {
   243  		return false, 0
   244  	}
   245  	parts := strings.SplitN(fname, "_", 2)
   246  	if len(parts) != 2 {
   247  		return false, 0
   248  	}
   249  	// invalid: no parts following <sid>_
   250  	if parts[1] == ext {
   251  		return false, 0
   252  	}
   253  	id, err := strconv.Atoi(parts[0])
   254  	if err != nil {
   255  		return false, 0
   256  	}
   257  	return true, uint64(id)
   258  }
   259  
   260  // EstimateSnapshotSize calculates estimated size of the snapshot.
   261  func EstimateSnapshotSize(si *snap.Info, usernames []string) (uint64, error) {
   262  	var total uint64
   263  	calculateSize := func(path string, finfo os.FileInfo, err error) error {
   264  		if finfo.Mode().IsRegular() {
   265  			total += uint64(finfo.Size())
   266  		}
   267  		return err
   268  	}
   269  
   270  	visitDir := func(dir string) error {
   271  		exists, isDir, err := osutil.DirExists(dir)
   272  		if err != nil {
   273  			return err
   274  		}
   275  		if !(exists && isDir) {
   276  			return nil
   277  		}
   278  		return filepath.Walk(dir, calculateSize)
   279  	}
   280  
   281  	for _, dir := range []string{si.DataDir(), si.CommonDataDir()} {
   282  		if err := visitDir(dir); err != nil {
   283  			return 0, err
   284  		}
   285  	}
   286  
   287  	users, err := usersForUsernames(usernames)
   288  	if err != nil {
   289  		return 0, err
   290  	}
   291  	for _, usr := range users {
   292  		if err := visitDir(si.UserDataDir(usr.HomeDir)); err != nil {
   293  			return 0, err
   294  		}
   295  		if err := visitDir(si.UserCommonDataDir(usr.HomeDir)); err != nil {
   296  			return 0, err
   297  		}
   298  	}
   299  
   300  	// XXX: we could use a typical compression factor here
   301  	return total, nil
   302  }
   303  
// Save a snapshot of the given snap's data for the users with the given
// usernames, as a zip file under dirs.SnapshotsDir. The zip is written
// atomically: it only appears under its final name once fully written
// and committed. Returns the metadata describing the new snapshot.
func Save(ctx context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string) (*client.Snapshot, error) {
	if err := os.MkdirAll(dirs.SnapshotsDir, 0700); err != nil {
		return nil, err
	}

	snapshot := &client.Snapshot{
		SetID:    id,
		Snap:     si.InstanceName(),
		SnapID:   si.SnapID,
		Revision: si.Revision,
		Version:  si.Version,
		Epoch:    si.Epoch,
		Time:     timeNow(),
		SHA3_384: make(map[string]string),
		Size:     0,
		Conf:     cfg,
		// Note: Auto is no longer set in the Snapshot.
	}

	aw, err := osutil.NewAtomicFile(Filename(snapshot), 0600, 0, osutil.NoChown, osutil.NoChown)
	if err != nil {
		return nil, err
	}
	// if things worked, we'll commit (and Cancel becomes a NOP)
	defer aw.Cancel()

	w := zip.NewWriter(aw)
	defer w.Close() // note this does not close the file descriptor (that's done by hand on the atomic writer, above)
	// system ("root") data first, then one archive per requested user
	if err := addDirToZip(ctx, snapshot, w, "root", archiveName, si.DataDir()); err != nil {
		return nil, err
	}

	users, err := usersForUsernames(usernames)
	if err != nil {
		return nil, err
	}

	for _, usr := range users {
		if err := addDirToZip(ctx, snapshot, w, usr.Username, userArchiveName(usr), si.UserDataDir(usr.HomeDir)); err != nil {
			return nil, err
		}
	}

	// the metadata entry is written last so that snapshot.SHA3_384 and
	// snapshot.Size already reflect the archives added above; it is
	// hashed while being written so we can store its own hash next
	metaWriter, err := w.Create(metadataName)
	if err != nil {
		return nil, err
	}

	hasher := crypto.SHA3_384.New()
	enc := json.NewEncoder(io.MultiWriter(metaWriter, hasher))
	if err := enc.Encode(snapshot); err != nil {
		return nil, err
	}

	hashWriter, err := w.Create(metaHashName)
	if err != nil {
		return nil, err
	}
	fmt.Fprintf(hashWriter, "%x\n", hasher.Sum(nil))
	if err := w.Close(); err != nil {
		return nil, err
	}

	// last chance to bail out before making the snapshot visible
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	if err := aw.Commit(); err != nil {
		return nil, err
	}

	return snapshot, nil
}
   378  
   379  var isTesting = snapdenv.Testing()
   380  
   381  func addDirToZip(ctx context.Context, snapshot *client.Snapshot, w *zip.Writer, username string, entry, dir string) error {
   382  	parent, revdir := filepath.Split(dir)
   383  	exists, isDir, err := osutil.DirExists(parent)
   384  	if err != nil {
   385  		return err
   386  	}
   387  	if exists && !isDir {
   388  		logger.Noticef("Not saving directories under %q in snapshot #%d of %q as it is not a directory.", parent, snapshot.SetID, snapshot.Snap)
   389  		return nil
   390  	}
   391  	if !exists {
   392  		logger.Debugf("Not saving directories under %q in snapshot #%d of %q as it is does not exist.", parent, snapshot.SetID, snapshot.Snap)
   393  		return nil
   394  	}
   395  	tarArgs := []string{
   396  		"--create",
   397  		"--sparse", "--gzip",
   398  		"--format", "gnu",
   399  		"--directory", parent,
   400  	}
   401  
   402  	noRev, noCommon := true, true
   403  
   404  	exists, isDir, err = osutil.DirExists(dir)
   405  	if err != nil {
   406  		return err
   407  	}
   408  	switch {
   409  	case exists && isDir:
   410  		tarArgs = append(tarArgs, revdir)
   411  		noRev = false
   412  	case exists && !isDir:
   413  		logger.Noticef("Not saving %q in snapshot #%d of %q as it is not a directory.", dir, snapshot.SetID, snapshot.Snap)
   414  	case !exists:
   415  		logger.Debugf("Not saving %q in snapshot #%d of %q as it is does not exist.", dir, snapshot.SetID, snapshot.Snap)
   416  	}
   417  
   418  	common := filepath.Join(parent, "common")
   419  	exists, isDir, err = osutil.DirExists(common)
   420  	if err != nil {
   421  		return err
   422  	}
   423  	switch {
   424  	case exists && isDir:
   425  		tarArgs = append(tarArgs, "common")
   426  		noCommon = false
   427  	case exists && !isDir:
   428  		logger.Noticef("Not saving %q in snapshot #%d of %q as it is not a directory.", common, snapshot.SetID, snapshot.Snap)
   429  	case !exists:
   430  		logger.Debugf("Not saving %q in snapshot #%d of %q as it is does not exist.", common, snapshot.SetID, snapshot.Snap)
   431  	}
   432  
   433  	if noCommon && noRev {
   434  		return nil
   435  	}
   436  
   437  	archiveWriter, err := w.CreateHeader(&zip.FileHeader{Name: entry})
   438  	if err != nil {
   439  		return err
   440  	}
   441  
   442  	var sz osutil.Sizer
   443  	hasher := crypto.SHA3_384.New()
   444  
   445  	cmd := tarAsUser(username, tarArgs...)
   446  	cmd.Stdout = io.MultiWriter(archiveWriter, hasher, &sz)
   447  	matchCounter := &strutil.MatchCounter{
   448  		// keep at most 5 matches
   449  		N: 5,
   450  		// keep the last lines only, those likely contain the reason for
   451  		// fatal errors
   452  		LastN: true,
   453  	}
   454  	cmd.Stderr = matchCounter
   455  	if isTesting {
   456  		matchCounter.N = -1
   457  		cmd.Stderr = io.MultiWriter(os.Stderr, matchCounter)
   458  	}
   459  	if err := osutil.RunWithContext(ctx, cmd); err != nil {
   460  		matches, count := matchCounter.Matches()
   461  		if count > 0 {
   462  			note := ""
   463  			if count > 5 {
   464  				note = fmt.Sprintf(" (showing last 5 lines out of %d)", count)
   465  			}
   466  			// we have at most 5 matches here
   467  			errStr := strings.Join(matches, "\n")
   468  			return fmt.Errorf("cannot create archive%s:\n%s", note, errStr)
   469  		}
   470  		return fmt.Errorf("tar failed: %v", err)
   471  	}
   472  
   473  	snapshot.SHA3_384[entry] = fmt.Sprintf("%x", hasher.Sum(nil))
   474  	snapshot.Size += sz.Size()
   475  
   476  	return nil
   477  }
   478  
   479  var ErrCannotCancel = errors.New("cannot cancel: import already finished")
   480  
   481  // multiError collects multiple errors that affected an operation.
   482  type multiError struct {
   483  	header string
   484  	errs   []error
   485  }
   486  
   487  // newMultiError returns a new multiError struct initialized with
   488  // the given format string that explains what operation potentially
   489  // went wrong. multiError can be nested and will render correctly
   490  // in these cases.
   491  func newMultiError(header string, errs []error) error {
   492  	return &multiError{header: header, errs: errs}
   493  }
   494  
   495  // Error formats the error string.
   496  func (me *multiError) Error() string {
   497  	return me.nestedError(0)
   498  }
   499  
   500  // helper to ensure formating of nested multiErrors works.
   501  func (me *multiError) nestedError(level int) string {
   502  	indent := strings.Repeat(" ", level)
   503  	buf := bytes.NewBufferString(fmt.Sprintf("%s:\n", me.header))
   504  	if level > 8 {
   505  		return "circular or too deep error nesting (max 8)?!"
   506  	}
   507  	for i, err := range me.errs {
   508  		switch v := err.(type) {
   509  		case *multiError:
   510  			fmt.Fprintf(buf, "%s- %v", indent, v.nestedError(level+1))
   511  		default:
   512  			fmt.Fprintf(buf, "%s- %v", indent, err)
   513  		}
   514  		if i < len(me.errs)-1 {
   515  			fmt.Fprintf(buf, "\n")
   516  		}
   517  	}
   518  	return buf.String()
   519  }
   520  
var (
	// importingFnRegexp matches the name of an import lock file and
	// captures its set id.
	importingFnRegexp = regexp.MustCompile("^([0-9]+)_importing$")
	// importingFnGlob globs all import lock files.
	importingFnGlob   = "[0-9]*_importing"
	// importingFnFmt formats the import lock file name for a set id.
	importingFnFmt    = "%d_importing"
	// importingForIDFmt globs all snapshot zips belonging to a set id.
	importingForIDFmt = "%d_*.zip"
)
   527  
   528  // importInProgressFor return true if the given snapshot id has an import
   529  // that is in progress.
   530  func importInProgressFor(setID uint64) bool {
   531  	return newImportTransaction(setID).InProgress()
   532  }
   533  
// importTransaction keeps track of the given snapshot ID import and
// ensures it can be committed/cancelled in an atomic way.
//
// Start() must be called before the first data is imported. When the
// import is successful Commit() should be called.
//
// Cancel() will cancel the given import and cleanup. It's always safe
// to defer a Cancel() it will just return a "ErrCannotCancel" after
// a commit.
type importTransaction struct {
	// id is the snapshot set id this import is for
	id        uint64
	// lockPath is the "<id>_importing" file that marks the import as
	// being in progress
	lockPath  string
	// committed is set by Commit; afterwards Cancel returns ErrCannotCancel
	committed bool
}
   548  
   549  // newImportTransaction creates a new importTransaction for the given
   550  // snapshot id.
   551  func newImportTransaction(setID uint64) *importTransaction {
   552  	return &importTransaction{
   553  		id:       setID,
   554  		lockPath: filepath.Join(dirs.SnapshotsDir, fmt.Sprintf(importingFnFmt, setID)),
   555  	}
   556  }
   557  
   558  // newImportTransactionFromImportFile creates a new importTransaction
   559  // for the given import file path. It may return an error if an
   560  // invalid file was specified.
   561  func newImportTransactionFromImportFile(p string) (*importTransaction, error) {
   562  	parts := importingFnRegexp.FindStringSubmatch(path.Base(p))
   563  	if len(parts) != 2 {
   564  		return nil, fmt.Errorf("cannot determine snapshot id from %q", p)
   565  	}
   566  	setID, err := strconv.ParseUint(parts[1], 10, 64)
   567  	if err != nil {
   568  		return nil, err
   569  	}
   570  	return newImportTransaction(setID), nil
   571  }
   572  
// Start marks the start of a snapshot import by creating the import
// lock file for this set id.
func (t *importTransaction) Start() error {
	return t.lock()
}
   577  
// InProgress returns true if there is an import for this transactions
// snapshot ID already, i.e. if its lock file exists.
func (t *importTransaction) InProgress() bool {
	return osutil.FileExists(t.lockPath)
}
   583  
   584  // Cancel cancels a snapshot import and cleanups any files on disk belonging
   585  // to this snapshot ID.
   586  func (t *importTransaction) Cancel() error {
   587  	if t.committed {
   588  		return ErrCannotCancel
   589  	}
   590  	inProgressImports, err := filepath.Glob(filepath.Join(dirs.SnapshotsDir, fmt.Sprintf(importingForIDFmt, t.id)))
   591  	if err != nil {
   592  		return err
   593  	}
   594  	var errs []error
   595  	for _, p := range inProgressImports {
   596  		if err := os.Remove(p); err != nil {
   597  			errs = append(errs, err)
   598  		}
   599  	}
   600  	if err := t.unlock(); err != nil {
   601  		errs = append(errs, err)
   602  	}
   603  	if len(errs) > 0 {
   604  		return newMultiError(fmt.Sprintf("cannot cancel import for set id %d", t.id), errs)
   605  	}
   606  	return nil
   607  }
   608  
// Commit will commit a given transaction: it removes the import lock
// file and marks the transaction committed, turning any subsequent
// Cancel into an ErrCannotCancel no-op.
func (t *importTransaction) Commit() error {
	if err := t.unlock(); err != nil {
		return err
	}
	t.committed = true
	return nil
}
   617  
// lock creates the (empty) import lock file for this set id.
func (t *importTransaction) lock() error {
	return ioutil.WriteFile(t.lockPath, nil, 0644)
}
   621  
// unlock removes the import lock file for this set id.
func (t *importTransaction) unlock() error {
	return os.Remove(t.lockPath)
}
   625  
   626  var filepathGlob = filepath.Glob
   627  
   628  // CleanupAbandondedImports will clean any import that is in progress.
   629  // This is meant to be called at startup of snapd before any real imports
   630  // happen. It is not safe to run this concurrently with any other snapshot
   631  // operation.
   632  //
   633  // The amount of snapshots cleaned is returned and an error if one or
   634  // more cleanups did not succeed.
   635  func CleanupAbandondedImports() (cleaned int, err error) {
   636  	inProgressSnapshots, err := filepathGlob(filepath.Join(dirs.SnapshotsDir, importingFnGlob))
   637  	if err != nil {
   638  		return 0, err
   639  	}
   640  
   641  	var errs []error
   642  	for _, p := range inProgressSnapshots {
   643  		tr, err := newImportTransactionFromImportFile(p)
   644  		if err != nil {
   645  			errs = append(errs, err)
   646  			continue
   647  		}
   648  		if err := tr.Cancel(); err != nil {
   649  			errs = append(errs, err)
   650  		} else {
   651  			cleaned++
   652  		}
   653  	}
   654  	if len(errs) > 0 {
   655  		return cleaned, newMultiError("cannot cleanup imports", errs)
   656  	}
   657  	return cleaned, nil
   658  }
   659  
// ImportFlags carries extra flags to drive import behavior.
type ImportFlags struct {
	// NoDuplicatedImportCheck tells import not to check for existing snapshot
	// with same content hash (and not report DuplicatedSnapshotImportError).
	NoDuplicatedImportCheck bool
}
   666  
// Import a snapshot from the export file format. The whole import runs
// under an import lock ("<id>_importing" file) so concurrent iteration
// and export skip this set until it is complete; on failure the lock's
// Cancel also removes any partially written snapshot files.
func Import(ctx context.Context, id uint64, r io.Reader, flags *ImportFlags) (snapNames []string, err error) {
	if err := os.MkdirAll(dirs.SnapshotsDir, 0700); err != nil {
		return nil, err
	}

	errPrefix := fmt.Sprintf("cannot import snapshot %d", id)

	tr := newImportTransaction(id)
	if tr.InProgress() {
		return nil, fmt.Errorf("%s: already in progress for this set id", errPrefix)
	}
	if err := tr.Start(); err != nil {
		return nil, err
	}
	// Cancel once Committed is a NOP
	defer tr.Cancel()

	// Unpack and validate the streamed data
	//
	// XXX: this will leak snapshot IDs, i.e. we allocate a new
	// snapshot ID before but then we error here because of e.g.
	// duplicated import attempts
	snapNames, err = unpackVerifySnapshotImport(ctx, r, id, flags)
	if err != nil {
		if _, ok := err.(DuplicatedSnapshotImportError); ok {
			// returned unwrapped so the caller can detect it and
			// reuse the already available set
			return nil, err
		}
		return nil, fmt.Errorf("%s: %v", errPrefix, err)
	}
	if err := tr.Commit(); err != nil {
		return nil, err
	}

	return snapNames, nil
}
   703  
   704  func writeOneSnapshotFile(targetPath string, tr io.Reader) error {
   705  	t, err := os.OpenFile(targetPath, os.O_CREATE|os.O_RDWR, 0600)
   706  	if err != nil {
   707  		return fmt.Errorf("cannot create snapshot file %q: %v", targetPath, err)
   708  	}
   709  	defer t.Close()
   710  
   711  	if _, err := io.Copy(t, tr); err != nil {
   712  		return fmt.Errorf("cannot write snapshot file %q: %v", targetPath, err)
   713  	}
   714  	return nil
   715  }
   716  
   717  type DuplicatedSnapshotImportError struct {
   718  	SetID     uint64
   719  	SnapNames []string
   720  }
   721  
   722  func (e DuplicatedSnapshotImportError) Error() string {
   723  	return fmt.Sprintf("cannot import snapshot, already available as snapshot id %v", e.SetID)
   724  }
   725  
   726  func checkDuplicatedSnapshotSetWithContentHash(ctx context.Context, contentHash []byte) error {
   727  	snapshotSetMap := map[uint64]client.SnapshotSet{}
   728  
   729  	// XXX: deal with import in progress here
   730  
   731  	// get all current snapshotSets
   732  	err := Iter(ctx, func(reader *Reader) error {
   733  		ss := snapshotSetMap[reader.SetID]
   734  		ss.Snapshots = append(ss.Snapshots, &reader.Snapshot)
   735  		snapshotSetMap[reader.SetID] = ss
   736  		return nil
   737  	})
   738  	if err != nil {
   739  		return fmt.Errorf("cannot calculate snapshot set hashes: %v", err)
   740  	}
   741  
   742  	for setID, ss := range snapshotSetMap {
   743  		h, err := ss.ContentHash()
   744  		if err != nil {
   745  			return fmt.Errorf("cannot calculate content hash for %v: %v", setID, err)
   746  		}
   747  		if bytes.Equal(h, contentHash) {
   748  			var snapNames []string
   749  			for _, snapshot := range ss.Snapshots {
   750  				snapNames = append(snapNames, snapshot.Snap)
   751  			}
   752  			return DuplicatedSnapshotImportError{SetID: setID, SnapNames: snapNames}
   753  		}
   754  	}
   755  	return nil
   756  }
   757  
// unpackVerifySnapshotImport reads a snapshot export stream (a tar
// archive) from r, writes the contained snapshot zip files into
// dirs.SnapshotsDir renamed to the given realSetID, and validates each
// one. It returns the names of the snaps written so far, also on error,
// so the caller can clean up.
func unpackVerifySnapshotImport(ctx context.Context, r io.Reader, realSetID uint64, flags *ImportFlags) (snapNames []string, err error) {
	var exportFound bool

	tr := tar.NewReader(r)
	var tarErr error
	var header *tar.Header

	if flags == nil {
		flags = &ImportFlags{}
	}

	for tarErr == nil {
		header, tarErr = tr.Next()
		if tarErr == io.EOF {
			break
		}
		switch {
		case tarErr != nil:
			return nil, fmt.Errorf("cannot read snapshot import: %v", tarErr)
		case header == nil:
			// should not happen
			return nil, fmt.Errorf("tar header not found")
		case header.Typeflag == tar.TypeDir:
			return nil, errors.New("unexpected directory in import file")
		}

		if header.Name == "content.json" {
			var ej contentJSON
			dec := json.NewDecoder(tr)
			if err := dec.Decode(&ej); err != nil {
				return nil, err
			}
			if !flags.NoDuplicatedImportCheck {
				// XXX: this is potentially slow as it needs
				//      to open all snapshots files and read a
				//      small amount of data from them
				if err := checkDuplicatedSnapshotSetWithContentHash(ctx, ej.ContentHash); err != nil {
					return nil, err
				}
			}
			continue
		}

		if header.Name == "export.json" {
			// XXX: read into memory and validate once we have
			// hashes in export.json
			exportFound = true
			continue
		}

		// Format of the snapshot import is:
		//     $setID_.....
		// But because the setID is local this will not be correct
		// for our system and we need to discard this setID.
		//
		// So chop off the incorrect (old) setID and just use
		// the rest that is still valid.
		l := strings.SplitN(header.Name, "_", 2)
		if len(l) != 2 {
			return nil, fmt.Errorf("unexpected filename in import stream: %v", header.Name)
		}
		targetPath := path.Join(dirs.SnapshotsDir, fmt.Sprintf("%d_%s", realSetID, l[1]))
		if err := writeOneSnapshotFile(targetPath, tr); err != nil {
			return snapNames, err
		}

		// re-open the freshly written zip to verify its integrity
		r, err := backendOpen(targetPath, realSetID)
		if err != nil {
			return snapNames, fmt.Errorf("cannot open snapshot: %v", err)
		}
		err = r.Check(context.TODO(), nil)
		r.Close()
		// the snap name is recorded even when Check failed so that
		// the caller can clean up everything written so far
		snapNames = append(snapNames, r.Snap)
		if err != nil {
			return snapNames, fmt.Errorf("validation failed for %q: %v", targetPath, err)
		}
	}

	if !exportFound {
		return nil, fmt.Errorf("no export.json file in uploaded data")
	}
	// XXX: validate using the unmarshalled export.json hashes here

	return snapNames, nil
}
   843  
// exportMetadata describes an export stream: its format version, the
// time it was created and the list of files it contains.
type exportMetadata struct {
	Format int       `json:"format"`
	Date   time.Time `json:"date"`
	Files  []string  `json:"files"`
}
   849  
// SnapshotExport holds the open files of a snapshot set that is being
// exported; it must be Close()d after use to release them.
type SnapshotExport struct {
	// open snapshot files
	snapshotFiles []*os.File

	// contentHash of the full snapshot
	contentHash []byte

	// remember setID mostly for nicer errors
	setID uint64

	// cached size, needs to be calculated with Init
	size int64
}
   863  
// NewSnapshotExport will return a SnapshotExport structure. It must be
// Close()ed after use to avoid leaking file descriptors.
func NewSnapshotExport(ctx context.Context, setID uint64) (se *SnapshotExport, err error) {
	var snapshotFiles []*os.File
	var snapshotSet client.SnapshotSet

	defer func() {
		// cleanup any open FDs if anything goes wrong
		if err != nil {
			for _, f := range snapshotFiles {
				f.Close()
			}
		}
	}()

	// Open all files first and keep the file descriptors
	// open. The caller should have locked the state so that no
	// delete/change snapshot operations can happen while the
	// files are getting opened.
	err = Iter(ctx, func(reader *Reader) error {
		if reader.SetID == setID {
			snapshotSet.Snapshots = append(snapshotSet.Snapshots, &reader.Snapshot)

			// Duplicate the file descriptor of the reader
			// we were handed as Iter() closes those as
			// soon as this anonymous function returns. We
			// re-package the file descriptor into
			// snapshotFiles below.
			fd, err := syscall.Dup(int(reader.Fd()))
			if err != nil {
				return fmt.Errorf("cannot duplicate descriptor: %v", err)
			}
			f := os.NewFile(uintptr(fd), reader.Name())
			if f == nil {
				return fmt.Errorf("cannot open file from descriptor %d", fd)
			}
			snapshotFiles = append(snapshotFiles, f)
		}
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("cannot export snapshot %v: %v", setID, err)
	}
	if len(snapshotFiles) == 0 {
		return nil, fmt.Errorf("no snapshot data found for %v", setID)
	}

	// the content hash identifies the set's data independent of its id
	h, err := snapshotSet.ContentHash()
	if err != nil {
		return nil, fmt.Errorf("cannot calculate content hash for snapshot export %v: %v", setID, err)
	}
	se = &SnapshotExport{snapshotFiles: snapshotFiles, setID: setID, contentHash: h}

	// ensure we never leak FDs even if the user does not call close
	runtime.SetFinalizer(se, (*SnapshotExport).Close)

	return se, nil
}
   922  
   923  // Init will calculate the snapshot size. This can take some time
   924  // so it should be called without any locks. The SnapshotExport
   925  // keeps the FDs open so even files moved/deleted will be found.
   926  func (se *SnapshotExport) Init() error {
   927  	// Export once into a dummy writer so that we can set the size
   928  	// of the export. This is then used to set the Content-Length
   929  	// in the response correctly.
   930  	//
   931  	// Note that the size of the generated tar could change if the
   932  	// time switches between this export and the export we stream
   933  	// to the client to a time after the year 2242. This is unlikely
   934  	// but a known issue with this approach here.
   935  	var sz osutil.Sizer
   936  	if err := se.StreamTo(&sz); err != nil {
   937  		return fmt.Errorf("cannot calculcate the size for %v: %s", se.setID, err)
   938  	}
   939  	se.size = sz.Size()
   940  	return nil
   941  }
   942  
// Size returns the size in bytes of the export stream as calculated
// by a previous Init call; before Init has run it returns 0.
func (se *SnapshotExport) Size() int64 {
	return se.size
}
   946  
   947  func (se *SnapshotExport) Close() {
   948  	for _, f := range se.snapshotFiles {
   949  		f.Close()
   950  	}
   951  	se.snapshotFiles = nil
   952  }
   953  
// contentJSON is the payload of the "content.json" entry in an
// exported snapshot tarball; it carries the content hash of the
// exported snapshot set.
type contentJSON struct {
	ContentHash []byte `json:"content-hash"`
}
   957  
   958  func (se *SnapshotExport) StreamTo(w io.Writer) error {
   959  	// write out a tar
   960  	var files []string
   961  	tw := tar.NewWriter(w)
   962  	defer tw.Close()
   963  
   964  	// export contentHash as content.json
   965  	h, err := json.Marshal(contentJSON{se.contentHash})
   966  	if err != nil {
   967  		return err
   968  	}
   969  	hdr := &tar.Header{
   970  		Typeflag: tar.TypeReg,
   971  		Name:     "content.json",
   972  		Size:     int64(len(h)),
   973  		Mode:     0640,
   974  		ModTime:  timeNow(),
   975  	}
   976  	if err := tw.WriteHeader(hdr); err != nil {
   977  		return err
   978  	}
   979  	if _, err := tw.Write(h); err != nil {
   980  		return err
   981  	}
   982  
   983  	// write out the individual snapshots
   984  	for _, snapshotFile := range se.snapshotFiles {
   985  		stat, err := snapshotFile.Stat()
   986  		if err != nil {
   987  			return err
   988  		}
   989  		if !stat.Mode().IsRegular() {
   990  			// should never happen
   991  			return fmt.Errorf("unexported special file %q in snapshot: %s", stat.Name(), stat.Mode())
   992  		}
   993  		if _, err := snapshotFile.Seek(0, 0); err != nil {
   994  			return fmt.Errorf("cannot seek on %v: %v", stat.Name(), err)
   995  		}
   996  		hdr, err := tar.FileInfoHeader(stat, "")
   997  		if err != nil {
   998  			return fmt.Errorf("symlink: %v", stat.Name())
   999  		}
  1000  		if err = tw.WriteHeader(hdr); err != nil {
  1001  			return fmt.Errorf("cannot write header for %v: %v", stat.Name(), err)
  1002  		}
  1003  		if _, err := io.Copy(tw, snapshotFile); err != nil {
  1004  			return fmt.Errorf("cannot write data for %v: %v", stat.Name(), err)
  1005  		}
  1006  
  1007  		files = append(files, path.Base(snapshotFile.Name()))
  1008  	}
  1009  
  1010  	// write the metadata last, then the client can use that to
  1011  	// validate the archive is complete
  1012  	meta := exportMetadata{
  1013  		Format: 1,
  1014  		Date:   timeNow(),
  1015  		Files:  files,
  1016  	}
  1017  	metaDataBuf, err := json.Marshal(&meta)
  1018  	if err != nil {
  1019  		return fmt.Errorf("cannot marshal meta-data: %v", err)
  1020  	}
  1021  	hdr = &tar.Header{
  1022  		Typeflag: tar.TypeReg,
  1023  		Name:     "export.json",
  1024  		Size:     int64(len(metaDataBuf)),
  1025  		Mode:     0640,
  1026  		ModTime:  timeNow(),
  1027  	}
  1028  	if err := tw.WriteHeader(hdr); err != nil {
  1029  		return err
  1030  	}
  1031  	if _, err := tw.Write(metaDataBuf); err != nil {
  1032  		return err
  1033  	}
  1034  
  1035  	return nil
  1036  }