github.com/hugh712/snapd@v0.0.0-20200910133618-1a99902bd583/daemon/api_snapshots_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2018 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package daemon_test
    21  
    22  import (
    23  	"context"
    24  	"errors"
    25  	"fmt"
    26  	"net/http"
    27  	"strings"
    28  
    29  	"gopkg.in/check.v1"
    30  
    31  	"github.com/snapcore/snapd/client"
    32  	"github.com/snapcore/snapd/daemon"
    33  	"github.com/snapcore/snapd/dirs"
    34  	"github.com/snapcore/snapd/overlord"
    35  	"github.com/snapcore/snapd/overlord/assertstate"
    36  	"github.com/snapcore/snapd/overlord/snapshotstate"
    37  	"github.com/snapcore/snapd/overlord/snapstate"
    38  	"github.com/snapcore/snapd/overlord/state"
    39  	"github.com/snapcore/snapd/store/storetest"
    40  )
    41  
    42  var _ = check.Suite(&snapshotSuite{})
    43  
    44  type snapshotSuite struct {
    45  	d *daemon.Daemon
    46  	o *overlord.Overlord
    47  }
    48  
    49  func (s *snapshotSuite) SetUpTest(c *check.C) {
    50  	s.o = overlord.Mock()
    51  	s.d = daemon.NewWithOverlord(s.o)
    52  
    53  	st := s.o.State()
    54  	// adds an assertion db
    55  	assertstate.Manager(st, s.o.TaskRunner())
    56  	st.Lock()
    57  	defer st.Unlock()
    58  	snapstate.ReplaceStore(st, storetest.Store{})
    59  	dirs.SetRootDir(c.MkDir())
    60  }
    61  
    62  func (s *snapshotSuite) TearDownTest(c *check.C) {
    63  	s.o = nil
    64  	s.d = nil
    65  }
    66  
    67  func (s *snapshotSuite) TestSnapshotMany(c *check.C) {
    68  	defer daemon.MockSnapshotSave(func(s *state.State, snaps, users []string) (uint64, []string, *state.TaskSet, error) {
    69  		c.Check(snaps, check.HasLen, 2)
    70  		t := s.NewTask("fake-snapshot-2", "Snapshot two")
    71  		return 1, snaps, state.NewTaskSet(t), nil
    72  	})()
    73  
    74  	inst := daemon.MustUnmarshalSnapInstruction(c, `{"action": "snapshot", "snaps": ["foo", "bar"]}`)
    75  	st := s.o.State()
    76  	st.Lock()
    77  	res, err := daemon.SnapshotMany(inst, st)
    78  	st.Unlock()
    79  	c.Assert(err, check.IsNil)
    80  	c.Check(res.Summary, check.Equals, `Snapshot snaps "foo", "bar"`)
    81  	c.Check(res.Affected, check.DeepEquals, inst.Snaps)
    82  }
    83  
    84  func (s *snapshotSuite) TestListSnapshots(c *check.C) {
    85  	snapshots := []client.SnapshotSet{{ID: 1}, {ID: 42}}
    86  
    87  	defer daemon.MockSnapshotList(func(context.Context, uint64, []string) ([]client.SnapshotSet, error) {
    88  		return snapshots, nil
    89  	})()
    90  
    91  	c.Check(daemon.SnapshotCmd.Path, check.Equals, "/v2/snapshots")
    92  	req, err := http.NewRequest("GET", "/v2/snapshots", nil)
    93  	c.Assert(err, check.IsNil)
    94  
    95  	rsp := daemon.ListSnapshots(daemon.SnapshotCmd, req, nil)
    96  	c.Check(rsp.Type, check.Equals, daemon.ResponseTypeSync)
    97  	c.Check(rsp.Status, check.Equals, 200)
    98  	c.Check(rsp.Result, check.DeepEquals, snapshots)
    99  }
   100  
   101  func (s *snapshotSuite) TestListSnapshotsFiltering(c *check.C) {
   102  	snapshots := []client.SnapshotSet{{ID: 1}, {ID: 42}}
   103  
   104  	defer daemon.MockSnapshotList(func(_ context.Context, setID uint64, _ []string) ([]client.SnapshotSet, error) {
   105  		c.Assert(setID, check.Equals, uint64(42))
   106  		return snapshots[1:], nil
   107  	})()
   108  
   109  	req, err := http.NewRequest("GET", "/v2/snapshots?set=42", nil)
   110  	c.Assert(err, check.IsNil)
   111  
   112  	rsp := daemon.ListSnapshots(daemon.SnapshotCmd, req, nil)
   113  	c.Check(rsp.Type, check.Equals, daemon.ResponseTypeSync)
   114  	c.Check(rsp.Status, check.Equals, 200)
   115  	c.Check(rsp.Result, check.DeepEquals, []client.SnapshotSet{{ID: 42}})
   116  }
   117  
   118  func (s *snapshotSuite) TestListSnapshotsBadFiltering(c *check.C) {
   119  	defer daemon.MockSnapshotList(func(_ context.Context, setID uint64, _ []string) ([]client.SnapshotSet, error) {
   120  		c.Fatal("snapshotList should not be reached (should have been blocked by validation!)")
   121  		return nil, nil
   122  	})()
   123  
   124  	req, err := http.NewRequest("GET", "/v2/snapshots?set=no", nil)
   125  	c.Assert(err, check.IsNil)
   126  
   127  	rsp := daemon.ListSnapshots(daemon.SnapshotCmd, req, nil)
   128  	c.Assert(rsp.Type, check.Equals, daemon.ResponseTypeError)
   129  	c.Check(rsp.Status, check.Equals, 400)
   130  	c.Check(rsp.ErrorResult().Message, check.Equals, `'set', if given, must be a positive base 10 number; got "no"`)
   131  }
   132  
   133  func (s *snapshotSuite) TestListSnapshotsListError(c *check.C) {
   134  	defer daemon.MockSnapshotList(func(_ context.Context, setID uint64, _ []string) ([]client.SnapshotSet, error) {
   135  		return nil, errors.New("no")
   136  	})()
   137  
   138  	c.Check(daemon.SnapshotCmd.Path, check.Equals, "/v2/snapshots")
   139  	req, err := http.NewRequest("GET", "/v2/snapshots", nil)
   140  	c.Assert(err, check.IsNil)
   141  
   142  	rsp := daemon.ListSnapshots(daemon.SnapshotCmd, req, nil)
   143  	c.Assert(rsp.Type, check.Equals, daemon.ResponseTypeError)
   144  	c.Check(rsp.Status, check.Equals, 500)
   145  	c.Check(rsp.ErrorResult().Message, check.Equals, "no")
   146  }
   147  
   148  func (s *snapshotSuite) TestFormatSnapshotAction(c *check.C) {
   149  	type table struct {
   150  		action   string
   151  		expected string
   152  	}
   153  	tests := []table{
   154  		{
   155  			`{"set": 2, "action": "verb"}`,
   156  			`Verb of snapshot set #2`,
   157  		}, {
   158  			`{"set": 2, "action": "verb", "snaps": ["foo"]}`,
   159  			`Verb of snapshot set #2 for snaps "foo"`,
   160  		}, {
   161  			`{"set": 2, "action": "verb", "snaps": ["foo", "bar"]}`,
   162  			`Verb of snapshot set #2 for snaps "foo", "bar"`,
   163  		}, {
   164  			`{"set": 2, "action": "verb", "users": ["meep"]}`,
   165  			`Verb of snapshot set #2 for users "meep"`,
   166  		}, {
   167  			`{"set": 2, "action": "verb", "users": ["meep", "quux"]}`,
   168  			`Verb of snapshot set #2 for users "meep", "quux"`,
   169  		}, {
   170  			`{"set": 2, "action": "verb", "users": ["meep", "quux"], "snaps": ["foo", "bar"]}`,
   171  			`Verb of snapshot set #2 for snaps "foo", "bar" for users "meep", "quux"`,
   172  		},
   173  	}
   174  
   175  	for _, test := range tests {
   176  		action := daemon.MustUnmarshalSnapshotAction(c, test.action)
   177  		c.Check(action.String(), check.Equals, test.expected)
   178  	}
   179  }
   180  
   181  func (s *snapshotSuite) TestChangeSnapshots400(c *check.C) {
   182  	type table struct{ body, error string }
   183  	tests := []table{
   184  		{
   185  			body:  `"woodchucks`,
   186  			error: "cannot decode request body into snapshot operation:.*",
   187  		}, {
   188  			body:  `{}"woodchucks`,
   189  			error: "extra content found after snapshot operation",
   190  		}, {
   191  			body:  `{}`,
   192  			error: "snapshot operation requires snapshot set ID",
   193  		}, {
   194  			body:  `{"set": 42}`,
   195  			error: "snapshot operation requires action",
   196  		}, {
   197  			body:  `{"set": 42, "action": "carrots"}`,
   198  			error: `unknown snapshot operation "carrots"`,
   199  		}, {
   200  			body:  `{"set": 42, "action": "forget", "users": ["foo"]}`,
   201  			error: `snapshot "forget" operation cannot specify users`,
   202  		},
   203  	}
   204  
   205  	for i, test := range tests {
   206  		comm := check.Commentf("%d:%q", i, test.body)
   207  		req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(test.body))
   208  		c.Assert(err, check.IsNil, comm)
   209  
   210  		rsp := daemon.ChangeSnapshots(daemon.SnapshotCmd, req, nil)
   211  		c.Check(rsp.Type, check.Equals, daemon.ResponseTypeError, comm)
   212  		c.Check(rsp.Status, check.Equals, 400, comm)
   213  		c.Check(rsp.ErrorResult().Message, check.Matches, test.error, comm)
   214  	}
   215  }
   216  
   217  func (s *snapshotSuite) TestChangeSnapshots404(c *check.C) {
   218  	var done string
   219  	expectedError := errors.New("bzzt")
   220  	defer daemon.MockSnapshotCheck(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   221  		done = "check"
   222  		return nil, nil, expectedError
   223  	})()
   224  	defer daemon.MockSnapshotRestore(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   225  		done = "restore"
   226  		return nil, nil, expectedError
   227  	})()
   228  	defer daemon.MockSnapshotForget(func(*state.State, uint64, []string) ([]string, *state.TaskSet, error) {
   229  		done = "forget"
   230  		return nil, nil, expectedError
   231  	})()
   232  	for _, expectedError = range []error{client.ErrSnapshotSetNotFound, client.ErrSnapshotSnapsNotFound} {
   233  		for _, action := range []string{"check", "restore", "forget"} {
   234  			done = ""
   235  			comm := check.Commentf("%s/%s", action, expectedError)
   236  			body := fmt.Sprintf(`{"set": 42, "action": "%s"}`, action)
   237  			req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(body))
   238  			c.Assert(err, check.IsNil, comm)
   239  
   240  			rsp := daemon.ChangeSnapshots(daemon.SnapshotCmd, req, nil)
   241  			c.Check(rsp.Type, check.Equals, daemon.ResponseTypeError, comm)
   242  			c.Check(rsp.Status, check.Equals, 404, comm)
   243  			c.Check(rsp.ErrorResult().Message, check.Matches, expectedError.Error(), comm)
   244  			c.Check(done, check.Equals, action, comm)
   245  		}
   246  	}
   247  }
   248  
   249  func (s *snapshotSuite) TestChangeSnapshots500(c *check.C) {
   250  	var done string
   251  	expectedError := errors.New("bzzt")
   252  	defer daemon.MockSnapshotCheck(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   253  		done = "check"
   254  		return nil, nil, expectedError
   255  	})()
   256  	defer daemon.MockSnapshotRestore(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   257  		done = "restore"
   258  		return nil, nil, expectedError
   259  	})()
   260  	defer daemon.MockSnapshotForget(func(*state.State, uint64, []string) ([]string, *state.TaskSet, error) {
   261  		done = "forget"
   262  		return nil, nil, expectedError
   263  	})()
   264  	for _, action := range []string{"check", "restore", "forget"} {
   265  		comm := check.Commentf("%s", action)
   266  		body := fmt.Sprintf(`{"set": 42, "action": "%s"}`, action)
   267  		req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(body))
   268  		c.Assert(err, check.IsNil, comm)
   269  
   270  		rsp := daemon.ChangeSnapshots(daemon.SnapshotCmd, req, nil)
   271  		c.Check(rsp.Type, check.Equals, daemon.ResponseTypeError, comm)
   272  		c.Check(rsp.Status, check.Equals, 500, comm)
   273  		c.Check(rsp.ErrorResult().Message, check.Matches, expectedError.Error(), comm)
   274  		c.Check(done, check.Equals, action, comm)
   275  	}
   276  }
   277  
   278  func (s *snapshotSuite) TestChangeSnapshot(c *check.C) {
   279  	var done string
   280  	defer daemon.MockSnapshotCheck(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   281  		done = "check"
   282  		return []string{"foo"}, state.NewTaskSet(), nil
   283  	})()
   284  	defer daemon.MockSnapshotRestore(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   285  		done = "restore"
   286  		return []string{"foo"}, state.NewTaskSet(), nil
   287  	})()
   288  	defer daemon.MockSnapshotForget(func(*state.State, uint64, []string) ([]string, *state.TaskSet, error) {
   289  		done = "forget"
   290  		return []string{"foo"}, state.NewTaskSet(), nil
   291  	})()
   292  
   293  	st := s.o.State()
   294  	st.Lock()
   295  	defer st.Unlock()
   296  	for _, action := range []string{"check", "restore", "forget"} {
   297  		comm := check.Commentf("%s", action)
   298  		body := fmt.Sprintf(`{"set": 42, "action": "%s"}`, action)
   299  		req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(body))
   300  
   301  		c.Assert(err, check.IsNil, comm)
   302  
   303  		st.Unlock()
   304  		rsp := daemon.ChangeSnapshots(daemon.SnapshotCmd, req, nil)
   305  		st.Lock()
   306  
   307  		c.Check(rsp.Type, check.Equals, daemon.ResponseTypeAsync, comm)
   308  		c.Check(rsp.Status, check.Equals, 202, comm)
   309  		c.Check(done, check.Equals, action, comm)
   310  
   311  		chg := st.Change(rsp.Change)
   312  		c.Assert(chg, check.NotNil)
   313  		c.Assert(chg.Tasks(), check.HasLen, 0)
   314  
   315  		c.Check(chg.Kind(), check.Equals, action+"-snapshot")
   316  		var apiData map[string]interface{}
   317  		err = chg.Get("api-data", &apiData)
   318  		c.Assert(err, check.IsNil)
   319  		c.Check(apiData, check.DeepEquals, map[string]interface{}{
   320  			"snap-names": []interface{}{"foo"},
   321  		})
   322  
   323  	}
   324  }
   325  
   326  func (s *snapshotSuite) TestExportSnapshots(c *check.C) {
   327  	var snapshotExportCalled int
   328  
   329  	defer daemon.MockMuxVars(func(*http.Request) map[string]string {
   330  		return map[string]string{"id": "1"}
   331  	})()
   332  	defer daemon.MockSnapshotExport(func(ctx context.Context, setID uint64) (*snapshotstate.SnapshotExport, error) {
   333  		snapshotExportCalled++
   334  		c.Check(setID, check.Equals, uint64(1))
   335  		return &snapshotstate.SnapshotExport{}, nil
   336  	})()
   337  
   338  	c.Check(daemon.SnapshotExportCmd.Path, check.Equals, "/v2/snapshots/{id}/export")
   339  	req, err := http.NewRequest("GET", "/v2/snapshots/1/export", nil)
   340  	c.Assert(err, check.IsNil)
   341  
   342  	rsp := daemon.ExportSnapshot(daemon.SnapshotExportCmd, req, nil)
   343  	c.Check(rsp, check.FitsTypeOf, &daemon.SnapshotExportResponse{})
   344  	c.Check(snapshotExportCalled, check.Equals, 1)
   345  }
   346  
   347  func (s *snapshotSuite) TestExportSnapshotsBadRequestOnNonNumericID(c *check.C) {
   348  	defer daemon.MockMuxVars(func(*http.Request) map[string]string {
   349  		return map[string]string{"id": "xxx"}
   350  	})()
   351  
   352  	c.Check(daemon.SnapshotExportCmd.Path, check.Equals, "/v2/snapshots/{id}/export")
   353  	req, err := http.NewRequest("GET", "/v2/snapshots/xxx/export", nil)
   354  	c.Assert(err, check.IsNil)
   355  
   356  	rsp := daemon.ExportSnapshot(daemon.SnapshotExportCmd, req, nil).(*daemon.Resp)
   357  	c.Check(rsp.Type, check.Equals, daemon.ResponseTypeError)
   358  	c.Check(rsp.Status, check.Equals, 400)
   359  	c.Check(rsp.Result, check.DeepEquals, &daemon.ErrorResult{Message: `'id' must be a positive base 10 number; got "xxx"`})
   360  }
   361  
   362  func (s *snapshotSuite) TestExportSnapshotsBadRequestOnError(c *check.C) {
   363  	var snapshotExportCalled int
   364  
   365  	defer daemon.MockMuxVars(func(*http.Request) map[string]string {
   366  		return map[string]string{"id": "1"}
   367  	})()
   368  	defer daemon.MockSnapshotExport(func(ctx context.Context, setID uint64) (*snapshotstate.SnapshotExport, error) {
   369  		snapshotExportCalled++
   370  		return nil, fmt.Errorf("boom")
   371  	})()
   372  
   373  	c.Check(daemon.SnapshotExportCmd.Path, check.Equals, "/v2/snapshots/{id}/export")
   374  	req, err := http.NewRequest("GET", "/v2/snapshots/1/export", nil)
   375  	c.Assert(err, check.IsNil)
   376  
   377  	rsp := daemon.ExportSnapshot(daemon.SnapshotExportCmd, req, nil).(*daemon.Resp)
   378  	c.Check(rsp.Type, check.Equals, daemon.ResponseTypeError)
   379  	c.Check(rsp.Status, check.Equals, 400)
   380  	c.Check(rsp.Result, check.DeepEquals, &daemon.ErrorResult{Message: `cannot export 1: boom`})
   381  	c.Check(snapshotExportCalled, check.Equals, 1)
   382  }