github.com/hugh712/snapd@v0.0.0-20200910133618-1a99902bd583/overlord/snapshotstate/snapshotmgr_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2018 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package snapshotstate_test
    21  
    22  import (
    23  	"context"
    24  	"encoding/json"
    25  	"errors"
    26  	"os"
    27  	"path/filepath"
    28  	"sort"
    29  	"time"
    30  
    31  	"gopkg.in/check.v1"
    32  	"gopkg.in/tomb.v2"
    33  
    34  	"github.com/snapcore/snapd/client"
    35  	"github.com/snapcore/snapd/dirs"
    36  	"github.com/snapcore/snapd/overlord/snapshotstate"
    37  	"github.com/snapcore/snapd/overlord/snapshotstate/backend"
    38  	"github.com/snapcore/snapd/overlord/state"
    39  	"github.com/snapcore/snapd/snap"
    40  )
    41  
    42  func (snapshotSuite) TestManager(c *check.C) {
    43  	st := state.New(nil)
    44  	st.Lock()
    45  	defer st.Unlock()
    46  	runner := state.NewTaskRunner(st)
    47  	mgr := snapshotstate.Manager(st, runner)
    48  	c.Assert(mgr, check.NotNil)
    49  	kinds := runner.KnownTaskKinds()
    50  	sort.Strings(kinds)
    51  	c.Check(kinds, check.DeepEquals, []string{
    52  		"check-snapshot",
    53  		"forget-snapshot",
    54  		"restore-snapshot",
    55  		"save-snapshot",
    56  	})
    57  }
    58  
    59  func mockDummySnapshot(c *check.C) (restore func()) {
    60  	shotfile, err := os.Create(filepath.Join(c.MkDir(), "foo.zip"))
    61  	c.Assert(err, check.IsNil)
    62  
    63  	fakeIter := func(_ context.Context, f func(*backend.Reader) error) error {
    64  		c.Assert(f(&backend.Reader{
    65  			Snapshot: client.Snapshot{SetID: 1, Snap: "a-snap", SnapID: "a-id", Epoch: snap.Epoch{Read: []uint32{42}, Write: []uint32{17}}},
    66  			File:     shotfile,
    67  		}), check.IsNil)
    68  		return nil
    69  	}
    70  
    71  	restoreBackendIter := snapshotstate.MockBackendIter(fakeIter)
    72  
    73  	return func() {
    74  		shotfile.Close()
    75  		restoreBackendIter()
    76  	}
    77  }
    78  
    79  func (snapshotSuite) TestEnsureForgetsSnapshots(c *check.C) {
    80  	var removedSnapshot string
    81  	restoreOsRemove := snapshotstate.MockOsRemove(func(fileName string) error {
    82  		removedSnapshot = fileName
    83  		return nil
    84  	})
    85  	defer restoreOsRemove()
    86  
    87  	restore := mockDummySnapshot(c)
    88  	defer restore()
    89  
    90  	st := state.New(nil)
    91  	runner := state.NewTaskRunner(st)
    92  	mgr := snapshotstate.Manager(st, runner)
    93  	c.Assert(mgr, check.NotNil)
    94  
    95  	st.Lock()
    96  	defer st.Unlock()
    97  
    98  	st.Set("snapshots", map[uint64]interface{}{
    99  		1: map[string]interface{}{"expiry-time": "2001-03-11T11:24:00Z"},
   100  		2: map[string]interface{}{"expiry-time": "2037-02-12T12:50:00Z"},
   101  	})
   102  
   103  	st.Unlock()
   104  	c.Assert(mgr.Ensure(), check.IsNil)
   105  	st.Lock()
   106  
   107  	// verify expired snapshots were removed
   108  	var expirations map[uint64]interface{}
   109  	c.Assert(st.Get("snapshots", &expirations), check.IsNil)
   110  	c.Check(expirations, check.DeepEquals, map[uint64]interface{}{
   111  		2: map[string]interface{}{"expiry-time": "2037-02-12T12:50:00Z"}})
   112  	c.Check(removedSnapshot, check.Matches, ".*/foo.zip")
   113  }
   114  
   115  func (snapshotSuite) TestEnsureForgetsSnapshotsRunsRegularly(c *check.C) {
   116  	var backendIterCalls int
   117  	shotfile, err := os.Create(filepath.Join(c.MkDir(), "foo.zip"))
   118  	c.Assert(err, check.IsNil)
   119  	fakeIter := func(_ context.Context, f func(*backend.Reader) error) error {
   120  		c.Assert(f(&backend.Reader{
   121  			Snapshot: client.Snapshot{SetID: 1, Snap: "a-snap", SnapID: "a-id", Epoch: snap.Epoch{Read: []uint32{42}, Write: []uint32{17}}},
   122  			File:     shotfile,
   123  		}), check.IsNil)
   124  		backendIterCalls++
   125  		return nil
   126  	}
   127  	restoreBackendIter := snapshotstate.MockBackendIter(fakeIter)
   128  	defer restoreBackendIter()
   129  
   130  	restoreOsRemove := snapshotstate.MockOsRemove(func(fileName string) error {
   131  		return nil
   132  	})
   133  	defer restoreOsRemove()
   134  
   135  	st := state.New(nil)
   136  	runner := state.NewTaskRunner(st)
   137  	mgr := snapshotstate.Manager(st, runner)
   138  	c.Assert(mgr, check.NotNil)
   139  
   140  	storeExpiredSnapshot := func() {
   141  		st.Lock()
   142  		// we need at least one snapshot set in the state for forgetExpiredSnapshots to do any work
   143  		st.Set("snapshots", map[uint64]interface{}{
   144  			1: map[string]interface{}{"expiry-time": "2001-03-11T11:24:00Z"},
   145  		})
   146  		st.Unlock()
   147  	}
   148  
   149  	// consecutive runs of Ensure call the backend just once because of the snapshotExpirationLoopInterval
   150  	for i := 0; i < 3; i++ {
   151  		storeExpiredSnapshot()
   152  		c.Assert(mgr.Ensure(), check.IsNil)
   153  		c.Check(backendIterCalls, check.Equals, 1)
   154  	}
   155  
   156  	// pretend we haven't run for a while
   157  	t, err := time.Parse(time.RFC3339, "2002-03-11T11:24:00Z")
   158  	c.Assert(err, check.IsNil)
   159  	mgr.SetLastForgetExpiredSnapshotTime(t)
   160  	c.Assert(mgr.Ensure(), check.IsNil)
   161  	c.Check(backendIterCalls, check.Equals, 2)
   162  
   163  	c.Assert(mgr.Ensure(), check.IsNil)
   164  	c.Check(backendIterCalls, check.Equals, 2)
   165  }
   166  
   167  func (snapshotSuite) testEnsureForgetSnapshotsConflict(c *check.C, snapshotTaskKind string) {
   168  	removeCalled := 0
   169  	restoreOsRemove := snapshotstate.MockOsRemove(func(string) error {
   170  		removeCalled++
   171  		return nil
   172  	})
   173  	defer restoreOsRemove()
   174  
   175  	restore := mockDummySnapshot(c)
   176  	defer restore()
   177  
   178  	st := state.New(nil)
   179  	runner := state.NewTaskRunner(st)
   180  	mgr := snapshotstate.Manager(st, runner)
   181  	c.Assert(mgr, check.NotNil)
   182  
   183  	st.Lock()
   184  	defer st.Unlock()
   185  
   186  	st.Set("snapshots", map[uint64]interface{}{
   187  		1: map[string]interface{}{"expiry-time": "2001-03-11T11:24:00Z"},
   188  	})
   189  
   190  	chg := st.NewChange("snapshot-change", "...")
   191  	tsk := st.NewTask(snapshotTaskKind, "...")
   192  	tsk.SetStatus(state.DoingStatus)
   193  	tsk.Set("snapshot-setup", map[string]int{"set-id": 1})
   194  	chg.AddTask(tsk)
   195  
   196  	st.Unlock()
   197  	c.Assert(mgr.Ensure(), check.IsNil)
   198  	st.Lock()
   199  
   200  	var expirations map[uint64]interface{}
   201  	c.Assert(st.Get("snapshots", &expirations), check.IsNil)
   202  	c.Check(expirations, check.DeepEquals, map[uint64]interface{}{
   203  		1: map[string]interface{}{"expiry-time": "2001-03-11T11:24:00Z"},
   204  	})
   205  	c.Check(removeCalled, check.Equals, 0)
   206  
   207  	// sanity check of the test setup: snapshot gets removed once conflict goes away
   208  	tsk.SetStatus(state.DoneStatus)
   209  
   210  	// pretend we haven't run for a while
   211  	t, err := time.Parse(time.RFC3339, "2002-03-11T11:24:00Z")
   212  	c.Assert(err, check.IsNil)
   213  	mgr.SetLastForgetExpiredSnapshotTime(t)
   214  
   215  	st.Unlock()
   216  	c.Assert(mgr.Ensure(), check.IsNil)
   217  	st.Lock()
   218  
   219  	expirations = nil
   220  	c.Assert(st.Get("snapshots", &expirations), check.IsNil)
   221  	c.Check(removeCalled, check.Equals, 1)
   222  	c.Check(expirations, check.HasLen, 0)
   223  }
   224  
   225  func (s *snapshotSuite) TestEnsureForgetSnapshotsConflictWithCheckSnapshot(c *check.C) {
   226  	s.testEnsureForgetSnapshotsConflict(c, "check-snapshot")
   227  }
   228  
   229  func (s *snapshotSuite) TestEnsureForgetSnapshotsConflictWithRestoreSnapshot(c *check.C) {
   230  	s.testEnsureForgetSnapshotsConflict(c, "restore-snapshot")
   231  }
   232  
   233  func (snapshotSuite) TestFilename(c *check.C) {
   234  	si := &snap.Info{
   235  		SideInfo: snap.SideInfo{
   236  			RealName: "a-snap",
   237  			Revision: snap.R(-1),
   238  		},
   239  		Version: "1.33",
   240  	}
   241  	filename := snapshotstate.Filename(42, si)
   242  	c.Check(filepath.Dir(filename), check.Equals, dirs.SnapshotsDir)
   243  	c.Check(filepath.Base(filename), check.Equals, "42_a-snap_1.33_x1.zip")
   244  }
   245  
   246  func (snapshotSuite) TestDoSave(c *check.C) {
   247  	snapInfo := snap.Info{
   248  		SideInfo: snap.SideInfo{
   249  			RealName: "a-snap",
   250  			Revision: snap.R(-1),
   251  		},
   252  		Version: "1.33",
   253  	}
   254  	defer snapshotstate.MockSnapstateCurrentInfo(func(_ *state.State, snapname string) (*snap.Info, error) {
   255  		c.Check(snapname, check.Equals, "a-snap")
   256  		return &snapInfo, nil
   257  	})()
   258  	defer snapshotstate.MockConfigGetSnapConfig(func(_ *state.State, snapname string) (*json.RawMessage, error) {
   259  		c.Check(snapname, check.Equals, "a-snap")
   260  		buf := json.RawMessage(`{"hello": "there"}`)
   261  		return &buf, nil
   262  	})()
   263  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   264  		c.Check(id, check.Equals, uint64(42))
   265  		c.Check(si, check.DeepEquals, &snapInfo)
   266  		c.Check(cfg, check.DeepEquals, map[string]interface{}{"hello": "there"})
   267  		c.Check(usernames, check.DeepEquals, []string{"a-user", "b-user"})
   268  		c.Check(flags.Auto, check.Equals, false)
   269  		return nil, nil
   270  	})()
   271  
   272  	st := state.New(nil)
   273  	st.Lock()
   274  	task := st.NewTask("save-snapshot", "...")
   275  	task.Set("snapshot-setup", map[string]interface{}{
   276  		"set-id": 42,
   277  		"snap":   "a-snap",
   278  		"users":  []string{"a-user", "b-user"},
   279  	})
   280  	st.Unlock()
   281  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   282  	c.Assert(err, check.IsNil)
   283  }
   284  
   285  func (snapshotSuite) TestDoSaveFailsWithNoSnap(c *check.C) {
   286  	defer snapshotstate.MockSnapstateCurrentInfo(func(*state.State, string) (*snap.Info, error) {
   287  		return nil, errors.New("bzzt")
   288  	})()
   289  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) { return nil, nil })()
   290  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   291  		return nil, nil
   292  	})()
   293  
   294  	st := state.New(nil)
   295  	st.Lock()
   296  	task := st.NewTask("save-snapshot", "...")
   297  	task.Set("snapshot-setup", map[string]interface{}{
   298  		"set-id": 42,
   299  		"snap":   "a-snap",
   300  		"users":  []string{"a-user", "b-user"},
   301  	})
   302  	st.Unlock()
   303  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   304  	c.Assert(err, check.ErrorMatches, "bzzt")
   305  }
   306  
   307  func (snapshotSuite) TestDoSaveFailsWithNoSnapshot(c *check.C) {
   308  	snapInfo := snap.Info{
   309  		SideInfo: snap.SideInfo{
   310  			RealName: "a-snap",
   311  			Revision: snap.R(-1),
   312  		},
   313  		Version: "1.33",
   314  	}
   315  	defer snapshotstate.MockSnapstateCurrentInfo(func(*state.State, string) (*snap.Info, error) { return &snapInfo, nil })()
   316  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) { return nil, nil })()
   317  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   318  		return nil, nil
   319  	})()
   320  
   321  	st := state.New(nil)
   322  	st.Lock()
   323  	task := st.NewTask("save-snapshot", "...")
   324  	// NOTE no task.Set("snapshot-setup", ...)
   325  	st.Unlock()
   326  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   327  	c.Assert(err, check.NotNil)
   328  	c.Assert(err.Error(), check.Equals, "internal error: task 1 (save-snapshot) is missing snapshot information")
   329  }
   330  
   331  func (snapshotSuite) TestDoSaveFailsBackendError(c *check.C) {
   332  	snapInfo := snap.Info{
   333  		SideInfo: snap.SideInfo{
   334  			RealName: "a-snap",
   335  			Revision: snap.R(-1),
   336  		},
   337  		Version: "1.33",
   338  	}
   339  	defer snapshotstate.MockSnapstateCurrentInfo(func(*state.State, string) (*snap.Info, error) { return &snapInfo, nil })()
   340  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) { return nil, nil })()
   341  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   342  		return nil, errors.New("bzzt")
   343  	})()
   344  
   345  	st := state.New(nil)
   346  	st.Lock()
   347  	task := st.NewTask("save-snapshot", "...")
   348  	task.Set("snapshot-setup", map[string]interface{}{
   349  		"set-id": 42,
   350  		"snap":   "a-snap",
   351  		"users":  []string{"a-user", "b-user"},
   352  	})
   353  	st.Unlock()
   354  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   355  	c.Assert(err, check.ErrorMatches, "bzzt")
   356  }
   357  
   358  func (snapshotSuite) TestDoSaveFailsConfigError(c *check.C) {
   359  	snapInfo := snap.Info{
   360  		SideInfo: snap.SideInfo{
   361  			RealName: "a-snap",
   362  			Revision: snap.R(-1),
   363  		},
   364  		Version: "1.33",
   365  	}
   366  	defer snapshotstate.MockSnapstateCurrentInfo(func(*state.State, string) (*snap.Info, error) { return &snapInfo, nil })()
   367  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) {
   368  		return nil, errors.New("bzzt")
   369  	})()
   370  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   371  		return nil, nil
   372  	})()
   373  
   374  	st := state.New(nil)
   375  	st.Lock()
   376  	task := st.NewTask("save-snapshot", "...")
   377  	task.Set("snapshot-setup", map[string]interface{}{
   378  		"set-id": 42,
   379  		"snap":   "a-snap",
   380  		"users":  []string{"a-user", "b-user"},
   381  	})
   382  	st.Unlock()
   383  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   384  	c.Assert(err, check.ErrorMatches, "bzzt")
   385  }
   386  
   387  func (snapshotSuite) TestDoSaveFailsBadConfig(c *check.C) {
   388  	snapInfo := snap.Info{
   389  		SideInfo: snap.SideInfo{
   390  			RealName: "a-snap",
   391  			Revision: snap.R(-1),
   392  		},
   393  		Version: "1.33",
   394  	}
   395  	defer snapshotstate.MockSnapstateCurrentInfo(func(*state.State, string) (*snap.Info, error) { return &snapInfo, nil })()
   396  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) {
   397  		// returns something that's not a JSON object
   398  		buf := json.RawMessage(`"hello-there"`)
   399  		return &buf, nil
   400  	})()
   401  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   402  		return nil, nil
   403  	})()
   404  
   405  	st := state.New(nil)
   406  	st.Lock()
   407  	task := st.NewTask("save-snapshot", "...")
   408  	task.Set("snapshot-setup", map[string]interface{}{
   409  		"set-id": 42,
   410  		"snap":   "a-snap",
   411  		"users":  []string{"a-user", "b-user"},
   412  	})
   413  	st.Unlock()
   414  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   415  	c.Assert(err, check.ErrorMatches, ".* cannot unmarshal .*")
   416  }
   417  
   418  func (snapshotSuite) TestDoSaveFailureRemovesStateEntry(c *check.C) {
   419  	st := state.New(nil)
   420  
   421  	snapInfo := snap.Info{
   422  		SideInfo: snap.SideInfo{
   423  			RealName: "a-snap",
   424  			Revision: snap.R(-1),
   425  		},
   426  		Version: "1.33",
   427  	}
   428  	defer snapshotstate.MockSnapstateCurrentInfo(func(_ *state.State, snapname string) (*snap.Info, error) {
   429  		return &snapInfo, nil
   430  	})()
   431  	defer snapshotstate.MockConfigGetSnapConfig(func(_ *state.State, snapname string) (*json.RawMessage, error) {
   432  		return nil, nil
   433  	})()
   434  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   435  		var expirations map[uint64]interface{}
   436  		st.Lock()
   437  		defer st.Unlock()
   438  		// verify that prepareSave stored expiration in the state
   439  		c.Assert(st.Get("snapshots", &expirations), check.IsNil)
   440  		c.Assert(expirations, check.HasLen, 1)
   441  		c.Check(expirations[42], check.NotNil)
   442  		return nil, errors.New("error")
   443  	})()
   444  
   445  	st.Lock()
   446  
   447  	task := st.NewTask("save-snapshot", "...")
   448  	task.Set("snapshot-setup", map[string]interface{}{
   449  		"set-id": 42,
   450  		"snap":   "a-snap",
   451  		"auto":   true,
   452  	})
   453  	st.Unlock()
   454  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   455  	c.Assert(err, check.ErrorMatches, "error")
   456  
   457  	st.Lock()
   458  	defer st.Unlock()
   459  
   460  	// verify that after backend.Save failure expiration was removed from the state
   461  	var expirations map[uint64]interface{}
   462  	c.Assert(st.Get("snapshots", &expirations), check.IsNil)
   463  	c.Check(expirations, check.HasLen, 0)
   464  }
   465  
   466  type readerSuite struct {
   467  	task     *state.Task
   468  	calls    []string
   469  	restores []func()
   470  }
   471  
   472  var _ = check.Suite(&readerSuite{})
   473  
   474  func (rs *readerSuite) SetUpTest(c *check.C) {
   475  	st := state.New(nil)
   476  	st.Lock()
   477  	rs.task = st.NewTask("restore-snapshot", "...")
   478  	rs.task.Set("snapshot-setup", map[string]interface{}{
   479  		// interestingly restore doesn't use the set-id
   480  		"snap":     "a-snap",
   481  		"filename": "/some/file.zip",
   482  		"users":    []string{"a-user", "b-user"},
   483  	})
   484  	st.Unlock()
   485  
   486  	rs.calls = nil
   487  	rs.restores = []func(){
   488  		snapshotstate.MockOsRemove(func(string) error {
   489  			rs.calls = append(rs.calls, "remove")
   490  			return nil
   491  		}),
   492  		snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) {
   493  			rs.calls = append(rs.calls, "get config")
   494  			return nil, nil
   495  		}),
   496  		snapshotstate.MockConfigSetSnapConfig(func(*state.State, string, *json.RawMessage) error {
   497  			rs.calls = append(rs.calls, "set config")
   498  			return nil
   499  		}),
   500  		snapshotstate.MockBackendOpen(func(string) (*backend.Reader, error) {
   501  			rs.calls = append(rs.calls, "open")
   502  			return &backend.Reader{}, nil
   503  		}),
   504  		snapshotstate.MockBackendRestore(func(*backend.Reader, context.Context, snap.Revision, []string, backend.Logf) (*backend.RestoreState, error) {
   505  			rs.calls = append(rs.calls, "restore")
   506  			return &backend.RestoreState{}, nil
   507  		}),
   508  		snapshotstate.MockBackendCheck(func(*backend.Reader, context.Context, []string) error {
   509  			rs.calls = append(rs.calls, "check")
   510  			return nil
   511  		}),
   512  		snapshotstate.MockBackendRevert(func(*backend.RestoreState) {
   513  			rs.calls = append(rs.calls, "revert")
   514  		}),
   515  		snapshotstate.MockBackendCleanup(func(*backend.RestoreState) {
   516  			rs.calls = append(rs.calls, "cleanup")
   517  		}),
   518  	}
   519  }
   520  
   521  func (rs *readerSuite) TearDownTest(c *check.C) {
   522  	for _, restore := range rs.restores {
   523  		restore()
   524  	}
   525  }
   526  
   527  func (rs *readerSuite) TestDoRestore(c *check.C) {
   528  	defer snapshotstate.MockConfigGetSnapConfig(func(_ *state.State, snapname string) (*json.RawMessage, error) {
   529  		rs.calls = append(rs.calls, "get config")
   530  		c.Check(snapname, check.Equals, "a-snap")
   531  		buf := json.RawMessage(`{"old": "conf"}`)
   532  		return &buf, nil
   533  	})()
   534  	defer snapshotstate.MockBackendOpen(func(filename string) (*backend.Reader, error) {
   535  		rs.calls = append(rs.calls, "open")
   536  		c.Check(filename, check.Equals, "/some/file.zip")
   537  		return &backend.Reader{
   538  			Snapshot: client.Snapshot{Conf: map[string]interface{}{"hello": "there"}},
   539  		}, nil
   540  	})()
   541  	defer snapshotstate.MockBackendRestore(func(_ *backend.Reader, _ context.Context, _ snap.Revision, users []string, _ backend.Logf) (*backend.RestoreState, error) {
   542  		rs.calls = append(rs.calls, "restore")
   543  		c.Check(users, check.DeepEquals, []string{"a-user", "b-user"})
   544  		return &backend.RestoreState{}, nil
   545  	})()
   546  	defer snapshotstate.MockConfigSetSnapConfig(func(_ *state.State, snapname string, conf *json.RawMessage) error {
   547  		rs.calls = append(rs.calls, "set config")
   548  		c.Check(snapname, check.Equals, "a-snap")
   549  		c.Check(string(*conf), check.Equals, `{"hello":"there"}`)
   550  		return nil
   551  	})()
   552  
   553  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   554  	c.Assert(err, check.IsNil)
   555  	c.Check(rs.calls, check.DeepEquals, []string{"get config", "open", "restore", "set config"})
   556  
   557  	st := rs.task.State()
   558  	st.Lock()
   559  	var v map[string]interface{}
   560  	rs.task.Get("restore-state", &v)
   561  	st.Unlock()
   562  	c.Check(v, check.DeepEquals, map[string]interface{}{"config": map[string]interface{}{"old": "conf"}})
   563  }
   564  
   565  func (rs *readerSuite) TestDoRestoreFailsNoTaskSnapshot(c *check.C) {
   566  	rs.task.State().Lock()
   567  	rs.task.Clear("snapshot-setup")
   568  	rs.task.State().Unlock()
   569  
   570  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   571  	c.Assert(err, check.NotNil)
   572  	c.Assert(err.Error(), check.Equals, "internal error: task 1 (restore-snapshot) is missing snapshot information")
   573  	c.Check(rs.calls, check.HasLen, 0)
   574  }
   575  
   576  func (rs *readerSuite) TestDoRestoreFailsOnGetConfigError(c *check.C) {
   577  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) {
   578  		rs.calls = append(rs.calls, "get config")
   579  		return nil, errors.New("bzzt")
   580  	})()
   581  
   582  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   583  	c.Assert(err, check.ErrorMatches, "internal error: cannot obtain current snap config for snapshot restore: bzzt")
   584  	c.Check(rs.calls, check.DeepEquals, []string{"get config"})
   585  }
   586  
   587  func (rs *readerSuite) TestDoRestoreFailsOnBadConfig(c *check.C) {
   588  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) {
   589  		rs.calls = append(rs.calls, "get config")
   590  		buf := json.RawMessage(`42`)
   591  		return &buf, nil
   592  	})()
   593  
   594  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   595  	c.Assert(err, check.ErrorMatches, ".* cannot unmarshal .*")
   596  	c.Check(rs.calls, check.DeepEquals, []string{"get config"})
   597  }
   598  
   599  func (rs *readerSuite) TestDoRestoreFailsOpenError(c *check.C) {
   600  	defer snapshotstate.MockBackendOpen(func(string) (*backend.Reader, error) {
   601  		rs.calls = append(rs.calls, "open")
   602  		return nil, errors.New("bzzt")
   603  	})()
   604  
   605  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   606  	c.Assert(err, check.ErrorMatches, "cannot open snapshot: bzzt")
   607  	c.Check(rs.calls, check.DeepEquals, []string{"get config", "open"})
   608  }
   609  
   610  func (rs *readerSuite) TestDoRestoreFailsUnserialisableSnapshotConfigError(c *check.C) {
   611  	defer snapshotstate.MockBackendOpen(func(string) (*backend.Reader, error) {
   612  		rs.calls = append(rs.calls, "open")
   613  		return &backend.Reader{
   614  			Snapshot: client.Snapshot{Conf: map[string]interface{}{"hello": func() {}}},
   615  		}, nil
   616  	})()
   617  
   618  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   619  	c.Assert(err, check.ErrorMatches, "cannot marshal saved config: json.*")
   620  	c.Check(rs.calls, check.DeepEquals, []string{"get config", "open", "restore", "revert"})
   621  }
   622  
   623  func (rs *readerSuite) TestDoRestoreFailsOnRestoreError(c *check.C) {
   624  	defer snapshotstate.MockBackendRestore(func(*backend.Reader, context.Context, snap.Revision, []string, backend.Logf) (*backend.RestoreState, error) {
   625  		rs.calls = append(rs.calls, "restore")
   626  		return nil, errors.New("bzzt")
   627  	})()
   628  
   629  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   630  	c.Assert(err, check.ErrorMatches, "bzzt")
   631  	c.Check(rs.calls, check.DeepEquals, []string{"get config", "open", "restore"})
   632  }
   633  
   634  func (rs *readerSuite) TestDoRestoreFailsAndRevertsOnSetConfigError(c *check.C) {
   635  	defer snapshotstate.MockConfigSetSnapConfig(func(*state.State, string, *json.RawMessage) error {
   636  		rs.calls = append(rs.calls, "set config")
   637  		return errors.New("bzzt")
   638  	})()
   639  
   640  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   641  	c.Assert(err, check.ErrorMatches, "cannot set snap config: bzzt")
   642  	c.Check(rs.calls, check.DeepEquals, []string{"get config", "open", "restore", "set config", "revert"})
   643  }
   644  
   645  func (rs *readerSuite) TestUndoRestore(c *check.C) {
   646  	st := rs.task.State()
   647  	st.Lock()
   648  	var v map[string]interface{}
   649  	rs.task.Set("restore-state", &v)
   650  	st.Unlock()
   651  
   652  	err := snapshotstate.UndoRestore(rs.task, &tomb.Tomb{})
   653  	c.Assert(err, check.IsNil)
   654  	c.Check(rs.calls, check.DeepEquals, []string{"set config", "revert"})
   655  }
   656  
   657  func (rs *readerSuite) TestCleanupRestore(c *check.C) {
   658  	st := rs.task.State()
   659  	st.Lock()
   660  	var v map[string]interface{}
   661  	rs.task.Set("restore-state", &v)
   662  	st.Unlock()
   663  
   664  	err := snapshotstate.CleanupRestore(rs.task, &tomb.Tomb{})
   665  	c.Assert(err, check.IsNil)
   666  	c.Check(rs.calls, check.HasLen, 0)
   667  
   668  	st.Lock()
   669  	rs.task.SetStatus(state.DoneStatus)
   670  	st.Unlock()
   671  
   672  	err = snapshotstate.CleanupRestore(rs.task, &tomb.Tomb{})
   673  	c.Assert(err, check.IsNil)
   674  	c.Check(rs.calls, check.DeepEquals, []string{"cleanup"})
   675  }
   676  
   677  func (rs *readerSuite) TestDoCheck(c *check.C) {
   678  	defer snapshotstate.MockBackendOpen(func(filename string) (*backend.Reader, error) {
   679  		rs.calls = append(rs.calls, "open")
   680  		c.Check(filename, check.Equals, "/some/file.zip")
   681  		return &backend.Reader{
   682  			Snapshot: client.Snapshot{Conf: map[string]interface{}{"hello": "there"}},
   683  		}, nil
   684  	})()
   685  	defer snapshotstate.MockBackendCheck(func(_ *backend.Reader, _ context.Context, users []string) error {
   686  		rs.calls = append(rs.calls, "check")
   687  		c.Check(users, check.DeepEquals, []string{"a-user", "b-user"})
   688  		return nil
   689  	})()
   690  
   691  	err := snapshotstate.DoCheck(rs.task, &tomb.Tomb{})
   692  	c.Assert(err, check.IsNil)
   693  	c.Check(rs.calls, check.DeepEquals, []string{"open", "check"})
   694  
   695  }
   696  
   697  func (rs *readerSuite) TestDoRemove(c *check.C) {
   698  	defer snapshotstate.MockOsRemove(func(filename string) error {
   699  		c.Check(filename, check.Equals, "/some/file.zip")
   700  		rs.calls = append(rs.calls, "remove")
   701  		return nil
   702  	})()
   703  	err := snapshotstate.DoForget(rs.task, &tomb.Tomb{})
   704  	c.Assert(err, check.IsNil)
   705  	c.Check(rs.calls, check.DeepEquals, []string{"remove"})
   706  }
   707  
   708  func (rs *readerSuite) TestDoForgetRemovesAutomaticSnapshotExpiry(c *check.C) {
   709  	defer snapshotstate.MockOsRemove(func(filename string) error {
   710  		return nil
   711  	})()
   712  
   713  	st := state.New(nil)
   714  	st.Lock()
   715  	defer st.Unlock()
   716  
   717  	task := st.NewTask("forget-snapshot", "...")
   718  	task.Set("snapshot-setup", map[string]interface{}{
   719  		"set-id":   1,
   720  		"filename": "a-file",
   721  		"snap":     "a-snap",
   722  	})
   723  
   724  	st.Set("snapshots", map[uint64]interface{}{
   725  		1: map[string]interface{}{
   726  			"expiry-time": "2001-03-11T11:24:00Z",
   727  		},
   728  		2: map[string]interface{}{
   729  			"expiry-time": "2037-02-12T12:50:00Z",
   730  		},
   731  	})
   732  
   733  	st.Unlock()
   734  	c.Assert(snapshotstate.DoForget(task, &tomb.Tomb{}), check.IsNil)
   735  
   736  	st.Lock()
   737  	var expirations map[uint64]interface{}
   738  	c.Assert(st.Get("snapshots", &expirations), check.IsNil)
   739  	c.Check(expirations, check.DeepEquals, map[uint64]interface{}{
   740  		2: map[string]interface{}{
   741  			"expiry-time": "2037-02-12T12:50:00Z",
   742  		}})
   743  }