github.com/david-imola/snapd@v0.0.0-20210611180407-2de8ddeece6d/daemon/daemon_test.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 3 /* 4 * Copyright (C) 2014-2021 Canonical Ltd 5 * 6 * This program is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU General Public License version 3 as 8 * published by the Free Software Foundation. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package daemon 21 22 import ( 23 "fmt" 24 25 "encoding/json" 26 "io/ioutil" 27 "net" 28 "net/http" 29 "net/http/httptest" 30 "os" 31 "path/filepath" 32 "sync" 33 "syscall" 34 "testing" 35 "time" 36 37 "github.com/gorilla/mux" 38 "gopkg.in/check.v1" 39 40 "github.com/snapcore/snapd/client" 41 "github.com/snapcore/snapd/dirs" 42 "github.com/snapcore/snapd/osutil" 43 "github.com/snapcore/snapd/overlord" 44 "github.com/snapcore/snapd/overlord/auth" 45 "github.com/snapcore/snapd/overlord/devicestate/devicestatetest" 46 "github.com/snapcore/snapd/overlord/ifacestate" 47 "github.com/snapcore/snapd/overlord/patch" 48 "github.com/snapcore/snapd/overlord/snapstate" 49 "github.com/snapcore/snapd/overlord/standby" 50 "github.com/snapcore/snapd/overlord/state" 51 "github.com/snapcore/snapd/polkit" 52 "github.com/snapcore/snapd/snap" 53 "github.com/snapcore/snapd/store" 54 "github.com/snapcore/snapd/systemd" 55 "github.com/snapcore/snapd/testutil" 56 ) 57 58 // Hook up check.v1 into the "go test" runner 59 func Test(t *testing.T) { check.TestingT(t) } 60 61 type daemonSuite struct { 62 testutil.BaseTest 63 64 authorized bool 65 err error 66 lastPolkitFlags polkit.CheckFlags 67 notified 
[]string 68 } 69 70 var _ = check.Suite(&daemonSuite{}) 71 72 func (s *daemonSuite) SetUpTest(c *check.C) { 73 s.BaseTest.SetUpTest(c) 74 75 dirs.SetRootDir(c.MkDir()) 76 s.AddCleanup(osutil.MockMountInfo("")) 77 78 err := os.MkdirAll(filepath.Dir(dirs.SnapStateFile), 0755) 79 c.Assert(err, check.IsNil) 80 systemdSdNotify = func(notif string) error { 81 s.notified = append(s.notified, notif) 82 return nil 83 } 84 s.notified = nil 85 s.AddCleanup(ifacestate.MockSecurityBackends(nil)) 86 } 87 88 func (s *daemonSuite) TearDownTest(c *check.C) { 89 systemdSdNotify = systemd.SdNotify 90 dirs.SetRootDir("") 91 s.authorized = false 92 s.err = nil 93 94 s.BaseTest.TearDownTest(c) 95 } 96 97 // build a new daemon, with only a little of Init(), suitable for the tests 98 func newTestDaemon(c *check.C) *Daemon { 99 d, err := New() 100 c.Assert(err, check.IsNil) 101 d.addRoutes() 102 103 // don't actually try to talk to the store on snapstate.Ensure 104 // needs doing after the call to devicestate.Manager (which 105 // happens in daemon.New via overlord.New) 106 snapstate.CanAutoRefresh = nil 107 108 return d 109 } 110 111 // a Response suitable for testing 112 type mockHandler struct { 113 cmd *Command 114 lastMethod string 115 } 116 117 func (mck *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 118 mck.lastMethod = r.Method 119 } 120 121 func (s *daemonSuite) TestCommandMethodDispatch(c *check.C) { 122 d := newTestDaemon(c) 123 st := d.Overlord().State() 124 st.Lock() 125 authUser, err := auth.NewUser(st, "username", "email@test.com", "macaroon", []string{"discharge"}) 126 st.Unlock() 127 c.Assert(err, check.IsNil) 128 129 fakeUserAgent := "some-agent-talking-to-snapd/1.0" 130 131 cmd := &Command{d: d} 132 mck := &mockHandler{cmd: cmd} 133 rf := func(innerCmd *Command, req *http.Request, user *auth.UserState) Response { 134 c.Assert(cmd, check.Equals, innerCmd) 135 c.Check(store.ClientUserAgent(req.Context()), check.Equals, fakeUserAgent) 136 c.Check(user, 
check.DeepEquals, authUser) 137 return mck 138 } 139 cmd.GET = rf 140 cmd.PUT = rf 141 cmd.POST = rf 142 cmd.ReadAccess = authenticatedAccess{} 143 cmd.WriteAccess = authenticatedAccess{} 144 145 for _, method := range []string{"GET", "POST", "PUT"} { 146 req, err := http.NewRequest(method, "", nil) 147 req.Header.Add("User-Agent", fakeUserAgent) 148 c.Assert(err, check.IsNil) 149 150 rec := httptest.NewRecorder() 151 req.RemoteAddr = fmt.Sprintf("pid=100;uid=1001;socket=%s;", dirs.SnapdSocket) 152 cmd.ServeHTTP(rec, req) 153 c.Check(rec.Code, check.Equals, 401, check.Commentf(method)) 154 155 rec = httptest.NewRecorder() 156 req.Header.Set("Authorization", fmt.Sprintf(`Macaroon root="%s"`, authUser.Macaroon)) 157 158 cmd.ServeHTTP(rec, req) 159 c.Check(mck.lastMethod, check.Equals, method) 160 c.Check(rec.Code, check.Equals, 200) 161 } 162 163 req, err := http.NewRequest("POTATO", "", nil) 164 c.Assert(err, check.IsNil) 165 req.RemoteAddr = fmt.Sprintf("pid=100;uid=1001;socket=%s;", dirs.SnapdSocket) 166 req.Header.Set("Authorization", fmt.Sprintf(`Macaroon root="%s"`, authUser.Macaroon)) 167 rec := httptest.NewRecorder() 168 cmd.ServeHTTP(rec, req) 169 c.Check(rec.Code, check.Equals, 405) 170 } 171 172 func (s *daemonSuite) TestCommandMethodDispatchRoot(c *check.C) { 173 fakeUserAgent := "some-agent-talking-to-snapd/1.0" 174 175 cmd := &Command{d: newTestDaemon(c)} 176 mck := &mockHandler{cmd: cmd} 177 rf := func(innerCmd *Command, req *http.Request, user *auth.UserState) Response { 178 c.Assert(cmd, check.Equals, innerCmd) 179 c.Check(store.ClientUserAgent(req.Context()), check.Equals, fakeUserAgent) 180 return mck 181 } 182 cmd.GET = rf 183 cmd.PUT = rf 184 cmd.POST = rf 185 cmd.ReadAccess = authenticatedAccess{} 186 cmd.WriteAccess = authenticatedAccess{} 187 188 for _, method := range []string{"GET", "POST", "PUT"} { 189 req, err := http.NewRequest(method, "", nil) 190 req.Header.Add("User-Agent", fakeUserAgent) 191 c.Assert(err, check.IsNil) 192 193 rec := 
httptest.NewRecorder() 194 // no ucred => forbidden 195 cmd.ServeHTTP(rec, req) 196 c.Check(rec.Code, check.Equals, 403, check.Commentf(method)) 197 198 rec = httptest.NewRecorder() 199 req.RemoteAddr = fmt.Sprintf("pid=100;uid=0;socket=%s;", dirs.SnapdSocket) 200 201 cmd.ServeHTTP(rec, req) 202 c.Check(mck.lastMethod, check.Equals, method) 203 c.Check(rec.Code, check.Equals, 200) 204 } 205 206 req, err := http.NewRequest("POTATO", "", nil) 207 c.Assert(err, check.IsNil) 208 req.RemoteAddr = fmt.Sprintf("pid=100;uid=0;socket=%s;", dirs.SnapdSocket) 209 210 rec := httptest.NewRecorder() 211 cmd.ServeHTTP(rec, req) 212 c.Check(rec.Code, check.Equals, 405) 213 } 214 215 func (s *daemonSuite) TestCommandRestartingState(c *check.C) { 216 d := newTestDaemon(c) 217 218 cmd := &Command{d: d} 219 cmd.GET = func(*Command, *http.Request, *auth.UserState) Response { 220 return SyncResponse(nil) 221 } 222 cmd.ReadAccess = openAccess{} 223 req, err := http.NewRequest("GET", "", nil) 224 c.Assert(err, check.IsNil) 225 req.RemoteAddr = fmt.Sprintf("pid=100;uid=42;socket=%s;", dirs.SnapdSocket) 226 227 rec := httptest.NewRecorder() 228 cmd.ServeHTTP(rec, req) 229 c.Check(rec.Code, check.Equals, 200) 230 var rst struct { 231 Maintenance *errorResult `json:"maintenance"` 232 } 233 err = json.Unmarshal(rec.Body.Bytes(), &rst) 234 c.Assert(err, check.IsNil) 235 c.Check(rst.Maintenance, check.IsNil) 236 237 tests := []struct { 238 rst state.RestartType 239 kind client.ErrorKind 240 msg string 241 op string 242 }{ 243 { 244 rst: state.RestartSystem, 245 kind: client.ErrorKindSystemRestart, 246 msg: "system is restarting", 247 op: "reboot", 248 }, { 249 rst: state.RestartSystemNow, 250 kind: client.ErrorKindSystemRestart, 251 msg: "system is restarting", 252 op: "reboot", 253 }, { 254 rst: state.RestartDaemon, 255 kind: client.ErrorKindDaemonRestart, 256 msg: "daemon is restarting", 257 }, { 258 rst: state.RestartSystemHaltNow, 259 kind: client.ErrorKindSystemRestart, 260 msg: "system is 
halting", 261 op: "halt", 262 }, { 263 rst: state.RestartSystemPoweroffNow, 264 kind: client.ErrorKindSystemRestart, 265 msg: "system is powering off", 266 op: "poweroff", 267 }, { 268 rst: state.RestartSocket, 269 kind: client.ErrorKindDaemonRestart, 270 msg: "daemon is stopping to wait for socket activation", 271 }, 272 } 273 274 for _, t := range tests { 275 state.MockRestarting(d.overlord.State(), t.rst) 276 rec = httptest.NewRecorder() 277 cmd.ServeHTTP(rec, req) 278 c.Check(rec.Code, check.Equals, 200) 279 var rst struct { 280 Maintenance *errorResult `json:"maintenance"` 281 } 282 err = json.Unmarshal(rec.Body.Bytes(), &rst) 283 c.Assert(err, check.IsNil) 284 var val errorValue 285 if t.op != "" { 286 val = map[string]interface{}{ 287 "op": t.op, 288 } 289 } 290 c.Check(rst.Maintenance, check.DeepEquals, &errorResult{ 291 Kind: t.kind, 292 Message: t.msg, 293 Value: val, 294 }) 295 } 296 } 297 298 func (s *daemonSuite) TestMaintenanceJsonDeletedOnStart(c *check.C) { 299 // write a maintenance.json file that has that the system is restarting 300 maintErr := &errorResult{ 301 Kind: client.ErrorKindDaemonRestart, 302 Message: systemRestartMsg, 303 } 304 305 b, err := json.Marshal(maintErr) 306 c.Assert(err, check.IsNil) 307 c.Assert(os.MkdirAll(filepath.Dir(dirs.SnapdMaintenanceFile), 0755), check.IsNil) 308 c.Assert(ioutil.WriteFile(dirs.SnapdMaintenanceFile, b, 0644), check.IsNil) 309 310 d := newTestDaemon(c) 311 makeDaemonListeners(c, d) 312 313 s.markSeeded(d) 314 315 // after starting, maintenance.json should be removed 316 c.Assert(d.Start(), check.IsNil) 317 c.Assert(dirs.SnapdMaintenanceFile, testutil.FileAbsent) 318 d.Stop(nil) 319 } 320 321 func (s *daemonSuite) TestFillsWarnings(c *check.C) { 322 d := newTestDaemon(c) 323 324 cmd := &Command{d: d} 325 cmd.GET = func(*Command, *http.Request, *auth.UserState) Response { 326 return SyncResponse(nil) 327 } 328 cmd.ReadAccess = openAccess{} 329 req, err := http.NewRequest("GET", "", nil) 330 
c.Assert(err, check.IsNil) 331 req.RemoteAddr = fmt.Sprintf("pid=100;uid=42;socket=%s;", dirs.SnapdSocket) 332 333 rec := httptest.NewRecorder() 334 cmd.ServeHTTP(rec, req) 335 c.Check(rec.Code, check.Equals, 200) 336 var rst struct { 337 WarningTimestamp *time.Time `json:"warning-timestamp,omitempty"` 338 WarningCount int `json:"warning-count,omitempty"` 339 } 340 err = json.Unmarshal(rec.Body.Bytes(), &rst) 341 c.Assert(err, check.IsNil) 342 c.Check(rst.WarningCount, check.Equals, 0) 343 c.Check(rst.WarningTimestamp, check.IsNil) 344 345 st := d.overlord.State() 346 st.Lock() 347 st.Warnf("hello world") 348 st.Unlock() 349 350 rec = httptest.NewRecorder() 351 cmd.ServeHTTP(rec, req) 352 c.Check(rec.Code, check.Equals, 200) 353 err = json.Unmarshal(rec.Body.Bytes(), &rst) 354 c.Assert(err, check.IsNil) 355 c.Check(rst.WarningCount, check.Equals, 1) 356 c.Check(rst.WarningTimestamp, check.NotNil) 357 } 358 359 type accessCheckFunc func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError 360 361 func (f accessCheckFunc) CheckAccess(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError { 362 return f(r, ucred, user) 363 } 364 365 func (s *daemonSuite) TestReadAccess(c *check.C) { 366 cmd := &Command{d: newTestDaemon(c)} 367 cmd.GET = func(*Command, *http.Request, *auth.UserState) Response { 368 return SyncResponse(nil) 369 } 370 var accessCalled bool 371 cmd.ReadAccess = accessCheckFunc(func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError { 372 accessCalled = true 373 c.Check(r, check.NotNil) 374 c.Assert(ucred, check.NotNil) 375 c.Check(ucred.Uid, check.Equals, uint32(42)) 376 c.Check(ucred.Pid, check.Equals, int32(100)) 377 c.Check(ucred.Socket, check.Equals, "xyz") 378 c.Check(user, check.IsNil) 379 return nil 380 }) 381 cmd.WriteAccess = accessCheckFunc(func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError { 382 c.Fail() 383 return Forbidden("") 384 }) 385 386 req := httptest.NewRequest("GET", "/", nil) 
387 req.RemoteAddr = "pid=100;uid=42;socket=xyz;" 388 rec := httptest.NewRecorder() 389 cmd.ServeHTTP(rec, req) 390 c.Check(rec.Code, check.Equals, 200) 391 c.Check(accessCalled, check.Equals, true) 392 } 393 394 func (s *daemonSuite) TestWriteAccess(c *check.C) { 395 cmd := &Command{d: newTestDaemon(c)} 396 cmd.PUT = func(*Command, *http.Request, *auth.UserState) Response { 397 return SyncResponse(nil) 398 } 399 cmd.POST = func(*Command, *http.Request, *auth.UserState) Response { 400 return SyncResponse(nil) 401 } 402 cmd.ReadAccess = accessCheckFunc(func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError { 403 c.Fail() 404 return Forbidden("") 405 }) 406 var accessCalled bool 407 cmd.WriteAccess = accessCheckFunc(func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError { 408 accessCalled = true 409 c.Check(r, check.NotNil) 410 c.Assert(ucred, check.NotNil) 411 c.Check(ucred.Uid, check.Equals, uint32(42)) 412 c.Check(ucred.Pid, check.Equals, int32(100)) 413 c.Check(ucred.Socket, check.Equals, "xyz") 414 c.Check(user, check.IsNil) 415 return nil 416 }) 417 418 req := httptest.NewRequest("PUT", "/", nil) 419 req.RemoteAddr = "pid=100;uid=42;socket=xyz;" 420 rec := httptest.NewRecorder() 421 cmd.ServeHTTP(rec, req) 422 c.Check(rec.Code, check.Equals, 200) 423 c.Check(accessCalled, check.Equals, true) 424 425 accessCalled = false 426 req = httptest.NewRequest("POST", "/", nil) 427 req.RemoteAddr = "pid=100;uid=42;socket=xyz;" 428 rec = httptest.NewRecorder() 429 cmd.ServeHTTP(rec, req) 430 c.Check(rec.Code, check.Equals, 200) 431 c.Check(accessCalled, check.Equals, true) 432 } 433 434 func (s *daemonSuite) TestWriteAccessWithUser(c *check.C) { 435 d := newTestDaemon(c) 436 st := d.Overlord().State() 437 st.Lock() 438 authUser, err := auth.NewUser(st, "username", "email@test.com", "macaroon", []string{"discharge"}) 439 st.Unlock() 440 c.Assert(err, check.IsNil) 441 442 cmd := &Command{d: d} 443 cmd.PUT = func(*Command, *http.Request, 
*auth.UserState) Response { 444 return SyncResponse(nil) 445 } 446 cmd.POST = func(*Command, *http.Request, *auth.UserState) Response { 447 return SyncResponse(nil) 448 } 449 cmd.ReadAccess = accessCheckFunc(func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError { 450 c.Fail() 451 return Forbidden("") 452 }) 453 var accessCalled bool 454 cmd.WriteAccess = accessCheckFunc(func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError { 455 accessCalled = true 456 c.Check(r, check.NotNil) 457 c.Assert(ucred, check.NotNil) 458 c.Check(ucred.Uid, check.Equals, uint32(1001)) 459 c.Check(ucred.Pid, check.Equals, int32(100)) 460 c.Check(ucred.Socket, check.Equals, "xyz") 461 c.Check(user, check.DeepEquals, authUser) 462 return nil 463 }) 464 465 req := httptest.NewRequest("PUT", "/", nil) 466 req.Header.Set("Authorization", fmt.Sprintf(`Macaroon root="%s"`, authUser.Macaroon)) 467 req.RemoteAddr = "pid=100;uid=1001;socket=xyz;" 468 rec := httptest.NewRecorder() 469 cmd.ServeHTTP(rec, req) 470 c.Check(rec.Code, check.Equals, 200) 471 c.Check(accessCalled, check.Equals, true) 472 473 accessCalled = false 474 req = httptest.NewRequest("POST", "/", nil) 475 req.Header.Set("Authorization", fmt.Sprintf(`Macaroon root="%s"`, authUser.Macaroon)) 476 req.RemoteAddr = "pid=100;uid=1001;socket=xyz;" 477 rec = httptest.NewRecorder() 478 cmd.ServeHTTP(rec, req) 479 c.Check(rec.Code, check.Equals, 200) 480 c.Check(accessCalled, check.Equals, true) 481 } 482 483 func (s *daemonSuite) TestPolkitAccessPath(c *check.C) { 484 cmd := &Command{d: newTestDaemon(c)} 485 cmd.POST = func(*Command, *http.Request, *auth.UserState) Response { 486 return SyncResponse(nil) 487 } 488 access := false 489 cmd.WriteAccess = authenticatedAccess{Polkit: "foo"} 490 checkPolkitAction = func(r *http.Request, ucred *ucrednet, action string) *apiError { 491 c.Check(action, check.Equals, "foo") 492 c.Check(ucred.Uid, check.Equals, uint32(1001)) 493 if access { 494 return nil 495 } 496 
return AuthCancelled("") 497 } 498 499 req := httptest.NewRequest("POST", "/", nil) 500 req.RemoteAddr = fmt.Sprintf("pid=100;uid=1001;socket=%s;", dirs.SnapdSocket) 501 rec := httptest.NewRecorder() 502 cmd.ServeHTTP(rec, req) 503 c.Check(rec.Code, check.Equals, 403) 504 c.Check(rec.Body.String(), testutil.Contains, `"kind":"auth-cancelled"`) 505 506 access = true 507 rec = httptest.NewRecorder() 508 cmd.ServeHTTP(rec, req) 509 c.Check(rec.Code, check.Equals, 200) 510 } 511 512 func (s *daemonSuite) TestCommandAccessSane(c *check.C) { 513 for _, cmd := range api { 514 // If Command.GET is set, ReadAccess must be set 515 c.Check(cmd.GET != nil, check.Equals, cmd.ReadAccess != nil, check.Commentf("%q ReadAccess", cmd.Path)) 516 // If Command.PUT or POST are set, WriteAccess must be set 517 c.Check(cmd.PUT != nil || cmd.POST != nil, check.Equals, cmd.WriteAccess != nil, check.Commentf("%q WriteAccess", cmd.Path)) 518 } 519 } 520 521 func (s *daemonSuite) TestAddRoutes(c *check.C) { 522 d := newTestDaemon(c) 523 524 expected := make([]string, len(api)) 525 for i, v := range api { 526 if v.PathPrefix != "" { 527 expected[i] = v.PathPrefix 528 continue 529 } 530 expected[i] = v.Path 531 } 532 533 got := make([]string, 0, len(api)) 534 c.Assert(d.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error { 535 got = append(got, route.GetName()) 536 return nil 537 }), check.IsNil) 538 539 c.Check(got, check.DeepEquals, expected) // this'll stop being true if routes are added that aren't commands (e.g. 
for the favicon) 540 541 // XXX: still waiting to know how to check d.router.NotFoundHandler has been set to NotFound 542 // the old test relied on undefined behaviour: 543 // c.Check(fmt.Sprintf("%p", d.router.NotFoundHandler), check.Equals, fmt.Sprintf("%p", NotFound)) 544 } 545 546 type witnessAcceptListener struct { 547 net.Listener 548 549 accept chan struct{} 550 accept1 bool 551 552 idempotClose sync.Once 553 closeErr error 554 closed chan struct{} 555 } 556 557 func (l *witnessAcceptListener) Accept() (net.Conn, error) { 558 if !l.accept1 { 559 l.accept1 = true 560 close(l.accept) 561 } 562 return l.Listener.Accept() 563 } 564 565 func (l *witnessAcceptListener) Close() error { 566 l.idempotClose.Do(func() { 567 l.closeErr = l.Listener.Close() 568 if l.closed != nil { 569 close(l.closed) 570 } 571 }) 572 return l.closeErr 573 } 574 575 func (s *daemonSuite) markSeeded(d *Daemon) { 576 st := d.overlord.State() 577 st.Lock() 578 st.Set("seeded", true) 579 devicestatetest.SetDevice(st, &auth.DeviceState{ 580 Brand: "canonical", 581 Model: "pc", 582 Serial: "serialserial", 583 }) 584 st.Unlock() 585 } 586 587 func (s *daemonSuite) TestStartStop(c *check.C) { 588 d := newTestDaemon(c) 589 // mark as already seeded 590 s.markSeeded(d) 591 // and pretend we have snaps 592 st := d.overlord.State() 593 st.Lock() 594 snapstate.Set(st, "core", &snapstate.SnapState{ 595 Active: true, 596 Sequence: []*snap.SideInfo{ 597 {RealName: "core", Revision: snap.R(1), SnapID: "core-snap-id"}, 598 }, 599 Current: snap.R(1), 600 }) 601 st.Unlock() 602 // 1 snap => extended timeout 30s + 5s 603 const extendedTimeoutUSec = "EXTEND_TIMEOUT_USEC=35000000" 604 605 l1, err := net.Listen("tcp", "127.0.0.1:0") 606 c.Assert(err, check.IsNil) 607 l2, err := net.Listen("tcp", "127.0.0.1:0") 608 c.Assert(err, check.IsNil) 609 610 snapdAccept := make(chan struct{}) 611 d.snapdListener = &witnessAcceptListener{Listener: l1, accept: snapdAccept} 612 613 snapAccept := make(chan struct{}) 614 
d.snapListener = &witnessAcceptListener{Listener: l2, accept: snapAccept} 615 616 c.Assert(d.Start(), check.IsNil) 617 618 c.Check(s.notified, check.DeepEquals, []string{extendedTimeoutUSec, "READY=1"}) 619 620 snapdDone := make(chan struct{}) 621 go func() { 622 select { 623 case <-snapdAccept: 624 case <-time.After(2 * time.Second): 625 c.Fatal("snapd accept was not called") 626 } 627 close(snapdDone) 628 }() 629 630 snapDone := make(chan struct{}) 631 go func() { 632 select { 633 case <-snapAccept: 634 case <-time.After(2 * time.Second): 635 c.Fatal("snapd accept was not called") 636 } 637 close(snapDone) 638 }() 639 640 <-snapdDone 641 <-snapDone 642 643 err = d.Stop(nil) 644 c.Check(err, check.IsNil) 645 646 c.Check(s.notified, check.DeepEquals, []string{extendedTimeoutUSec, "READY=1", "STOPPING=1"}) 647 } 648 649 func (s *daemonSuite) TestRestartWiring(c *check.C) { 650 d := newTestDaemon(c) 651 // mark as already seeded 652 s.markSeeded(d) 653 654 l, err := net.Listen("tcp", "127.0.0.1:0") 655 c.Assert(err, check.IsNil) 656 657 snapdAccept := make(chan struct{}) 658 d.snapdListener = &witnessAcceptListener{Listener: l, accept: snapdAccept} 659 660 snapAccept := make(chan struct{}) 661 d.snapListener = &witnessAcceptListener{Listener: l, accept: snapAccept} 662 663 c.Assert(d.Start(), check.IsNil) 664 stoppedYet := false 665 defer func() { 666 if !stoppedYet { 667 d.Stop(nil) 668 } 669 }() 670 671 snapdDone := make(chan struct{}) 672 go func() { 673 select { 674 case <-snapdAccept: 675 case <-time.After(2 * time.Second): 676 c.Fatal("snapd accept was not called") 677 } 678 close(snapdDone) 679 }() 680 681 snapDone := make(chan struct{}) 682 go func() { 683 select { 684 case <-snapAccept: 685 case <-time.After(2 * time.Second): 686 c.Fatal("snap accept was not called") 687 } 688 close(snapDone) 689 }() 690 691 <-snapdDone 692 <-snapDone 693 694 d.overlord.State().RequestRestart(state.RestartDaemon) 695 696 select { 697 case <-d.Dying(): 698 case <-time.After(2 
* time.Second): 699 c.Fatal("RequestRestart -> overlord -> Kill chain didn't work") 700 } 701 702 d.Stop(nil) 703 stoppedYet = true 704 705 c.Assert(s.notified, check.DeepEquals, []string{"EXTEND_TIMEOUT_USEC=30000000", "READY=1", "STOPPING=1"}) 706 } 707 708 func (s *daemonSuite) TestGracefulStop(c *check.C) { 709 d := newTestDaemon(c) 710 711 responding := make(chan struct{}) 712 doRespond := make(chan bool, 1) 713 714 d.router.HandleFunc("/endp", func(w http.ResponseWriter, r *http.Request) { 715 close(responding) 716 if <-doRespond { 717 w.Write([]byte("OKOK")) 718 } else { 719 w.Write([]byte("Gone")) 720 } 721 }) 722 723 // mark as already seeded 724 s.markSeeded(d) 725 // and pretend we have snaps 726 st := d.overlord.State() 727 st.Lock() 728 snapstate.Set(st, "core", &snapstate.SnapState{ 729 Active: true, 730 Sequence: []*snap.SideInfo{ 731 {RealName: "core", Revision: snap.R(1), SnapID: "core-snap-id"}, 732 }, 733 Current: snap.R(1), 734 }) 735 st.Unlock() 736 737 snapdL, err := net.Listen("tcp", "127.0.0.1:0") 738 c.Assert(err, check.IsNil) 739 740 snapL, err := net.Listen("tcp", "127.0.0.1:0") 741 c.Assert(err, check.IsNil) 742 743 snapdAccept := make(chan struct{}) 744 snapdClosed := make(chan struct{}) 745 d.snapdListener = &witnessAcceptListener{Listener: snapdL, accept: snapdAccept, closed: snapdClosed} 746 747 snapAccept := make(chan struct{}) 748 d.snapListener = &witnessAcceptListener{Listener: snapL, accept: snapAccept} 749 750 c.Assert(d.Start(), check.IsNil) 751 752 snapdAccepting := make(chan struct{}) 753 go func() { 754 select { 755 case <-snapdAccept: 756 case <-time.After(2 * time.Second): 757 c.Fatal("snapd accept was not called") 758 } 759 close(snapdAccepting) 760 }() 761 762 snapAccepting := make(chan struct{}) 763 go func() { 764 select { 765 case <-snapAccept: 766 case <-time.After(2 * time.Second): 767 c.Fatal("snapd accept was not called") 768 } 769 close(snapAccepting) 770 }() 771 772 <-snapdAccepting 773 <-snapAccepting 774 775 
alright := make(chan struct{}) 776 777 go func() { 778 res, err := http.Get(fmt.Sprintf("http://%s/endp", snapdL.Addr())) 779 c.Assert(err, check.IsNil) 780 c.Check(res.StatusCode, check.Equals, 200) 781 body, err := ioutil.ReadAll(res.Body) 782 res.Body.Close() 783 c.Assert(err, check.IsNil) 784 c.Check(string(body), check.Equals, "OKOK") 785 close(alright) 786 }() 787 go func() { 788 <-snapdClosed 789 time.Sleep(200 * time.Millisecond) 790 doRespond <- true 791 }() 792 793 <-responding 794 err = d.Stop(nil) 795 doRespond <- false 796 c.Check(err, check.IsNil) 797 798 select { 799 case <-alright: 800 case <-time.After(2 * time.Second): 801 c.Fatal("never got proper response") 802 } 803 } 804 805 func (s *daemonSuite) TestGracefulStopHasLimits(c *check.C) { 806 d := newTestDaemon(c) 807 808 // mark as already seeded 809 s.markSeeded(d) 810 811 restore := MockShutdownTimeout(time.Second) 812 defer restore() 813 814 responding := make(chan struct{}) 815 doRespond := make(chan bool, 1) 816 817 d.router.HandleFunc("/endp", func(w http.ResponseWriter, r *http.Request) { 818 close(responding) 819 if <-doRespond { 820 for { 821 // write in a loop to keep the handler running 822 if _, err := w.Write([]byte("OKOK")); err != nil { 823 break 824 } 825 time.Sleep(50 * time.Millisecond) 826 } 827 } else { 828 w.Write([]byte("Gone")) 829 } 830 }) 831 832 snapdL, err := net.Listen("tcp", "127.0.0.1:0") 833 c.Assert(err, check.IsNil) 834 835 snapL, err := net.Listen("tcp", "127.0.0.1:0") 836 c.Assert(err, check.IsNil) 837 838 snapdAccept := make(chan struct{}) 839 snapdClosed := make(chan struct{}) 840 d.snapdListener = &witnessAcceptListener{Listener: snapdL, accept: snapdAccept, closed: snapdClosed} 841 842 snapAccept := make(chan struct{}) 843 d.snapListener = &witnessAcceptListener{Listener: snapL, accept: snapAccept} 844 845 c.Assert(d.Start(), check.IsNil) 846 847 snapdAccepting := make(chan struct{}) 848 go func() { 849 select { 850 case <-snapdAccept: 851 case 
<-time.After(2 * time.Second): 852 c.Fatal("snapd accept was not called") 853 } 854 close(snapdAccepting) 855 }() 856 857 snapAccepting := make(chan struct{}) 858 go func() { 859 select { 860 case <-snapAccept: 861 case <-time.After(2 * time.Second): 862 c.Fatal("snapd accept was not called") 863 } 864 close(snapAccepting) 865 }() 866 867 <-snapdAccepting 868 <-snapAccepting 869 870 clientErr := make(chan error) 871 872 go func() { 873 _, err := http.Get(fmt.Sprintf("http://%s/endp", snapdL.Addr())) 874 c.Assert(err, check.NotNil) 875 clientErr <- err 876 close(clientErr) 877 }() 878 go func() { 879 <-snapdClosed 880 time.Sleep(200 * time.Millisecond) 881 doRespond <- true 882 }() 883 884 <-responding 885 err = d.Stop(nil) 886 doRespond <- false 887 c.Check(err, check.IsNil) 888 889 select { 890 case cErr := <-clientErr: 891 c.Check(cErr, check.ErrorMatches, ".*: EOF") 892 case <-time.After(5 * time.Second): 893 c.Fatal("never got proper response") 894 } 895 } 896 897 func (s *daemonSuite) testRestartSystemWiring(c *check.C, prep func(d *Daemon), restart func(*state.State, state.RestartType), restartKind state.RestartType, wait time.Duration) { 898 d := newTestDaemon(c) 899 // mark as already seeded 900 s.markSeeded(d) 901 902 if prep != nil { 903 prep(d) 904 } 905 906 l, err := net.Listen("tcp", "127.0.0.1:0") 907 c.Assert(err, check.IsNil) 908 909 snapdAccept := make(chan struct{}) 910 d.snapdListener = &witnessAcceptListener{Listener: l, accept: snapdAccept} 911 912 snapAccept := make(chan struct{}) 913 d.snapListener = &witnessAcceptListener{Listener: l, accept: snapAccept} 914 915 oldRebootNoticeWait := rebootNoticeWait 916 oldRebootWaitTimeout := rebootWaitTimeout 917 defer func() { 918 reboot = rebootImpl 919 rebootNoticeWait = oldRebootNoticeWait 920 rebootWaitTimeout = oldRebootWaitTimeout 921 }() 922 rebootWaitTimeout = 100 * time.Millisecond 923 rebootNoticeWait = 150 * time.Millisecond 924 925 expectedAction := rebootReboot 926 expectedOp := "reboot" 
927 if restartKind == state.RestartSystemHaltNow { 928 expectedAction = rebootHalt 929 expectedOp = "halt" 930 } else if restartKind == state.RestartSystemPoweroffNow { 931 expectedAction = rebootPoweroff 932 expectedOp = "poweroff" 933 } 934 var delays []time.Duration 935 reboot = func(a rebootAction, d time.Duration) error { 936 c.Check(a, check.Equals, expectedAction) 937 delays = append(delays, d) 938 return nil 939 } 940 941 c.Assert(d.Start(), check.IsNil) 942 defer d.Stop(nil) 943 944 st := d.overlord.State() 945 946 snapdDone := make(chan struct{}) 947 go func() { 948 select { 949 case <-snapdAccept: 950 case <-time.After(2 * time.Second): 951 c.Fatal("snapd accept was not called") 952 } 953 close(snapdDone) 954 }() 955 956 snapDone := make(chan struct{}) 957 go func() { 958 select { 959 case <-snapAccept: 960 case <-time.After(2 * time.Second): 961 c.Fatal("snap accept was not called") 962 } 963 close(snapDone) 964 }() 965 966 <-snapdDone 967 <-snapDone 968 969 st.Lock() 970 restart(st, restartKind) 971 st.Unlock() 972 973 defer func() { 974 d.mu.Lock() 975 d.requestedRestart = state.RestartUnset 976 d.mu.Unlock() 977 }() 978 979 select { 980 case <-d.Dying(): 981 case <-time.After(2 * time.Second): 982 c.Fatal("RequestRestart -> overlord -> Kill chain didn't work") 983 } 984 985 d.mu.Lock() 986 rs := d.requestedRestart 987 d.mu.Unlock() 988 989 c.Check(rs, check.Equals, restartKind) 990 991 c.Check(delays, check.HasLen, 1) 992 c.Check(delays[0], check.DeepEquals, rebootWaitTimeout) 993 994 now := time.Now() 995 996 err = d.Stop(nil) 997 998 c.Check(err, check.ErrorMatches, fmt.Sprintf("expected %s did not happen", expectedAction)) 999 1000 c.Check(delays, check.HasLen, 2) 1001 c.Check(delays[1], check.DeepEquals, wait) 1002 1003 // we are not stopping, we wait for the reboot instead 1004 c.Check(s.notified, check.DeepEquals, []string{"EXTEND_TIMEOUT_USEC=30000000", "READY=1"}) 1005 1006 st.Lock() 1007 defer st.Unlock() 1008 var rebootAt time.Time 1009 err 
= st.Get("daemon-system-restart-at", &rebootAt)
	c.Assert(err, check.IsNil)
	if wait > 0 {
		approxAt := now.Add(wait)
		c.Check(rebootAt.After(approxAt) || rebootAt.Equal(approxAt), check.Equals, true)
	} else {
		// should be good enough
		c.Check(rebootAt.Before(now.Add(10*time.Second)), check.Equals, true)
	}

	// finally check that maintenance.json was written appropriate for this
	// restart reason
	b, err := ioutil.ReadFile(dirs.SnapdMaintenanceFile)
	c.Assert(err, check.IsNil)

	maintErr := &errorResult{}
	c.Assert(json.Unmarshal(b, maintErr), check.IsNil)
	c.Check(maintErr.Kind, check.Equals, client.ErrorKindSystemRestart)
	c.Check(maintErr.Value, check.DeepEquals, map[string]interface{}{
		"op": expectedOp,
	})

	exp := maintenanceForRestartType(restartKind)
	c.Assert(maintErr, check.DeepEquals, exp)
}

// TestRestartSystemGracefulWiring checks the wiring of a graceful
// state.RestartSystem request made via (*state.State).RequestRestart,
// passing a one-minute expected delay before the reboot.
func (s *daemonSuite) TestRestartSystemGracefulWiring(c *check.C) {
	s.testRestartSystemWiring(c, nil, (*state.State).RequestRestart, state.RestartSystem, 1*time.Minute)
}

// TestRestartSystemImmediateWiring checks the wiring of an immediate
// state.RestartSystemNow request, with no expected delay.
func (s *daemonSuite) TestRestartSystemImmediateWiring(c *check.C) {
	s.testRestartSystemWiring(c, nil, (*state.State).RequestRestart, state.RestartSystemNow, 0)
}

// TestRestartSystemHaltImmediateWiring checks the wiring of an immediate
// state.RestartSystemHaltNow request, with no expected delay.
func (s *daemonSuite) TestRestartSystemHaltImmediateWiring(c *check.C) {
	s.testRestartSystemWiring(c, nil, (*state.State).RequestRestart, state.RestartSystemHaltNow, 0)
}

// TestRestartSystemPoweroffImmediateWiring checks the wiring of an immediate
// state.RestartSystemPoweroffNow request, with no expected delay.
func (s *daemonSuite) TestRestartSystemPoweroffImmediateWiring(c *check.C) {
	s.testRestartSystemWiring(c, nil, (*state.State).RequestRestart, state.RestartSystemPoweroffNow, 0)
}

// rstManager is a test state manager whose Ensure unconditionally requests
// an immediate system restart (state.RestartSystemNow).
type rstManager struct {
	st *state.State
}

// Ensure takes the state lock and requests an immediate system restart.
// It always returns nil.
func (m *rstManager) Ensure() error {
	m.st.Lock()
	defer m.st.Unlock()
	m.st.RequestRestart(state.RestartSystemNow)
	return nil
}

// witnessManager is a test state manager that only records how many times
// its Ensure method has been called.
type witnessManager struct {
	ensureCalled int
}

// Ensure increments the call counter and always returns nil.
func (m *witnessManager) Ensure() error {
	m.ensureCalled++
	return nil
}

func (s *daemonSuite) TestRestartSystemFromEnsure(c *check.C) {
	// Test that calling RequestRestart from inside the first
	// Ensure loop works.
	wm := &witnessManager{}

	// prep swaps the daemon's overlord for a mocked one that routes restart
	// requests to d.HandleRestart and runs only the existing hook manager
	// plus the two test managers (rstManager added before witnessManager).
	prep := func(d *Daemon) {
		st := d.overlord.State()
		hm := d.overlord.HookManager()
		o := overlord.MockWithStateAndRestartHandler(st, d.HandleRestart)
		d.overlord = o
		o.AddManager(hm)
		rm := &rstManager{st: st}
		o.AddManager(rm)
		o.AddManager(wm)
	}

	// the restart is requested by rstManager.Ensure, so the request function
	// handed to the helper deliberately does nothing
	nop := func(*state.State, state.RestartType) {}

	s.testRestartSystemWiring(c, prep, nop, state.RestartSystemNow, 0)

	// the witness manager's Ensure ran exactly once, i.e. (presumably because
	// managers run in the order they were added) the first ensure pass went
	// on past rstManager after the restart request
	c.Check(wm.ensureCalled, check.Equals, 1)
}

// TestRebootHelper checks that reboot() invokes the (mocked) shutdown command
// with the flag and message matching each rebootAction, and with the delay
// expressed in whole minutes: negative and sub-minute delays become "+0".
func (s *daemonSuite) TestRebootHelper(c *check.C) {
	cmd := testutil.MockCommand(c, "shutdown", "")
	defer cmd.Restore()

	tests := []struct {
		delay    time.Duration
		delayArg string
	}{
		{-1, "+0"},
		{0, "+0"},
		{time.Minute, "+1"},
		{10 * time.Minute, "+10"},
		{30 * time.Second, "+0"},
	}

	args := []struct {
		a   rebootAction
		arg string
		msg string
	}{
		{rebootReboot, "-r", "reboot scheduled to update the system"},
		{rebootHalt, "--halt", "system halt scheduled"},
		{rebootPoweroff, "--poweroff", "system poweroff scheduled"},
	}

	for _, arg := range args {
		for _, t := range tests {
			err := reboot(arg.a, t.delay)
			c.Assert(err, check.IsNil)
			c.Check(cmd.Calls(), check.DeepEquals, [][]string{
				{"shutdown", arg.arg, t.delayArg, arg.msg},
			})

			// reset the recorded calls so each case checks exactly one call
			cmd.ForgetCalls()
		}
	}
}

// makeDaemonListeners gives d a snapd and a snap listener, each a
// witnessAcceptListener wrapping a TCP listener on an ephemeral loopback
// port. Only the snapd listener is given a "closed" channel. The witness
// channels are not returned, so callers cannot observe them.
func makeDaemonListeners(c *check.C, d *Daemon) {
	snapdL, err := net.Listen("tcp", "127.0.0.1:0")
	c.Assert(err, check.IsNil)

	snapL, err := net.Listen("tcp", "127.0.0.1:0")
	c.Assert(err, check.IsNil)

	snapdAccept := make(chan struct{})
	snapdClosed := make(chan struct{})
	d.snapdListener = &witnessAcceptListener{Listener: snapdL, accept: snapdAccept, closed: snapdClosed}

	snapAccept := make(chan struct{})
	d.snapListener = &witnessAcceptListener{Listener: snapL, accept: snapAccept}
}

// This test checks that when snapd requests a restart of the system, a
// SIGTERM (e.g. from systemd) is handled when it arrives before stop is
// fully done.
func (s *daemonSuite) TestRestartShutdownWithSigtermInBetween(c *check.C) {
	oldRebootNoticeWait := rebootNoticeWait
	defer func() {
		rebootNoticeWait = oldRebootNoticeWait
	}()
	rebootNoticeWait = 150 * time.Millisecond

	cmd := testutil.MockCommand(c, "shutdown", "")
	defer cmd.Restore()

	d := newTestDaemon(c)
	makeDaemonListeners(c, d)
	s.markSeeded(d)

	c.Assert(d.Start(), check.IsNil)
	st := d.overlord.State()

	st.Lock()
	st.RequestRestart(state.RestartSystem)
	st.Unlock()

	// queue the SIGTERM up front on a buffered channel so it is already
	// pending when Stop runs
	ch := make(chan os.Signal, 2)
	ch <- syscall.SIGTERM
	// stop will check if we got a sigterm in between (which we did)
	err := d.Stop(ch)
	c.Assert(err, check.IsNil)
}

// This test checks that on shutdown we close the sigterm handler channel so
// that systemd can kill snapd.
1180 func (s *daemonSuite) TestRestartShutdown(c *check.C) { 1181 oldRebootNoticeWait := rebootNoticeWait 1182 oldRebootWaitTimeout := rebootWaitTimeout 1183 defer func() { 1184 rebootNoticeWait = oldRebootNoticeWait 1185 rebootWaitTimeout = oldRebootWaitTimeout 1186 }() 1187 rebootWaitTimeout = 100 * time.Millisecond 1188 rebootNoticeWait = 150 * time.Millisecond 1189 1190 cmd := testutil.MockCommand(c, "shutdown", "") 1191 defer cmd.Restore() 1192 1193 d := newTestDaemon(c) 1194 makeDaemonListeners(c, d) 1195 s.markSeeded(d) 1196 1197 c.Assert(d.Start(), check.IsNil) 1198 st := d.overlord.State() 1199 1200 st.Lock() 1201 st.RequestRestart(state.RestartSystem) 1202 st.Unlock() 1203 1204 sigCh := make(chan os.Signal, 2) 1205 // stop (this will timeout but that's not relevant for this test) 1206 d.Stop(sigCh) 1207 1208 // ensure that the sigCh got closed as part of the stop 1209 _, chOpen := <-sigCh 1210 c.Assert(chOpen, check.Equals, false) 1211 } 1212 1213 func (s *daemonSuite) TestRestartExpectedRebootDidNotHappen(c *check.C) { 1214 curBootID, err := osutil.BootID() 1215 c.Assert(err, check.IsNil) 1216 1217 fakeState := []byte(fmt.Sprintf(`{"data":{"patch-level":%d,"patch-sublevel":%d,"some":"data","refresh-privacy-key":"0123456789ABCDEF","system-restart-from-boot-id":%q,"daemon-system-restart-at":"%s"},"changes":null,"tasks":null,"last-change-id":0,"last-task-id":0,"last-lane-id":0}`, patch.Level, patch.Sublevel, curBootID, time.Now().UTC().Format(time.RFC3339))) 1218 err = ioutil.WriteFile(dirs.SnapStateFile, fakeState, 0600) 1219 c.Assert(err, check.IsNil) 1220 1221 oldRebootNoticeWait := rebootNoticeWait 1222 oldRebootRetryWaitTimeout := rebootRetryWaitTimeout 1223 defer func() { 1224 rebootNoticeWait = oldRebootNoticeWait 1225 rebootRetryWaitTimeout = oldRebootRetryWaitTimeout 1226 }() 1227 rebootRetryWaitTimeout = 100 * time.Millisecond 1228 rebootNoticeWait = 150 * time.Millisecond 1229 1230 cmd := testutil.MockCommand(c, "shutdown", "") 1231 defer 
cmd.Restore() 1232 1233 d := newTestDaemon(c) 1234 c.Check(d.overlord, check.IsNil) 1235 c.Check(d.expectedRebootDidNotHappen, check.Equals, true) 1236 1237 var n int 1238 d.state.Lock() 1239 err = d.state.Get("daemon-system-restart-tentative", &n) 1240 d.state.Unlock() 1241 c.Check(err, check.IsNil) 1242 c.Check(n, check.Equals, 1) 1243 1244 c.Assert(d.Start(), check.IsNil) 1245 1246 c.Check(s.notified, check.DeepEquals, []string{"READY=1"}) 1247 1248 select { 1249 case <-d.Dying(): 1250 case <-time.After(2 * time.Second): 1251 c.Fatal("expected reboot not happening should proceed to try to shutdown again") 1252 } 1253 1254 sigCh := make(chan os.Signal, 2) 1255 // stop (this will timeout but thats not relevant for this test) 1256 d.Stop(sigCh) 1257 1258 // an immediate shutdown was scheduled again 1259 c.Check(cmd.Calls(), check.DeepEquals, [][]string{ 1260 {"shutdown", "-r", "+0", "reboot scheduled to update the system"}, 1261 }) 1262 } 1263 1264 func (s *daemonSuite) TestRestartExpectedRebootOK(c *check.C) { 1265 fakeState := []byte(fmt.Sprintf(`{"data":{"patch-level":%d,"patch-sublevel":%d,"some":"data","refresh-privacy-key":"0123456789ABCDEF","system-restart-from-boot-id":%q,"daemon-system-restart-at":"%s"},"changes":null,"tasks":null,"last-change-id":0,"last-task-id":0,"last-lane-id":0}`, patch.Level, patch.Sublevel, "boot-id-0", time.Now().UTC().Format(time.RFC3339))) 1266 err := ioutil.WriteFile(dirs.SnapStateFile, fakeState, 0600) 1267 c.Assert(err, check.IsNil) 1268 1269 cmd := testutil.MockCommand(c, "shutdown", "") 1270 defer cmd.Restore() 1271 1272 d := newTestDaemon(c) 1273 c.Assert(d.overlord, check.NotNil) 1274 1275 st := d.overlord.State() 1276 st.Lock() 1277 defer st.Unlock() 1278 var v interface{} 1279 // these were cleared 1280 c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState) 1281 c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState) 1282 } 1283 1284 func (s *daemonSuite) 
TestRestartExpectedRebootGiveUp(c *check.C) { 1285 // we give up trying to restart the system after 3 retry tentatives 1286 curBootID, err := osutil.BootID() 1287 c.Assert(err, check.IsNil) 1288 1289 fakeState := []byte(fmt.Sprintf(`{"data":{"patch-level":%d,"patch-sublevel":%d,"some":"data","refresh-privacy-key":"0123456789ABCDEF","system-restart-from-boot-id":%q,"daemon-system-restart-at":"%s","daemon-system-restart-tentative":3},"changes":null,"tasks":null,"last-change-id":0,"last-task-id":0,"last-lane-id":0}`, patch.Level, patch.Sublevel, curBootID, time.Now().UTC().Format(time.RFC3339))) 1290 err = ioutil.WriteFile(dirs.SnapStateFile, fakeState, 0600) 1291 c.Assert(err, check.IsNil) 1292 1293 cmd := testutil.MockCommand(c, "shutdown", "") 1294 defer cmd.Restore() 1295 1296 d := newTestDaemon(c) 1297 c.Assert(d.overlord, check.NotNil) 1298 1299 st := d.overlord.State() 1300 st.Lock() 1301 defer st.Unlock() 1302 var v interface{} 1303 // these were cleared 1304 c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState) 1305 c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState) 1306 c.Check(st.Get("daemon-system-restart-tentative", &v), check.Equals, state.ErrNoState) 1307 } 1308 1309 func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *check.C) { 1310 restore := standby.MockStandbyWait(5 * time.Millisecond) 1311 defer restore() 1312 1313 d := newTestDaemon(c) 1314 makeDaemonListeners(c, d) 1315 1316 // mark as already seeded, we also have no snaps so this will 1317 // go into socket activation mode 1318 s.markSeeded(d) 1319 1320 c.Assert(d.Start(), check.IsNil) 1321 // pretend some ensure happened 1322 for i := 0; i < 5; i++ { 1323 c.Check(d.overlord.StateEngine().Ensure(), check.IsNil) 1324 time.Sleep(5 * time.Millisecond) 1325 } 1326 1327 select { 1328 case <-d.Dying(): 1329 // exit the loop 1330 case <-time.After(15 * time.Second): 1331 c.Errorf("daemon did not stop after 15s") 1332 } 1333 err := 
d.Stop(nil) 1334 c.Check(err, check.Equals, ErrRestartSocket) 1335 c.Check(d.restartSocket, check.Equals, true) 1336 } 1337 1338 func (s *daemonSuite) TestRestartIntoSocketModePendingChanges(c *check.C) { 1339 restore := standby.MockStandbyWait(5 * time.Millisecond) 1340 defer restore() 1341 1342 d := newTestDaemon(c) 1343 makeDaemonListeners(c, d) 1344 1345 // mark as already seeded, we also have no snaps so this will 1346 // go into socket activation mode 1347 s.markSeeded(d) 1348 st := d.overlord.State() 1349 1350 c.Assert(d.Start(), check.IsNil) 1351 // pretend some ensure happened 1352 for i := 0; i < 5; i++ { 1353 c.Check(d.overlord.StateEngine().Ensure(), check.IsNil) 1354 time.Sleep(5 * time.Millisecond) 1355 } 1356 1357 select { 1358 case <-d.Dying(): 1359 // Pretend we got change while shutting down, this can 1360 // happen when e.g. the user requested a `snap install 1361 // foo` at the same time as the code in the overlord 1362 // checked that it can go into socket activated 1363 // mode. I.e. the daemon was processing the request 1364 // but no change was generated at the time yet. 
1365 st.Lock() 1366 chg := st.NewChange("fake-install", "fake install some snap") 1367 chg.AddTask(st.NewTask("fake-install-task", "fake install task")) 1368 chgStatus := chg.Status() 1369 st.Unlock() 1370 // ensure our change is valid and ready 1371 c.Check(chgStatus, check.Equals, state.DoStatus) 1372 case <-time.After(5 * time.Second): 1373 c.Errorf("daemon did not stop after 5s") 1374 } 1375 // when the daemon got a pending change it just restarts 1376 err := d.Stop(nil) 1377 c.Check(err, check.IsNil) 1378 c.Check(d.restartSocket, check.Equals, false) 1379 } 1380 1381 func (s *daemonSuite) TestConnTrackerCanShutdown(c *check.C) { 1382 ct := &connTracker{conns: make(map[net.Conn]struct{})} 1383 c.Check(ct.CanStandby(), check.Equals, true) 1384 1385 con := &net.IPConn{} 1386 ct.trackConn(con, http.StateActive) 1387 c.Check(ct.CanStandby(), check.Equals, false) 1388 1389 ct.trackConn(con, http.StateIdle) 1390 c.Check(ct.CanStandby(), check.Equals, true) 1391 } 1392 1393 func doTestReq(c *check.C, cmd *Command, mth string) *httptest.ResponseRecorder { 1394 req, err := http.NewRequest(mth, "", nil) 1395 c.Assert(err, check.IsNil) 1396 req.RemoteAddr = fmt.Sprintf("pid=100;uid=0;socket=%s;", dirs.SnapdSocket) 1397 rec := httptest.NewRecorder() 1398 cmd.ServeHTTP(rec, req) 1399 return rec 1400 } 1401 1402 func (s *daemonSuite) TestDegradedModeReply(c *check.C) { 1403 d := newTestDaemon(c) 1404 cmd := &Command{d: d} 1405 cmd.GET = func(*Command, *http.Request, *auth.UserState) Response { 1406 return SyncResponse(nil) 1407 } 1408 cmd.POST = func(*Command, *http.Request, *auth.UserState) Response { 1409 return SyncResponse(nil) 1410 } 1411 cmd.ReadAccess = authenticatedAccess{} 1412 cmd.WriteAccess = authenticatedAccess{} 1413 1414 // pretend we are in degraded mode 1415 d.SetDegradedMode(fmt.Errorf("foo error")) 1416 1417 // GET is ok even in degraded mode 1418 rec := doTestReq(c, cmd, "GET") 1419 c.Check(rec.Code, check.Equals, 200) 1420 // POST is not allowed 1421 rec 
= doTestReq(c, cmd, "POST") 1422 c.Check(rec.Code, check.Equals, 500) 1423 // verify we get the error 1424 var v struct{ Result errorResult } 1425 c.Assert(json.NewDecoder(rec.Body).Decode(&v), check.IsNil) 1426 c.Check(v.Result.Message, check.Equals, "foo error") 1427 1428 // clean degraded mode 1429 d.SetDegradedMode(nil) 1430 rec = doTestReq(c, cmd, "POST") 1431 c.Check(rec.Code, check.Equals, 200) 1432 }