github.com/stolowski/snapd@v0.0.0-20210407085831-115137ce5a22/daemon/api_snapshots_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2018 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package daemon_test
    21  
    22  import (
    23  	"bytes"
    24  	"context"
    25  	"errors"
    26  	"fmt"
    27  	"io"
    28  	"io/ioutil"
    29  	"net/http"
    30  	"strconv"
    31  	"strings"
    32  
    33  	"gopkg.in/check.v1"
    34  
    35  	"github.com/snapcore/snapd/client"
    36  	"github.com/snapcore/snapd/daemon"
    37  	"github.com/snapcore/snapd/overlord/snapshotstate"
    38  	"github.com/snapcore/snapd/overlord/state"
    39  )
    40  
    41  var _ = check.Suite(&snapshotSuite{})
    42  
    43  type snapshotSuite struct {
    44  	apiBaseSuite
    45  }
    46  
// SetUpTest runs the embedded apiBaseSuite set-up and then calls
// daemonWithOverlordMock.
func (s *snapshotSuite) SetUpTest(c *check.C) {
	// Call the embedded implementation explicitly; s.SetUpTest would recurse.
	s.apiBaseSuite.SetUpTest(c)
	// NOTE(review): presumably this is what populates s.d, which several
	// tests in this file read via s.d.Overlord() -- confirm in apiBaseSuite.
	s.daemonWithOverlordMock(c)
}
    51  
    52  func (s *snapshotSuite) TestSnapshotMany(c *check.C) {
    53  	defer daemon.MockSnapshotSave(func(s *state.State, snaps, users []string) (uint64, []string, *state.TaskSet, error) {
    54  		c.Check(snaps, check.HasLen, 2)
    55  		t := s.NewTask("fake-snapshot-2", "Snapshot two")
    56  		return 1, snaps, state.NewTaskSet(t), nil
    57  	})()
    58  
    59  	inst := daemon.MustUnmarshalSnapInstruction(c, `{"action": "snapshot", "snaps": ["foo", "bar"]}`)
    60  	st := s.d.Overlord().State()
    61  	st.Lock()
    62  	res, err := inst.DispatchForMany()(inst, st)
    63  	st.Unlock()
    64  	c.Assert(err, check.IsNil)
    65  	c.Check(res.Summary, check.Equals, `Snapshot snaps "foo", "bar"`)
    66  	c.Check(res.Affected, check.DeepEquals, inst.Snaps)
    67  }
    68  
    69  func (s *snapshotSuite) TestListSnapshots(c *check.C) {
    70  	snapshots := []client.SnapshotSet{{ID: 1}, {ID: 42}}
    71  
    72  	defer daemon.MockSnapshotList(func(context.Context, *state.State, uint64, []string) ([]client.SnapshotSet, error) {
    73  		return snapshots, nil
    74  	})()
    75  
    76  	req, err := http.NewRequest("GET", "/v2/snapshots", nil)
    77  	c.Assert(err, check.IsNil)
    78  
    79  	rsp := s.syncReq(c, req, nil)
    80  	c.Check(rsp.Status, check.Equals, 200)
    81  	c.Check(rsp.Result, check.DeepEquals, snapshots)
    82  }
    83  
    84  func (s *snapshotSuite) TestListSnapshotsFiltering(c *check.C) {
    85  	snapshots := []client.SnapshotSet{{ID: 1}, {ID: 42}}
    86  
    87  	defer daemon.MockSnapshotList(func(_ context.Context, st *state.State, setID uint64, _ []string) ([]client.SnapshotSet, error) {
    88  		c.Assert(setID, check.Equals, uint64(42))
    89  		return snapshots[1:], nil
    90  	})()
    91  
    92  	req, err := http.NewRequest("GET", "/v2/snapshots?set=42", nil)
    93  	c.Assert(err, check.IsNil)
    94  
    95  	rsp := s.syncReq(c, req, nil)
    96  	c.Check(rsp.Status, check.Equals, 200)
    97  	c.Check(rsp.Result, check.DeepEquals, []client.SnapshotSet{{ID: 42}})
    98  }
    99  
   100  func (s *snapshotSuite) TestListSnapshotsBadFiltering(c *check.C) {
   101  	defer daemon.MockSnapshotList(func(_ context.Context, _ *state.State, setID uint64, _ []string) ([]client.SnapshotSet, error) {
   102  		c.Fatal("snapshotList should not be reached (should have been blocked by validation!)")
   103  		return nil, nil
   104  	})()
   105  
   106  	req, err := http.NewRequest("GET", "/v2/snapshots?set=no", nil)
   107  	c.Assert(err, check.IsNil)
   108  
   109  	rsp := s.errorReq(c, req, nil)
   110  	c.Check(rsp.Status, check.Equals, 400)
   111  	c.Check(rsp.ErrorResult().Message, check.Equals, `'set', if given, must be a positive base 10 number; got "no"`)
   112  }
   113  
   114  func (s *snapshotSuite) TestListSnapshotsListError(c *check.C) {
   115  	defer daemon.MockSnapshotList(func(_ context.Context, _ *state.State, setID uint64, _ []string) ([]client.SnapshotSet, error) {
   116  		return nil, errors.New("no")
   117  	})()
   118  
   119  	req, err := http.NewRequest("GET", "/v2/snapshots", nil)
   120  	c.Assert(err, check.IsNil)
   121  
   122  	rsp := s.errorReq(c, req, nil)
   123  	c.Check(rsp.Status, check.Equals, 500)
   124  	c.Check(rsp.ErrorResult().Message, check.Equals, "no")
   125  }
   126  
   127  func (s *snapshotSuite) TestFormatSnapshotAction(c *check.C) {
   128  	type table struct {
   129  		action   string
   130  		expected string
   131  	}
   132  	tests := []table{
   133  		{
   134  			`{"set": 2, "action": "verb"}`,
   135  			`Verb of snapshot set #2`,
   136  		}, {
   137  			`{"set": 2, "action": "verb", "snaps": ["foo"]}`,
   138  			`Verb of snapshot set #2 for snaps "foo"`,
   139  		}, {
   140  			`{"set": 2, "action": "verb", "snaps": ["foo", "bar"]}`,
   141  			`Verb of snapshot set #2 for snaps "foo", "bar"`,
   142  		}, {
   143  			`{"set": 2, "action": "verb", "users": ["meep"]}`,
   144  			`Verb of snapshot set #2 for users "meep"`,
   145  		}, {
   146  			`{"set": 2, "action": "verb", "users": ["meep", "quux"]}`,
   147  			`Verb of snapshot set #2 for users "meep", "quux"`,
   148  		}, {
   149  			`{"set": 2, "action": "verb", "users": ["meep", "quux"], "snaps": ["foo", "bar"]}`,
   150  			`Verb of snapshot set #2 for snaps "foo", "bar" for users "meep", "quux"`,
   151  		},
   152  	}
   153  
   154  	for _, test := range tests {
   155  		action := daemon.MustUnmarshalSnapshotAction(c, test.action)
   156  		c.Check(action.String(), check.Equals, test.expected)
   157  	}
   158  }
   159  
   160  func (s *snapshotSuite) TestChangeSnapshots400(c *check.C) {
   161  	type table struct{ body, error string }
   162  	tests := []table{
   163  		{
   164  			body:  `"woodchucks`,
   165  			error: "cannot decode request body into snapshot operation:.*",
   166  		}, {
   167  			body:  `{}"woodchucks`,
   168  			error: "extra content found after snapshot operation",
   169  		}, {
   170  			body:  `{}`,
   171  			error: "snapshot operation requires snapshot set ID",
   172  		}, {
   173  			body:  `{"set": 42}`,
   174  			error: "snapshot operation requires action",
   175  		}, {
   176  			body:  `{"set": 42, "action": "carrots"}`,
   177  			error: `unknown snapshot operation "carrots"`,
   178  		}, {
   179  			body:  `{"set": 42, "action": "forget", "users": ["foo"]}`,
   180  			error: `snapshot "forget" operation cannot specify users`,
   181  		},
   182  	}
   183  
   184  	for i, test := range tests {
   185  		comm := check.Commentf("%d:%q", i, test.body)
   186  		req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(test.body))
   187  		c.Assert(err, check.IsNil, comm)
   188  
   189  		rsp := s.errorReq(c, req, nil)
   190  		c.Check(rsp.Status, check.Equals, 400, comm)
   191  		c.Check(rsp.ErrorResult().Message, check.Matches, test.error, comm)
   192  	}
   193  }
   194  
   195  func (s *snapshotSuite) TestChangeSnapshots404(c *check.C) {
   196  	var done string
   197  	expectedError := errors.New("bzzt")
   198  	defer daemon.MockSnapshotCheck(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   199  		done = "check"
   200  		return nil, nil, expectedError
   201  	})()
   202  	defer daemon.MockSnapshotRestore(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   203  		done = "restore"
   204  		return nil, nil, expectedError
   205  	})()
   206  	defer daemon.MockSnapshotForget(func(*state.State, uint64, []string) ([]string, *state.TaskSet, error) {
   207  		done = "forget"
   208  		return nil, nil, expectedError
   209  	})()
   210  	for _, expectedError = range []error{client.ErrSnapshotSetNotFound, client.ErrSnapshotSnapsNotFound} {
   211  		for _, action := range []string{"check", "restore", "forget"} {
   212  			done = ""
   213  			comm := check.Commentf("%s/%s", action, expectedError)
   214  			body := fmt.Sprintf(`{"set": 42, "action": "%s"}`, action)
   215  			req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(body))
   216  			c.Assert(err, check.IsNil, comm)
   217  
   218  			rsp := s.errorReq(c, req, nil)
   219  			c.Check(rsp.Status, check.Equals, 404, comm)
   220  			c.Check(rsp.ErrorResult().Message, check.Matches, expectedError.Error(), comm)
   221  			c.Check(done, check.Equals, action, comm)
   222  		}
   223  	}
   224  }
   225  
   226  func (s *snapshotSuite) TestChangeSnapshots500(c *check.C) {
   227  	var done string
   228  	expectedError := errors.New("bzzt")
   229  	defer daemon.MockSnapshotCheck(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   230  		done = "check"
   231  		return nil, nil, expectedError
   232  	})()
   233  	defer daemon.MockSnapshotRestore(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   234  		done = "restore"
   235  		return nil, nil, expectedError
   236  	})()
   237  	defer daemon.MockSnapshotForget(func(*state.State, uint64, []string) ([]string, *state.TaskSet, error) {
   238  		done = "forget"
   239  		return nil, nil, expectedError
   240  	})()
   241  	for _, action := range []string{"check", "restore", "forget"} {
   242  		comm := check.Commentf("%s", action)
   243  		body := fmt.Sprintf(`{"set": 42, "action": "%s"}`, action)
   244  		req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(body))
   245  		c.Assert(err, check.IsNil, comm)
   246  
   247  		rsp := s.errorReq(c, req, nil)
   248  		c.Check(rsp.Status, check.Equals, 500, comm)
   249  		c.Check(rsp.ErrorResult().Message, check.Matches, expectedError.Error(), comm)
   250  		c.Check(done, check.Equals, action, comm)
   251  	}
   252  }
   253  
   254  func (s *snapshotSuite) TestChangeSnapshot(c *check.C) {
   255  	var done string
   256  	defer daemon.MockSnapshotCheck(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   257  		done = "check"
   258  		return []string{"foo"}, state.NewTaskSet(), nil
   259  	})()
   260  	defer daemon.MockSnapshotRestore(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   261  		done = "restore"
   262  		return []string{"foo"}, state.NewTaskSet(), nil
   263  	})()
   264  	defer daemon.MockSnapshotForget(func(*state.State, uint64, []string) ([]string, *state.TaskSet, error) {
   265  		done = "forget"
   266  		return []string{"foo"}, state.NewTaskSet(), nil
   267  	})()
   268  
   269  	st := s.d.Overlord().State()
   270  	st.Lock()
   271  	defer st.Unlock()
   272  	for _, action := range []string{"check", "restore", "forget"} {
   273  		comm := check.Commentf("%s", action)
   274  		body := fmt.Sprintf(`{"set": 42, "action": "%s"}`, action)
   275  		req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(body))
   276  
   277  		c.Assert(err, check.IsNil, comm)
   278  
   279  		st.Unlock()
   280  		rsp := s.asyncReq(c, req, nil)
   281  		st.Lock()
   282  
   283  		c.Check(rsp.Status, check.Equals, 202, comm)
   284  		c.Check(done, check.Equals, action, comm)
   285  
   286  		chg := st.Change(rsp.Change)
   287  		c.Assert(chg, check.NotNil)
   288  		c.Assert(chg.Tasks(), check.HasLen, 0)
   289  
   290  		c.Check(chg.Kind(), check.Equals, action+"-snapshot")
   291  		var apiData map[string]interface{}
   292  		err = chg.Get("api-data", &apiData)
   293  		c.Assert(err, check.IsNil)
   294  		c.Check(apiData, check.DeepEquals, map[string]interface{}{
   295  			"snap-names": []interface{}{"foo"},
   296  		})
   297  
   298  	}
   299  }
   300  
   301  func (s *snapshotSuite) TestExportSnapshots(c *check.C) {
   302  	var snapshotExportCalled int
   303  
   304  	defer daemon.MockSnapshotExport(func(ctx context.Context, st *state.State, setID uint64) (*snapshotstate.SnapshotExport, error) {
   305  		snapshotExportCalled++
   306  		c.Check(setID, check.Equals, uint64(1))
   307  		return &snapshotstate.SnapshotExport{}, nil
   308  	})()
   309  
   310  	req, err := http.NewRequest("GET", "/v2/snapshots/1/export", nil)
   311  	c.Assert(err, check.IsNil)
   312  
   313  	rsp := s.req(c, req, nil)
   314  	c.Check(rsp, check.FitsTypeOf, &daemon.SnapshotExportResponse{})
   315  	c.Check(snapshotExportCalled, check.Equals, 1)
   316  }
   317  
   318  func (s *snapshotSuite) TestExportSnapshotsBadRequestOnNonNumericID(c *check.C) {
   319  	req, err := http.NewRequest("GET", "/v2/snapshots/xxx/export", nil)
   320  	c.Assert(err, check.IsNil)
   321  
   322  	rsp := s.errorReq(c, req, nil)
   323  	c.Check(rsp.Status, check.Equals, 400)
   324  	c.Check(rsp.Result, check.DeepEquals, &daemon.ErrorResult{Message: `'id' must be a positive base 10 number; got "xxx"`})
   325  }
   326  
   327  func (s *snapshotSuite) TestExportSnapshotsBadRequestOnError(c *check.C) {
   328  	var snapshotExportCalled int
   329  
   330  	defer daemon.MockSnapshotExport(func(ctx context.Context, st *state.State, setID uint64) (*snapshotstate.SnapshotExport, error) {
   331  		snapshotExportCalled++
   332  		return nil, fmt.Errorf("boom")
   333  	})()
   334  
   335  	req, err := http.NewRequest("GET", "/v2/snapshots/1/export", nil)
   336  	c.Assert(err, check.IsNil)
   337  
   338  	rsp := s.errorReq(c, req, nil)
   339  	c.Check(rsp.Status, check.Equals, 400)
   340  	c.Check(rsp.Result, check.DeepEquals, &daemon.ErrorResult{Message: `cannot export 1: boom`})
   341  	c.Check(snapshotExportCalled, check.Equals, 1)
   342  }
   343  
   344  func (s *snapshotSuite) TestImportSnapshot(c *check.C) {
   345  	data := []byte("mocked snapshot export data file")
   346  
   347  	setID := uint64(3)
   348  	snapNames := []string{"baz", "bar", "foo"}
   349  	defer daemon.MockSnapshotImport(func(context.Context, *state.State, io.Reader) (uint64, []string, error) {
   350  		return setID, snapNames, nil
   351  	})()
   352  
   353  	req, err := http.NewRequest("POST", "/v2/snapshots", bytes.NewReader(data))
   354  	req.Header.Add("Content-Length", strconv.Itoa(len(data)))
   355  	c.Assert(err, check.IsNil)
   356  	req.Header.Set("Content-Type", client.SnapshotExportMediaType)
   357  
   358  	rsp := s.syncReq(c, req, nil)
   359  	c.Check(rsp.Status, check.Equals, 200)
   360  	c.Check(rsp.Result, check.DeepEquals, map[string]interface{}{"set-id": setID, "snaps": snapNames})
   361  }
   362  
   363  func (s *snapshotSuite) TestImportSnapshotError(c *check.C) {
   364  	defer daemon.MockSnapshotImport(func(context.Context, *state.State, io.Reader) (uint64, []string, error) {
   365  		return uint64(0), nil, errors.New("no")
   366  	})()
   367  
   368  	data := []byte("mocked snapshot export data file")
   369  	req, err := http.NewRequest("POST", "/v2/snapshots", bytes.NewReader(data))
   370  	req.Header.Add("Content-Length", strconv.Itoa(len(data)))
   371  	c.Assert(err, check.IsNil)
   372  	req.Header.Set("Content-Type", client.SnapshotExportMediaType)
   373  
   374  	rsp := s.errorReq(c, req, nil)
   375  	c.Check(rsp.Status, check.Equals, 400)
   376  	c.Check(rsp.ErrorResult().Message, check.Equals, "no")
   377  }
   378  
   379  func (s *snapshotSuite) TestImportSnapshotNoContentLengthError(c *check.C) {
   380  	data := []byte("mocked snapshot export data file")
   381  	req, err := http.NewRequest("POST", "/v2/snapshots", bytes.NewReader(data))
   382  	c.Assert(err, check.IsNil)
   383  	req.Header.Set("Content-Type", client.SnapshotExportMediaType)
   384  
   385  	rsp := s.errorReq(c, req, nil)
   386  	c.Check(rsp.Status, check.Equals, 400)
   387  	c.Check(rsp.ErrorResult().Message, check.Equals, `cannot parse Content-Length: strconv.ParseInt: parsing "": invalid syntax`)
   388  }
   389  
   390  func (s *snapshotSuite) TestImportSnapshotLimits(c *check.C) {
   391  	var dataRead int
   392  
   393  	defer daemon.MockSnapshotImport(func(ctx context.Context, st *state.State, r io.Reader) (uint64, []string, error) {
   394  		data, err := ioutil.ReadAll(r)
   395  		c.Assert(err, check.IsNil)
   396  		dataRead = len(data)
   397  		return uint64(0), nil, nil
   398  	})()
   399  
   400  	data := []byte("much more data than expected from Content-Length")
   401  	req, err := http.NewRequest("POST", "/v2/snapshots", bytes.NewReader(data))
   402  	// limit to 10 and check that this is really all that is read
   403  	req.Header.Add("Content-Length", "10")
   404  	c.Assert(err, check.IsNil)
   405  	req.Header.Set("Content-Type", client.SnapshotExportMediaType)
   406  
   407  	rsp := s.syncReq(c, req, nil)
   408  	c.Check(rsp.Status, check.Equals, 200)
   409  	c.Check(dataRead, check.Equals, 10)
   410  }