github.com/kubiko/snapd@v0.0.0-20201013125620-d4f3094d9ddf/overlord/snapshotstate/snapshotmgr_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2018 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package snapshotstate_test
    21  
    22  import (
    23  	"context"
    24  	"encoding/json"
    25  	"errors"
    26  	"os"
    27  	"path/filepath"
    28  	"sort"
    29  	"time"
    30  
    31  	"gopkg.in/check.v1"
    32  	"gopkg.in/tomb.v2"
    33  
    34  	"github.com/snapcore/snapd/client"
    35  	"github.com/snapcore/snapd/dirs"
    36  	"github.com/snapcore/snapd/overlord/snapshotstate"
    37  	"github.com/snapcore/snapd/overlord/snapshotstate/backend"
    38  	"github.com/snapcore/snapd/overlord/state"
    39  	"github.com/snapcore/snapd/snap"
    40  )
    41  
    42  func (snapshotSuite) TestManager(c *check.C) {
    43  	st := state.New(nil)
    44  	st.Lock()
    45  	defer st.Unlock()
    46  	runner := state.NewTaskRunner(st)
    47  	mgr := snapshotstate.Manager(st, runner)
    48  	c.Assert(mgr, check.NotNil)
    49  	kinds := runner.KnownTaskKinds()
    50  	sort.Strings(kinds)
    51  	c.Check(kinds, check.DeepEquals, []string{
    52  		"check-snapshot",
    53  		"forget-snapshot",
    54  		"restore-snapshot",
    55  		"save-snapshot",
    56  	})
    57  }
    58  
    59  func mockDummySnapshot(c *check.C) (restore func()) {
    60  	shotfile, err := os.Create(filepath.Join(c.MkDir(), "foo.zip"))
    61  	c.Assert(err, check.IsNil)
    62  
    63  	fakeIter := func(_ context.Context, f func(*backend.Reader) error) error {
    64  		c.Assert(f(&backend.Reader{
    65  			Snapshot: client.Snapshot{SetID: 1, Snap: "a-snap", SnapID: "a-id", Epoch: snap.Epoch{Read: []uint32{42}, Write: []uint32{17}}},
    66  			File:     shotfile,
    67  		}), check.IsNil)
    68  		return nil
    69  	}
    70  
    71  	restoreBackendIter := snapshotstate.MockBackendIter(fakeIter)
    72  
    73  	return func() {
    74  		shotfile.Close()
    75  		restoreBackendIter()
    76  	}
    77  }
    78  
    79  func (snapshotSuite) TestEnsureForgetsSnapshots(c *check.C) {
    80  	var removedSnapshot string
    81  	restoreOsRemove := snapshotstate.MockOsRemove(func(fileName string) error {
    82  		removedSnapshot = fileName
    83  		return nil
    84  	})
    85  	defer restoreOsRemove()
    86  
    87  	restore := mockDummySnapshot(c)
    88  	defer restore()
    89  
    90  	st := state.New(nil)
    91  	runner := state.NewTaskRunner(st)
    92  	mgr := snapshotstate.Manager(st, runner)
    93  	c.Assert(mgr, check.NotNil)
    94  
    95  	st.Lock()
    96  	defer st.Unlock()
    97  
    98  	st.Set("snapshots", map[uint64]interface{}{
    99  		1: map[string]interface{}{"expiry-time": "2001-03-11T11:24:00Z"},
   100  		2: map[string]interface{}{"expiry-time": "2037-02-12T12:50:00Z"},
   101  	})
   102  
   103  	st.Unlock()
   104  	c.Assert(mgr.Ensure(), check.IsNil)
   105  	st.Lock()
   106  
   107  	// verify expired snapshots were removed
   108  	var expirations map[uint64]interface{}
   109  	c.Assert(st.Get("snapshots", &expirations), check.IsNil)
   110  	c.Check(expirations, check.DeepEquals, map[uint64]interface{}{
   111  		2: map[string]interface{}{"expiry-time": "2037-02-12T12:50:00Z"}})
   112  	c.Check(removedSnapshot, check.Matches, ".*/foo.zip")
   113  }
   114  
   115  func (snapshotSuite) TestEnsureForgetsSnapshotsRunsRegularly(c *check.C) {
   116  	var backendIterCalls int
   117  	shotfile, err := os.Create(filepath.Join(c.MkDir(), "foo.zip"))
   118  	c.Assert(err, check.IsNil)
   119  	fakeIter := func(_ context.Context, f func(*backend.Reader) error) error {
   120  		c.Assert(f(&backend.Reader{
   121  			Snapshot: client.Snapshot{SetID: 1, Snap: "a-snap", SnapID: "a-id", Epoch: snap.Epoch{Read: []uint32{42}, Write: []uint32{17}}},
   122  			File:     shotfile,
   123  		}), check.IsNil)
   124  		backendIterCalls++
   125  		return nil
   126  	}
   127  	restoreBackendIter := snapshotstate.MockBackendIter(fakeIter)
   128  	defer restoreBackendIter()
   129  
   130  	restoreOsRemove := snapshotstate.MockOsRemove(func(fileName string) error {
   131  		return nil
   132  	})
   133  	defer restoreOsRemove()
   134  
   135  	st := state.New(nil)
   136  	runner := state.NewTaskRunner(st)
   137  	mgr := snapshotstate.Manager(st, runner)
   138  	c.Assert(mgr, check.NotNil)
   139  
   140  	storeExpiredSnapshot := func() {
   141  		st.Lock()
   142  		// we need at least one snapshot set in the state for forgetExpiredSnapshots to do any work
   143  		st.Set("snapshots", map[uint64]interface{}{
   144  			1: map[string]interface{}{"expiry-time": "2001-03-11T11:24:00Z"},
   145  		})
   146  		st.Unlock()
   147  	}
   148  
   149  	// consecutive runs of Ensure call the backend just once because of the snapshotExpirationLoopInterval
   150  	for i := 0; i < 3; i++ {
   151  		storeExpiredSnapshot()
   152  		c.Assert(mgr.Ensure(), check.IsNil)
   153  		c.Check(backendIterCalls, check.Equals, 1)
   154  	}
   155  
   156  	// pretend we haven't run for a while
   157  	t, err := time.Parse(time.RFC3339, "2002-03-11T11:24:00Z")
   158  	c.Assert(err, check.IsNil)
   159  	mgr.SetLastForgetExpiredSnapshotTime(t)
   160  	c.Assert(mgr.Ensure(), check.IsNil)
   161  	c.Check(backendIterCalls, check.Equals, 2)
   162  
   163  	c.Assert(mgr.Ensure(), check.IsNil)
   164  	c.Check(backendIterCalls, check.Equals, 2)
   165  }
   166  
   167  func (snapshotSuite) testEnsureForgetSnapshotsConflict(c *check.C, snapshotTaskKind string) {
   168  	removeCalled := 0
   169  	restoreOsRemove := snapshotstate.MockOsRemove(func(string) error {
   170  		removeCalled++
   171  		return nil
   172  	})
   173  	defer restoreOsRemove()
   174  
   175  	restore := mockDummySnapshot(c)
   176  	defer restore()
   177  
   178  	st := state.New(nil)
   179  	runner := state.NewTaskRunner(st)
   180  	mgr := snapshotstate.Manager(st, runner)
   181  	c.Assert(mgr, check.NotNil)
   182  
   183  	st.Lock()
   184  	defer st.Unlock()
   185  
   186  	st.Set("snapshots", map[uint64]interface{}{
   187  		1: map[string]interface{}{"expiry-time": "2001-03-11T11:24:00Z"},
   188  	})
   189  
   190  	chg := st.NewChange("snapshot-change", "...")
   191  	tsk := st.NewTask(snapshotTaskKind, "...")
   192  	tsk.SetStatus(state.DoingStatus)
   193  	tsk.Set("snapshot-setup", map[string]int{"set-id": 1})
   194  	chg.AddTask(tsk)
   195  
   196  	st.Unlock()
   197  	c.Assert(mgr.Ensure(), check.IsNil)
   198  	st.Lock()
   199  
   200  	var expirations map[uint64]interface{}
   201  	c.Assert(st.Get("snapshots", &expirations), check.IsNil)
   202  	c.Check(expirations, check.DeepEquals, map[uint64]interface{}{
   203  		1: map[string]interface{}{"expiry-time": "2001-03-11T11:24:00Z"},
   204  	})
   205  	c.Check(removeCalled, check.Equals, 0)
   206  
   207  	// sanity check of the test setup: snapshot gets removed once conflict goes away
   208  	tsk.SetStatus(state.DoneStatus)
   209  
   210  	// pretend we haven't run for a while
   211  	t, err := time.Parse(time.RFC3339, "2002-03-11T11:24:00Z")
   212  	c.Assert(err, check.IsNil)
   213  	mgr.SetLastForgetExpiredSnapshotTime(t)
   214  
   215  	st.Unlock()
   216  	c.Assert(mgr.Ensure(), check.IsNil)
   217  	st.Lock()
   218  
   219  	expirations = nil
   220  	c.Assert(st.Get("snapshots", &expirations), check.IsNil)
   221  	c.Check(removeCalled, check.Equals, 1)
   222  	c.Check(expirations, check.HasLen, 0)
   223  }
   224  
   225  func (s *snapshotSuite) TestEnsureForgetSnapshotsConflictWithCheckSnapshot(c *check.C) {
   226  	s.testEnsureForgetSnapshotsConflict(c, "check-snapshot")
   227  }
   228  
   229  func (s *snapshotSuite) TestEnsureForgetSnapshotsConflictWithRestoreSnapshot(c *check.C) {
   230  	s.testEnsureForgetSnapshotsConflict(c, "restore-snapshot")
   231  }
   232  
   233  func (snapshotSuite) TestFilename(c *check.C) {
   234  	si := &snap.Info{
   235  		SideInfo: snap.SideInfo{
   236  			RealName: "a-snap",
   237  			Revision: snap.R(-1),
   238  		},
   239  		Version: "1.33",
   240  	}
   241  	filename := snapshotstate.Filename(42, si)
   242  	c.Check(filepath.Dir(filename), check.Equals, dirs.SnapshotsDir)
   243  	c.Check(filepath.Base(filename), check.Equals, "42_a-snap_1.33_x1.zip")
   244  }
   245  
   246  func (snapshotSuite) TestDoSave(c *check.C) {
   247  	snapInfo := snap.Info{
   248  		SideInfo: snap.SideInfo{
   249  			RealName: "a-snap",
   250  			Revision: snap.R(-1),
   251  		},
   252  		Version: "1.33",
   253  	}
   254  	defer snapshotstate.MockSnapstateCurrentInfo(func(_ *state.State, snapname string) (*snap.Info, error) {
   255  		c.Check(snapname, check.Equals, "a-snap")
   256  		return &snapInfo, nil
   257  	})()
   258  	defer snapshotstate.MockConfigGetSnapConfig(func(_ *state.State, snapname string) (*json.RawMessage, error) {
   259  		c.Check(snapname, check.Equals, "a-snap")
   260  		buf := json.RawMessage(`{"hello": "there"}`)
   261  		return &buf, nil
   262  	})()
   263  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   264  		c.Check(id, check.Equals, uint64(42))
   265  		c.Check(si, check.DeepEquals, &snapInfo)
   266  		c.Check(cfg, check.DeepEquals, map[string]interface{}{"hello": "there"})
   267  		c.Check(usernames, check.DeepEquals, []string{"a-user", "b-user"})
   268  		c.Check(flags.Auto, check.Equals, false)
   269  		return nil, nil
   270  	})()
   271  
   272  	st := state.New(nil)
   273  	st.Lock()
   274  	task := st.NewTask("save-snapshot", "...")
   275  	task.Set("snapshot-setup", map[string]interface{}{
   276  		"set-id": 42,
   277  		"snap":   "a-snap",
   278  		"users":  []string{"a-user", "b-user"},
   279  	})
   280  	st.Unlock()
   281  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   282  	c.Assert(err, check.IsNil)
   283  }
   284  
   285  func (snapshotSuite) TestDoSaveFailsWithNoSnap(c *check.C) {
   286  	defer snapshotstate.MockSnapstateCurrentInfo(func(*state.State, string) (*snap.Info, error) {
   287  		return nil, errors.New("bzzt")
   288  	})()
   289  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) { return nil, nil })()
   290  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   291  		return nil, nil
   292  	})()
   293  
   294  	st := state.New(nil)
   295  	st.Lock()
   296  	task := st.NewTask("save-snapshot", "...")
   297  	task.Set("snapshot-setup", map[string]interface{}{
   298  		"set-id": 42,
   299  		"snap":   "a-snap",
   300  		"users":  []string{"a-user", "b-user"},
   301  	})
   302  	st.Unlock()
   303  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   304  	c.Assert(err, check.ErrorMatches, "bzzt")
   305  }
   306  
   307  func (snapshotSuite) TestDoSaveFailsWithNoSnapshot(c *check.C) {
   308  	snapInfo := snap.Info{
   309  		SideInfo: snap.SideInfo{
   310  			RealName: "a-snap",
   311  			Revision: snap.R(-1),
   312  		},
   313  		Version: "1.33",
   314  	}
   315  	defer snapshotstate.MockSnapstateCurrentInfo(func(*state.State, string) (*snap.Info, error) { return &snapInfo, nil })()
   316  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) { return nil, nil })()
   317  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   318  		return nil, nil
   319  	})()
   320  
   321  	st := state.New(nil)
   322  	st.Lock()
   323  	task := st.NewTask("save-snapshot", "...")
   324  	// NOTE no task.Set("snapshot-setup", ...)
   325  	st.Unlock()
   326  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   327  	c.Assert(err, check.NotNil)
   328  	c.Assert(err.Error(), check.Equals, "internal error: task 1 (save-snapshot) is missing snapshot information")
   329  }
   330  
   331  func (snapshotSuite) TestDoSaveFailsBackendError(c *check.C) {
   332  	snapInfo := snap.Info{
   333  		SideInfo: snap.SideInfo{
   334  			RealName: "a-snap",
   335  			Revision: snap.R(-1),
   336  		},
   337  		Version: "1.33",
   338  	}
   339  	defer snapshotstate.MockSnapstateCurrentInfo(func(*state.State, string) (*snap.Info, error) { return &snapInfo, nil })()
   340  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) { return nil, nil })()
   341  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   342  		return nil, errors.New("bzzt")
   343  	})()
   344  
   345  	st := state.New(nil)
   346  	st.Lock()
   347  	task := st.NewTask("save-snapshot", "...")
   348  	task.Set("snapshot-setup", map[string]interface{}{
   349  		"set-id": 42,
   350  		"snap":   "a-snap",
   351  		"users":  []string{"a-user", "b-user"},
   352  	})
   353  	st.Unlock()
   354  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   355  	c.Assert(err, check.ErrorMatches, "bzzt")
   356  }
   357  
   358  func (snapshotSuite) TestDoSaveFailsConfigError(c *check.C) {
   359  	snapInfo := snap.Info{
   360  		SideInfo: snap.SideInfo{
   361  			RealName: "a-snap",
   362  			Revision: snap.R(-1),
   363  		},
   364  		Version: "1.33",
   365  	}
   366  	defer snapshotstate.MockSnapstateCurrentInfo(func(*state.State, string) (*snap.Info, error) { return &snapInfo, nil })()
   367  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) {
   368  		return nil, errors.New("bzzt")
   369  	})()
   370  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   371  		return nil, nil
   372  	})()
   373  
   374  	st := state.New(nil)
   375  	st.Lock()
   376  	task := st.NewTask("save-snapshot", "...")
   377  	task.Set("snapshot-setup", map[string]interface{}{
   378  		"set-id": 42,
   379  		"snap":   "a-snap",
   380  		"users":  []string{"a-user", "b-user"},
   381  	})
   382  	st.Unlock()
   383  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   384  	c.Assert(err, check.ErrorMatches, "bzzt")
   385  }
   386  
   387  func (snapshotSuite) TestDoSaveFailsBadConfig(c *check.C) {
   388  	snapInfo := snap.Info{
   389  		SideInfo: snap.SideInfo{
   390  			RealName: "a-snap",
   391  			Revision: snap.R(-1),
   392  		},
   393  		Version: "1.33",
   394  	}
   395  	defer snapshotstate.MockSnapstateCurrentInfo(func(*state.State, string) (*snap.Info, error) { return &snapInfo, nil })()
   396  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) {
   397  		// returns something that's not a JSON object
   398  		buf := json.RawMessage(`"hello-there"`)
   399  		return &buf, nil
   400  	})()
   401  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   402  		return nil, nil
   403  	})()
   404  
   405  	st := state.New(nil)
   406  	st.Lock()
   407  	task := st.NewTask("save-snapshot", "...")
   408  	task.Set("snapshot-setup", map[string]interface{}{
   409  		"set-id": 42,
   410  		"snap":   "a-snap",
   411  		"users":  []string{"a-user", "b-user"},
   412  	})
   413  	st.Unlock()
   414  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   415  	c.Assert(err, check.ErrorMatches, ".* cannot unmarshal .*")
   416  }
   417  
   418  func (snapshotSuite) TestDoSaveFailureRemovesStateEntry(c *check.C) {
   419  	st := state.New(nil)
   420  
   421  	snapInfo := snap.Info{
   422  		SideInfo: snap.SideInfo{
   423  			RealName: "a-snap",
   424  			Revision: snap.R(-1),
   425  		},
   426  		Version: "1.33",
   427  	}
   428  	defer snapshotstate.MockSnapstateCurrentInfo(func(_ *state.State, snapname string) (*snap.Info, error) {
   429  		return &snapInfo, nil
   430  	})()
   431  	defer snapshotstate.MockConfigGetSnapConfig(func(_ *state.State, snapname string) (*json.RawMessage, error) {
   432  		return nil, nil
   433  	})()
   434  	defer snapshotstate.MockBackendSave(func(_ context.Context, id uint64, si *snap.Info, cfg map[string]interface{}, usernames []string, flags *backend.Flags) (*client.Snapshot, error) {
   435  		var expirations map[uint64]interface{}
   436  		st.Lock()
   437  		defer st.Unlock()
   438  		// verify that prepareSave stored expiration in the state
   439  		c.Assert(st.Get("snapshots", &expirations), check.IsNil)
   440  		c.Assert(expirations, check.HasLen, 1)
   441  		c.Check(expirations[42], check.NotNil)
   442  		return nil, errors.New("error")
   443  	})()
   444  
   445  	st.Lock()
   446  
   447  	task := st.NewTask("save-snapshot", "...")
   448  	task.Set("snapshot-setup", map[string]interface{}{
   449  		"set-id": 42,
   450  		"snap":   "a-snap",
   451  		"auto":   true,
   452  	})
   453  	st.Unlock()
   454  	err := snapshotstate.DoSave(task, &tomb.Tomb{})
   455  	c.Assert(err, check.ErrorMatches, "error")
   456  
   457  	st.Lock()
   458  	defer st.Unlock()
   459  
   460  	// verify that after backend.Save failure expiration was removed from the state
   461  	var expirations map[uint64]interface{}
   462  	c.Assert(st.Get("snapshots", &expirations), check.IsNil)
   463  	c.Check(expirations, check.HasLen, 0)
   464  }
   465  
   466  type readerSuite struct {
   467  	task     *state.Task
   468  	calls    []string
   469  	restores []func()
   470  }
   471  
   472  var _ = check.Suite(&readerSuite{})
   473  
   474  func (rs *readerSuite) SetUpTest(c *check.C) {
   475  	st := state.New(nil)
   476  	st.Lock()
   477  	rs.task = st.NewTask("restore-snapshot", "...")
   478  	rs.task.Set("snapshot-setup", map[string]interface{}{
   479  		// interestingly restore doesn't use the set-id
   480  		"snap":     "a-snap",
   481  		"filename": "/some/1_file.zip",
   482  		"users":    []string{"a-user", "b-user"},
   483  	})
   484  	st.Unlock()
   485  
   486  	rs.calls = nil
   487  	rs.restores = []func(){
   488  		snapshotstate.MockOsRemove(func(string) error {
   489  			rs.calls = append(rs.calls, "remove")
   490  			return nil
   491  		}),
   492  		snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) {
   493  			rs.calls = append(rs.calls, "get config")
   494  			return nil, nil
   495  		}),
   496  		snapshotstate.MockConfigSetSnapConfig(func(*state.State, string, *json.RawMessage) error {
   497  			rs.calls = append(rs.calls, "set config")
   498  			return nil
   499  		}),
   500  		snapshotstate.MockBackendOpen(func(string, uint64) (*backend.Reader, error) {
   501  			rs.calls = append(rs.calls, "open")
   502  			return &backend.Reader{}, nil
   503  		}),
   504  		snapshotstate.MockBackendRestore(func(*backend.Reader, context.Context, snap.Revision, []string, backend.Logf) (*backend.RestoreState, error) {
   505  			rs.calls = append(rs.calls, "restore")
   506  			return &backend.RestoreState{}, nil
   507  		}),
   508  		snapshotstate.MockBackendCheck(func(*backend.Reader, context.Context, []string) error {
   509  			rs.calls = append(rs.calls, "check")
   510  			return nil
   511  		}),
   512  		snapshotstate.MockBackendRevert(func(*backend.RestoreState) {
   513  			rs.calls = append(rs.calls, "revert")
   514  		}),
   515  		snapshotstate.MockBackendCleanup(func(*backend.RestoreState) {
   516  			rs.calls = append(rs.calls, "cleanup")
   517  		}),
   518  	}
   519  }
   520  
   521  func (rs *readerSuite) TearDownTest(c *check.C) {
   522  	for _, restore := range rs.restores {
   523  		restore()
   524  	}
   525  }
   526  
   527  func (rs *readerSuite) TestDoRestore(c *check.C) {
   528  	defer snapshotstate.MockConfigGetSnapConfig(func(_ *state.State, snapname string) (*json.RawMessage, error) {
   529  		rs.calls = append(rs.calls, "get config")
   530  		c.Check(snapname, check.Equals, "a-snap")
   531  		buf := json.RawMessage(`{"old": "conf"}`)
   532  		return &buf, nil
   533  	})()
   534  	defer snapshotstate.MockBackendOpen(func(filename string, setID uint64) (*backend.Reader, error) {
   535  		rs.calls = append(rs.calls, "open")
   536  		// set id 0 tells backend.Open to use set id from the filename
   537  		c.Check(setID, check.Equals, uint64(0))
   538  		c.Check(filename, check.Equals, "/some/1_file.zip")
   539  		return &backend.Reader{
   540  			Snapshot: client.Snapshot{Conf: map[string]interface{}{"hello": "there"}},
   541  		}, nil
   542  	})()
   543  	defer snapshotstate.MockBackendRestore(func(_ *backend.Reader, _ context.Context, _ snap.Revision, users []string, _ backend.Logf) (*backend.RestoreState, error) {
   544  		rs.calls = append(rs.calls, "restore")
   545  		c.Check(users, check.DeepEquals, []string{"a-user", "b-user"})
   546  		return &backend.RestoreState{}, nil
   547  	})()
   548  	defer snapshotstate.MockConfigSetSnapConfig(func(_ *state.State, snapname string, conf *json.RawMessage) error {
   549  		rs.calls = append(rs.calls, "set config")
   550  		c.Check(snapname, check.Equals, "a-snap")
   551  		c.Check(string(*conf), check.Equals, `{"hello":"there"}`)
   552  		return nil
   553  	})()
   554  
   555  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   556  	c.Assert(err, check.IsNil)
   557  	c.Check(rs.calls, check.DeepEquals, []string{"get config", "open", "restore", "set config"})
   558  
   559  	st := rs.task.State()
   560  	st.Lock()
   561  	var v map[string]interface{}
   562  	rs.task.Get("restore-state", &v)
   563  	st.Unlock()
   564  	c.Check(v, check.DeepEquals, map[string]interface{}{"config": map[string]interface{}{"old": "conf"}})
   565  }
   566  
   567  func (rs *readerSuite) TestDoRestoreFailsNoTaskSnapshot(c *check.C) {
   568  	rs.task.State().Lock()
   569  	rs.task.Clear("snapshot-setup")
   570  	rs.task.State().Unlock()
   571  
   572  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   573  	c.Assert(err, check.NotNil)
   574  	c.Assert(err.Error(), check.Equals, "internal error: task 1 (restore-snapshot) is missing snapshot information")
   575  	c.Check(rs.calls, check.HasLen, 0)
   576  }
   577  
   578  func (rs *readerSuite) TestDoRestoreFailsOnGetConfigError(c *check.C) {
   579  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) {
   580  		rs.calls = append(rs.calls, "get config")
   581  		return nil, errors.New("bzzt")
   582  	})()
   583  
   584  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   585  	c.Assert(err, check.ErrorMatches, "internal error: cannot obtain current snap config for snapshot restore: bzzt")
   586  	c.Check(rs.calls, check.DeepEquals, []string{"get config"})
   587  }
   588  
   589  func (rs *readerSuite) TestDoRestoreFailsOnBadConfig(c *check.C) {
   590  	defer snapshotstate.MockConfigGetSnapConfig(func(*state.State, string) (*json.RawMessage, error) {
   591  		rs.calls = append(rs.calls, "get config")
   592  		buf := json.RawMessage(`42`)
   593  		return &buf, nil
   594  	})()
   595  
   596  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   597  	c.Assert(err, check.ErrorMatches, ".* cannot unmarshal .*")
   598  	c.Check(rs.calls, check.DeepEquals, []string{"get config"})
   599  }
   600  
   601  func (rs *readerSuite) TestDoRestoreFailsOpenError(c *check.C) {
   602  	defer snapshotstate.MockBackendOpen(func(string, uint64) (*backend.Reader, error) {
   603  		rs.calls = append(rs.calls, "open")
   604  		return nil, errors.New("bzzt")
   605  	})()
   606  
   607  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   608  	c.Assert(err, check.ErrorMatches, "cannot open snapshot: bzzt")
   609  	c.Check(rs.calls, check.DeepEquals, []string{"get config", "open"})
   610  }
   611  
   612  func (rs *readerSuite) TestDoRestoreFailsUnserialisableSnapshotConfigError(c *check.C) {
   613  	defer snapshotstate.MockBackendOpen(func(string, uint64) (*backend.Reader, error) {
   614  		rs.calls = append(rs.calls, "open")
   615  		return &backend.Reader{
   616  			Snapshot: client.Snapshot{Conf: map[string]interface{}{"hello": func() {}}},
   617  		}, nil
   618  	})()
   619  
   620  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   621  	c.Assert(err, check.ErrorMatches, "cannot marshal saved config: json.*")
   622  	c.Check(rs.calls, check.DeepEquals, []string{"get config", "open", "restore", "revert"})
   623  }
   624  
   625  func (rs *readerSuite) TestDoRestoreFailsOnRestoreError(c *check.C) {
   626  	defer snapshotstate.MockBackendRestore(func(*backend.Reader, context.Context, snap.Revision, []string, backend.Logf) (*backend.RestoreState, error) {
   627  		rs.calls = append(rs.calls, "restore")
   628  		return nil, errors.New("bzzt")
   629  	})()
   630  
   631  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   632  	c.Assert(err, check.ErrorMatches, "bzzt")
   633  	c.Check(rs.calls, check.DeepEquals, []string{"get config", "open", "restore"})
   634  }
   635  
   636  func (rs *readerSuite) TestDoRestoreFailsAndRevertsOnSetConfigError(c *check.C) {
   637  	defer snapshotstate.MockConfigSetSnapConfig(func(*state.State, string, *json.RawMessage) error {
   638  		rs.calls = append(rs.calls, "set config")
   639  		return errors.New("bzzt")
   640  	})()
   641  
   642  	err := snapshotstate.DoRestore(rs.task, &tomb.Tomb{})
   643  	c.Assert(err, check.ErrorMatches, "cannot set snap config: bzzt")
   644  	c.Check(rs.calls, check.DeepEquals, []string{"get config", "open", "restore", "set config", "revert"})
   645  }
   646  
   647  func (rs *readerSuite) TestUndoRestore(c *check.C) {
   648  	st := rs.task.State()
   649  	st.Lock()
   650  	var v map[string]interface{}
   651  	rs.task.Set("restore-state", &v)
   652  	st.Unlock()
   653  
   654  	err := snapshotstate.UndoRestore(rs.task, &tomb.Tomb{})
   655  	c.Assert(err, check.IsNil)
   656  	c.Check(rs.calls, check.DeepEquals, []string{"set config", "revert"})
   657  }
   658  
   659  func (rs *readerSuite) TestCleanupRestore(c *check.C) {
   660  	st := rs.task.State()
   661  	st.Lock()
   662  	var v map[string]interface{}
   663  	rs.task.Set("restore-state", &v)
   664  	st.Unlock()
   665  
   666  	err := snapshotstate.CleanupRestore(rs.task, &tomb.Tomb{})
   667  	c.Assert(err, check.IsNil)
   668  	c.Check(rs.calls, check.HasLen, 0)
   669  
   670  	st.Lock()
   671  	rs.task.SetStatus(state.DoneStatus)
   672  	st.Unlock()
   673  
   674  	err = snapshotstate.CleanupRestore(rs.task, &tomb.Tomb{})
   675  	c.Assert(err, check.IsNil)
   676  	c.Check(rs.calls, check.DeepEquals, []string{"cleanup"})
   677  }
   678  
   679  func (rs *readerSuite) TestDoCheck(c *check.C) {
   680  	defer snapshotstate.MockBackendOpen(func(filename string, setID uint64) (*backend.Reader, error) {
   681  		rs.calls = append(rs.calls, "open")
   682  		c.Check(filename, check.Equals, "/some/1_file.zip")
   683  		// set id 0 tells backend.Open to use set id from the filename
   684  		c.Check(setID, check.Equals, uint64(0))
   685  		return &backend.Reader{
   686  			Snapshot: client.Snapshot{Conf: map[string]interface{}{"hello": "there"}},
   687  		}, nil
   688  	})()
   689  	defer snapshotstate.MockBackendCheck(func(_ *backend.Reader, _ context.Context, users []string) error {
   690  		rs.calls = append(rs.calls, "check")
   691  		c.Check(users, check.DeepEquals, []string{"a-user", "b-user"})
   692  		return nil
   693  	})()
   694  
   695  	err := snapshotstate.DoCheck(rs.task, &tomb.Tomb{})
   696  	c.Assert(err, check.IsNil)
   697  	c.Check(rs.calls, check.DeepEquals, []string{"open", "check"})
   698  }
   699  
   700  func (rs *readerSuite) TestDoRemove(c *check.C) {
   701  	defer snapshotstate.MockOsRemove(func(filename string) error {
   702  		c.Check(filename, check.Equals, "/some/1_file.zip")
   703  		rs.calls = append(rs.calls, "remove")
   704  		return nil
   705  	})()
   706  	err := snapshotstate.DoForget(rs.task, &tomb.Tomb{})
   707  	c.Assert(err, check.IsNil)
   708  	c.Check(rs.calls, check.DeepEquals, []string{"remove"})
   709  }
   710  
   711  func (rs *readerSuite) TestDoForgetRemovesAutomaticSnapshotExpiry(c *check.C) {
   712  	defer snapshotstate.MockOsRemove(func(filename string) error {
   713  		return nil
   714  	})()
   715  
   716  	st := state.New(nil)
   717  	st.Lock()
   718  	defer st.Unlock()
   719  
   720  	task := st.NewTask("forget-snapshot", "...")
   721  	task.Set("snapshot-setup", map[string]interface{}{
   722  		"set-id":   1,
   723  		"filename": "a-file",
   724  		"snap":     "a-snap",
   725  	})
   726  
   727  	st.Set("snapshots", map[uint64]interface{}{
   728  		1: map[string]interface{}{
   729  			"expiry-time": "2001-03-11T11:24:00Z",
   730  		},
   731  		2: map[string]interface{}{
   732  			"expiry-time": "2037-02-12T12:50:00Z",
   733  		},
   734  	})
   735  
   736  	st.Unlock()
   737  	c.Assert(snapshotstate.DoForget(task, &tomb.Tomb{}), check.IsNil)
   738  
   739  	st.Lock()
   740  	var expirations map[uint64]interface{}
   741  	c.Assert(st.Get("snapshots", &expirations), check.IsNil)
   742  	c.Check(expirations, check.DeepEquals, map[uint64]interface{}{
   743  		2: map[string]interface{}{
   744  			"expiry-time": "2037-02-12T12:50:00Z",
   745  		}})
   746  }