github.com/freetocompute/snapd@v0.0.0-20210618182524-2fb355d72fd9/daemon/daemon_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2014-2021 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package daemon
    21  
    22  import (
    23  	"encoding/json"
    24  	"fmt"
    25  	"io/ioutil"
    26  	"net"
    27  	"net/http"
    28  	"net/http/httptest"
    29  	"os"
    30  	"path/filepath"
    31  	"sync"
    32  	"syscall"
    33  	"testing"
    34  	"time"
    35  
    36  	"github.com/gorilla/mux"
    37  	"gopkg.in/check.v1"
    38  
    39  	"github.com/snapcore/snapd/client"
    40  	"github.com/snapcore/snapd/dirs"
    41  	"github.com/snapcore/snapd/osutil"
    42  	"github.com/snapcore/snapd/overlord"
    43  	"github.com/snapcore/snapd/overlord/auth"
    44  	"github.com/snapcore/snapd/overlord/devicestate/devicestatetest"
    45  	"github.com/snapcore/snapd/overlord/ifacestate"
    46  	"github.com/snapcore/snapd/overlord/patch"
    47  	"github.com/snapcore/snapd/overlord/snapstate"
    48  	"github.com/snapcore/snapd/overlord/standby"
    49  	"github.com/snapcore/snapd/overlord/state"
    50  	"github.com/snapcore/snapd/polkit"
    51  	"github.com/snapcore/snapd/snap"
    52  	"github.com/snapcore/snapd/store"
    53  	"github.com/snapcore/snapd/systemd"
    54  	"github.com/snapcore/snapd/testutil"
    55  )
    56  
    57  // Hook up check.v1 into the "go test" runner
    58  func Test(t *testing.T) { check.TestingT(t) }
    59  
    60  type daemonSuite struct {
    61  	testutil.BaseTest
    62  
    63  	authorized      bool
    64  	err             error
    65  	lastPolkitFlags polkit.CheckFlags
    66  	notified        []string
    67  }
    68  
    69  var _ = check.Suite(&daemonSuite{})
    70  
    71  func (s *daemonSuite) SetUpTest(c *check.C) {
    72  	s.BaseTest.SetUpTest(c)
    73  
    74  	dirs.SetRootDir(c.MkDir())
    75  	s.AddCleanup(osutil.MockMountInfo(""))
    76  
    77  	err := os.MkdirAll(filepath.Dir(dirs.SnapStateFile), 0755)
    78  	c.Assert(err, check.IsNil)
    79  	systemdSdNotify = func(notif string) error {
    80  		s.notified = append(s.notified, notif)
    81  		return nil
    82  	}
    83  	s.notified = nil
    84  	s.AddCleanup(ifacestate.MockSecurityBackends(nil))
    85  }
    86  
    87  func (s *daemonSuite) TearDownTest(c *check.C) {
    88  	systemdSdNotify = systemd.SdNotify
    89  	dirs.SetRootDir("")
    90  	s.authorized = false
    91  	s.err = nil
    92  
    93  	s.BaseTest.TearDownTest(c)
    94  }
    95  
    96  // build a new daemon, with only a little of Init(), suitable for the tests
    97  func newTestDaemon(c *check.C) *Daemon {
    98  	d, err := New()
    99  	c.Assert(err, check.IsNil)
   100  	d.addRoutes()
   101  
   102  	// don't actually try to talk to the store on snapstate.Ensure
   103  	// needs doing after the call to devicestate.Manager (which
   104  	// happens in daemon.New via overlord.New)
   105  	snapstate.CanAutoRefresh = nil
   106  
   107  	return d
   108  }
   109  
   110  // a Response suitable for testing
   111  type mockHandler struct {
   112  	cmd        *Command
   113  	lastMethod string
   114  }
   115  
   116  func (mck *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   117  	mck.lastMethod = r.Method
   118  }
   119  
   120  func (s *daemonSuite) TestCommandMethodDispatch(c *check.C) {
   121  	d := newTestDaemon(c)
   122  	st := d.Overlord().State()
   123  	st.Lock()
   124  	authUser, err := auth.NewUser(st, "username", "email@test.com", "macaroon", []string{"discharge"})
   125  	st.Unlock()
   126  	c.Assert(err, check.IsNil)
   127  
   128  	fakeUserAgent := "some-agent-talking-to-snapd/1.0"
   129  
   130  	cmd := &Command{d: d}
   131  	mck := &mockHandler{cmd: cmd}
   132  	rf := func(innerCmd *Command, req *http.Request, user *auth.UserState) Response {
   133  		c.Assert(cmd, check.Equals, innerCmd)
   134  		c.Check(store.ClientUserAgent(req.Context()), check.Equals, fakeUserAgent)
   135  		c.Check(user, check.DeepEquals, authUser)
   136  		return mck
   137  	}
   138  	cmd.GET = rf
   139  	cmd.PUT = rf
   140  	cmd.POST = rf
   141  	cmd.ReadAccess = authenticatedAccess{}
   142  	cmd.WriteAccess = authenticatedAccess{}
   143  
   144  	for _, method := range []string{"GET", "POST", "PUT"} {
   145  		req, err := http.NewRequest(method, "", nil)
   146  		req.Header.Add("User-Agent", fakeUserAgent)
   147  		c.Assert(err, check.IsNil)
   148  
   149  		rec := httptest.NewRecorder()
   150  		req.RemoteAddr = fmt.Sprintf("pid=100;uid=1001;socket=%s;", dirs.SnapdSocket)
   151  		cmd.ServeHTTP(rec, req)
   152  		c.Check(rec.Code, check.Equals, 401, check.Commentf(method))
   153  
   154  		rec = httptest.NewRecorder()
   155  		req.Header.Set("Authorization", fmt.Sprintf(`Macaroon root="%s"`, authUser.Macaroon))
   156  
   157  		cmd.ServeHTTP(rec, req)
   158  		c.Check(mck.lastMethod, check.Equals, method)
   159  		c.Check(rec.Code, check.Equals, 200)
   160  	}
   161  
   162  	req, err := http.NewRequest("POTATO", "", nil)
   163  	c.Assert(err, check.IsNil)
   164  	req.RemoteAddr = fmt.Sprintf("pid=100;uid=1001;socket=%s;", dirs.SnapdSocket)
   165  	req.Header.Set("Authorization", fmt.Sprintf(`Macaroon root="%s"`, authUser.Macaroon))
   166  	rec := httptest.NewRecorder()
   167  	cmd.ServeHTTP(rec, req)
   168  	c.Check(rec.Code, check.Equals, 405)
   169  }
   170  
   171  func (s *daemonSuite) TestCommandMethodDispatchRoot(c *check.C) {
   172  	fakeUserAgent := "some-agent-talking-to-snapd/1.0"
   173  
   174  	cmd := &Command{d: newTestDaemon(c)}
   175  	mck := &mockHandler{cmd: cmd}
   176  	rf := func(innerCmd *Command, req *http.Request, user *auth.UserState) Response {
   177  		c.Assert(cmd, check.Equals, innerCmd)
   178  		c.Check(store.ClientUserAgent(req.Context()), check.Equals, fakeUserAgent)
   179  		return mck
   180  	}
   181  	cmd.GET = rf
   182  	cmd.PUT = rf
   183  	cmd.POST = rf
   184  	cmd.ReadAccess = authenticatedAccess{}
   185  	cmd.WriteAccess = authenticatedAccess{}
   186  
   187  	for _, method := range []string{"GET", "POST", "PUT"} {
   188  		req, err := http.NewRequest(method, "", nil)
   189  		req.Header.Add("User-Agent", fakeUserAgent)
   190  		c.Assert(err, check.IsNil)
   191  
   192  		rec := httptest.NewRecorder()
   193  		// no ucred => forbidden
   194  		cmd.ServeHTTP(rec, req)
   195  		c.Check(rec.Code, check.Equals, 403, check.Commentf(method))
   196  
   197  		rec = httptest.NewRecorder()
   198  		req.RemoteAddr = fmt.Sprintf("pid=100;uid=0;socket=%s;", dirs.SnapdSocket)
   199  
   200  		cmd.ServeHTTP(rec, req)
   201  		c.Check(mck.lastMethod, check.Equals, method)
   202  		c.Check(rec.Code, check.Equals, 200)
   203  	}
   204  
   205  	req, err := http.NewRequest("POTATO", "", nil)
   206  	c.Assert(err, check.IsNil)
   207  	req.RemoteAddr = fmt.Sprintf("pid=100;uid=0;socket=%s;", dirs.SnapdSocket)
   208  
   209  	rec := httptest.NewRecorder()
   210  	cmd.ServeHTTP(rec, req)
   211  	c.Check(rec.Code, check.Equals, 405)
   212  }
   213  
   214  func (s *daemonSuite) TestCommandRestartingState(c *check.C) {
   215  	d := newTestDaemon(c)
   216  
   217  	cmd := &Command{d: d}
   218  	cmd.GET = func(*Command, *http.Request, *auth.UserState) Response {
   219  		return SyncResponse(nil)
   220  	}
   221  	cmd.ReadAccess = openAccess{}
   222  	req, err := http.NewRequest("GET", "", nil)
   223  	c.Assert(err, check.IsNil)
   224  	req.RemoteAddr = fmt.Sprintf("pid=100;uid=42;socket=%s;", dirs.SnapdSocket)
   225  
   226  	rec := httptest.NewRecorder()
   227  	cmd.ServeHTTP(rec, req)
   228  	c.Check(rec.Code, check.Equals, 200)
   229  	var rst struct {
   230  		Maintenance *errorResult `json:"maintenance"`
   231  	}
   232  	err = json.Unmarshal(rec.Body.Bytes(), &rst)
   233  	c.Assert(err, check.IsNil)
   234  	c.Check(rst.Maintenance, check.IsNil)
   235  
   236  	tests := []struct {
   237  		rst  state.RestartType
   238  		kind client.ErrorKind
   239  		msg  string
   240  		op   string
   241  	}{
   242  		{
   243  			rst:  state.RestartSystem,
   244  			kind: client.ErrorKindSystemRestart,
   245  			msg:  "system is restarting",
   246  			op:   "reboot",
   247  		}, {
   248  			rst:  state.RestartSystemNow,
   249  			kind: client.ErrorKindSystemRestart,
   250  			msg:  "system is restarting",
   251  			op:   "reboot",
   252  		}, {
   253  			rst:  state.RestartDaemon,
   254  			kind: client.ErrorKindDaemonRestart,
   255  			msg:  "daemon is restarting",
   256  		}, {
   257  			rst:  state.RestartSystemHaltNow,
   258  			kind: client.ErrorKindSystemRestart,
   259  			msg:  "system is halting",
   260  			op:   "halt",
   261  		}, {
   262  			rst:  state.RestartSystemPoweroffNow,
   263  			kind: client.ErrorKindSystemRestart,
   264  			msg:  "system is powering off",
   265  			op:   "poweroff",
   266  		}, {
   267  			rst:  state.RestartSocket,
   268  			kind: client.ErrorKindDaemonRestart,
   269  			msg:  "daemon is stopping to wait for socket activation",
   270  		},
   271  	}
   272  
   273  	for _, t := range tests {
   274  		state.MockRestarting(d.overlord.State(), t.rst)
   275  		rec = httptest.NewRecorder()
   276  		cmd.ServeHTTP(rec, req)
   277  		c.Check(rec.Code, check.Equals, 200)
   278  		var rst struct {
   279  			Maintenance *errorResult `json:"maintenance"`
   280  		}
   281  		err = json.Unmarshal(rec.Body.Bytes(), &rst)
   282  		c.Assert(err, check.IsNil)
   283  		var val errorValue
   284  		if t.op != "" {
   285  			val = map[string]interface{}{
   286  				"op": t.op,
   287  			}
   288  		}
   289  		c.Check(rst.Maintenance, check.DeepEquals, &errorResult{
   290  			Kind:    t.kind,
   291  			Message: t.msg,
   292  			Value:   val,
   293  		})
   294  	}
   295  }
   296  
   297  func (s *daemonSuite) TestMaintenanceJsonDeletedOnStart(c *check.C) {
   298  	// write a maintenance.json file that has that the system is restarting
   299  	maintErr := &errorResult{
   300  		Kind:    client.ErrorKindDaemonRestart,
   301  		Message: systemRestartMsg,
   302  	}
   303  
   304  	b, err := json.Marshal(maintErr)
   305  	c.Assert(err, check.IsNil)
   306  	c.Assert(os.MkdirAll(filepath.Dir(dirs.SnapdMaintenanceFile), 0755), check.IsNil)
   307  	c.Assert(ioutil.WriteFile(dirs.SnapdMaintenanceFile, b, 0644), check.IsNil)
   308  
   309  	d := newTestDaemon(c)
   310  	makeDaemonListeners(c, d)
   311  
   312  	s.markSeeded(d)
   313  
   314  	// after starting, maintenance.json should be removed
   315  	c.Assert(d.Start(), check.IsNil)
   316  	c.Assert(dirs.SnapdMaintenanceFile, testutil.FileAbsent)
   317  	d.Stop(nil)
   318  }
   319  
   320  func (s *daemonSuite) TestFillsWarnings(c *check.C) {
   321  	d := newTestDaemon(c)
   322  
   323  	cmd := &Command{d: d}
   324  	cmd.GET = func(*Command, *http.Request, *auth.UserState) Response {
   325  		return SyncResponse(nil)
   326  	}
   327  	cmd.ReadAccess = openAccess{}
   328  	req, err := http.NewRequest("GET", "", nil)
   329  	c.Assert(err, check.IsNil)
   330  	req.RemoteAddr = fmt.Sprintf("pid=100;uid=42;socket=%s;", dirs.SnapdSocket)
   331  
   332  	rec := httptest.NewRecorder()
   333  	cmd.ServeHTTP(rec, req)
   334  	c.Check(rec.Code, check.Equals, 200)
   335  	var rst struct {
   336  		WarningTimestamp *time.Time `json:"warning-timestamp,omitempty"`
   337  		WarningCount     int        `json:"warning-count,omitempty"`
   338  	}
   339  	err = json.Unmarshal(rec.Body.Bytes(), &rst)
   340  	c.Assert(err, check.IsNil)
   341  	c.Check(rst.WarningCount, check.Equals, 0)
   342  	c.Check(rst.WarningTimestamp, check.IsNil)
   343  
   344  	st := d.overlord.State()
   345  	st.Lock()
   346  	st.Warnf("hello world")
   347  	st.Unlock()
   348  
   349  	rec = httptest.NewRecorder()
   350  	cmd.ServeHTTP(rec, req)
   351  	c.Check(rec.Code, check.Equals, 200)
   352  	err = json.Unmarshal(rec.Body.Bytes(), &rst)
   353  	c.Assert(err, check.IsNil)
   354  	c.Check(rst.WarningCount, check.Equals, 1)
   355  	c.Check(rst.WarningTimestamp, check.NotNil)
   356  }
   357  
   358  type accessCheckFunc func(d *Daemon, r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError
   359  
   360  func (f accessCheckFunc) CheckAccess(d *Daemon, r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   361  	return f(d, r, ucred, user)
   362  }
   363  
   364  func (s *daemonSuite) TestReadAccess(c *check.C) {
   365  	cmd := &Command{d: newTestDaemon(c)}
   366  	cmd.GET = func(*Command, *http.Request, *auth.UserState) Response {
   367  		return SyncResponse(nil)
   368  	}
   369  	var accessCalled bool
   370  	cmd.ReadAccess = accessCheckFunc(func(d *Daemon, r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   371  		accessCalled = true
   372  		c.Check(d, check.Equals, cmd.d)
   373  		c.Check(r, check.NotNil)
   374  		c.Assert(ucred, check.NotNil)
   375  		c.Check(ucred.Uid, check.Equals, uint32(42))
   376  		c.Check(ucred.Pid, check.Equals, int32(100))
   377  		c.Check(ucred.Socket, check.Equals, "xyz")
   378  		c.Check(user, check.IsNil)
   379  		return nil
   380  	})
   381  	cmd.WriteAccess = accessCheckFunc(func(d *Daemon, r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   382  		c.Fail()
   383  		return Forbidden("")
   384  	})
   385  
   386  	req := httptest.NewRequest("GET", "/", nil)
   387  	req.RemoteAddr = "pid=100;uid=42;socket=xyz;"
   388  	rec := httptest.NewRecorder()
   389  	cmd.ServeHTTP(rec, req)
   390  	c.Check(rec.Code, check.Equals, 200)
   391  	c.Check(accessCalled, check.Equals, true)
   392  }
   393  
   394  func (s *daemonSuite) TestWriteAccess(c *check.C) {
   395  	cmd := &Command{d: newTestDaemon(c)}
   396  	cmd.PUT = func(*Command, *http.Request, *auth.UserState) Response {
   397  		return SyncResponse(nil)
   398  	}
   399  	cmd.POST = func(*Command, *http.Request, *auth.UserState) Response {
   400  		return SyncResponse(nil)
   401  	}
   402  	cmd.ReadAccess = accessCheckFunc(func(d *Daemon, r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   403  		c.Fail()
   404  		return Forbidden("")
   405  	})
   406  	var accessCalled bool
   407  	cmd.WriteAccess = accessCheckFunc(func(d *Daemon, r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   408  		accessCalled = true
   409  		c.Check(d, check.Equals, cmd.d)
   410  		c.Check(r, check.NotNil)
   411  		c.Assert(ucred, check.NotNil)
   412  		c.Check(ucred.Uid, check.Equals, uint32(42))
   413  		c.Check(ucred.Pid, check.Equals, int32(100))
   414  		c.Check(ucred.Socket, check.Equals, "xyz")
   415  		c.Check(user, check.IsNil)
   416  		return nil
   417  	})
   418  
   419  	req := httptest.NewRequest("PUT", "/", nil)
   420  	req.RemoteAddr = "pid=100;uid=42;socket=xyz;"
   421  	rec := httptest.NewRecorder()
   422  	cmd.ServeHTTP(rec, req)
   423  	c.Check(rec.Code, check.Equals, 200)
   424  	c.Check(accessCalled, check.Equals, true)
   425  
   426  	accessCalled = false
   427  	req = httptest.NewRequest("POST", "/", nil)
   428  	req.RemoteAddr = "pid=100;uid=42;socket=xyz;"
   429  	rec = httptest.NewRecorder()
   430  	cmd.ServeHTTP(rec, req)
   431  	c.Check(rec.Code, check.Equals, 200)
   432  	c.Check(accessCalled, check.Equals, true)
   433  }
   434  
   435  func (s *daemonSuite) TestWriteAccessWithUser(c *check.C) {
   436  	d := newTestDaemon(c)
   437  	st := d.Overlord().State()
   438  	st.Lock()
   439  	authUser, err := auth.NewUser(st, "username", "email@test.com", "macaroon", []string{"discharge"})
   440  	st.Unlock()
   441  	c.Assert(err, check.IsNil)
   442  
   443  	cmd := &Command{d: d}
   444  	cmd.PUT = func(*Command, *http.Request, *auth.UserState) Response {
   445  		return SyncResponse(nil)
   446  	}
   447  	cmd.POST = func(*Command, *http.Request, *auth.UserState) Response {
   448  		return SyncResponse(nil)
   449  	}
   450  	cmd.ReadAccess = accessCheckFunc(func(d *Daemon, r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   451  		c.Fail()
   452  		return Forbidden("")
   453  	})
   454  	var accessCalled bool
   455  	cmd.WriteAccess = accessCheckFunc(func(d *Daemon, r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   456  		accessCalled = true
   457  		c.Check(d, check.Equals, cmd.d)
   458  		c.Check(r, check.NotNil)
   459  		c.Assert(ucred, check.NotNil)
   460  		c.Check(ucred.Uid, check.Equals, uint32(1001))
   461  		c.Check(ucred.Pid, check.Equals, int32(100))
   462  		c.Check(ucred.Socket, check.Equals, "xyz")
   463  		c.Check(user, check.DeepEquals, authUser)
   464  		return nil
   465  	})
   466  
   467  	req := httptest.NewRequest("PUT", "/", nil)
   468  	req.Header.Set("Authorization", fmt.Sprintf(`Macaroon root="%s"`, authUser.Macaroon))
   469  	req.RemoteAddr = "pid=100;uid=1001;socket=xyz;"
   470  	rec := httptest.NewRecorder()
   471  	cmd.ServeHTTP(rec, req)
   472  	c.Check(rec.Code, check.Equals, 200)
   473  	c.Check(accessCalled, check.Equals, true)
   474  
   475  	accessCalled = false
   476  	req = httptest.NewRequest("POST", "/", nil)
   477  	req.Header.Set("Authorization", fmt.Sprintf(`Macaroon root="%s"`, authUser.Macaroon))
   478  	req.RemoteAddr = "pid=100;uid=1001;socket=xyz;"
   479  	rec = httptest.NewRecorder()
   480  	cmd.ServeHTTP(rec, req)
   481  	c.Check(rec.Code, check.Equals, 200)
   482  	c.Check(accessCalled, check.Equals, true)
   483  }
   484  
   485  func (s *daemonSuite) TestPolkitAccessPath(c *check.C) {
   486  	cmd := &Command{d: newTestDaemon(c)}
   487  	cmd.POST = func(*Command, *http.Request, *auth.UserState) Response {
   488  		return SyncResponse(nil)
   489  	}
   490  	access := false
   491  	cmd.WriteAccess = authenticatedAccess{Polkit: "foo"}
   492  	checkPolkitAction = func(r *http.Request, ucred *ucrednet, action string) *apiError {
   493  		c.Check(action, check.Equals, "foo")
   494  		c.Check(ucred.Uid, check.Equals, uint32(1001))
   495  		if access {
   496  			return nil
   497  		}
   498  		return AuthCancelled("")
   499  	}
   500  
   501  	req := httptest.NewRequest("POST", "/", nil)
   502  	req.RemoteAddr = fmt.Sprintf("pid=100;uid=1001;socket=%s;", dirs.SnapdSocket)
   503  	rec := httptest.NewRecorder()
   504  	cmd.ServeHTTP(rec, req)
   505  	c.Check(rec.Code, check.Equals, 403)
   506  	c.Check(rec.Body.String(), testutil.Contains, `"kind":"auth-cancelled"`)
   507  
   508  	access = true
   509  	rec = httptest.NewRecorder()
   510  	cmd.ServeHTTP(rec, req)
   511  	c.Check(rec.Code, check.Equals, 200)
   512  }
   513  
   514  func (s *daemonSuite) TestCommandAccessSane(c *check.C) {
   515  	for _, cmd := range api {
   516  		// If Command.GET is set, ReadAccess must be set
   517  		c.Check(cmd.GET != nil, check.Equals, cmd.ReadAccess != nil, check.Commentf("%q ReadAccess", cmd.Path))
   518  		// If Command.PUT or POST are set, WriteAccess must be set
   519  		c.Check(cmd.PUT != nil || cmd.POST != nil, check.Equals, cmd.WriteAccess != nil, check.Commentf("%q WriteAccess", cmd.Path))
   520  	}
   521  }
   522  
   523  func (s *daemonSuite) TestAddRoutes(c *check.C) {
   524  	d := newTestDaemon(c)
   525  
   526  	expected := make([]string, len(api))
   527  	for i, v := range api {
   528  		if v.PathPrefix != "" {
   529  			expected[i] = v.PathPrefix
   530  			continue
   531  		}
   532  		expected[i] = v.Path
   533  	}
   534  
   535  	got := make([]string, 0, len(api))
   536  	c.Assert(d.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
   537  		got = append(got, route.GetName())
   538  		return nil
   539  	}), check.IsNil)
   540  
   541  	c.Check(got, check.DeepEquals, expected) // this'll stop being true if routes are added that aren't commands (e.g. for the favicon)
   542  
   543  	// XXX: still waiting to know how to check d.router.NotFoundHandler has been set to NotFound
   544  	//      the old test relied on undefined behaviour:
   545  	//      c.Check(fmt.Sprintf("%p", d.router.NotFoundHandler), check.Equals, fmt.Sprintf("%p", NotFound))
   546  }
   547  
   548  type witnessAcceptListener struct {
   549  	net.Listener
   550  
   551  	accept  chan struct{}
   552  	accept1 bool
   553  
   554  	idempotClose sync.Once
   555  	closeErr     error
   556  	closed       chan struct{}
   557  }
   558  
   559  func (l *witnessAcceptListener) Accept() (net.Conn, error) {
   560  	if !l.accept1 {
   561  		l.accept1 = true
   562  		close(l.accept)
   563  	}
   564  	return l.Listener.Accept()
   565  }
   566  
   567  func (l *witnessAcceptListener) Close() error {
   568  	l.idempotClose.Do(func() {
   569  		l.closeErr = l.Listener.Close()
   570  		if l.closed != nil {
   571  			close(l.closed)
   572  		}
   573  	})
   574  	return l.closeErr
   575  }
   576  
   577  func (s *daemonSuite) markSeeded(d *Daemon) {
   578  	st := d.overlord.State()
   579  	st.Lock()
   580  	st.Set("seeded", true)
   581  	devicestatetest.SetDevice(st, &auth.DeviceState{
   582  		Brand:  "canonical",
   583  		Model:  "pc",
   584  		Serial: "serialserial",
   585  	})
   586  	st.Unlock()
   587  }
   588  
   589  func (s *daemonSuite) TestStartStop(c *check.C) {
   590  	d := newTestDaemon(c)
   591  	// mark as already seeded
   592  	s.markSeeded(d)
   593  	// and pretend we have snaps
   594  	st := d.overlord.State()
   595  	st.Lock()
   596  	snapstate.Set(st, "core", &snapstate.SnapState{
   597  		Active: true,
   598  		Sequence: []*snap.SideInfo{
   599  			{RealName: "core", Revision: snap.R(1), SnapID: "core-snap-id"},
   600  		},
   601  		Current: snap.R(1),
   602  	})
   603  	st.Unlock()
   604  	// 1 snap => extended timeout 30s + 5s
   605  	const extendedTimeoutUSec = "EXTEND_TIMEOUT_USEC=35000000"
   606  
   607  	l1, err := net.Listen("tcp", "127.0.0.1:0")
   608  	c.Assert(err, check.IsNil)
   609  	l2, err := net.Listen("tcp", "127.0.0.1:0")
   610  	c.Assert(err, check.IsNil)
   611  
   612  	snapdAccept := make(chan struct{})
   613  	d.snapdListener = &witnessAcceptListener{Listener: l1, accept: snapdAccept}
   614  
   615  	snapAccept := make(chan struct{})
   616  	d.snapListener = &witnessAcceptListener{Listener: l2, accept: snapAccept}
   617  
   618  	c.Assert(d.Start(), check.IsNil)
   619  
   620  	c.Check(s.notified, check.DeepEquals, []string{extendedTimeoutUSec, "READY=1"})
   621  
   622  	snapdDone := make(chan struct{})
   623  	go func() {
   624  		select {
   625  		case <-snapdAccept:
   626  		case <-time.After(2 * time.Second):
   627  			c.Fatal("snapd accept was not called")
   628  		}
   629  		close(snapdDone)
   630  	}()
   631  
   632  	snapDone := make(chan struct{})
   633  	go func() {
   634  		select {
   635  		case <-snapAccept:
   636  		case <-time.After(2 * time.Second):
   637  			c.Fatal("snapd accept was not called")
   638  		}
   639  		close(snapDone)
   640  	}()
   641  
   642  	<-snapdDone
   643  	<-snapDone
   644  
   645  	err = d.Stop(nil)
   646  	c.Check(err, check.IsNil)
   647  
   648  	c.Check(s.notified, check.DeepEquals, []string{extendedTimeoutUSec, "READY=1", "STOPPING=1"})
   649  }
   650  
   651  func (s *daemonSuite) TestRestartWiring(c *check.C) {
   652  	d := newTestDaemon(c)
   653  	// mark as already seeded
   654  	s.markSeeded(d)
   655  
   656  	l, err := net.Listen("tcp", "127.0.0.1:0")
   657  	c.Assert(err, check.IsNil)
   658  
   659  	snapdAccept := make(chan struct{})
   660  	d.snapdListener = &witnessAcceptListener{Listener: l, accept: snapdAccept}
   661  
   662  	snapAccept := make(chan struct{})
   663  	d.snapListener = &witnessAcceptListener{Listener: l, accept: snapAccept}
   664  
   665  	c.Assert(d.Start(), check.IsNil)
   666  	stoppedYet := false
   667  	defer func() {
   668  		if !stoppedYet {
   669  			d.Stop(nil)
   670  		}
   671  	}()
   672  
   673  	snapdDone := make(chan struct{})
   674  	go func() {
   675  		select {
   676  		case <-snapdAccept:
   677  		case <-time.After(2 * time.Second):
   678  			c.Fatal("snapd accept was not called")
   679  		}
   680  		close(snapdDone)
   681  	}()
   682  
   683  	snapDone := make(chan struct{})
   684  	go func() {
   685  		select {
   686  		case <-snapAccept:
   687  		case <-time.After(2 * time.Second):
   688  			c.Fatal("snap accept was not called")
   689  		}
   690  		close(snapDone)
   691  	}()
   692  
   693  	<-snapdDone
   694  	<-snapDone
   695  
   696  	d.overlord.State().RequestRestart(state.RestartDaemon)
   697  
   698  	select {
   699  	case <-d.Dying():
   700  	case <-time.After(2 * time.Second):
   701  		c.Fatal("RequestRestart -> overlord -> Kill chain didn't work")
   702  	}
   703  
   704  	d.Stop(nil)
   705  	stoppedYet = true
   706  
   707  	c.Assert(s.notified, check.DeepEquals, []string{"EXTEND_TIMEOUT_USEC=30000000", "READY=1", "STOPPING=1"})
   708  }
   709  
   710  func (s *daemonSuite) TestGracefulStop(c *check.C) {
   711  	d := newTestDaemon(c)
   712  
   713  	responding := make(chan struct{})
   714  	doRespond := make(chan bool, 1)
   715  
   716  	d.router.HandleFunc("/endp", func(w http.ResponseWriter, r *http.Request) {
   717  		close(responding)
   718  		if <-doRespond {
   719  			w.Write([]byte("OKOK"))
   720  		} else {
   721  			w.Write([]byte("Gone"))
   722  		}
   723  	})
   724  
   725  	// mark as already seeded
   726  	s.markSeeded(d)
   727  	// and pretend we have snaps
   728  	st := d.overlord.State()
   729  	st.Lock()
   730  	snapstate.Set(st, "core", &snapstate.SnapState{
   731  		Active: true,
   732  		Sequence: []*snap.SideInfo{
   733  			{RealName: "core", Revision: snap.R(1), SnapID: "core-snap-id"},
   734  		},
   735  		Current: snap.R(1),
   736  	})
   737  	st.Unlock()
   738  
   739  	snapdL, err := net.Listen("tcp", "127.0.0.1:0")
   740  	c.Assert(err, check.IsNil)
   741  
   742  	snapL, err := net.Listen("tcp", "127.0.0.1:0")
   743  	c.Assert(err, check.IsNil)
   744  
   745  	snapdAccept := make(chan struct{})
   746  	snapdClosed := make(chan struct{})
   747  	d.snapdListener = &witnessAcceptListener{Listener: snapdL, accept: snapdAccept, closed: snapdClosed}
   748  
   749  	snapAccept := make(chan struct{})
   750  	d.snapListener = &witnessAcceptListener{Listener: snapL, accept: snapAccept}
   751  
   752  	c.Assert(d.Start(), check.IsNil)
   753  
   754  	snapdAccepting := make(chan struct{})
   755  	go func() {
   756  		select {
   757  		case <-snapdAccept:
   758  		case <-time.After(2 * time.Second):
   759  			c.Fatal("snapd accept was not called")
   760  		}
   761  		close(snapdAccepting)
   762  	}()
   763  
   764  	snapAccepting := make(chan struct{})
   765  	go func() {
   766  		select {
   767  		case <-snapAccept:
   768  		case <-time.After(2 * time.Second):
   769  			c.Fatal("snapd accept was not called")
   770  		}
   771  		close(snapAccepting)
   772  	}()
   773  
   774  	<-snapdAccepting
   775  	<-snapAccepting
   776  
   777  	alright := make(chan struct{})
   778  
   779  	go func() {
   780  		res, err := http.Get(fmt.Sprintf("http://%s/endp", snapdL.Addr()))
   781  		c.Assert(err, check.IsNil)
   782  		c.Check(res.StatusCode, check.Equals, 200)
   783  		body, err := ioutil.ReadAll(res.Body)
   784  		res.Body.Close()
   785  		c.Assert(err, check.IsNil)
   786  		c.Check(string(body), check.Equals, "OKOK")
   787  		close(alright)
   788  	}()
   789  	go func() {
   790  		<-snapdClosed
   791  		time.Sleep(200 * time.Millisecond)
   792  		doRespond <- true
   793  	}()
   794  
   795  	<-responding
   796  	err = d.Stop(nil)
   797  	doRespond <- false
   798  	c.Check(err, check.IsNil)
   799  
   800  	select {
   801  	case <-alright:
   802  	case <-time.After(2 * time.Second):
   803  		c.Fatal("never got proper response")
   804  	}
   805  }
   806  
   807  func (s *daemonSuite) TestGracefulStopHasLimits(c *check.C) {
   808  	d := newTestDaemon(c)
   809  
   810  	// mark as already seeded
   811  	s.markSeeded(d)
   812  
   813  	restore := MockShutdownTimeout(time.Second)
   814  	defer restore()
   815  
   816  	responding := make(chan struct{})
   817  	doRespond := make(chan bool, 1)
   818  
   819  	d.router.HandleFunc("/endp", func(w http.ResponseWriter, r *http.Request) {
   820  		close(responding)
   821  		if <-doRespond {
   822  			for {
   823  				// write in a loop to keep the handler running
   824  				if _, err := w.Write([]byte("OKOK")); err != nil {
   825  					break
   826  				}
   827  				time.Sleep(50 * time.Millisecond)
   828  			}
   829  		} else {
   830  			w.Write([]byte("Gone"))
   831  		}
   832  	})
   833  
   834  	snapdL, err := net.Listen("tcp", "127.0.0.1:0")
   835  	c.Assert(err, check.IsNil)
   836  
   837  	snapL, err := net.Listen("tcp", "127.0.0.1:0")
   838  	c.Assert(err, check.IsNil)
   839  
   840  	snapdAccept := make(chan struct{})
   841  	snapdClosed := make(chan struct{})
   842  	d.snapdListener = &witnessAcceptListener{Listener: snapdL, accept: snapdAccept, closed: snapdClosed}
   843  
   844  	snapAccept := make(chan struct{})
   845  	d.snapListener = &witnessAcceptListener{Listener: snapL, accept: snapAccept}
   846  
   847  	c.Assert(d.Start(), check.IsNil)
   848  
   849  	snapdAccepting := make(chan struct{})
   850  	go func() {
   851  		select {
   852  		case <-snapdAccept:
   853  		case <-time.After(2 * time.Second):
   854  			c.Fatal("snapd accept was not called")
   855  		}
   856  		close(snapdAccepting)
   857  	}()
   858  
   859  	snapAccepting := make(chan struct{})
   860  	go func() {
   861  		select {
   862  		case <-snapAccept:
   863  		case <-time.After(2 * time.Second):
   864  			c.Fatal("snapd accept was not called")
   865  		}
   866  		close(snapAccepting)
   867  	}()
   868  
   869  	<-snapdAccepting
   870  	<-snapAccepting
   871  
   872  	clientErr := make(chan error)
   873  
   874  	go func() {
   875  		_, err := http.Get(fmt.Sprintf("http://%s/endp", snapdL.Addr()))
   876  		c.Assert(err, check.NotNil)
   877  		clientErr <- err
   878  		close(clientErr)
   879  	}()
   880  	go func() {
   881  		<-snapdClosed
   882  		time.Sleep(200 * time.Millisecond)
   883  		doRespond <- true
   884  	}()
   885  
   886  	<-responding
   887  	err = d.Stop(nil)
   888  	doRespond <- false
   889  	c.Check(err, check.IsNil)
   890  
   891  	select {
   892  	case cErr := <-clientErr:
   893  		c.Check(cErr, check.ErrorMatches, ".*: EOF")
   894  	case <-time.After(5 * time.Second):
   895  		c.Fatal("never got proper response")
   896  	}
   897  }
   898  
   899  func (s *daemonSuite) testRestartSystemWiring(c *check.C, prep func(d *Daemon), restart func(*state.State, state.RestartType), restartKind state.RestartType, wait time.Duration) {
   900  	d := newTestDaemon(c)
   901  	// mark as already seeded
   902  	s.markSeeded(d)
   903  
   904  	if prep != nil {
   905  		prep(d)
   906  	}
   907  
   908  	l, err := net.Listen("tcp", "127.0.0.1:0")
   909  	c.Assert(err, check.IsNil)
   910  
   911  	snapdAccept := make(chan struct{})
   912  	d.snapdListener = &witnessAcceptListener{Listener: l, accept: snapdAccept}
   913  
   914  	snapAccept := make(chan struct{})
   915  	d.snapListener = &witnessAcceptListener{Listener: l, accept: snapAccept}
   916  
   917  	oldRebootNoticeWait := rebootNoticeWait
   918  	oldRebootWaitTimeout := rebootWaitTimeout
   919  	defer func() {
   920  		reboot = rebootImpl
   921  		rebootNoticeWait = oldRebootNoticeWait
   922  		rebootWaitTimeout = oldRebootWaitTimeout
   923  	}()
   924  	rebootWaitTimeout = 100 * time.Millisecond
   925  	rebootNoticeWait = 150 * time.Millisecond
   926  
   927  	expectedAction := rebootReboot
   928  	expectedOp := "reboot"
   929  	if restartKind == state.RestartSystemHaltNow {
   930  		expectedAction = rebootHalt
   931  		expectedOp = "halt"
   932  	} else if restartKind == state.RestartSystemPoweroffNow {
   933  		expectedAction = rebootPoweroff
   934  		expectedOp = "poweroff"
   935  	}
   936  	var delays []time.Duration
   937  	reboot = func(a rebootAction, d time.Duration) error {
   938  		c.Check(a, check.Equals, expectedAction)
   939  		delays = append(delays, d)
   940  		return nil
   941  	}
   942  
   943  	c.Assert(d.Start(), check.IsNil)
   944  	defer d.Stop(nil)
   945  
   946  	st := d.overlord.State()
   947  
   948  	snapdDone := make(chan struct{})
   949  	go func() {
   950  		select {
   951  		case <-snapdAccept:
   952  		case <-time.After(2 * time.Second):
   953  			c.Fatal("snapd accept was not called")
   954  		}
   955  		close(snapdDone)
   956  	}()
   957  
   958  	snapDone := make(chan struct{})
   959  	go func() {
   960  		select {
   961  		case <-snapAccept:
   962  		case <-time.After(2 * time.Second):
   963  			c.Fatal("snap accept was not called")
   964  		}
   965  		close(snapDone)
   966  	}()
   967  
   968  	<-snapdDone
   969  	<-snapDone
   970  
   971  	st.Lock()
   972  	restart(st, restartKind)
   973  	st.Unlock()
   974  
   975  	defer func() {
   976  		d.mu.Lock()
   977  		d.requestedRestart = state.RestartUnset
   978  		d.mu.Unlock()
   979  	}()
   980  
   981  	select {
   982  	case <-d.Dying():
   983  	case <-time.After(2 * time.Second):
   984  		c.Fatal("RequestRestart -> overlord -> Kill chain didn't work")
   985  	}
   986  
   987  	d.mu.Lock()
   988  	rs := d.requestedRestart
   989  	d.mu.Unlock()
   990  
   991  	c.Check(rs, check.Equals, restartKind)
   992  
   993  	c.Check(delays, check.HasLen, 1)
   994  	c.Check(delays[0], check.DeepEquals, rebootWaitTimeout)
   995  
   996  	now := time.Now()
   997  
   998  	err = d.Stop(nil)
   999  
  1000  	c.Check(err, check.ErrorMatches, fmt.Sprintf("expected %s did not happen", expectedAction))
  1001  
  1002  	c.Check(delays, check.HasLen, 2)
  1003  	c.Check(delays[1], check.DeepEquals, wait)
  1004  
  1005  	// we are not stopping, we wait for the reboot instead
  1006  	c.Check(s.notified, check.DeepEquals, []string{"EXTEND_TIMEOUT_USEC=30000000", "READY=1"})
  1007  
  1008  	st.Lock()
  1009  	defer st.Unlock()
  1010  	var rebootAt time.Time
  1011  	err = st.Get("daemon-system-restart-at", &rebootAt)
  1012  	c.Assert(err, check.IsNil)
  1013  	if wait > 0 {
  1014  		approxAt := now.Add(wait)
  1015  		c.Check(rebootAt.After(approxAt) || rebootAt.Equal(approxAt), check.Equals, true)
  1016  	} else {
  1017  		// should be good enough
  1018  		c.Check(rebootAt.Before(now.Add(10*time.Second)), check.Equals, true)
  1019  	}
  1020  
  1021  	// finally check that maintenance.json was written appropriate for this
  1022  	// restart reason
  1023  	b, err := ioutil.ReadFile(dirs.SnapdMaintenanceFile)
  1024  	c.Assert(err, check.IsNil)
  1025  
  1026  	maintErr := &errorResult{}
  1027  	c.Assert(json.Unmarshal(b, maintErr), check.IsNil)
  1028  	c.Check(maintErr.Kind, check.Equals, client.ErrorKindSystemRestart)
  1029  	c.Check(maintErr.Value, check.DeepEquals, map[string]interface{}{
  1030  		"op": expectedOp,
  1031  	})
  1032  
  1033  	exp := maintenanceForRestartType(restartKind)
  1034  	c.Assert(maintErr, check.DeepEquals, exp)
  1035  }
  1036  
  1037  func (s *daemonSuite) TestRestartSystemGracefulWiring(c *check.C) {
  1038  	s.testRestartSystemWiring(c, nil, (*state.State).RequestRestart, state.RestartSystem, 1*time.Minute)
  1039  }
  1040  
  1041  func (s *daemonSuite) TestRestartSystemImmediateWiring(c *check.C) {
  1042  	s.testRestartSystemWiring(c, nil, (*state.State).RequestRestart, state.RestartSystemNow, 0)
  1043  }
  1044  
  1045  func (s *daemonSuite) TestRestartSystemHaltImmediateWiring(c *check.C) {
  1046  	s.testRestartSystemWiring(c, nil, (*state.State).RequestRestart, state.RestartSystemHaltNow, 0)
  1047  }
  1048  
  1049  func (s *daemonSuite) TestRestartSystemPoweroffImmediateWiring(c *check.C) {
  1050  	s.testRestartSystemWiring(c, nil, (*state.State).RequestRestart, state.RestartSystemPoweroffNow, 0)
  1051  }
  1052  
  1053  type rstManager struct {
  1054  	st *state.State
  1055  }
  1056  
  1057  func (m *rstManager) Ensure() error {
  1058  	m.st.Lock()
  1059  	defer m.st.Unlock()
  1060  	m.st.RequestRestart(state.RestartSystemNow)
  1061  	return nil
  1062  }
  1063  
  1064  type witnessManager struct {
  1065  	ensureCalled int
  1066  }
  1067  
  1068  func (m *witnessManager) Ensure() error {
  1069  	m.ensureCalled++
  1070  	return nil
  1071  }
  1072  
  1073  func (s *daemonSuite) TestRestartSystemFromEnsure(c *check.C) {
  1074  	// Test that calling RequestRestart from inside the first
  1075  	// Ensure loop works.
  1076  	wm := &witnessManager{}
  1077  
  1078  	prep := func(d *Daemon) {
  1079  		st := d.overlord.State()
  1080  		hm := d.overlord.HookManager()
  1081  		o := overlord.MockWithStateAndRestartHandler(st, d.HandleRestart)
  1082  		d.overlord = o
  1083  		o.AddManager(hm)
  1084  		rm := &rstManager{st: st}
  1085  		o.AddManager(rm)
  1086  		o.AddManager(wm)
  1087  	}
  1088  
  1089  	nop := func(*state.State, state.RestartType) {}
  1090  
  1091  	s.testRestartSystemWiring(c, prep, nop, state.RestartSystemNow, 0)
  1092  
  1093  	c.Check(wm.ensureCalled, check.Equals, 1)
  1094  }
  1095  
  1096  func (s *daemonSuite) TestRebootHelper(c *check.C) {
  1097  	cmd := testutil.MockCommand(c, "shutdown", "")
  1098  	defer cmd.Restore()
  1099  
  1100  	tests := []struct {
  1101  		delay    time.Duration
  1102  		delayArg string
  1103  	}{
  1104  		{-1, "+0"},
  1105  		{0, "+0"},
  1106  		{time.Minute, "+1"},
  1107  		{10 * time.Minute, "+10"},
  1108  		{30 * time.Second, "+0"},
  1109  	}
  1110  
  1111  	args := []struct {
  1112  		a   rebootAction
  1113  		arg string
  1114  		msg string
  1115  	}{
  1116  		{rebootReboot, "-r", "reboot scheduled to update the system"},
  1117  		{rebootHalt, "--halt", "system halt scheduled"},
  1118  		{rebootPoweroff, "--poweroff", "system poweroff scheduled"},
  1119  	}
  1120  
  1121  	for _, arg := range args {
  1122  		for _, t := range tests {
  1123  			err := reboot(arg.a, t.delay)
  1124  			c.Assert(err, check.IsNil)
  1125  			c.Check(cmd.Calls(), check.DeepEquals, [][]string{
  1126  				{"shutdown", arg.arg, t.delayArg, arg.msg},
  1127  			})
  1128  
  1129  			cmd.ForgetCalls()
  1130  		}
  1131  	}
  1132  }
  1133  
  1134  func makeDaemonListeners(c *check.C, d *Daemon) {
  1135  	snapdL, err := net.Listen("tcp", "127.0.0.1:0")
  1136  	c.Assert(err, check.IsNil)
  1137  
  1138  	snapL, err := net.Listen("tcp", "127.0.0.1:0")
  1139  	c.Assert(err, check.IsNil)
  1140  
  1141  	snapdAccept := make(chan struct{})
  1142  	snapdClosed := make(chan struct{})
  1143  	d.snapdListener = &witnessAcceptListener{Listener: snapdL, accept: snapdAccept, closed: snapdClosed}
  1144  
  1145  	snapAccept := make(chan struct{})
  1146  	d.snapListener = &witnessAcceptListener{Listener: snapL, accept: snapAccept}
  1147  }
  1148  
  1149  // This test tests that when the snapd calls a restart of the system
  1150  // a sigterm (from e.g. systemd) is handled when it arrives before
  1151  // stop is fully done.
  1152  func (s *daemonSuite) TestRestartShutdownWithSigtermInBetween(c *check.C) {
  1153  	oldRebootNoticeWait := rebootNoticeWait
  1154  	defer func() {
  1155  		rebootNoticeWait = oldRebootNoticeWait
  1156  	}()
  1157  	rebootNoticeWait = 150 * time.Millisecond
  1158  
  1159  	cmd := testutil.MockCommand(c, "shutdown", "")
  1160  	defer cmd.Restore()
  1161  
  1162  	d := newTestDaemon(c)
  1163  	makeDaemonListeners(c, d)
  1164  	s.markSeeded(d)
  1165  
  1166  	c.Assert(d.Start(), check.IsNil)
  1167  	st := d.overlord.State()
  1168  
  1169  	st.Lock()
  1170  	st.RequestRestart(state.RestartSystem)
  1171  	st.Unlock()
  1172  
  1173  	ch := make(chan os.Signal, 2)
  1174  	ch <- syscall.SIGTERM
  1175  	// stop will check if we got a sigterm in between (which we did)
  1176  	err := d.Stop(ch)
  1177  	c.Assert(err, check.IsNil)
  1178  }
  1179  
  1180  // This test tests that when there is a shutdown we close the sigterm
  1181  // handler so that systemd can kill snapd.
  1182  func (s *daemonSuite) TestRestartShutdown(c *check.C) {
  1183  	oldRebootNoticeWait := rebootNoticeWait
  1184  	oldRebootWaitTimeout := rebootWaitTimeout
  1185  	defer func() {
  1186  		rebootNoticeWait = oldRebootNoticeWait
  1187  		rebootWaitTimeout = oldRebootWaitTimeout
  1188  	}()
  1189  	rebootWaitTimeout = 100 * time.Millisecond
  1190  	rebootNoticeWait = 150 * time.Millisecond
  1191  
  1192  	cmd := testutil.MockCommand(c, "shutdown", "")
  1193  	defer cmd.Restore()
  1194  
  1195  	d := newTestDaemon(c)
  1196  	makeDaemonListeners(c, d)
  1197  	s.markSeeded(d)
  1198  
  1199  	c.Assert(d.Start(), check.IsNil)
  1200  	st := d.overlord.State()
  1201  
  1202  	st.Lock()
  1203  	st.RequestRestart(state.RestartSystem)
  1204  	st.Unlock()
  1205  
  1206  	sigCh := make(chan os.Signal, 2)
  1207  	// stop (this will timeout but that's not relevant for this test)
  1208  	d.Stop(sigCh)
  1209  
  1210  	// ensure that the sigCh got closed as part of the stop
  1211  	_, chOpen := <-sigCh
  1212  	c.Assert(chOpen, check.Equals, false)
  1213  }
  1214  
  1215  func (s *daemonSuite) TestRestartExpectedRebootDidNotHappen(c *check.C) {
  1216  	curBootID, err := osutil.BootID()
  1217  	c.Assert(err, check.IsNil)
  1218  
  1219  	fakeState := []byte(fmt.Sprintf(`{"data":{"patch-level":%d,"patch-sublevel":%d,"some":"data","refresh-privacy-key":"0123456789ABCDEF","system-restart-from-boot-id":%q,"daemon-system-restart-at":"%s"},"changes":null,"tasks":null,"last-change-id":0,"last-task-id":0,"last-lane-id":0}`, patch.Level, patch.Sublevel, curBootID, time.Now().UTC().Format(time.RFC3339)))
  1220  	err = ioutil.WriteFile(dirs.SnapStateFile, fakeState, 0600)
  1221  	c.Assert(err, check.IsNil)
  1222  
  1223  	oldRebootNoticeWait := rebootNoticeWait
  1224  	oldRebootRetryWaitTimeout := rebootRetryWaitTimeout
  1225  	defer func() {
  1226  		rebootNoticeWait = oldRebootNoticeWait
  1227  		rebootRetryWaitTimeout = oldRebootRetryWaitTimeout
  1228  	}()
  1229  	rebootRetryWaitTimeout = 100 * time.Millisecond
  1230  	rebootNoticeWait = 150 * time.Millisecond
  1231  
  1232  	cmd := testutil.MockCommand(c, "shutdown", "")
  1233  	defer cmd.Restore()
  1234  
  1235  	d := newTestDaemon(c)
  1236  	c.Check(d.overlord, check.IsNil)
  1237  	c.Check(d.expectedRebootDidNotHappen, check.Equals, true)
  1238  
  1239  	var n int
  1240  	d.state.Lock()
  1241  	err = d.state.Get("daemon-system-restart-tentative", &n)
  1242  	d.state.Unlock()
  1243  	c.Check(err, check.IsNil)
  1244  	c.Check(n, check.Equals, 1)
  1245  
  1246  	c.Assert(d.Start(), check.IsNil)
  1247  
  1248  	c.Check(s.notified, check.DeepEquals, []string{"READY=1"})
  1249  
  1250  	select {
  1251  	case <-d.Dying():
  1252  	case <-time.After(2 * time.Second):
  1253  		c.Fatal("expected reboot not happening should proceed to try to shutdown again")
  1254  	}
  1255  
  1256  	sigCh := make(chan os.Signal, 2)
  1257  	// stop (this will timeout but thats not relevant for this test)
  1258  	d.Stop(sigCh)
  1259  
  1260  	// an immediate shutdown was scheduled again
  1261  	c.Check(cmd.Calls(), check.DeepEquals, [][]string{
  1262  		{"shutdown", "-r", "+0", "reboot scheduled to update the system"},
  1263  	})
  1264  }
  1265  
  1266  func (s *daemonSuite) TestRestartExpectedRebootOK(c *check.C) {
  1267  	fakeState := []byte(fmt.Sprintf(`{"data":{"patch-level":%d,"patch-sublevel":%d,"some":"data","refresh-privacy-key":"0123456789ABCDEF","system-restart-from-boot-id":%q,"daemon-system-restart-at":"%s"},"changes":null,"tasks":null,"last-change-id":0,"last-task-id":0,"last-lane-id":0}`, patch.Level, patch.Sublevel, "boot-id-0", time.Now().UTC().Format(time.RFC3339)))
  1268  	err := ioutil.WriteFile(dirs.SnapStateFile, fakeState, 0600)
  1269  	c.Assert(err, check.IsNil)
  1270  
  1271  	cmd := testutil.MockCommand(c, "shutdown", "")
  1272  	defer cmd.Restore()
  1273  
  1274  	d := newTestDaemon(c)
  1275  	c.Assert(d.overlord, check.NotNil)
  1276  
  1277  	st := d.overlord.State()
  1278  	st.Lock()
  1279  	defer st.Unlock()
  1280  	var v interface{}
  1281  	// these were cleared
  1282  	c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState)
  1283  	c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState)
  1284  }
  1285  
  1286  func (s *daemonSuite) TestRestartExpectedRebootGiveUp(c *check.C) {
  1287  	// we give up trying to restart the system after 3 retry tentatives
  1288  	curBootID, err := osutil.BootID()
  1289  	c.Assert(err, check.IsNil)
  1290  
  1291  	fakeState := []byte(fmt.Sprintf(`{"data":{"patch-level":%d,"patch-sublevel":%d,"some":"data","refresh-privacy-key":"0123456789ABCDEF","system-restart-from-boot-id":%q,"daemon-system-restart-at":"%s","daemon-system-restart-tentative":3},"changes":null,"tasks":null,"last-change-id":0,"last-task-id":0,"last-lane-id":0}`, patch.Level, patch.Sublevel, curBootID, time.Now().UTC().Format(time.RFC3339)))
  1292  	err = ioutil.WriteFile(dirs.SnapStateFile, fakeState, 0600)
  1293  	c.Assert(err, check.IsNil)
  1294  
  1295  	cmd := testutil.MockCommand(c, "shutdown", "")
  1296  	defer cmd.Restore()
  1297  
  1298  	d := newTestDaemon(c)
  1299  	c.Assert(d.overlord, check.NotNil)
  1300  
  1301  	st := d.overlord.State()
  1302  	st.Lock()
  1303  	defer st.Unlock()
  1304  	var v interface{}
  1305  	// these were cleared
  1306  	c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState)
  1307  	c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState)
  1308  	c.Check(st.Get("daemon-system-restart-tentative", &v), check.Equals, state.ErrNoState)
  1309  }
  1310  
  1311  func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *check.C) {
  1312  	restore := standby.MockStandbyWait(5 * time.Millisecond)
  1313  	defer restore()
  1314  
  1315  	d := newTestDaemon(c)
  1316  	makeDaemonListeners(c, d)
  1317  
  1318  	// mark as already seeded, we also have no snaps so this will
  1319  	// go into socket activation mode
  1320  	s.markSeeded(d)
  1321  
  1322  	c.Assert(d.Start(), check.IsNil)
  1323  	// pretend some ensure happened
  1324  	for i := 0; i < 5; i++ {
  1325  		c.Check(d.overlord.StateEngine().Ensure(), check.IsNil)
  1326  		time.Sleep(5 * time.Millisecond)
  1327  	}
  1328  
  1329  	select {
  1330  	case <-d.Dying():
  1331  		// exit the loop
  1332  	case <-time.After(15 * time.Second):
  1333  		c.Errorf("daemon did not stop after 15s")
  1334  	}
  1335  	err := d.Stop(nil)
  1336  	c.Check(err, check.Equals, ErrRestartSocket)
  1337  	c.Check(d.restartSocket, check.Equals, true)
  1338  }
  1339  
  1340  func (s *daemonSuite) TestRestartIntoSocketModePendingChanges(c *check.C) {
  1341  	restore := standby.MockStandbyWait(5 * time.Millisecond)
  1342  	defer restore()
  1343  
  1344  	d := newTestDaemon(c)
  1345  	makeDaemonListeners(c, d)
  1346  
  1347  	// mark as already seeded, we also have no snaps so this will
  1348  	// go into socket activation mode
  1349  	s.markSeeded(d)
  1350  	st := d.overlord.State()
  1351  
  1352  	c.Assert(d.Start(), check.IsNil)
  1353  	// pretend some ensure happened
  1354  	for i := 0; i < 5; i++ {
  1355  		c.Check(d.overlord.StateEngine().Ensure(), check.IsNil)
  1356  		time.Sleep(5 * time.Millisecond)
  1357  	}
  1358  
  1359  	select {
  1360  	case <-d.Dying():
  1361  		// Pretend we got change while shutting down, this can
  1362  		// happen when e.g. the user requested a `snap install
  1363  		// foo` at the same time as the code in the overlord
  1364  		// checked that it can go into socket activated
  1365  		// mode. I.e. the daemon was processing the request
  1366  		// but no change was generated at the time yet.
  1367  		st.Lock()
  1368  		chg := st.NewChange("fake-install", "fake install some snap")
  1369  		chg.AddTask(st.NewTask("fake-install-task", "fake install task"))
  1370  		chgStatus := chg.Status()
  1371  		st.Unlock()
  1372  		// ensure our change is valid and ready
  1373  		c.Check(chgStatus, check.Equals, state.DoStatus)
  1374  	case <-time.After(5 * time.Second):
  1375  		c.Errorf("daemon did not stop after 5s")
  1376  	}
  1377  	// when the daemon got a pending change it just restarts
  1378  	err := d.Stop(nil)
  1379  	c.Check(err, check.IsNil)
  1380  	c.Check(d.restartSocket, check.Equals, false)
  1381  }
  1382  
  1383  func (s *daemonSuite) TestConnTrackerCanShutdown(c *check.C) {
  1384  	ct := &connTracker{conns: make(map[net.Conn]struct{})}
  1385  	c.Check(ct.CanStandby(), check.Equals, true)
  1386  
  1387  	con := &net.IPConn{}
  1388  	ct.trackConn(con, http.StateActive)
  1389  	c.Check(ct.CanStandby(), check.Equals, false)
  1390  
  1391  	ct.trackConn(con, http.StateIdle)
  1392  	c.Check(ct.CanStandby(), check.Equals, true)
  1393  }
  1394  
  1395  func doTestReq(c *check.C, cmd *Command, mth string) *httptest.ResponseRecorder {
  1396  	req, err := http.NewRequest(mth, "", nil)
  1397  	c.Assert(err, check.IsNil)
  1398  	req.RemoteAddr = fmt.Sprintf("pid=100;uid=0;socket=%s;", dirs.SnapdSocket)
  1399  	rec := httptest.NewRecorder()
  1400  	cmd.ServeHTTP(rec, req)
  1401  	return rec
  1402  }
  1403  
  1404  func (s *daemonSuite) TestDegradedModeReply(c *check.C) {
  1405  	d := newTestDaemon(c)
  1406  	cmd := &Command{d: d}
  1407  	cmd.GET = func(*Command, *http.Request, *auth.UserState) Response {
  1408  		return SyncResponse(nil)
  1409  	}
  1410  	cmd.POST = func(*Command, *http.Request, *auth.UserState) Response {
  1411  		return SyncResponse(nil)
  1412  	}
  1413  	cmd.ReadAccess = authenticatedAccess{}
  1414  	cmd.WriteAccess = authenticatedAccess{}
  1415  
  1416  	// pretend we are in degraded mode
  1417  	d.SetDegradedMode(fmt.Errorf("foo error"))
  1418  
  1419  	// GET is ok even in degraded mode
  1420  	rec := doTestReq(c, cmd, "GET")
  1421  	c.Check(rec.Code, check.Equals, 200)
  1422  	// POST is not allowed
  1423  	rec = doTestReq(c, cmd, "POST")
  1424  	c.Check(rec.Code, check.Equals, 500)
  1425  	// verify we get the error
  1426  	var v struct{ Result errorResult }
  1427  	c.Assert(json.NewDecoder(rec.Body).Decode(&v), check.IsNil)
  1428  	c.Check(v.Result.Message, check.Equals, "foo error")
  1429  
  1430  	// clean degraded mode
  1431  	d.SetDegradedMode(nil)
  1432  	rec = doTestReq(c, cmd, "POST")
  1433  	c.Check(rec.Code, check.Equals, 200)
  1434  }