github.com/hugh712/snapd@v0.0.0-20200910133618-1a99902bd583/overlord/snapshotstate/backend/reader.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2018 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package backend
    21  
    22  import (
    23  	"bytes"
    24  	"context"
    25  	"crypto"
    26  	"errors"
    27  	"fmt"
    28  	"hash"
    29  	"io"
    30  	"io/ioutil"
    31  	"os"
    32  	"path/filepath"
    33  	"sort"
    34  	"syscall"
    35  
    36  	"github.com/snapcore/snapd/client"
    37  	"github.com/snapcore/snapd/jsonutil"
    38  	"github.com/snapcore/snapd/logger"
    39  	"github.com/snapcore/snapd/osutil"
    40  	"github.com/snapcore/snapd/osutil/sys"
    41  	"github.com/snapcore/snapd/snap"
    42  	"github.com/snapcore/snapd/strutil"
    43  )
    44  
// A Reader is a snapshot that's been opened for reading.
//
// It is produced by Open and embeds both the opened snapshot archive and
// the metadata decoded from it, so the fields and methods of each are
// promoted onto the Reader.
type Reader struct {
	// The opened snapshot archive. Per Open's contract the caller closes
	// the Reader (or this file) when done with it.
	*os.File
	// The snapshot metadata decoded from the archive's metadata member.
	client.Snapshot
}
    48  	client.Snapshot
    49  }
    50  
    51  // Open a Snapshot given its full filename.
    52  //
    53  // If the returned error is nil, the caller must close the reader (or
    54  // its file) when done with it.
    55  //
    56  // If the returned error is non-nil, the returned Reader will be nil,
    57  // *or* have a non-empty Broken; in the latter case its file will be
    58  // closed.
    59  func Open(fn string) (reader *Reader, e error) {
    60  	f, err := os.Open(fn)
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	defer func() {
    65  		if e != nil && f != nil {
    66  			f.Close()
    67  		}
    68  	}()
    69  
    70  	reader = &Reader{
    71  		File: f,
    72  	}
    73  
    74  	// first try to load the metadata itself
    75  	var sz osutil.Sizer
    76  	hasher := crypto.SHA3_384.New()
    77  	metaReader, metaSize, err := zipMember(f, metadataName)
    78  	if err != nil {
    79  		// no metadata file -> nothing to do :-(
    80  		return nil, err
    81  	}
    82  
    83  	if err := jsonutil.DecodeWithNumber(io.TeeReader(metaReader, io.MultiWriter(hasher, &sz)), &reader.Snapshot); err != nil {
    84  		return nil, err
    85  	}
    86  
    87  	// OK, from here on we have a Snapshot
    88  
    89  	if !reader.IsValid() {
    90  		reader.Broken = "invalid snapshot"
    91  		return reader, errors.New(reader.Broken)
    92  	}
    93  
    94  	if sz.Size() != metaSize {
    95  		reader.Broken = fmt.Sprintf("declared metadata size (%d) does not match actual (%d)", metaSize, sz.Size())
    96  		return reader, errors.New(reader.Broken)
    97  	}
    98  
    99  	actualMetaHash := fmt.Sprintf("%x", hasher.Sum(nil))
   100  
   101  	// grab the metadata hash
   102  	sz.Reset()
   103  	metaHashReader, metaHashSize, err := zipMember(f, metaHashName)
   104  	if err != nil {
   105  		reader.Broken = err.Error()
   106  		return reader, err
   107  	}
   108  	metaHashBuf, err := ioutil.ReadAll(io.TeeReader(metaHashReader, &sz))
   109  	if err != nil {
   110  		reader.Broken = err.Error()
   111  		return reader, err
   112  	}
   113  	if sz.Size() != metaHashSize {
   114  		reader.Broken = fmt.Sprintf("declared hash size (%d) does not match actual (%d)", metaHashSize, sz.Size())
   115  		return reader, errors.New(reader.Broken)
   116  	}
   117  	if expectedMetaHash := string(bytes.TrimSpace(metaHashBuf)); actualMetaHash != expectedMetaHash {
   118  		reader.Broken = fmt.Sprintf("declared hash (%.7s…) does not match actual (%.7s…)", expectedMetaHash, actualMetaHash)
   119  		return reader, errors.New(reader.Broken)
   120  	}
   121  
   122  	return reader, nil
   123  }
   124  
   125  func (r *Reader) checkOne(ctx context.Context, entry string, hasher hash.Hash) error {
   126  	body, reportedSize, err := zipMember(r.File, entry)
   127  	if err != nil {
   128  		return err
   129  	}
   130  	defer body.Close()
   131  
   132  	expectedHash := r.SHA3_384[entry]
   133  	readSize, err := io.Copy(io.MultiWriter(osutil.ContextWriter(ctx), hasher), body)
   134  	if err != nil {
   135  		return err
   136  	}
   137  
   138  	if readSize != reportedSize {
   139  		return fmt.Errorf("snapshot entry %q size (%d) different from actual (%d)", entry, reportedSize, readSize)
   140  	}
   141  
   142  	if actualHash := fmt.Sprintf("%x", hasher.Sum(nil)); actualHash != expectedHash {
   143  		return fmt.Errorf("snapshot entry %q expected hash (%.7s…) does not match actual (%.7s…)", entry, expectedHash, actualHash)
   144  	}
   145  	return nil
   146  }
   147  
   148  // Check that the data contained in the snapshot matches its hashsums.
   149  func (r *Reader) Check(ctx context.Context, usernames []string) error {
   150  	sort.Strings(usernames)
   151  
   152  	hasher := crypto.SHA3_384.New()
   153  	for entry := range r.SHA3_384 {
   154  		if len(usernames) > 0 && isUserArchive(entry) {
   155  			username := entryUsername(entry)
   156  			if !strutil.SortedListContains(usernames, username) {
   157  				logger.Debugf("In checking snapshot %q, skipping entry %q by user request.", r.Name(), username)
   158  				continue
   159  			}
   160  		}
   161  
   162  		if err := r.checkOne(ctx, entry, hasher); err != nil {
   163  			return err
   164  		}
   165  		hasher.Reset()
   166  	}
   167  
   168  	return nil
   169  }
   170  
   171  // Logf is the type implemented by logging functions.
   172  type Logf func(format string, args ...interface{})
   173  
   174  // Restore the data from the snapshot.
   175  //
   176  // If successful this will replace the existing data (for the given revision,
   177  // or the one in the snapshot) with that contained in the snapshot. It keeps
   178  // track of the old data in the task so it can be undone (or cleaned up).
   179  func (r *Reader) Restore(ctx context.Context, current snap.Revision, usernames []string, logf Logf) (rs *RestoreState, e error) {
   180  	rs = &RestoreState{}
   181  	defer func() {
   182  		if e != nil {
   183  			logger.Noticef("Restore of snapshot %q failed (%v); undoing.", r.Name(), e)
   184  			rs.Revert()
   185  			rs = nil
   186  		}
   187  	}()
   188  
   189  	sort.Strings(usernames)
   190  	isRoot := sys.Geteuid() == 0
   191  	si := snap.MinimalPlaceInfo(r.Snap, r.Revision)
   192  	hasher := crypto.SHA3_384.New()
   193  	var sz osutil.Sizer
   194  
   195  	var curdir string
   196  	if !current.Unset() {
   197  		curdir = current.String()
   198  	}
   199  
   200  	for entry := range r.SHA3_384 {
   201  		if err := ctx.Err(); err != nil {
   202  			return rs, err
   203  		}
   204  
   205  		var dest string
   206  		isUser := isUserArchive(entry)
   207  		username := "root"
   208  		uid := sys.UserID(osutil.NoChown)
   209  		gid := sys.GroupID(osutil.NoChown)
   210  
   211  		if !isUser {
   212  			if entry != archiveName {
   213  				// hmmm
   214  				logf("Skipping restore of unknown entry %q.", entry)
   215  				continue
   216  			}
   217  			dest = si.DataDir()
   218  		} else {
   219  			username = entryUsername(entry)
   220  			if len(usernames) > 0 && !strutil.SortedListContains(usernames, username) {
   221  				logger.Debugf("In restoring snapshot %q, skipping entry %q by user request.", r.Name(), username)
   222  				continue
   223  			}
   224  			usr, err := userLookup(username)
   225  			if err != nil {
   226  				logf("Skipping restore of user %q: %v.", username, err)
   227  				continue
   228  			}
   229  
   230  			dest = si.UserDataDir(usr.HomeDir)
   231  			fi, err := os.Stat(usr.HomeDir)
   232  			if err != nil {
   233  				if osutil.IsDirNotExist(err) {
   234  					logf("Skipping restore of %q as %q doesn't exist.", dest, usr.HomeDir)
   235  				} else {
   236  					logf("Skipping restore of %q: %v.", dest, err)
   237  				}
   238  				continue
   239  			}
   240  
   241  			if !fi.IsDir() {
   242  				logf("Skipping restore of %q as %q is not a directory.", dest, usr.HomeDir)
   243  				continue
   244  			}
   245  
   246  			if st, ok := fi.Sys().(*syscall.Stat_t); ok && isRoot {
   247  				// the mkdir below will use the uid/gid of usr.HomeDir
   248  				if st.Uid > 0 {
   249  					uid = sys.UserID(st.Uid)
   250  				}
   251  				if st.Gid > 0 {
   252  					gid = sys.GroupID(st.Gid)
   253  				}
   254  			}
   255  		}
   256  		parent, revdir := filepath.Split(dest)
   257  
   258  		exists, isDir, err := osutil.DirExists(parent)
   259  		if err != nil {
   260  			return rs, err
   261  		}
   262  		if !exists {
   263  			// NOTE that the chown won't happen (it'll be NoChown)
   264  			// for the system path, and we won't be creating the
   265  			// user's home (as we skip restore in that case).
   266  			// Also no chown happens for root/root.
   267  			if err := osutil.MkdirAllChown(parent, 0755, uid, gid); err != nil {
   268  				return rs, err
   269  			}
   270  			rs.Created = append(rs.Created, parent)
   271  		} else if !isDir {
   272  			return rs, fmt.Errorf("Cannot restore snapshot into %q: not a directory.", parent)
   273  		}
   274  
   275  		// TODO: have something more atomic in osutil
   276  		tempdir, err := ioutil.TempDir(parent, ".snapshot")
   277  		if err != nil {
   278  			return rs, err
   279  		}
   280  		if err := sys.ChownPath(tempdir, uid, gid); err != nil {
   281  			return rs, err
   282  		}
   283  
   284  		// one way or another we want tempdir gone
   285  		defer func() {
   286  			if err := os.RemoveAll(tempdir); err != nil {
   287  				logf("Cannot clean up temporary directory %q: %v.", tempdir, err)
   288  			}
   289  		}()
   290  
   291  		logger.Debugf("Restoring %q from %q into %q.", entry, r.Name(), tempdir)
   292  
   293  		body, expectedSize, err := zipMember(r.File, entry)
   294  		if err != nil {
   295  			return rs, err
   296  		}
   297  
   298  		expectedHash := r.SHA3_384[entry]
   299  
   300  		tr := io.TeeReader(body, io.MultiWriter(hasher, &sz))
   301  
   302  		// resist the temptation of using archive/tar unless it's proven
   303  		// that calling out to tar has issues -- there are a lot of
   304  		// special cases we'd need to consider otherwise
   305  		cmd := tarAsUser(username,
   306  			"--extract",
   307  			"--preserve-permissions", "--preserve-order", "--gunzip",
   308  			"--directory", tempdir)
   309  		cmd.Env = []string{}
   310  		cmd.Stdin = tr
   311  		matchCounter := &strutil.MatchCounter{N: 1}
   312  		cmd.Stderr = matchCounter
   313  		cmd.Stdout = os.Stderr
   314  		if isTesting {
   315  			matchCounter.N = -1
   316  			cmd.Stderr = io.MultiWriter(os.Stderr, matchCounter)
   317  		}
   318  
   319  		if err = osutil.RunWithContext(ctx, cmd); err != nil {
   320  			matches, count := matchCounter.Matches()
   321  			if count > 0 {
   322  				return rs, fmt.Errorf("cannot unpack archive: %s (and %d more)", matches[0], count-1)
   323  			}
   324  			return rs, fmt.Errorf("tar failed: %v", err)
   325  		}
   326  
   327  		if sz.Size() != expectedSize {
   328  			return rs, fmt.Errorf("snapshot %q entry %q expected size (%d) does not match actual (%d)",
   329  				r.Name(), entry, expectedSize, sz.Size())
   330  		}
   331  
   332  		if actualHash := fmt.Sprintf("%x", hasher.Sum(nil)); actualHash != expectedHash {
   333  			return rs, fmt.Errorf("snapshot %q entry %q expected hash (%.7s…) does not match actual (%.7s…)",
   334  				r.Name(), entry, expectedHash, actualHash)
   335  		}
   336  
   337  		if curdir != "" && curdir != revdir {
   338  			// rename it in tempdir
   339  			// this is where we assume the current revision can read the snapshot revision's data
   340  			if err := os.Rename(filepath.Join(tempdir, revdir), filepath.Join(tempdir, curdir)); err != nil {
   341  				return rs, err
   342  			}
   343  			revdir = curdir
   344  		}
   345  
   346  		for _, dir := range []string{"common", revdir} {
   347  			source := filepath.Join(tempdir, dir)
   348  			if exists, _, err := osutil.DirExists(source); err != nil {
   349  				return rs, err
   350  			} else if !exists {
   351  				continue
   352  			}
   353  			target := filepath.Join(parent, dir)
   354  			exists, _, err := osutil.DirExists(target)
   355  			if err != nil {
   356  				return rs, err
   357  			}
   358  			if exists {
   359  				rsfn := restoreStateFilename(target)
   360  				if err := os.Rename(target, rsfn); err != nil {
   361  					return rs, err
   362  				}
   363  				rs.Moved = append(rs.Moved, rsfn)
   364  			}
   365  
   366  			if err := os.Rename(source, target); err != nil {
   367  				return rs, err
   368  			}
   369  			rs.Created = append(rs.Created, target)
   370  		}
   371  
   372  		sz.Reset()
   373  		hasher.Reset()
   374  	}
   375  
   376  	return rs, nil
   377  }