github.com/stolowski/snapd@v0.0.0-20210407085831-115137ce5a22/daemon/api_snapshots_test.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 3 /* 4 * Copyright (C) 2018 Canonical Ltd 5 * 6 * This program is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU General Public License version 3 as 8 * published by the Free Software Foundation. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package daemon_test 21 22 import ( 23 "bytes" 24 "context" 25 "errors" 26 "fmt" 27 "io" 28 "io/ioutil" 29 "net/http" 30 "strconv" 31 "strings" 32 33 "gopkg.in/check.v1" 34 35 "github.com/snapcore/snapd/client" 36 "github.com/snapcore/snapd/daemon" 37 "github.com/snapcore/snapd/overlord/snapshotstate" 38 "github.com/snapcore/snapd/overlord/state" 39 ) 40 41 var _ = check.Suite(&snapshotSuite{}) 42 43 type snapshotSuite struct { 44 apiBaseSuite 45 } 46 47 func (s *snapshotSuite) SetUpTest(c *check.C) { 48 s.apiBaseSuite.SetUpTest(c) 49 s.daemonWithOverlordMock(c) 50 } 51 52 func (s *snapshotSuite) TestSnapshotMany(c *check.C) { 53 defer daemon.MockSnapshotSave(func(s *state.State, snaps, users []string) (uint64, []string, *state.TaskSet, error) { 54 c.Check(snaps, check.HasLen, 2) 55 t := s.NewTask("fake-snapshot-2", "Snapshot two") 56 return 1, snaps, state.NewTaskSet(t), nil 57 })() 58 59 inst := daemon.MustUnmarshalSnapInstruction(c, `{"action": "snapshot", "snaps": ["foo", "bar"]}`) 60 st := s.d.Overlord().State() 61 st.Lock() 62 res, err := inst.DispatchForMany()(inst, st) 63 st.Unlock() 64 c.Assert(err, check.IsNil) 65 c.Check(res.Summary, check.Equals, `Snapshot snaps "foo", "bar"`) 66 c.Check(res.Affected, check.DeepEquals, inst.Snaps) 67 } 68 69 func (s *snapshotSuite) TestListSnapshots(c *check.C) { 70 snapshots := []client.SnapshotSet{{ID: 1}, {ID: 42}} 71 72 defer daemon.MockSnapshotList(func(context.Context, *state.State, uint64, []string) ([]client.SnapshotSet, error) { 73 return snapshots, nil 74 })() 75 76 req, err := http.NewRequest("GET", "/v2/snapshots", nil) 77 c.Assert(err, check.IsNil) 78 79 rsp := s.syncReq(c, req, nil) 80 c.Check(rsp.Status, check.Equals, 200) 81 c.Check(rsp.Result, check.DeepEquals, snapshots) 82 } 83 84 func (s *snapshotSuite) TestListSnapshotsFiltering(c *check.C) { 85 snapshots := []client.SnapshotSet{{ID: 1}, {ID: 42}} 86 87 defer daemon.MockSnapshotList(func(_ context.Context, st *state.State, setID uint64, _ []string) ([]client.SnapshotSet, error) { 88 c.Assert(setID, check.Equals, uint64(42)) 89 return snapshots[1:], nil 90 })() 91 92 req, err := http.NewRequest("GET", "/v2/snapshots?set=42", nil) 93 c.Assert(err, check.IsNil) 94 95 rsp := s.syncReq(c, req, nil) 96 c.Check(rsp.Status, check.Equals, 200) 97 c.Check(rsp.Result, check.DeepEquals, []client.SnapshotSet{{ID: 42}}) 98 } 99 100 func (s *snapshotSuite) TestListSnapshotsBadFiltering(c *check.C) { 101 defer daemon.MockSnapshotList(func(_ context.Context, _ *state.State, setID uint64, _ []string) ([]client.SnapshotSet, error) { 102 c.Fatal("snapshotList should not be reached (should have been blocked by validation!)") 103 return nil, nil 104 })() 105 106 req, err := http.NewRequest("GET", "/v2/snapshots?set=no", nil) 107 c.Assert(err, check.IsNil) 108 109 rsp := s.errorReq(c, req, nil) 110 c.Check(rsp.Status, check.Equals, 400) 111 c.Check(rsp.ErrorResult().Message, check.Equals, `'set', if given, must be a positive base 10 number; got "no"`) 112 } 113 114 func (s *snapshotSuite) TestListSnapshotsListError(c *check.C) { 115 defer daemon.MockSnapshotList(func(_ context.Context, _ *state.State, setID uint64, _ []string) ([]client.SnapshotSet, error) { 116 return nil, errors.New("no") 117 })() 118 119 req, err := http.NewRequest("GET", "/v2/snapshots", nil) 120 c.Assert(err, check.IsNil) 121 122 rsp := s.errorReq(c, req, nil) 123 c.Check(rsp.Status, check.Equals, 500) 124 c.Check(rsp.ErrorResult().Message, check.Equals, "no") 125 } 126 127 func (s *snapshotSuite) TestFormatSnapshotAction(c *check.C) { 128 type table struct { 129 action string 130 expected string 131 } 132 tests := []table{ 133 { 134 `{"set": 2, "action": "verb"}`, 135 `Verb of snapshot set #2`, 136 }, { 137 `{"set": 2, "action": "verb", "snaps": ["foo"]}`, 138 `Verb of snapshot set #2 for snaps "foo"`, 139 }, { 140 `{"set": 2, "action": "verb", "snaps": ["foo", "bar"]}`, 141 `Verb of snapshot set #2 for snaps "foo", "bar"`, 142 }, { 143 `{"set": 2, "action": "verb", "users": ["meep"]}`, 144 `Verb of snapshot set #2 for users "meep"`, 145 }, { 146 `{"set": 2, "action": "verb", "users": ["meep", "quux"]}`, 147 `Verb of snapshot set #2 for users "meep", "quux"`, 148 }, { 149 `{"set": 2, "action": "verb", "users": ["meep", "quux"], "snaps": ["foo", "bar"]}`, 150 `Verb of snapshot set #2 for snaps "foo", "bar" for users "meep", "quux"`, 151 }, 152 } 153 154 for _, test := range tests { 155 action := daemon.MustUnmarshalSnapshotAction(c, test.action) 156 c.Check(action.String(), check.Equals, test.expected) 157 } 158 } 159 160 func (s *snapshotSuite) TestChangeSnapshots400(c *check.C) { 161 type table struct{ body, error string } 162 tests := []table{ 163 { 164 body: `"woodchucks`, 165 error: "cannot decode request body into snapshot operation:.*", 166 }, { 167 body: `{}"woodchucks`, 168 error: "extra content found after snapshot operation", 169 }, { 170 body: `{}`, 171 error: "snapshot operation requires snapshot set ID", 172 }, { 173 body: `{"set": 42}`, 174 error: "snapshot operation requires action", 175 }, { 176 body: `{"set": 42, "action": "carrots"}`, 177 error: `unknown snapshot operation "carrots"`, 178 }, { 179 body: `{"set": 42, "action": "forget", "users": ["foo"]}`, 180 error: `snapshot "forget" operation cannot specify users`, 181 }, 182 } 183 184 for i, test := range tests { 185 comm := check.Commentf("%d:%q", i, test.body) 186 req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(test.body)) 187 c.Assert(err, check.IsNil, comm) 188 189 rsp := s.errorReq(c, req, nil) 190 c.Check(rsp.Status, check.Equals, 400, comm) 191 c.Check(rsp.ErrorResult().Message, check.Matches, test.error, comm) 192 } 193 } 194 195 func (s *snapshotSuite) TestChangeSnapshots404(c *check.C) { 196 var done string 197 expectedError := errors.New("bzzt") 198 defer daemon.MockSnapshotCheck(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) { 199 done = "check" 200 return nil, nil, expectedError 201 })() 202 defer daemon.MockSnapshotRestore(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) { 203 done = "restore" 204 return nil, nil, expectedError 205 })() 206 defer daemon.MockSnapshotForget(func(*state.State, uint64, []string) ([]string, *state.TaskSet, error) { 207 done = "forget" 208 return nil, nil, expectedError 209 })() 210 for _, expectedError = range []error{client.ErrSnapshotSetNotFound, client.ErrSnapshotSnapsNotFound} { 211 for _, action := range []string{"check", "restore", "forget"} { 212 done = "" 213 comm := check.Commentf("%s/%s", action, expectedError) 214 body := fmt.Sprintf(`{"set": 42, "action": "%s"}`, action) 215 req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(body)) 216 c.Assert(err, check.IsNil, comm) 217 218 rsp := s.errorReq(c, req, nil) 219 c.Check(rsp.Status, check.Equals, 404, comm) 220 c.Check(rsp.ErrorResult().Message, check.Matches, expectedError.Error(), comm) 221 c.Check(done, check.Equals, action, comm) 222 } 223 } 224 } 225 226 func (s *snapshotSuite) TestChangeSnapshots500(c *check.C) { 227 var done string 228 expectedError := errors.New("bzzt") 229 defer daemon.MockSnapshotCheck(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) { 230 done = "check" 231 return nil, nil, expectedError 232 })() 233 defer daemon.MockSnapshotRestore(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) { 234 done = "restore" 235 return nil, nil, expectedError 236 })() 237 defer daemon.MockSnapshotForget(func(*state.State, uint64, []string) ([]string, *state.TaskSet, error) { 238 done = "forget" 239 return nil, nil, expectedError 240 })() 241 for _, action := range []string{"check", "restore", "forget"} { 242 comm := check.Commentf("%s", action) 243 body := fmt.Sprintf(`{"set": 42, "action": "%s"}`, action) 244 req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(body)) 245 c.Assert(err, check.IsNil, comm) 246 247 rsp := s.errorReq(c, req, nil) 248 c.Check(rsp.Status, check.Equals, 500, comm) 249 c.Check(rsp.ErrorResult().Message, check.Matches, expectedError.Error(), comm) 250 c.Check(done, check.Equals, action, comm) 251 } 252 } 253 254 func (s *snapshotSuite) TestChangeSnapshot(c *check.C) { 255 var done string 256 defer daemon.MockSnapshotCheck(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) { 257 done = "check" 258 return []string{"foo"}, state.NewTaskSet(), nil 259 })() 260 defer daemon.MockSnapshotRestore(func(*state.State, uint64, []string, []string) ([]string, *state.TaskSet, error) { 261 done = "restore" 262 return []string{"foo"}, state.NewTaskSet(), nil 263 })() 264 defer daemon.MockSnapshotForget(func(*state.State, uint64, []string) ([]string, *state.TaskSet, error) { 265 done = "forget" 266 return []string{"foo"}, state.NewTaskSet(), nil 267 })() 268 269 st := s.d.Overlord().State() 270 st.Lock() 271 defer st.Unlock() 272 for _, action := range []string{"check", "restore", "forget"} { 273 comm := check.Commentf("%s", action) 274 body := fmt.Sprintf(`{"set": 42, "action": "%s"}`, action) 275 req, err := http.NewRequest("POST", "/v2/snapshots", strings.NewReader(body)) 276 277 c.Assert(err, check.IsNil, comm) 278 279 st.Unlock() 280 rsp := s.asyncReq(c, req, nil) 281 st.Lock() 282 283 c.Check(rsp.Status, check.Equals, 202, comm) 284 c.Check(done, check.Equals, action, comm) 285 286 chg := st.Change(rsp.Change) 287 c.Assert(chg, check.NotNil) 288 c.Assert(chg.Tasks(), check.HasLen, 0) 289 290 c.Check(chg.Kind(), check.Equals, action+"-snapshot") 291 var apiData map[string]interface{} 292 err = chg.Get("api-data", &apiData) 293 c.Assert(err, check.IsNil) 294 c.Check(apiData, check.DeepEquals, map[string]interface{}{ 295 "snap-names": []interface{}{"foo"}, 296 }) 297 298 } 299 } 300 301 func (s *snapshotSuite) TestExportSnapshots(c *check.C) { 302 var snapshotExportCalled int 303 304 defer daemon.MockSnapshotExport(func(ctx context.Context, st *state.State, setID uint64) (*snapshotstate.SnapshotExport, error) { 305 snapshotExportCalled++ 306 c.Check(setID, check.Equals, uint64(1)) 307 return &snapshotstate.SnapshotExport{}, nil 308 })() 309 310 req, err := http.NewRequest("GET", "/v2/snapshots/1/export", nil) 311 c.Assert(err, check.IsNil) 312 313 rsp := s.req(c, req, nil) 314 c.Check(rsp, check.FitsTypeOf, &daemon.SnapshotExportResponse{}) 315 c.Check(snapshotExportCalled, check.Equals, 1) 316 } 317 318 func (s *snapshotSuite) TestExportSnapshotsBadRequestOnNonNumericID(c *check.C) { 319 req, err := http.NewRequest("GET", "/v2/snapshots/xxx/export", nil) 320 c.Assert(err, check.IsNil) 321 322 rsp := s.errorReq(c, req, nil) 323 c.Check(rsp.Status, check.Equals, 400) 324 c.Check(rsp.Result, check.DeepEquals, &daemon.ErrorResult{Message: `'id' must be a positive base 10 number; got "xxx"`}) 325 } 326 327 func (s *snapshotSuite) TestExportSnapshotsBadRequestOnError(c *check.C) { 328 var snapshotExportCalled int 329 330 defer daemon.MockSnapshotExport(func(ctx context.Context, st *state.State, setID uint64) (*snapshotstate.SnapshotExport, error) { 331 snapshotExportCalled++ 332 return nil, fmt.Errorf("boom") 333 })() 334 335 req, err := http.NewRequest("GET", "/v2/snapshots/1/export", nil) 336 c.Assert(err, check.IsNil) 337 338 rsp := s.errorReq(c, req, nil) 339 c.Check(rsp.Status, check.Equals, 400) 340 c.Check(rsp.Result, check.DeepEquals, &daemon.ErrorResult{Message: `cannot export 1: boom`}) 341 c.Check(snapshotExportCalled, check.Equals, 1) 342 } 343 344 func (s *snapshotSuite) TestImportSnapshot(c *check.C) { 345 data := []byte("mocked snapshot export data file") 346 347 setID := uint64(3) 348 snapNames := []string{"baz", "bar", "foo"} 349 defer daemon.MockSnapshotImport(func(context.Context, *state.State, io.Reader) (uint64, []string, error) { 350 return setID, snapNames, nil 351 })() 352 353 req, err := http.NewRequest("POST", "/v2/snapshots", bytes.NewReader(data)) 354 req.Header.Add("Content-Length", strconv.Itoa(len(data))) 355 c.Assert(err, check.IsNil) 356 req.Header.Set("Content-Type", client.SnapshotExportMediaType) 357 358 rsp := s.syncReq(c, req, nil) 359 c.Check(rsp.Status, check.Equals, 200) 360 c.Check(rsp.Result, check.DeepEquals, map[string]interface{}{"set-id": setID, "snaps": snapNames}) 361 } 362 363 func (s *snapshotSuite) TestImportSnapshotError(c *check.C) { 364 defer daemon.MockSnapshotImport(func(context.Context, *state.State, io.Reader) (uint64, []string, error) { 365 return uint64(0), nil, errors.New("no") 366 })() 367 368 data := []byte("mocked snapshot export data file") 369 req, err := http.NewRequest("POST", "/v2/snapshots", bytes.NewReader(data)) 370 req.Header.Add("Content-Length", strconv.Itoa(len(data))) 371 c.Assert(err, check.IsNil) 372 req.Header.Set("Content-Type", client.SnapshotExportMediaType) 373 374 rsp := s.errorReq(c, req, nil) 375 c.Check(rsp.Status, check.Equals, 400) 376 c.Check(rsp.ErrorResult().Message, check.Equals, "no") 377 } 378 379 func (s *snapshotSuite) TestImportSnapshotNoContentLengthError(c *check.C) { 380 data := []byte("mocked snapshot export data file") 381 req, err := http.NewRequest("POST", "/v2/snapshots", bytes.NewReader(data)) 382 c.Assert(err, check.IsNil) 383 req.Header.Set("Content-Type", client.SnapshotExportMediaType) 384 385 rsp := s.errorReq(c, req, nil) 386 c.Check(rsp.Status, check.Equals, 400) 387 c.Check(rsp.ErrorResult().Message, check.Equals, `cannot parse Content-Length: strconv.ParseInt: parsing "": invalid syntax`) 388 } 389 390 func (s *snapshotSuite) TestImportSnapshotLimits(c *check.C) { 391 var dataRead int 392 393 defer daemon.MockSnapshotImport(func(ctx context.Context, st *state.State, r io.Reader) (uint64, []string, error) { 394 data, err := ioutil.ReadAll(r) 395 c.Assert(err, check.IsNil) 396 dataRead = len(data) 397 return uint64(0), nil, nil 398 })() 399 400 data := []byte("much more data than expected from Content-Length") 401 req, err := http.NewRequest("POST", "/v2/snapshots", bytes.NewReader(data)) 402 // limit to 10 and check that this is really all that is read 403 req.Header.Add("Content-Length", "10") 404 c.Assert(err, check.IsNil) 405 req.Header.Set("Content-Type", client.SnapshotExportMediaType) 406 407 rsp := s.syncReq(c, req, nil) 408 c.Check(rsp.Status, check.Equals, 200) 409 c.Check(dataRead, check.Equals, 10) 410 }