github.com/Lephar/snapd@v0.0.0-20210825215435-c7fba9cef4d2/overlord/snapshotstate/backend/reader.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2018 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package backend
    21  
    22  import (
    23  	"bytes"
    24  	"context"
    25  	"crypto"
    26  	"errors"
    27  	"fmt"
    28  	"hash"
    29  	"io"
    30  	"io/ioutil"
    31  	"os"
    32  	"path/filepath"
    33  	"sort"
    34  	"syscall"
    35  
    36  	"github.com/snapcore/snapd/client"
    37  	"github.com/snapcore/snapd/jsonutil"
    38  	"github.com/snapcore/snapd/logger"
    39  	"github.com/snapcore/snapd/osutil"
    40  	"github.com/snapcore/snapd/osutil/sys"
    41  	"github.com/snapcore/snapd/snap"
    42  	"github.com/snapcore/snapd/strutil"
    43  )
    44  
// ExtractFnameSetID can be passed to Open() as its setID argument to have
// the set ID inferred from the snapshot filename (which then overrides any
// set ID found in the snapshot's metadata).
const ExtractFnameSetID = 0
    48  
// A Reader is a snapshot that's been opened for reading.
//
// It embeds the open snapshot file (so file methods such as Close are
// available directly on the Reader) together with the snapshot metadata
// decoded from it (so fields such as SetID, Broken and SHA3_384 are too).
type Reader struct {
	*os.File
	client.Snapshot
}
    54  
    55  // Open a Snapshot given its full filename.
    56  //
    57  // The returned reader will have its setID set to the value of the argument,
    58  // or inferred from the snapshot filename if ExtractFnameSetID constant is
    59  // passed.
    60  //
    61  // If the returned error is nil, the caller must close the reader (or
    62  // its file) when done with it.
    63  //
    64  // If the returned error is non-nil, the returned Reader will be nil,
    65  // *or* have a non-empty Broken; in the latter case its file will be
    66  // closed.
    67  func Open(fn string, setID uint64) (reader *Reader, e error) {
    68  	f, err := os.Open(fn)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	defer func() {
    73  		if e != nil && f != nil {
    74  			f.Close()
    75  		}
    76  	}()
    77  
    78  	reader = &Reader{
    79  		File: f,
    80  	}
    81  
    82  	// first try to load the metadata itself
    83  	var sz osutil.Sizer
    84  	hasher := crypto.SHA3_384.New()
    85  	metaReader, metaSize, err := zipMember(f, metadataName)
    86  	if err != nil {
    87  		// no metadata file -> nothing to do :-(
    88  		return nil, err
    89  	}
    90  
    91  	if err := jsonutil.DecodeWithNumber(io.TeeReader(metaReader, io.MultiWriter(hasher, &sz)), &reader.Snapshot); err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	if setID == ExtractFnameSetID {
    96  		// set id from the filename has the authority and overrides the one from
    97  		// meta file.
    98  		var ok bool
    99  		ok, setID = isSnapshotFilename(fn)
   100  		if !ok {
   101  			return nil, fmt.Errorf("not a snapshot filename: %q", fn)
   102  		}
   103  	}
   104  
   105  	reader.SetID = setID
   106  
   107  	// OK, from here on we have a Snapshot
   108  
   109  	if !reader.IsValid() {
   110  		reader.Broken = "invalid snapshot"
   111  		return reader, errors.New(reader.Broken)
   112  	}
   113  
   114  	if sz.Size() != metaSize {
   115  		reader.Broken = fmt.Sprintf("declared metadata size (%d) does not match actual (%d)", metaSize, sz.Size())
   116  		return reader, errors.New(reader.Broken)
   117  	}
   118  
   119  	actualMetaHash := fmt.Sprintf("%x", hasher.Sum(nil))
   120  
   121  	// grab the metadata hash
   122  	sz.Reset()
   123  	metaHashReader, metaHashSize, err := zipMember(f, metaHashName)
   124  	if err != nil {
   125  		reader.Broken = err.Error()
   126  		return reader, err
   127  	}
   128  	metaHashBuf, err := ioutil.ReadAll(io.TeeReader(metaHashReader, &sz))
   129  	if err != nil {
   130  		reader.Broken = err.Error()
   131  		return reader, err
   132  	}
   133  	if sz.Size() != metaHashSize {
   134  		reader.Broken = fmt.Sprintf("declared hash size (%d) does not match actual (%d)", metaHashSize, sz.Size())
   135  		return reader, errors.New(reader.Broken)
   136  	}
   137  	if expectedMetaHash := string(bytes.TrimSpace(metaHashBuf)); actualMetaHash != expectedMetaHash {
   138  		reader.Broken = fmt.Sprintf("declared hash (%.7s…) does not match actual (%.7s…)", expectedMetaHash, actualMetaHash)
   139  		return reader, errors.New(reader.Broken)
   140  	}
   141  
   142  	return reader, nil
   143  }
   144  
   145  func (r *Reader) checkOne(ctx context.Context, entry string, hasher hash.Hash) error {
   146  	body, reportedSize, err := zipMember(r.File, entry)
   147  	if err != nil {
   148  		return err
   149  	}
   150  	defer body.Close()
   151  
   152  	expectedHash := r.SHA3_384[entry]
   153  	readSize, err := io.Copy(io.MultiWriter(osutil.ContextWriter(ctx), hasher), body)
   154  	if err != nil {
   155  		return err
   156  	}
   157  
   158  	if readSize != reportedSize {
   159  		return fmt.Errorf("snapshot entry %q size (%d) different from actual (%d)", entry, reportedSize, readSize)
   160  	}
   161  
   162  	if actualHash := fmt.Sprintf("%x", hasher.Sum(nil)); actualHash != expectedHash {
   163  		return fmt.Errorf("snapshot entry %q expected hash (%.7s…) does not match actual (%.7s…)", entry, expectedHash, actualHash)
   164  	}
   165  	return nil
   166  }
   167  
   168  // Check that the data contained in the snapshot matches its hashsums.
   169  func (r *Reader) Check(ctx context.Context, usernames []string) error {
   170  	sort.Strings(usernames)
   171  
   172  	hasher := crypto.SHA3_384.New()
   173  	for entry := range r.SHA3_384 {
   174  		if len(usernames) > 0 && isUserArchive(entry) {
   175  			username := entryUsername(entry)
   176  			if !strutil.SortedListContains(usernames, username) {
   177  				logger.Debugf("In checking snapshot %q, skipping entry %q by user request.", r.Name(), username)
   178  				continue
   179  			}
   180  		}
   181  
   182  		if err := r.checkOne(ctx, entry, hasher); err != nil {
   183  			return err
   184  		}
   185  		hasher.Reset()
   186  	}
   187  
   188  	return nil
   189  }
   190  
