github.com/anonymouse64/snapd@v0.0.0-20210824153203-04c4c42d842d/daemon/api_snapshots_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2018 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package daemon_test
    21  
    22  import (
    23  	"bytes"
    24  	"context"
    25  	"errors"
    26  	"fmt"
    27  	"io"
    28  	"io/ioutil"
    29  	"net/http"
    30  	"strconv"
    31  	"strings"
    32  
    33  	"gopkg.in/check.v1"
    34  
    35  	"github.com/snapcore/snapd/client"
    36  	"github.com/snapcore/snapd/daemon"
    37  	"github.com/snapcore/snapd/overlord/snapshotstate"
    38  	"github.com/snapcore/snapd/overlord/state"
    39  )
    40  
    41  var _ = check.Suite(&snapshotSuite{})
    42  
    43  type snapshotSuite struct {
    44  	apiBaseSuite
    45  }
    46  
    47  func (s *snapshotSuite) SetUpTest(c *check.C) {
    48  	s.apiBaseSuite.SetUpTest(c)
    49  	s.daemonWithOverlordMock(c)
    50  	s.expectAuthenticatedAccess()
    51  	s.expectWriteAccess(daemon.AuthenticatedAccess{Polkit: "io.snapcraft.snapd.manage"})
    52  }
    53  
    54  func (s *snapshotSuite) TestSnapshotMany(c *check.C) {
    55  	defer daemon.MockSnapshotSave(func(s *state.State, snaps, users []string) (uint64, []string, *state.TaskSet, error) {
    56  		c.Check(snaps, check.HasLen, 2)
    57  		t := s.NewTask("fake-snapshot-2", "Snapshot two")
    58  		return 1, snaps, state.NewTaskSet(t), nil
    59  	})()
    60  
    61  	inst := daemon.MustUnmarshalSnapInstruction(c, `{"action": "snapshot", "snaps": ["foo", "bar"]}`)
    62  	st := s.d.Overlord().State()
    63  	st.Lock()
    64  	res, err := inst.DispatchForMany()(inst, st)
    65  	st.Unlock()
    66  	c.Assert(err, check.IsNil)
    67  	c.Check(res.Summary, check.Equals, `Snapshot snaps "foo", "bar"`)
    68  	c.Check(res.Affected, check.DeepEquals, inst.Snaps)
    69  }
    70  
    71  func (s *snapshotSuite) TestListSnapshots(c *check.C) {
    72  	s.expectOpenAccess()
    73  
    74  	snapshots := []client.SnapshotSet{{ID: 1}, {ID: 42}}
    75  
    76  	defer daemon.MockSnapshotList(func(context.Context, *state.State, uint64, []string) ([]client.SnapshotSet, error) {
    77  		return snapshots, nil
    78  	})()
    79  
    80  	req, err := http.NewRequest("GET", "/v2/snapshots", nil)
    81  	c.Assert(err, check.IsNil)
    82  
    83  	rsp := s.syncReq(c, req, nil)
    84  	c.Check(rsp.Status, check.Equals, 200)
    85  	c.Check(rsp.Result, check.DeepEquals, snapshots)
    86  }
    87  
    88  func (s *snapshotSuite) TestListSnapshotsFiltering(c *check.C) {
    89  	s.expectOpenAccess()
    90  
    91  	snapshots := []client.SnapshotSet{{ID: 1}, {ID: 42}}
    92  
    93  	defer daemon.MockSnapshotList(func(_ context.Context, st *state.State, setID uint64, _ []string) ([]client.SnapshotSet, error) {
    94  		c.Assert(setID, check.Equals, uint64(42))
    95  		return snapshots[1:], nil
    96  	})()
    97  
    98  	req, err := http.NewRequest("GET", "/v2/snapshots?set=42", nil)
    99  	c.Assert(err, check.IsNil)
   100  
   101  	rsp := s.syncReq(c, req, nil)
   102  	c.Check(rsp.Status, check.Equals, 200)
   103  	c.Check(rsp.Result, check.DeepEquals, []client.SnapshotSet{{ID: 42}})
   104  }
   105  
   106  func (s *snapshotSuite) TestListSnapshotsBadFiltering(c *check.C) {
   107  	s.expectOpenAccess()
   108  
   109  	defer daemon.MockSnapshotList(func(_ context.Context, _ *state.State, setID uint64, _ []string) ([]client.SnapshotSet, error) {
   110  		c.Fatal("snapshotList should not be reached (should have been blocked by validation!)")
   111  		return nil, nil
   112  	})()
   113  
   114  	req, err := http.NewRequest("GET", "/v2/snapshots?set=no", nil)
   115  	c.Assert(err, check.IsNil)
   116  
   117  	rspe := s.errorReq(c, req, nil)
   118  	c.Check(rspe.Status, check.Equals, 400)
   119  	c.Check(rspe.Message, check.Equals, `'set', if given, must be a positive base 10 number; got "no"`)
   120  }
   121  
   122  func (s *snapshotSuite) TestListSnapshotsListError(c *check.C) {
   123  	s.expectOpenAccess()
   124  
   125  	defer daemon.MockSnapshotList(func(_ context.Context, _ *state.State, setID uint64, _ []string) ([]client.SnapshotSet, error) {
   126  		return nil, errors.New("no")
   127  	})()
   128  
   129  	req, err := http.NewRequest("GET", "/v2/snapshots", nil)
   130  	c.Assert(err, check.IsNil)
   131  
   132  	rspe := s.errorReq(c, req, nil)
   133  	c.Check(rspe.Status, check.Equals, 500)
   134  	c.Check(rspe.Message, check.Equals, "no")
   135  }
   136  
   137  func (s *snapshotSuite) TestFormatSnapshotAction(c *check.C) {
   138  	type table struct {
   139  		action   string
   140  		expected string
   141  	}
   142  	tests := []table{
   143  		{
   144  			`{"set": 2, "action": "verb"}`,
   145  			`Verb of snapshot set #2`,
   146  		}, {
   147  			`{"set": 2, "action": "verb", "snaps": ["foo"]}`,
   148  			`Verb of snapshot set #2 for snaps "foo"`,
   149  		}, {
   150  			`{"set": 2, "action": "verb", "snaps": ["foo", "bar"]}`,
   151  			`Verb of snapshot set #2 for snaps "foo", "bar"`,
   152  		}, {
   153  			`{"set": 2, "action": "verb", "users": ["meep"]}`,
   154  			`Verb of snapshot set #2 for users "meep"`,
   155  		}, {
   156  			`{"set": 2, "action": "verb", "users": ["meep", "quux"]}`,
   157  			`Verb of snapshot set #2 for users "meep", "quux"`,
   158  		}, {
   159  			`{"set": 2, "action": "verb", "users": ["meep", "quux"], "snaps": ["foo", "bar"]}`,
   160  			`Verb of snapshot set #2 for snaps "foo", "bar" for users "meep", "quux"`,
   161  		},
   162  	}
   163  
   164  	for _, test := range tests {
   165  		action := daemon.MustUnmarshalSnapshotAction(c, test.action)
   166  		c.Check(action.String(), check.Equals, test.expected)
   167  	}
   168  }
   169  
   170  func (s *snapshotSuite) TestChangeSnapshots400(c *check.C) {
   171  	type table struct{ body, error string }
   172  	tests := []table{
   173  		{
   174  			body:  `"woodchucks`,
   175  			error: "cannot decode request body into snapshot operation:.*",
   176  		}, {
   177  			body:  `{}"woodchucks`,
   178  			error: "extra content found after snapshot operation",
   179  		}, {
   180  			body:  `{}`,
   181  			error: "snapshot operation requires snapshot set ID",
   182  		}, {
   183  			body:  `{"set": 42}`,
   184  			error: "snapshot operation requires action",
   185  		}, {
   186  			body:  `{"set": 42, "action": "carrots"}`,
   187  			error: `unknown snapshot operation "carrots"`,
   188  		}, {
   189  			body:  `{"set": 42, "action": "forget", "users": ["foo"]}`,
   190  			error: `snapshot "forget" operation cannot specify users`,
   191  		},
   192  	}
   193  
   194  	for i, test := range tests {
   195  		comm := check.Commentf("%d:%q", i, test.body)
   196  		req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(test.body))
   197  		c.Assert(err, check.IsNil, comm)
   198  
   199  		rspe := s.errorReq(c, req, nil)
   200  		c.Check(rspe.Status, check.Equals, 400, comm)
   201  		c.Check(rspe.Message, check.Matches, test.error, comm)
   202  	}
   203  }
   204  
   205  func (s *snapshotSuite) TestChangeSnapshots404(c *check.C) {
   206  	var done string
   207  	expectedError := errors.New("bzzt")
   208  	defer daemon.MockSnapshotCheck(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   209  		done = "check"
   210  		return nil, nil, expectedError
   211  	})()
   212  	defer daemon.MockSnapshotRestore(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   213  		done = "restore"
   214  		return nil, nil, expectedError
   215  	})()
   216  	defer daemon.MockSnapshotForget(func(*state.State, uint64, []string) ([]string, *state.TaskSet, error) {
   217  		done = "forget"
   218  		return nil, nil, expectedError
   219  	})()
   220  	for _, expectedError = range []error{client.ErrSnapshotSetNotFound, client.ErrSnapshotSnapsNotFound} {
   221  		for _, action := range []string{"check", "restore", "forget"} {
   222  			done = ""
   223  			comm := check.Commentf("%s/%s", action, expectedError)
   224  			body := fmt.Sprintf(`{"set": 42, "action": "%s"}`, action)
   225  			req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(body))
   226  			c.Assert(err, check.IsNil, comm)
   227  
   228  			rspe := s.errorReq(c, req, nil)
   229  			c.Check(rspe.Status, check.Equals, 404, comm)
   230  			c.Check(rspe.Message, check.Matches, expectedError.Error(), comm)
   231  			c.Check(done, check.Equals, action, comm)
   232  		}
   233  	}
   234  }
   235  
   236  func (s *snapshotSuite) TestChangeSnapshots500(c *check.C) {
   237  	var done string
   238  	expectedError := errors.New("bzzt")
   239  	defer daemon.MockSnapshotCheck(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   240  		done = "check"
   241  		return nil, nil, expectedError
   242  	})()
   243  	defer daemon.MockSnapshotRestore(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   244  		done = "restore"
   245  		return nil, nil, expectedError
   246  	})()
   247  	defer daemon.MockSnapshotForget(func(*state.State, uint64, []string) ([]string, *state.TaskSet, error) {
   248  		done = "forget"
   249  		return nil, nil, expectedError
   250  	})()
   251  	for _, action := range []string{"check", "restore", "forget"} {
   252  		comm := check.Commentf("%s", action)
   253  		body := fmt.Sprintf(`{"set": 42, "action": "%s"}`, action)
   254  		req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(body))
   255  		c.Assert(err, check.IsNil, comm)
   256  
   257  		rspe := s.errorReq(c, req, nil)
   258  		c.Check(rspe.Status, check.Equals, 500, comm)
   259  		c.Check(rspe.Message, check.Matches, expectedError.Error(), comm)
   260  		c.Check(done, check.Equals, action, comm)
   261  	}
   262  }
   263  
   264  func (s *snapshotSuite) TestChangeSnapshot(c *check.C) {
   265  	var done string
   266  	defer daemon.MockSnapshotCheck(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   267  		done = "check"
   268  		return []string{"foo"}, state.NewTaskSet(), nil
   269  	})()
   270  	defer daemon.MockSnapshotRestore(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) {
   271  		done = "restore"
   272  		return []string{"foo"}, state.NewTaskSet(), nil
   273  	})()
   274  	defer daemon.MockSnapshotForget(func(*state.State, uint64, []string) ([]string, *state.TaskSet, error) {
   275  		done = "forget"
   276  		return []string{"foo"}, state.NewTaskSet(), nil
   277  	})()
   278  
   279  	st := s.d.Overlord().State()
   280  	st.Lock()
   281  	defer st.Unlock()
   282  	for _, action := range []string{"check", "restore", "forget"} {
   283  		comm := check.Commentf("%s", action)
   284  		body := fmt.Sprintf(`{"set": 42, "action": "%s"}`, action)
   285  		req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(body))
   286  
   287  		c.Assert(err, check.IsNil, comm)
   288  
   289  		st.Unlock()
   290  		rsp := s.asyncReq(c, req, nil)
   291  		st.Lock()
   292  
   293  		c.Check(rsp.Status, check.Equals, 202, comm)
   294  		c.Check(done, check.Equals, action, comm)
   295  
   296  		chg := st.Change(rsp.Change)
   297  		c.Assert(chg, check.NotNil)
   298  		c.Assert(chg.Tasks(), check.HasLen, 0)
   299  
   300  		c.Check(chg.Kind(), check.Equals, action+"-snapshot")
   301  		var apiData map[string]interface{}
   302  		err = chg.Get("api-data", &apiData)
   303  		c.Assert(err, check.IsNil)
   304  		c.Check(apiData, check.DeepEquals, map[string]interface{}{
   305  			"snap-names": []interface{}{"foo"},
   306  		})
   307  
   308  	}
   309  }
   310  
   311  func (s *snapshotSuite) TestExportSnapshots(c *check.C) {
   312  	var snapshotExportCalled int
   313  
   314  	defer daemon.MockSnapshotExport(func(ctx context.Context, st *state.State, setID uint64) (*snapshotstate.SnapshotExport, error) {
   315  		snapshotExportCalled++
   316  		c.Check(setID, check.Equals, uint64(1))
   317  		return &snapshotstate.SnapshotExport{}, nil
   318  	})()
   319  
   320  	req, err := http.NewRequest("GET", "/v2/snapshots/1/export", nil)
   321  	c.Assert(err, check.IsNil)
   322  
   323  	rsp := s.req(c, req, nil)
   324  	c.Check(rsp, check.FitsTypeOf, &daemon.SnapshotExportResponse{})
   325  	c.Check(snapshotExportCalled, check.Equals, 1)
   326  }
   327  
   328  func (s *snapshotSuite) TestExportSnapshotsBadRequestOnNonNumericID(c *check.C) {
   329  	req, err := http.NewRequest("GET", "/v2/snapshots/xxx/export", nil)
   330  	c.Assert(err, check.IsNil)
   331  
   332  	rspe := s.errorReq(c, req, nil)
   333  	c.Check(rspe.Status, check.Equals, 400)
   334  	c.Check(rspe.Message, check.Equals, `'id' must be a positive base 10 number; got "xxx"`)
   335  }
   336  
   337  func (s *snapshotSuite) TestExportSnapshotsBadRequestOnError(c *check.C) {
   338  	var snapshotExportCalled int
   339  
   340  	defer daemon.MockSnapshotExport(func(ctx context.Context, st *state.State, setID uint64) (*snapshotstate.SnapshotExport, error) {
   341  		snapshotExportCalled++
   342  		return nil, fmt.Errorf("boom")
   343  	})()
   344  
   345  	req, err := http.NewRequest("GET", "/v2/snapshots/1/export", nil)
   346  	c.Assert(err, check.IsNil)
   347  
   348  	rspe := s.errorReq(c, req, nil)
   349  	c.Check(rspe.Status, check.Equals, 400)
   350  	c.Check(rspe.Message, check.Equals, `cannot export 1: boom`)
   351  	c.Check(snapshotExportCalled, check.Equals, 1)
   352  }
   353  
   354  func (s *snapshotSuite) TestImportSnapshot(c *check.C) {
   355  	data := []byte("mocked snapshot export data file")
   356  
   357  	setID := uint64(3)
   358  	snapNames := []string{"baz", "bar", "foo"}
   359  	defer daemon.MockSnapshotImport(func(context.Context, *state.State, io.Reader) (uint64, []string, error) {
   360  		return setID, snapNames, nil
   361  	})()
   362  
   363  	req, err := http.NewRequest("POST", "/v2/snapshots", bytes.NewReader(data))
   364  	req.Header.Add("Content-Length", strconv.Itoa(len(data)))
   365  	c.Assert(err, check.IsNil)
   366  	req.Header.Set("Content-Type", client.SnapshotExportMediaType)
   367  
   368  	rsp := s.syncReq(c, req, nil)
   369  	c.Check(rsp.Status, check.Equals, 200)
   370  	c.Check(rsp.Result, check.DeepEquals, map[string]interface{}{"set-id": setID, "snaps": snapNames})
   371  }
   372  
   373  func (s *snapshotSuite) TestImportSnapshotError(c *check.C) {
   374  	defer daemon.MockSnapshotImport(func(context.Context, *state.State, io.Reader) (uint64, []string, error) {
   375  		return uint64(0), nil, errors.New("no")
   376  	})()
   377  
   378  	data := []byte("mocked snapshot export data file")
   379  	req, err := http.NewRequest("POST", "/v2/snapshots", bytes.NewReader(data))
   380  	req.Header.Add("Content-Length", strconv.Itoa(len(data)))
   381  	c.Assert(err, check.IsNil)
   382  	req.Header.Set("Content-Type", client.SnapshotExportMediaType)
   383  
   384  	rspe := s.errorReq(c, req, nil)
   385  	c.Check(rspe.Status, check.Equals, 400)
   386  	c.Check(rspe.Message, check.Equals, "no")
   387  }
   388  
   389  func (s *snapshotSuite) TestImportSnapshotNoContentLengthError(c *check.C) {
   390  	data := []byte("mocked snapshot export data file")
   391  	req, err := http.NewRequest("POST", "/v2/snapshots", bytes.NewReader(data))
   392  	c.Assert(err, check.IsNil)
   393  	req.Header.Set("Content-Type", client.SnapshotExportMediaType)
   394  
   395  	rspe := s.errorReq(c, req, nil)
   396  	c.Check(rspe.Status, check.Equals, 400)
   397  	c.Check(rspe.Message, check.Equals, `cannot parse Content-Length: strconv.ParseInt: parsing "": invalid syntax`)
   398  }
   399  
   400  func (s *snapshotSuite) TestImportSnapshotLimits(c *check.C) {
   401  	var dataRead int
   402  
   403  	defer daemon.MockSnapshotImport(func(ctx context.Context, st *state.State, r io.Reader) (uint64, []string, error) {
   404  		data, err := ioutil.ReadAll(r)
   405  		c.Assert(err, check.IsNil)
   406  		dataRead = len(data)
   407  		return uint64(0), nil, nil
   408  	})()
   409  
   410  	data := []byte("much more data than expected from Content-Length")
   411  	req, err := http.NewRequest("POST", "/v2/snapshots", bytes.NewReader(data))
   412  	// limit to 10 and check that this is really all that is read
   413  	req.Header.Add("Content-Length", "10")
   414  	c.Assert(err, check.IsNil)
   415  	req.Header.Set("Content-Type", client.SnapshotExportMediaType)
   416  
   417  	rsp := s.syncReq(c, req, nil)
   418  	c.Check(rsp.Status, check.Equals, 200)
   419  	c.Check(dataRead, check.Equals, 10)
   420  }