// Logf is the type implemented by logging functions.
//
// It has a printf-style signature. Restore uses it to report entries it
// skips, and temporary directories it fails to clean up, without failing
// the restore.
type Logf func(format string, args ...interface{})
   193  
   194  // Restore the data from the snapshot.
   195  //
   196  // If successful this will replace the existing data (for the given revision,
   197  // or the one in the snapshot) with that contained in the snapshot. It keeps
   198  // track of the old data in the task so it can be undone (or cleaned up).
   199  func (r *Reader) Restore(ctx context.Context, current snap.Revision, usernames []string, logf Logf) (rs *RestoreState, e error) {
   200  	rs = &RestoreState{}
   201  	defer func() {
   202  		if e != nil {
   203  			logger.Noticef("Restore of snapshot %q failed (%v); undoing.", r.Name(), e)
   204  			rs.Revert()
   205  			rs = nil
   206  		}
   207  	}()
   208  
   209  	sort.Strings(usernames)
   210  	isRoot := sys.Geteuid() == 0
   211  	si := snap.MinimalPlaceInfo(r.Snap, r.Revision)
   212  	hasher := crypto.SHA3_384.New()
   213  	var sz osutil.Sizer
   214  
   215  	var curdir string
   216  	if !current.Unset() {
   217  		curdir = current.String()
   218  	}
   219  
   220  	for entry := range r.SHA3_384 {
   221  		if err := ctx.Err(); err != nil {
   222  			return rs, err
   223  		}
   224  
   225  		var dest string
   226  		isUser := isUserArchive(entry)
   227  		username := "root"
   228  		uid := sys.UserID(osutil.NoChown)
   229  		gid := sys.GroupID(osutil.NoChown)
   230  
   231  		if !isUser {
   232  			if entry != archiveName {
   233  				// hmmm
   234  				logf("Skipping restore of unknown entry %q.", entry)
   235  				continue
   236  			}
   237  			dest = si.DataDir()
   238  		} else {
   239  			username = entryUsername(entry)
   240  			if len(usernames) > 0 && !strutil.SortedListContains(usernames, username) {
   241  				logger.Debugf("In restoring snapshot %q, skipping entry %q by user request.", r.Name(), username)
   242  				continue
   243  			}
   244  			usr, err := userLookup(username)
   245  			if err != nil {
   246  				logf("Skipping restore of user %q: %v.", username, err)
   247  				continue
   248  			}
   249  
   250  			dest = si.UserDataDir(usr.HomeDir)
   251  			fi, err := os.Stat(usr.HomeDir)
   252  			if err != nil {
   253  				if osutil.IsDirNotExist(err) {
   254  					logf("Skipping restore of %q as %q doesn't exist.", dest, usr.HomeDir)
   255  				} else {
   256  					logf("Skipping restore of %q: %v.", dest, err)
   257  				}
   258  				continue
   259  			}
   260  
   261  			if !fi.IsDir() {
   262  				logf("Skipping restore of %q as %q is not a directory.", dest, usr.HomeDir)
   263  				continue
   264  			}
   265  
   266  			if st, ok := fi.Sys().(*syscall.Stat_t); ok && isRoot {
   267  				// the mkdir below will use the uid/gid of usr.HomeDir
   268  				if st.Uid > 0 {
   269  					uid = sys.UserID(st.Uid)
   270  				}
   271  				if st.Gid > 0 {
   272  					gid = sys.GroupID(st.Gid)
   273  				}
   274  			}
   275  		}
   276  		parent, revdir := filepath.Split(dest)
   277  
   278  		exists, isDir, err := osutil.DirExists(parent)
   279  		if err != nil {
   280  			return rs, err
   281  		}
   282  		if !exists {
   283  			// NOTE that the chown won't happen (it'll be NoChown)
   284  			// for the system path, and we won't be creating the
   285  			// user's home (as we skip restore in that case).
   286  			// Also no chown happens for root/root.
   287  			if err := osutil.MkdirAllChown(parent, 0755, uid, gid); err != nil {
   288  				return rs, err
   289  			}
   290  			rs.Created = append(rs.Created, parent)
   291  		} else if !isDir {
   292  			return rs, fmt.Errorf("Cannot restore snapshot into %q: not a directory.", parent)
   293  		}
   294  
   295  		// TODO: have something more atomic in osutil
   296  		tempdir, err := ioutil.TempDir(parent, ".snapshot")
   297  		if err != nil {
   298  			return rs, err
   299  		}
   300  		if err := sys.ChownPath(tempdir, uid, gid); err != nil {
   301  			return rs, err
   302  		}
   303  
   304  		// one way or another we want tempdir gone
   305  		defer func() {
   306  			if err := os.RemoveAll(tempdir); err != nil {
   307  				logf("Cannot clean up temporary directory %q: %v.", tempdir, err)
   308  			}
   309  		}()
   310  
   311  		logger.Debugf("Restoring %q from %q into %q.", entry, r.Name(), tempdir)
   312  
   313  		body, expectedSize, err := zipMember(r.File, entry)
   314  		if err != nil {
   315  			return rs, err
   316  		}
   317  
   318  		expectedHash := r.SHA3_384[entry]
   319  
   320  		tr := io.TeeReader(body, io.MultiWriter(hasher, &sz))
   321  
   322  		// resist the temptation of using archive/tar unless it's proven
   323  		// that calling out to tar has issues -- there are a lot of
   324  		// special cases we'd need to consider otherwise
   325  		cmd := tarAsUser(username,
   326  			"--extract",
   327  			"--preserve-permissions", "--preserve-order", "--gunzip",
   328  			"--directory", tempdir)
   329  		cmd.Env = []string{}
   330  		cmd.Stdin = tr
   331  		matchCounter := &strutil.MatchCounter{N: 1}
   332  		cmd.Stderr = matchCounter
   333  		cmd.Stdout = os.Stderr
   334  		if isTesting {
   335  			matchCounter.N = -1
   336  			cmd.Stderr = io.MultiWriter(os.Stderr, matchCounter)
   337  		}
   338  
   339  		if err = osutil.RunWithContext(ctx, cmd); err != nil {
   340  			matches, count := matchCounter.Matches()
   341  			if count > 0 {
   342  				return rs, fmt.Errorf("cannot unpack archive: %s (and %d more)", matches[0], count-1)
   343  			}
   344  			return rs, fmt.Errorf("tar failed: %v", err)
   345  		}
   346  
   347  		if sz.Size() != expectedSize {
   348  			return rs, fmt.Errorf("snapshot %q entry %q expected size (%d) does not match actual (%d)",
   349  				r.Name(), entry, expectedSize, sz.Size())
   350  		}
   351  
   352  		if actualHash := fmt.Sprintf("%x", hasher.Sum(nil)); actualHash != expectedHash {
   353  			return rs, fmt.Errorf("snapshot %q entry %q expected hash (%.7s…) does not match actual (%.7s…)",
   354  				r.Name(), entry, expectedHash, actualHash)
   355  		}
   356  
   357  		if curdir != "" && curdir != revdir {
   358  			// rename it in tempdir
   359  			// this is where we assume the current revision can read the snapshot revision's data
   360  			if err := os.Rename(filepath.Join(tempdir, revdir), filepath.Join(tempdir, curdir)); err != nil {
   361  				return rs, err
   362  			}
   363  			revdir = curdir
   364  		}
   365  
   366  		for _, dir := range []string{"common", revdir} {
   367  			source := filepath.Join(tempdir, dir)
   368  			if exists, _, err := osutil.DirExists(source); err != nil {
   369  				return rs, err
   370  			} else if !exists {
   371  				continue
   372  			}
   373  			target := filepath.Join(parent, dir)
   374  			exists, _, err := osutil.DirExists(target)
   375  			if err != nil {
   376  				return rs, err
   377  			}
   378  			if exists {
   379  				rsfn := restoreStateFilename(target)
   380  				if err := os.Rename(target, rsfn); err != nil {
   381  					return rs, err
   382  				}
   383  				rs.Moved = append(rs.Moved, rsfn)
   384  			}
   385  
   386  			if err := os.Rename(source, target); err != nil {
   387  				return rs, err
   388  			}
   389  			rs.Created = append(rs.Created, target)
   390  		}
   391  
   392  		sz.Reset()
   393  		hasher.Reset()
   394  	}
   395  
   396  	return rs, nil
   397  }