github.com/david-imola/snapd@v0.0.0-20210611180407-2de8ddeece6d/daemon/daemon_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2014-2021 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package daemon
    21  
    22  import (
    23  	"fmt"
    24  
    25  	"encoding/json"
    26  	"io/ioutil"
    27  	"net"
    28  	"net/http"
    29  	"net/http/httptest"
    30  	"os"
    31  	"path/filepath"
    32  	"sync"
    33  	"syscall"
    34  	"testing"
    35  	"time"
    36  
    37  	"github.com/gorilla/mux"
    38  	"gopkg.in/check.v1"
    39  
    40  	"github.com/snapcore/snapd/client"
    41  	"github.com/snapcore/snapd/dirs"
    42  	"github.com/snapcore/snapd/osutil"
    43  	"github.com/snapcore/snapd/overlord"
    44  	"github.com/snapcore/snapd/overlord/auth"
    45  	"github.com/snapcore/snapd/overlord/devicestate/devicestatetest"
    46  	"github.com/snapcore/snapd/overlord/ifacestate"
    47  	"github.com/snapcore/snapd/overlord/patch"
    48  	"github.com/snapcore/snapd/overlord/snapstate"
    49  	"github.com/snapcore/snapd/overlord/standby"
    50  	"github.com/snapcore/snapd/overlord/state"
    51  	"github.com/snapcore/snapd/polkit"
    52  	"github.com/snapcore/snapd/snap"
    53  	"github.com/snapcore/snapd/store"
    54  	"github.com/snapcore/snapd/systemd"
    55  	"github.com/snapcore/snapd/testutil"
    56  )
    57  
    58  // Hook up check.v1 into the "go test" runner
    59  func Test(t *testing.T) { check.TestingT(t) }
    60  
    61  type daemonSuite struct {
    62  	testutil.BaseTest
    63  
    64  	authorized      bool
    65  	err             error
    66  	lastPolkitFlags polkit.CheckFlags
    67  	notified        []string
    68  }
    69  
    70  var _ = check.Suite(&daemonSuite{})
    71  
    72  func (s *daemonSuite) SetUpTest(c *check.C) {
    73  	s.BaseTest.SetUpTest(c)
    74  
    75  	dirs.SetRootDir(c.MkDir())
    76  	s.AddCleanup(osutil.MockMountInfo(""))
    77  
    78  	err := os.MkdirAll(filepath.Dir(dirs.SnapStateFile), 0755)
    79  	c.Assert(err, check.IsNil)
    80  	systemdSdNotify = func(notif string) error {
    81  		s.notified = append(s.notified, notif)
    82  		return nil
    83  	}
    84  	s.notified = nil
    85  	s.AddCleanup(ifacestate.MockSecurityBackends(nil))
    86  }
    87  
    88  func (s *daemonSuite) TearDownTest(c *check.C) {
    89  	systemdSdNotify = systemd.SdNotify
    90  	dirs.SetRootDir("")
    91  	s.authorized = false
    92  	s.err = nil
    93  
    94  	s.BaseTest.TearDownTest(c)
    95  }
    96  
    97  // build a new daemon, with only a little of Init(), suitable for the tests
    98  func newTestDaemon(c *check.C) *Daemon {
    99  	d, err := New()
   100  	c.Assert(err, check.IsNil)
   101  	d.addRoutes()
   102  
   103  	// don't actually try to talk to the store on snapstate.Ensure
   104  	// needs doing after the call to devicestate.Manager (which
   105  	// happens in daemon.New via overlord.New)
   106  	snapstate.CanAutoRefresh = nil
   107  
   108  	return d
   109  }
   110  
   111  // a Response suitable for testing
   112  type mockHandler struct {
   113  	cmd        *Command
   114  	lastMethod string
   115  }
   116  
   117  func (mck *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   118  	mck.lastMethod = r.Method
   119  }
   120  
   121  func (s *daemonSuite) TestCommandMethodDispatch(c *check.C) {
   122  	d := newTestDaemon(c)
   123  	st := d.Overlord().State()
   124  	st.Lock()
   125  	authUser, err := auth.NewUser(st, "username", "email@test.com", "macaroon", []string{"discharge"})
   126  	st.Unlock()
   127  	c.Assert(err, check.IsNil)
   128  
   129  	fakeUserAgent := "some-agent-talking-to-snapd/1.0"
   130  
   131  	cmd := &Command{d: d}
   132  	mck := &mockHandler{cmd: cmd}
   133  	rf := func(innerCmd *Command, req *http.Request, user *auth.UserState) Response {
   134  		c.Assert(cmd, check.Equals, innerCmd)
   135  		c.Check(store.ClientUserAgent(req.Context()), check.Equals, fakeUserAgent)
   136  		c.Check(user, check.DeepEquals, authUser)
   137  		return mck
   138  	}
   139  	cmd.GET = rf
   140  	cmd.PUT = rf
   141  	cmd.POST = rf
   142  	cmd.ReadAccess = authenticatedAccess{}
   143  	cmd.WriteAccess = authenticatedAccess{}
   144  
   145  	for _, method := range []string{"GET", "POST", "PUT"} {
   146  		req, err := http.NewRequest(method, "", nil)
   147  		req.Header.Add("User-Agent", fakeUserAgent)
   148  		c.Assert(err, check.IsNil)
   149  
   150  		rec := httptest.NewRecorder()
   151  		req.RemoteAddr = fmt.Sprintf("pid=100;uid=1001;socket=%s;", dirs.SnapdSocket)
   152  		cmd.ServeHTTP(rec, req)
   153  		c.Check(rec.Code, check.Equals, 401, check.Commentf(method))
   154  
   155  		rec = httptest.NewRecorder()
   156  		req.Header.Set("Authorization", fmt.Sprintf(`Macaroon root="%s"`, authUser.Macaroon))
   157  
   158  		cmd.ServeHTTP(rec, req)
   159  		c.Check(mck.lastMethod, check.Equals, method)
   160  		c.Check(rec.Code, check.Equals, 200)
   161  	}
   162  
   163  	req, err := http.NewRequest("POTATO", "", nil)
   164  	c.Assert(err, check.IsNil)
   165  	req.RemoteAddr = fmt.Sprintf("pid=100;uid=1001;socket=%s;", dirs.SnapdSocket)
   166  	req.Header.Set("Authorization", fmt.Sprintf(`Macaroon root="%s"`, authUser.Macaroon))
   167  	rec := httptest.NewRecorder()
   168  	cmd.ServeHTTP(rec, req)
   169  	c.Check(rec.Code, check.Equals, 405)
   170  }
   171  
   172  func (s *daemonSuite) TestCommandMethodDispatchRoot(c *check.C) {
   173  	fakeUserAgent := "some-agent-talking-to-snapd/1.0"
   174  
   175  	cmd := &Command{d: newTestDaemon(c)}
   176  	mck := &mockHandler{cmd: cmd}
   177  	rf := func(innerCmd *Command, req *http.Request, user *auth.UserState) Response {
   178  		c.Assert(cmd, check.Equals, innerCmd)
   179  		c.Check(store.ClientUserAgent(req.Context()), check.Equals, fakeUserAgent)
   180  		return mck
   181  	}
   182  	cmd.GET = rf
   183  	cmd.PUT = rf
   184  	cmd.POST = rf
   185  	cmd.ReadAccess = authenticatedAccess{}
   186  	cmd.WriteAccess = authenticatedAccess{}
   187  
   188  	for _, method := range []string{"GET", "POST", "PUT"} {
   189  		req, err := http.NewRequest(method, "", nil)
   190  		req.Header.Add("User-Agent", fakeUserAgent)
   191  		c.Assert(err, check.IsNil)
   192  
   193  		rec := httptest.NewRecorder()
   194  		// no ucred => forbidden
   195  		cmd.ServeHTTP(rec, req)
   196  		c.Check(rec.Code, check.Equals, 403, check.Commentf(method))
   197  
   198  		rec = httptest.NewRecorder()
   199  		req.RemoteAddr = fmt.Sprintf("pid=100;uid=0;socket=%s;", dirs.SnapdSocket)
   200  
   201  		cmd.ServeHTTP(rec, req)
   202  		c.Check(mck.lastMethod, check.Equals, method)
   203  		c.Check(rec.Code, check.Equals, 200)
   204  	}
   205  
   206  	req, err := http.NewRequest("POTATO", "", nil)
   207  	c.Assert(err, check.IsNil)
   208  	req.RemoteAddr = fmt.Sprintf("pid=100;uid=0;socket=%s;", dirs.SnapdSocket)
   209  
   210  	rec := httptest.NewRecorder()
   211  	cmd.ServeHTTP(rec, req)
   212  	c.Check(rec.Code, check.Equals, 405)
   213  }
   214  
   215  func (s *daemonSuite) TestCommandRestartingState(c *check.C) {
   216  	d := newTestDaemon(c)
   217  
   218  	cmd := &Command{d: d}
   219  	cmd.GET = func(*Command, *http.Request, *auth.UserState) Response {
   220  		return SyncResponse(nil)
   221  	}
   222  	cmd.ReadAccess = openAccess{}
   223  	req, err := http.NewRequest("GET", "", nil)
   224  	c.Assert(err, check.IsNil)
   225  	req.RemoteAddr = fmt.Sprintf("pid=100;uid=42;socket=%s;", dirs.SnapdSocket)
   226  
   227  	rec := httptest.NewRecorder()
   228  	cmd.ServeHTTP(rec, req)
   229  	c.Check(rec.Code, check.Equals, 200)
   230  	var rst struct {
   231  		Maintenance *errorResult `json:"maintenance"`
   232  	}
   233  	err = json.Unmarshal(rec.Body.Bytes(), &rst)
   234  	c.Assert(err, check.IsNil)
   235  	c.Check(rst.Maintenance, check.IsNil)
   236  
   237  	tests := []struct {
   238  		rst  state.RestartType
   239  		kind client.ErrorKind
   240  		msg  string
   241  		op   string
   242  	}{
   243  		{
   244  			rst:  state.RestartSystem,
   245  			kind: client.ErrorKindSystemRestart,
   246  			msg:  "system is restarting",
   247  			op:   "reboot",
   248  		}, {
   249  			rst:  state.RestartSystemNow,
   250  			kind: client.ErrorKindSystemRestart,
   251  			msg:  "system is restarting",
   252  			op:   "reboot",
   253  		}, {
   254  			rst:  state.RestartDaemon,
   255  			kind: client.ErrorKindDaemonRestart,
   256  			msg:  "daemon is restarting",
   257  		}, {
   258  			rst:  state.RestartSystemHaltNow,
   259  			kind: client.ErrorKindSystemRestart,
   260  			msg:  "system is halting",
   261  			op:   "halt",
   262  		}, {
   263  			rst:  state.RestartSystemPoweroffNow,
   264  			kind: client.ErrorKindSystemRestart,
   265  			msg:  "system is powering off",
   266  			op:   "poweroff",
   267  		}, {
   268  			rst:  state.RestartSocket,
   269  			kind: client.ErrorKindDaemonRestart,
   270  			msg:  "daemon is stopping to wait for socket activation",
   271  		},
   272  	}
   273  
   274  	for _, t := range tests {
   275  		state.MockRestarting(d.overlord.State(), t.rst)
   276  		rec = httptest.NewRecorder()
   277  		cmd.ServeHTTP(rec, req)
   278  		c.Check(rec.Code, check.Equals, 200)
   279  		var rst struct {
   280  			Maintenance *errorResult `json:"maintenance"`
   281  		}
   282  		err = json.Unmarshal(rec.Body.Bytes(), &rst)
   283  		c.Assert(err, check.IsNil)
   284  		var val errorValue
   285  		if t.op != "" {
   286  			val = map[string]interface{}{
   287  				"op": t.op,
   288  			}
   289  		}
   290  		c.Check(rst.Maintenance, check.DeepEquals, &errorResult{
   291  			Kind:    t.kind,
   292  			Message: t.msg,
   293  			Value:   val,
   294  		})
   295  	}
   296  }
   297  
   298  func (s *daemonSuite) TestMaintenanceJsonDeletedOnStart(c *check.C) {
   299  	// write a maintenance.json file that has that the system is restarting
   300  	maintErr := &errorResult{
   301  		Kind:    client.ErrorKindDaemonRestart,
   302  		Message: systemRestartMsg,
   303  	}
   304  
   305  	b, err := json.Marshal(maintErr)
   306  	c.Assert(err, check.IsNil)
   307  	c.Assert(os.MkdirAll(filepath.Dir(dirs.SnapdMaintenanceFile), 0755), check.IsNil)
   308  	c.Assert(ioutil.WriteFile(dirs.SnapdMaintenanceFile, b, 0644), check.IsNil)
   309  
   310  	d := newTestDaemon(c)
   311  	makeDaemonListeners(c, d)
   312  
   313  	s.markSeeded(d)
   314  
   315  	// after starting, maintenance.json should be removed
   316  	c.Assert(d.Start(), check.IsNil)
   317  	c.Assert(dirs.SnapdMaintenanceFile, testutil.FileAbsent)
   318  	d.Stop(nil)
   319  }
   320  
   321  func (s *daemonSuite) TestFillsWarnings(c *check.C) {
   322  	d := newTestDaemon(c)
   323  
   324  	cmd := &Command{d: d}
   325  	cmd.GET = func(*Command, *http.Request, *auth.UserState) Response {
   326  		return SyncResponse(nil)
   327  	}
   328  	cmd.ReadAccess = openAccess{}
   329  	req, err := http.NewRequest("GET", "", nil)
   330  	c.Assert(err, check.IsNil)
   331  	req.RemoteAddr = fmt.Sprintf("pid=100;uid=42;socket=%s;", dirs.SnapdSocket)
   332  
   333  	rec := httptest.NewRecorder()
   334  	cmd.ServeHTTP(rec, req)
   335  	c.Check(rec.Code, check.Equals, 200)
   336  	var rst struct {
   337  		WarningTimestamp *time.Time `json:"warning-timestamp,omitempty"`
   338  		WarningCount     int        `json:"warning-count,omitempty"`
   339  	}
   340  	err = json.Unmarshal(rec.Body.Bytes(), &rst)
   341  	c.Assert(err, check.IsNil)
   342  	c.Check(rst.WarningCount, check.Equals, 0)
   343  	c.Check(rst.WarningTimestamp, check.IsNil)
   344  
   345  	st := d.overlord.State()
   346  	st.Lock()
   347  	st.Warnf("hello world")
   348  	st.Unlock()
   349  
   350  	rec = httptest.NewRecorder()
   351  	cmd.ServeHTTP(rec, req)
   352  	c.Check(rec.Code, check.Equals, 200)
   353  	err = json.Unmarshal(rec.Body.Bytes(), &rst)
   354  	c.Assert(err, check.IsNil)
   355  	c.Check(rst.WarningCount, check.Equals, 1)
   356  	c.Check(rst.WarningTimestamp, check.NotNil)
   357  }
   358  
   359  type accessCheckFunc func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError
   360  
   361  func (f accessCheckFunc) CheckAccess(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   362  	return f(r, ucred, user)
   363  }
   364  
   365  func (s *daemonSuite) TestReadAccess(c *check.C) {
   366  	cmd := &Command{d: newTestDaemon(c)}
   367  	cmd.GET = func(*Command, *http.Request, *auth.UserState) Response {
   368  		return SyncResponse(nil)
   369  	}
   370  	var accessCalled bool
   371  	cmd.ReadAccess = accessCheckFunc(func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   372  		accessCalled = true
   373  		c.Check(r, check.NotNil)
   374  		c.Assert(ucred, check.NotNil)
   375  		c.Check(ucred.Uid, check.Equals, uint32(42))
   376  		c.Check(ucred.Pid, check.Equals, int32(100))
   377  		c.Check(ucred.Socket, check.Equals, "xyz")
   378  		c.Check(user, check.IsNil)
   379  		return nil
   380  	})
   381  	cmd.WriteAccess = accessCheckFunc(func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   382  		c.Fail()
   383  		return Forbidden("")
   384  	})
   385  
   386  	req := httptest.NewRequest("GET", "/", nil)
   387  	req.RemoteAddr = "pid=100;uid=42;socket=xyz;"
   388  	rec := httptest.NewRecorder()
   389  	cmd.ServeHTTP(rec, req)
   390  	c.Check(rec.Code, check.Equals, 200)
   391  	c.Check(accessCalled, check.Equals, true)
   392  }
   393  
   394  func (s *daemonSuite) TestWriteAccess(c *check.C) {
   395  	cmd := &Command{d: newTestDaemon(c)}
   396  	cmd.PUT = func(*Command, *http.Request, *auth.UserState) Response {
   397  		return SyncResponse(nil)
   398  	}
   399  	cmd.POST = func(*Command, *http.Request, *auth.UserState) Response {
   400  		return SyncResponse(nil)
   401  	}
   402  	cmd.ReadAccess = accessCheckFunc(func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   403  		c.Fail()
   404  		return Forbidden("")
   405  	})
   406  	var accessCalled bool
   407  	cmd.WriteAccess = accessCheckFunc(func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   408  		accessCalled = true
   409  		c.Check(r, check.NotNil)
   410  		c.Assert(ucred, check.NotNil)
   411  		c.Check(ucred.Uid, check.Equals, uint32(42))
   412  		c.Check(ucred.Pid, check.Equals, int32(100))
   413  		c.Check(ucred.Socket, check.Equals, "xyz")
   414  		c.Check(user, check.IsNil)
   415  		return nil
   416  	})
   417  
   418  	req := httptest.NewRequest("PUT", "/", nil)
   419  	req.RemoteAddr = "pid=100;uid=42;socket=xyz;"
   420  	rec := httptest.NewRecorder()
   421  	cmd.ServeHTTP(rec, req)
   422  	c.Check(rec.Code, check.Equals, 200)
   423  	c.Check(accessCalled, check.Equals, true)
   424  
   425  	accessCalled = false
   426  	req = httptest.NewRequest("POST", "/", nil)
   427  	req.RemoteAddr = "pid=100;uid=42;socket=xyz;"
   428  	rec = httptest.NewRecorder()
   429  	cmd.ServeHTTP(rec, req)
   430  	c.Check(rec.Code, check.Equals, 200)
   431  	c.Check(accessCalled, check.Equals, true)
   432  }
   433  
   434  func (s *daemonSuite) TestWriteAccessWithUser(c *check.C) {
   435  	d := newTestDaemon(c)
   436  	st := d.Overlord().State()
   437  	st.Lock()
   438  	authUser, err := auth.NewUser(st, "username", "email@test.com", "macaroon", []string{"discharge"})
   439  	st.Unlock()
   440  	c.Assert(err, check.IsNil)
   441  
   442  	cmd := &Command{d: d}
   443  	cmd.PUT = func(*Command, *http.Request, *auth.UserState) Response {
   444  		return SyncResponse(nil)
   445  	}
   446  	cmd.POST = func(*Command, *http.Request, *auth.UserState) Response {
   447  		return SyncResponse(nil)
   448  	}
   449  	cmd.ReadAccess = accessCheckFunc(func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   450  		c.Fail()
   451  		return Forbidden("")
   452  	})
   453  	var accessCalled bool
   454  	cmd.WriteAccess = accessCheckFunc(func(r *http.Request, ucred *ucrednet, user *auth.UserState) *apiError {
   455  		accessCalled = true
   456  		c.Check(r, check.NotNil)
   457  		c.Assert(ucred, check.NotNil)
   458  		c.Check(ucred.Uid, check.Equals, uint32(1001))
   459  		c.Check(ucred.Pid, check.Equals, int32(100))
   460  		c.Check(ucred.Socket, check.Equals, "xyz")
   461  		c.Check(user, check.DeepEquals, authUser)
   462  		return nil
   463  	})
   464  
   465  	req := httptest.NewRequest("PUT", "/", nil)
   466  	req.Header.Set("Authorization", fmt.Sprintf(`Macaroon root="%s"`, authUser.Macaroon))
   467  	req.RemoteAddr = "pid=100;uid=1001;socket=xyz;"
   468  	rec := httptest.NewRecorder()
   469  	cmd.ServeHTTP(rec, req)
   470  	c.Check(rec.Code, check.Equals, 200)
   471  	c.Check(accessCalled, check.Equals, true)
   472  
   473  	accessCalled = false
   474  	req = httptest.NewRequest("POST", "/", nil)
   475  	req.Header.Set("Authorization", fmt.Sprintf(`Macaroon root="%s"`, authUser.Macaroon))
   476  	req.RemoteAddr = "pid=100;uid=1001;socket=xyz;"
   477  	rec = httptest.NewRecorder()
   478  	cmd.ServeHTTP(rec, req)
   479  	c.Check(rec.Code, check.Equals, 200)
   480  	c.Check(accessCalled, check.Equals, true)
   481  }
   482  
   483  func (s *daemonSuite) TestPolkitAccessPath(c *check.C) {
   484  	cmd := &Command{d: newTestDaemon(c)}
   485  	cmd.POST = func(*Command, *http.Request, *auth.UserState) Response {
   486  		return SyncResponse(nil)
   487  	}
   488  	access := false
   489  	cmd.WriteAccess = authenticatedAccess{Polkit: "foo"}
   490  	checkPolkitAction = func(r *http.Request, ucred *ucrednet, action string) *apiError {
   491  		c.Check(action, check.Equals, "foo")
   492  		c.Check(ucred.Uid, check.Equals, uint32(1001))
   493  		if access {
   494  			return nil
   495  		}
   496  		return AuthCancelled("")
   497  	}
   498  
   499  	req := httptest.NewRequest("POST", "/", nil)
   500  	req.RemoteAddr = fmt.Sprintf("pid=100;uid=1001;socket=%s;", dirs.SnapdSocket)
   501  	rec := httptest.NewRecorder()
   502  	cmd.ServeHTTP(rec, req)
   503  	c.Check(rec.Code, check.Equals, 403)
   504  	c.Check(rec.Body.String(), testutil.Contains, `"kind":"auth-cancelled"`)
   505  
   506  	access = true
   507  	rec = httptest.NewRecorder()
   508  	cmd.ServeHTTP(rec, req)
   509  	c.Check(rec.Code, check.Equals, 200)
   510  }
   511  
   512  func (s *daemonSuite) TestCommandAccessSane(c *check.C) {
   513  	for _, cmd := range api {
   514  		// If Command.GET is set, ReadAccess must be set
   515  		c.Check(cmd.GET != nil, check.Equals, cmd.ReadAccess != nil, check.Commentf("%q ReadAccess", cmd.Path))
   516  		// If Command.PUT or POST are set, WriteAccess must be set
   517  		c.Check(cmd.PUT != nil || cmd.POST != nil, check.Equals, cmd.WriteAccess != nil, check.Commentf("%q WriteAccess", cmd.Path))
   518  	}
   519  }
   520  
   521  func (s *daemonSuite) TestAddRoutes(c *check.C) {
   522  	d := newTestDaemon(c)
   523  
   524  	expected := make([]string, len(api))
   525  	for i, v := range api {
   526  		if v.PathPrefix != "" {
   527  			expected[i] = v.PathPrefix
   528  			continue
   529  		}
   530  		expected[i] = v.Path
   531  	}
   532  
   533  	got := make([]string, 0, len(api))
   534  	c.Assert(d.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
   535  		got = append(got, route.GetName())
   536  		return nil
   537  	}), check.IsNil)
   538  
   539  	c.Check(got, check.DeepEquals, expected) // this'll stop being true if routes are added that aren't commands (e.g. for the favicon)
   540  
   541  	// XXX: still waiting to know how to check d.router.NotFoundHandler has been set to NotFound
   542  	//      the old test relied on undefined behaviour:
   543  	//      c.Check(fmt.Sprintf("%p", d.router.NotFoundHandler), check.Equals, fmt.Sprintf("%p", NotFound))
   544  }
   545  
   546  type witnessAcceptListener struct {
   547  	net.Listener
   548  
   549  	accept  chan struct{}
   550  	accept1 bool
   551  
   552  	idempotClose sync.Once
   553  	closeErr     error
   554  	closed       chan struct{}
   555  }
   556  
   557  func (l *witnessAcceptListener) Accept() (net.Conn, error) {
   558  	if !l.accept1 {
   559  		l.accept1 = true
   560  		close(l.accept)
   561  	}
   562  	return l.Listener.Accept()
   563  }
   564  
   565  func (l *witnessAcceptListener) Close() error {
   566  	l.idempotClose.Do(func() {
   567  		l.closeErr = l.Listener.Close()
   568  		if l.closed != nil {
   569  			close(l.closed)
   570  		}
   571  	})
   572  	return l.closeErr
   573  }
   574  
   575  func (s *daemonSuite) markSeeded(d *Daemon) {
   576  	st := d.overlord.State()
   577  	st.Lock()
   578  	st.Set("seeded", true)
   579  	devicestatetest.SetDevice(st, &auth.DeviceState{
   580  		Brand:  "canonical",
   581  		Model:  "pc",
   582  		Serial: "serialserial",
   583  	})
   584  	st.Unlock()
   585  }
   586  
   587  func (s *daemonSuite) TestStartStop(c *check.C) {
   588  	d := newTestDaemon(c)
   589  	// mark as already seeded
   590  	s.markSeeded(d)
   591  	// and pretend we have snaps
   592  	st := d.overlord.State()
   593  	st.Lock()
   594  	snapstate.Set(st, "core", &snapstate.SnapState{
   595  		Active: true,
   596  		Sequence: []*snap.SideInfo{
   597  			{RealName: "core", Revision: snap.R(1), SnapID: "core-snap-id"},
   598  		},
   599  		Current: snap.R(1),
   600  	})
   601  	st.Unlock()
   602  	// 1 snap => extended timeout 30s + 5s
   603  	const extendedTimeoutUSec = "EXTEND_TIMEOUT_USEC=35000000"
   604  
   605  	l1, err := net.Listen("tcp", "127.0.0.1:0")
   606  	c.Assert(err, check.IsNil)
   607  	l2, err := net.Listen("tcp", "127.0.0.1:0")
   608  	c.Assert(err, check.IsNil)
   609  
   610  	snapdAccept := make(chan struct{})
   611  	d.snapdListener = &witnessAcceptListener{Listener: l1, accept: snapdAccept}
   612  
   613  	snapAccept := make(chan struct{})
   614  	d.snapListener = &witnessAcceptListener{Listener: l2, accept: snapAccept}
   615  
   616  	c.Assert(d.Start(), check.IsNil)
   617  
   618  	c.Check(s.notified, check.DeepEquals, []string{extendedTimeoutUSec, "READY=1"})
   619  
   620  	snapdDone := make(chan struct{})
   621  	go func() {
   622  		select {
   623  		case <-snapdAccept:
   624  		case <-time.After(2 * time.Second):
   625  			c.Fatal("snapd accept was not called")
   626  		}
   627  		close(snapdDone)
   628  	}()
   629  
   630  	snapDone := make(chan struct{})
   631  	go func() {
   632  		select {
   633  		case <-snapAccept:
   634  		case <-time.After(2 * time.Second):
   635  			c.Fatal("snapd accept was not called")
   636  		}
   637  		close(snapDone)
   638  	}()
   639  
   640  	<-snapdDone
   641  	<-snapDone
   642  
   643  	err = d.Stop(nil)
   644  	c.Check(err, check.IsNil)
   645  
   646  	c.Check(s.notified, check.DeepEquals, []string{extendedTimeoutUSec, "READY=1", "STOPPING=1"})
   647  }
   648  
   649  func (s *daemonSuite) TestRestartWiring(c *check.C) {
   650  	d := newTestDaemon(c)
   651  	// mark as already seeded
   652  	s.markSeeded(d)
   653  
   654  	l, err := net.Listen("tcp", "127.0.0.1:0")
   655  	c.Assert(err, check.IsNil)
   656  
   657  	snapdAccept := make(chan struct{})
   658  	d.snapdListener = &witnessAcceptListener{Listener: l, accept: snapdAccept}
   659  
   660  	snapAccept := make(chan struct{})
   661  	d.snapListener = &witnessAcceptListener{Listener: l, accept: snapAccept}
   662  
   663  	c.Assert(d.Start(), check.IsNil)
   664  	stoppedYet := false
   665  	defer func() {
   666  		if !stoppedYet {
   667  			d.Stop(nil)
   668  		}
   669  	}()
   670  
   671  	snapdDone := make(chan struct{})
   672  	go func() {
   673  		select {
   674  		case <-snapdAccept:
   675  		case <-time.After(2 * time.Second):
   676  			c.Fatal("snapd accept was not called")
   677  		}
   678  		close(snapdDone)
   679  	}()
   680  
   681  	snapDone := make(chan struct{})
   682  	go func() {
   683  		select {
   684  		case <-snapAccept:
   685  		case <-time.After(2 * time.Second):
   686  			c.Fatal("snap accept was not called")
   687  		}
   688  		close(snapDone)
   689  	}()
   690  
   691  	<-snapdDone
   692  	<-snapDone
   693  
   694  	d.overlord.State().RequestRestart(state.RestartDaemon)
   695  
   696  	select {
   697  	case <-d.Dying():
   698  	case <-time.After(2 * time.Second):
   699  		c.Fatal("RequestRestart -> overlord -> Kill chain didn't work")
   700  	}
   701  
   702  	d.Stop(nil)
   703  	stoppedYet = true
   704  
   705  	c.Assert(s.notified, check.DeepEquals, []string{"EXTEND_TIMEOUT_USEC=30000000", "READY=1", "STOPPING=1"})
   706  }
   707  
   708  func (s *daemonSuite) TestGracefulStop(c *check.C) {
   709  	d := newTestDaemon(c)
   710  
   711  	responding := make(chan struct{})
   712  	doRespond := make(chan bool, 1)
   713  
   714  	d.router.HandleFunc("/endp", func(w http.ResponseWriter, r *http.Request) {
   715  		close(responding)
   716  		if <-doRespond {
   717  			w.Write([]byte("OKOK"))
   718  		} else {
   719  			w.Write([]byte("Gone"))
   720  		}
   721  	})
   722  
   723  	// mark as already seeded
   724  	s.markSeeded(d)
   725  	// and pretend we have snaps
   726  	st := d.overlord.State()
   727  	st.Lock()
   728  	snapstate.Set(st, "core", &snapstate.SnapState{
   729  		Active: true,
   730  		Sequence: []*snap.SideInfo{
   731  			{RealName: "core", Revision: snap.R(1), SnapID: "core-snap-id"},
   732  		},
   733  		Current: snap.R(1),
   734  	})
   735  	st.Unlock()
   736  
   737  	snapdL, err := net.Listen("tcp", "127.0.0.1:0")
   738  	c.Assert(err, check.IsNil)
   739  
   740  	snapL, err := net.Listen("tcp", "127.0.0.1:0")
   741  	c.Assert(err, check.IsNil)
   742  
   743  	snapdAccept := make(chan struct{})
   744  	snapdClosed := make(chan struct{})
   745  	d.snapdListener = &witnessAcceptListener{Listener: snapdL, accept: snapdAccept, closed: snapdClosed}
   746  
   747  	snapAccept := make(chan struct{})
   748  	d.snapListener = &witnessAcceptListener{Listener: snapL, accept: snapAccept}
   749  
   750  	c.Assert(d.Start(), check.IsNil)
   751  
   752  	snapdAccepting := make(chan struct{})
   753  	go func() {
   754  		select {
   755  		case <-snapdAccept:
   756  		case <-time.After(2 * time.Second):
   757  			c.Fatal("snapd accept was not called")
   758  		}
   759  		close(snapdAccepting)
   760  	}()
   761  
   762  	snapAccepting := make(chan struct{})
   763  	go func() {
   764  		select {
   765  		case <-snapAccept:
   766  		case <-time.After(2 * time.Second):
   767  			c.Fatal("snapd accept was not called")
   768  		}
   769  		close(snapAccepting)
   770  	}()
   771  
   772  	<-snapdAccepting
   773  	<-snapAccepting
   774  
   775  	alright := make(chan struct{})
   776  
   777  	go func() {
   778  		res, err := http.Get(fmt.Sprintf("http://%s/endp", snapdL.Addr()))
   779  		c.Assert(err, check.IsNil)
   780  		c.Check(res.StatusCode, check.Equals, 200)
   781  		body, err := ioutil.ReadAll(res.Body)
   782  		res.Body.Close()
   783  		c.Assert(err, check.IsNil)
   784  		c.Check(string(body), check.Equals, "OKOK")
   785  		close(alright)
   786  	}()
   787  	go func() {
   788  		<-snapdClosed
   789  		time.Sleep(200 * time.Millisecond)
   790  		doRespond <- true
   791  	}()
   792  
   793  	<-responding
   794  	err = d.Stop(nil)
   795  	doRespond <- false
   796  	c.Check(err, check.IsNil)
   797  
   798  	select {
   799  	case <-alright:
   800  	case <-time.After(2 * time.Second):
   801  		c.Fatal("never got proper response")
   802  	}
   803  }
   804  
   805  func (s *daemonSuite) TestGracefulStopHasLimits(c *check.C) {
   806  	d := newTestDaemon(c)
   807  
   808  	// mark as already seeded
   809  	s.markSeeded(d)
   810  
   811  	restore := MockShutdownTimeout(time.Second)
   812  	defer restore()
   813  
   814  	responding := make(chan struct{})
   815  	doRespond := make(chan bool, 1)
   816  
   817  	d.router.HandleFunc("/endp", func(w http.ResponseWriter, r *http.Request) {
   818  		close(responding)
   819  		if <-doRespond {
   820  			for {
   821  				// write in a loop to keep the handler running
   822  				if _, err := w.Write([]byte("OKOK")); err != nil {
   823  					break
   824  				}
   825  				time.Sleep(50 * time.Millisecond)
   826  			}
   827  		} else {
   828  			w.Write([]byte("Gone"))
   829  		}
   830  	})
   831  
   832  	snapdL, err := net.Listen("tcp", "127.0.0.1:0")
   833  	c.Assert(err, check.IsNil)
   834  
   835  	snapL, err := net.Listen("tcp", "127.0.0.1:0")
   836  	c.Assert(err, check.IsNil)
   837  
   838  	snapdAccept := make(chan struct{})
   839  	snapdClosed := make(chan struct{})
   840  	d.snapdListener = &witnessAcceptListener{Listener: snapdL, accept: snapdAccept, closed: snapdClosed}
   841  
   842  	snapAccept := make(chan struct{})
   843  	d.snapListener = &witnessAcceptListener{Listener: snapL, accept: snapAccept}
   844  
   845  	c.Assert(d.Start(), check.IsNil)
   846  
   847  	snapdAccepting := make(chan struct{})
   848  	go func() {
   849  		select {
   850  		case <-snapdAccept:
   851  		case <-time.After(2 * time.Second):
   852  			c.Fatal("snapd accept was not called")
   853  		}
   854  		close(snapdAccepting)
   855  	}()
   856  
   857  	snapAccepting := make(chan struct{})
   858  	go func() {
   859  		select {
   860  		case <-snapAccept:
   861  		case <-time.After(2 * time.Second):
   862  			c.Fatal("snapd accept was not called")
   863  		}
   864  		close(snapAccepting)
   865  	}()
   866  
   867  	<-snapdAccepting
   868  	<-snapAccepting
   869  
   870  	clientErr := make(chan error)
   871  
   872  	go func() {
   873  		_, err := http.Get(fmt.Sprintf("http://%s/endp", snapdL.Addr()))
   874  		c.Assert(err, check.NotNil)
   875  		clientErr <- err
   876  		close(clientErr)
   877  	}()
   878  	go func() {
   879  		<-snapdClosed
   880  		time.Sleep(200 * time.Millisecond)
   881  		doRespond <- true
   882  	}()
   883  
   884  	<-responding
   885  	err = d.Stop(nil)
   886  	doRespond <- false
   887  	c.Check(err, check.IsNil)
   888  
   889  	select {
   890  	case cErr := <-clientErr:
   891  		c.Check(cErr, check.ErrorMatches, ".*: EOF")
   892  	case <-time.After(5 * time.Second):
   893  		c.Fatal("never got proper response")
   894  	}
   895  }
   896  
   897  func (s *daemonSuite) testRestartSystemWiring(c *check.C, prep func(d *Daemon), restart func(*state.State, state.RestartType), restartKind state.RestartType, wait time.Duration) {
   898  	d := newTestDaemon(c)
   899  	// mark as already seeded
   900  	s.markSeeded(d)
   901  
   902  	if prep != nil {
   903  		prep(d)
   904  	}
   905  
   906  	l, err := net.Listen("tcp", "127.0.0.1:0")
   907  	c.Assert(err, check.IsNil)
   908  
   909  	snapdAccept := make(chan struct{})
   910  	d.snapdListener = &witnessAcceptListener{Listener: l, accept: snapdAccept}
   911  
   912  	snapAccept := make(chan struct{})
   913  	d.snapListener = &witnessAcceptListener{Listener: l, accept: snapAccept}
   914  
   915  	oldRebootNoticeWait := rebootNoticeWait
   916  	oldRebootWaitTimeout := rebootWaitTimeout
   917  	defer func() {
   918  		reboot = rebootImpl
   919  		rebootNoticeWait = oldRebootNoticeWait
   920  		rebootWaitTimeout = oldRebootWaitTimeout
   921  	}()
   922  	rebootWaitTimeout = 100 * time.Millisecond
   923  	rebootNoticeWait = 150 * time.Millisecond
   924  
   925  	expectedAction := rebootReboot
   926  	expectedOp := "reboot"
   927  	if restartKind == state.RestartSystemHaltNow {
   928  		expectedAction = rebootHalt
   929  		expectedOp = "halt"
   930  	} else if restartKind == state.RestartSystemPoweroffNow {
   931  		expectedAction = rebootPoweroff
   932  		expectedOp = "poweroff"
   933  	}
   934  	var delays []time.Duration
   935  	reboot = func(a rebootAction, d time.Duration) error {
   936  		c.Check(a, check.Equals, expectedAction)
   937  		delays = append(delays, d)
   938  		return nil
   939  	}
   940  
   941  	c.Assert(d.Start(), check.IsNil)
   942  	defer d.Stop(nil)
   943  
   944  	st := d.overlord.State()
   945  
   946  	snapdDone := make(chan struct{})
   947  	go func() {
   948  		select {
   949  		case <-snapdAccept:
   950  		case <-time.After(2 * time.Second):
   951  			c.Fatal("snapd accept was not called")
   952  		}
   953  		close(snapdDone)
   954  	}()
   955  
   956  	snapDone := make(chan struct{})
   957  	go func() {
   958  		select {
   959  		case <-snapAccept:
   960  		case <-time.After(2 * time.Second):
   961  			c.Fatal("snap accept was not called")
   962  		}
   963  		close(snapDone)
   964  	}()
   965  
   966  	<-snapdDone
   967  	<-snapDone
   968  
   969  	st.Lock()
   970  	restart(st, restartKind)
   971  	st.Unlock()
   972  
   973  	defer func() {
   974  		d.mu.Lock()
   975  		d.requestedRestart = state.RestartUnset
   976  		d.mu.Unlock()
   977  	}()
   978  
   979  	select {
   980  	case <-d.Dying():
   981  	case <-time.After(2 * time.Second):
   982  		c.Fatal("RequestRestart -> overlord -> Kill chain didn't work")
   983  	}
   984  
   985  	d.mu.Lock()
   986  	rs := d.requestedRestart
   987  	d.mu.Unlock()
   988  
   989  	c.Check(rs, check.Equals, restartKind)
   990  
   991  	c.Check(delays, check.HasLen, 1)
   992  	c.Check(delays[0], check.DeepEquals, rebootWaitTimeout)
   993  
   994  	now := time.Now()
   995  
   996  	err = d.Stop(nil)
   997  
   998  	c.Check(err, check.ErrorMatches, fmt.Sprintf("expected %s did not happen", expectedAction))
   999  
  1000  	c.Check(delays, check.HasLen, 2)
  1001  	c.Check(delays[1], check.DeepEquals, wait)
  1002  
  1003  	// we are not stopping, we wait for the reboot instead
  1004  	c.Check(s.notified, check.DeepEquals, []string{"EXTEND_TIMEOUT_USEC=30000000", "READY=1"})
  1005  
  1006  	st.Lock()
  1007  	defer st.Unlock()
  1008  	var rebootAt time.Time
  1009  	err = st.Get("daemon-system-restart-at", &rebootAt)
  1010  	c.Assert(err, check.IsNil)
  1011  	if wait > 0 {
  1012  		approxAt := now.Add(wait)
  1013  		c.Check(rebootAt.After(approxAt) || rebootAt.Equal(approxAt), check.Equals, true)
  1014  	} else {
  1015  		// should be good enough
  1016  		c.Check(rebootAt.Before(now.Add(10*time.Second)), check.Equals, true)
  1017  	}
  1018  
  1019  	// finally check that maintenance.json was written appropriate for this
  1020  	// restart reason
  1021  	b, err := ioutil.ReadFile(dirs.SnapdMaintenanceFile)
  1022  	c.Assert(err, check.IsNil)
  1023  
  1024  	maintErr := &errorResult{}
  1025  	c.Assert(json.Unmarshal(b, maintErr), check.IsNil)
  1026  	c.Check(maintErr.Kind, check.Equals, client.ErrorKindSystemRestart)
  1027  	c.Check(maintErr.Value, check.DeepEquals, map[string]interface{}{
  1028  		"op": expectedOp,
  1029  	})
  1030  
  1031  	exp := maintenanceForRestartType(restartKind)
  1032  	c.Assert(maintErr, check.DeepEquals, exp)
  1033  }
  1034  
  1035  func (s *daemonSuite) TestRestartSystemGracefulWiring(c *check.C) {
  1036  	s.testRestartSystemWiring(c, nil, (*state.State).RequestRestart, state.RestartSystem, 1*time.Minute)
  1037  }
  1038  
  1039  func (s *daemonSuite) TestRestartSystemImmediateWiring(c *check.C) {
  1040  	s.testRestartSystemWiring(c, nil, (*state.State).RequestRestart, state.RestartSystemNow, 0)
  1041  }
  1042  
  1043  func (s *daemonSuite) TestRestartSystemHaltImmediateWiring(c *check.C) {
  1044  	s.testRestartSystemWiring(c, nil, (*state.State).RequestRestart, state.RestartSystemHaltNow, 0)
  1045  }
  1046  
  1047  func (s *daemonSuite) TestRestartSystemPoweroffImmediateWiring(c *check.C) {
  1048  	s.testRestartSystemWiring(c, nil, (*state.State).RequestRestart, state.RestartSystemPoweroffNow, 0)
  1049  }
  1050  
  1051  type rstManager struct {
  1052  	st *state.State
  1053  }
  1054  
  1055  func (m *rstManager) Ensure() error {
  1056  	m.st.Lock()
  1057  	defer m.st.Unlock()
  1058  	m.st.RequestRestart(state.RestartSystemNow)
  1059  	return nil
  1060  }
  1061  
  1062  type witnessManager struct {
  1063  	ensureCalled int
  1064  }
  1065  
  1066  func (m *witnessManager) Ensure() error {
  1067  	m.ensureCalled++
  1068  	return nil
  1069  }
  1070  
  1071  func (s *daemonSuite) TestRestartSystemFromEnsure(c *check.C) {
  1072  	// Test that calling RequestRestart from inside the first
  1073  	// Ensure loop works.
  1074  	wm := &witnessManager{}
  1075  
  1076  	prep := func(d *Daemon) {
  1077  		st := d.overlord.State()
  1078  		hm := d.overlord.HookManager()
  1079  		o := overlord.MockWithStateAndRestartHandler(st, d.HandleRestart)
  1080  		d.overlord = o
  1081  		o.AddManager(hm)
  1082  		rm := &rstManager{st: st}
  1083  		o.AddManager(rm)
  1084  		o.AddManager(wm)
  1085  	}
  1086  
  1087  	nop := func(*state.State, state.RestartType) {}
  1088  
  1089  	s.testRestartSystemWiring(c, prep, nop, state.RestartSystemNow, 0)
  1090  
  1091  	c.Check(wm.ensureCalled, check.Equals, 1)
  1092  }
  1093  
  1094  func (s *daemonSuite) TestRebootHelper(c *check.C) {
  1095  	cmd := testutil.MockCommand(c, "shutdown", "")
  1096  	defer cmd.Restore()
  1097  
  1098  	tests := []struct {
  1099  		delay    time.Duration
  1100  		delayArg string
  1101  	}{
  1102  		{-1, "+0"},
  1103  		{0, "+0"},
  1104  		{time.Minute, "+1"},
  1105  		{10 * time.Minute, "+10"},
  1106  		{30 * time.Second, "+0"},
  1107  	}
  1108  
  1109  	args := []struct {
  1110  		a   rebootAction
  1111  		arg string
  1112  		msg string
  1113  	}{
  1114  		{rebootReboot, "-r", "reboot scheduled to update the system"},
  1115  		{rebootHalt, "--halt", "system halt scheduled"},
  1116  		{rebootPoweroff, "--poweroff", "system poweroff scheduled"},
  1117  	}
  1118  
  1119  	for _, arg := range args {
  1120  		for _, t := range tests {
  1121  			err := reboot(arg.a, t.delay)
  1122  			c.Assert(err, check.IsNil)
  1123  			c.Check(cmd.Calls(), check.DeepEquals, [][]string{
  1124  				{"shutdown", arg.arg, t.delayArg, arg.msg},
  1125  			})
  1126  
  1127  			cmd.ForgetCalls()
  1128  		}
  1129  	}
  1130  }
  1131  
  1132  func makeDaemonListeners(c *check.C, d *Daemon) {
  1133  	snapdL, err := net.Listen("tcp", "127.0.0.1:0")
  1134  	c.Assert(err, check.IsNil)
  1135  
  1136  	snapL, err := net.Listen("tcp", "127.0.0.1:0")
  1137  	c.Assert(err, check.IsNil)
  1138  
  1139  	snapdAccept := make(chan struct{})
  1140  	snapdClosed := make(chan struct{})
  1141  	d.snapdListener = &witnessAcceptListener{Listener: snapdL, accept: snapdAccept, closed: snapdClosed}
  1142  
  1143  	snapAccept := make(chan struct{})
  1144  	d.snapListener = &witnessAcceptListener{Listener: snapL, accept: snapAccept}
  1145  }
  1146  
  1147  // This test tests that when the snapd calls a restart of the system
  1148  // a sigterm (from e.g. systemd) is handled when it arrives before
  1149  // stop is fully done.
  1150  func (s *daemonSuite) TestRestartShutdownWithSigtermInBetween(c *check.C) {
  1151  	oldRebootNoticeWait := rebootNoticeWait
  1152  	defer func() {
  1153  		rebootNoticeWait = oldRebootNoticeWait
  1154  	}()
  1155  	rebootNoticeWait = 150 * time.Millisecond
  1156  
  1157  	cmd := testutil.MockCommand(c, "shutdown", "")
  1158  	defer cmd.Restore()
  1159  
  1160  	d := newTestDaemon(c)
  1161  	makeDaemonListeners(c, d)
  1162  	s.markSeeded(d)
  1163  
  1164  	c.Assert(d.Start(), check.IsNil)
  1165  	st := d.overlord.State()
  1166  
  1167  	st.Lock()
  1168  	st.RequestRestart(state.RestartSystem)
  1169  	st.Unlock()
  1170  
  1171  	ch := make(chan os.Signal, 2)
  1172  	ch <- syscall.SIGTERM
  1173  	// stop will check if we got a sigterm in between (which we did)
  1174  	err := d.Stop(ch)
  1175  	c.Assert(err, check.IsNil)
  1176  }
  1177  
  1178  // This test tests that when there is a shutdown we close the sigterm
  1179  // handler so that systemd can kill snapd.
  1180  func (s *daemonSuite) TestRestartShutdown(c *check.C) {
  1181  	oldRebootNoticeWait := rebootNoticeWait
  1182  	oldRebootWaitTimeout := rebootWaitTimeout
  1183  	defer func() {
  1184  		rebootNoticeWait = oldRebootNoticeWait
  1185  		rebootWaitTimeout = oldRebootWaitTimeout
  1186  	}()
  1187  	rebootWaitTimeout = 100 * time.Millisecond
  1188  	rebootNoticeWait = 150 * time.Millisecond
  1189  
  1190  	cmd := testutil.MockCommand(c, "shutdown", "")
  1191  	defer cmd.Restore()
  1192  
  1193  	d := newTestDaemon(c)
  1194  	makeDaemonListeners(c, d)
  1195  	s.markSeeded(d)
  1196  
  1197  	c.Assert(d.Start(), check.IsNil)
  1198  	st := d.overlord.State()
  1199  
  1200  	st.Lock()
  1201  	st.RequestRestart(state.RestartSystem)
  1202  	st.Unlock()
  1203  
  1204  	sigCh := make(chan os.Signal, 2)
  1205  	// stop (this will timeout but that's not relevant for this test)
  1206  	d.Stop(sigCh)
  1207  
  1208  	// ensure that the sigCh got closed as part of the stop
  1209  	_, chOpen := <-sigCh
  1210  	c.Assert(chOpen, check.Equals, false)
  1211  }
  1212  
  1213  func (s *daemonSuite) TestRestartExpectedRebootDidNotHappen(c *check.C) {
  1214  	curBootID, err := osutil.BootID()
  1215  	c.Assert(err, check.IsNil)
  1216  
  1217  	fakeState := []byte(fmt.Sprintf(`{"data":{"patch-level":%d,"patch-sublevel":%d,"some":"data","refresh-privacy-key":"0123456789ABCDEF","system-restart-from-boot-id":%q,"daemon-system-restart-at":"%s"},"changes":null,"tasks":null,"last-change-id":0,"last-task-id":0,"last-lane-id":0}`, patch.Level, patch.Sublevel, curBootID, time.Now().UTC().Format(time.RFC3339)))
  1218  	err = ioutil.WriteFile(dirs.SnapStateFile, fakeState, 0600)
  1219  	c.Assert(err, check.IsNil)
  1220  
  1221  	oldRebootNoticeWait := rebootNoticeWait
  1222  	oldRebootRetryWaitTimeout := rebootRetryWaitTimeout
  1223  	defer func() {
  1224  		rebootNoticeWait = oldRebootNoticeWait
  1225  		rebootRetryWaitTimeout = oldRebootRetryWaitTimeout
  1226  	}()
  1227  	rebootRetryWaitTimeout = 100 * time.Millisecond
  1228  	rebootNoticeWait = 150 * time.Millisecond
  1229  
  1230  	cmd := testutil.MockCommand(c, "shutdown", "")
  1231  	defer cmd.Restore()
  1232  
  1233  	d := newTestDaemon(c)
  1234  	c.Check(d.overlord, check.IsNil)
  1235  	c.Check(d.expectedRebootDidNotHappen, check.Equals, true)
  1236  
  1237  	var n int
  1238  	d.state.Lock()
  1239  	err = d.state.Get("daemon-system-restart-tentative", &n)
  1240  	d.state.Unlock()
  1241  	c.Check(err, check.IsNil)
  1242  	c.Check(n, check.Equals, 1)
  1243  
  1244  	c.Assert(d.Start(), check.IsNil)
  1245  
  1246  	c.Check(s.notified, check.DeepEquals, []string{"READY=1"})
  1247  
  1248  	select {
  1249  	case <-d.Dying():
  1250  	case <-time.After(2 * time.Second):
  1251  		c.Fatal("expected reboot not happening should proceed to try to shutdown again")
  1252  	}
  1253  
  1254  	sigCh := make(chan os.Signal, 2)
  1255  	// stop (this will timeout but thats not relevant for this test)
  1256  	d.Stop(sigCh)
  1257  
  1258  	// an immediate shutdown was scheduled again
  1259  	c.Check(cmd.Calls(), check.DeepEquals, [][]string{
  1260  		{"shutdown", "-r", "+0", "reboot scheduled to update the system"},
  1261  	})
  1262  }
  1263  
  1264  func (s *daemonSuite) TestRestartExpectedRebootOK(c *check.C) {
  1265  	fakeState := []byte(fmt.Sprintf(`{"data":{"patch-level":%d,"patch-sublevel":%d,"some":"data","refresh-privacy-key":"0123456789ABCDEF","system-restart-from-boot-id":%q,"daemon-system-restart-at":"%s"},"changes":null,"tasks":null,"last-change-id":0,"last-task-id":0,"last-lane-id":0}`, patch.Level, patch.Sublevel, "boot-id-0", time.Now().UTC().Format(time.RFC3339)))
  1266  	err := ioutil.WriteFile(dirs.SnapStateFile, fakeState, 0600)
  1267  	c.Assert(err, check.IsNil)
  1268  
  1269  	cmd := testutil.MockCommand(c, "shutdown", "")
  1270  	defer cmd.Restore()
  1271  
  1272  	d := newTestDaemon(c)
  1273  	c.Assert(d.overlord, check.NotNil)
  1274  
  1275  	st := d.overlord.State()
  1276  	st.Lock()
  1277  	defer st.Unlock()
  1278  	var v interface{}
  1279  	// these were cleared
  1280  	c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState)
  1281  	c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState)
  1282  }
  1283  
  1284  func (s *daemonSuite) TestRestartExpectedRebootGiveUp(c *check.C) {
  1285  	// we give up trying to restart the system after 3 retry tentatives
  1286  	curBootID, err := osutil.BootID()
  1287  	c.Assert(err, check.IsNil)
  1288  
  1289  	fakeState := []byte(fmt.Sprintf(`{"data":{"patch-level":%d,"patch-sublevel":%d,"some":"data","refresh-privacy-key":"0123456789ABCDEF","system-restart-from-boot-id":%q,"daemon-system-restart-at":"%s","daemon-system-restart-tentative":3},"changes":null,"tasks":null,"last-change-id":0,"last-task-id":0,"last-lane-id":0}`, patch.Level, patch.Sublevel, curBootID, time.Now().UTC().Format(time.RFC3339)))
  1290  	err = ioutil.WriteFile(dirs.SnapStateFile, fakeState, 0600)
  1291  	c.Assert(err, check.IsNil)
  1292  
  1293  	cmd := testutil.MockCommand(c, "shutdown", "")
  1294  	defer cmd.Restore()
  1295  
  1296  	d := newTestDaemon(c)
  1297  	c.Assert(d.overlord, check.NotNil)
  1298  
  1299  	st := d.overlord.State()
  1300  	st.Lock()
  1301  	defer st.Unlock()
  1302  	var v interface{}
  1303  	// these were cleared
  1304  	c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState)
  1305  	c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState)
  1306  	c.Check(st.Get("daemon-system-restart-tentative", &v), check.Equals, state.ErrNoState)
  1307  }
  1308  
  1309  func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *check.C) {
  1310  	restore := standby.MockStandbyWait(5 * time.Millisecond)
  1311  	defer restore()
  1312  
  1313  	d := newTestDaemon(c)
  1314  	makeDaemonListeners(c, d)
  1315  
  1316  	// mark as already seeded, we also have no snaps so this will
  1317  	// go into socket activation mode
  1318  	s.markSeeded(d)
  1319  
  1320  	c.Assert(d.Start(), check.IsNil)
  1321  	// pretend some ensure happened
  1322  	for i := 0; i < 5; i++ {
  1323  		c.Check(d.overlord.StateEngine().Ensure(), check.IsNil)
  1324  		time.Sleep(5 * time.Millisecond)
  1325  	}
  1326  
  1327  	select {
  1328  	case <-d.Dying():
  1329  		// exit the loop
  1330  	case <-time.After(15 * time.Second):
  1331  		c.Errorf("daemon did not stop after 15s")
  1332  	}
  1333  	err := d.Stop(nil)
  1334  	c.Check(err, check.Equals, ErrRestartSocket)
  1335  	c.Check(d.restartSocket, check.Equals, true)
  1336  }
  1337  
  1338  func (s *daemonSuite) TestRestartIntoSocketModePendingChanges(c *check.C) {
  1339  	restore := standby.MockStandbyWait(5 * time.Millisecond)
  1340  	defer restore()
  1341  
  1342  	d := newTestDaemon(c)
  1343  	makeDaemonListeners(c, d)
  1344  
  1345  	// mark as already seeded, we also have no snaps so this will
  1346  	// go into socket activation mode
  1347  	s.markSeeded(d)
  1348  	st := d.overlord.State()
  1349  
  1350  	c.Assert(d.Start(), check.IsNil)
  1351  	// pretend some ensure happened
  1352  	for i := 0; i < 5; i++ {
  1353  		c.Check(d.overlord.StateEngine().Ensure(), check.IsNil)
  1354  		time.Sleep(5 * time.Millisecond)
  1355  	}
  1356  
  1357  	select {
  1358  	case <-d.Dying():
  1359  		// Pretend we got change while shutting down, this can
  1360  		// happen when e.g. the user requested a `snap install
  1361  		// foo` at the same time as the code in the overlord
  1362  		// checked that it can go into socket activated
  1363  		// mode. I.e. the daemon was processing the request
  1364  		// but no change was generated at the time yet.
  1365  		st.Lock()
  1366  		chg := st.NewChange("fake-install", "fake install some snap")
  1367  		chg.AddTask(st.NewTask("fake-install-task", "fake install task"))
  1368  		chgStatus := chg.Status()
  1369  		st.Unlock()
  1370  		// ensure our change is valid and ready
  1371  		c.Check(chgStatus, check.Equals, state.DoStatus)
  1372  	case <-time.After(5 * time.Second):
  1373  		c.Errorf("daemon did not stop after 5s")
  1374  	}
  1375  	// when the daemon got a pending change it just restarts
  1376  	err := d.Stop(nil)
  1377  	c.Check(err, check.IsNil)
  1378  	c.Check(d.restartSocket, check.Equals, false)
  1379  }
  1380  
  1381  func (s *daemonSuite) TestConnTrackerCanShutdown(c *check.C) {
  1382  	ct := &connTracker{conns: make(map[net.Conn]struct{})}
  1383  	c.Check(ct.CanStandby(), check.Equals, true)
  1384  
  1385  	con := &net.IPConn{}
  1386  	ct.trackConn(con, http.StateActive)
  1387  	c.Check(ct.CanStandby(), check.Equals, false)
  1388  
  1389  	ct.trackConn(con, http.StateIdle)
  1390  	c.Check(ct.CanStandby(), check.Equals, true)
  1391  }
  1392  
  1393  func doTestReq(c *check.C, cmd *Command, mth string) *httptest.ResponseRecorder {
  1394  	req, err := http.NewRequest(mth, "", nil)
  1395  	c.Assert(err, check.IsNil)
  1396  	req.RemoteAddr = fmt.Sprintf("pid=100;uid=0;socket=%s;", dirs.SnapdSocket)
  1397  	rec := httptest.NewRecorder()
  1398  	cmd.ServeHTTP(rec, req)
  1399  	return rec
  1400  }
  1401  
  1402  func (s *daemonSuite) TestDegradedModeReply(c *check.C) {
  1403  	d := newTestDaemon(c)
  1404  	cmd := &Command{d: d}
  1405  	cmd.GET = func(*Command, *http.Request, *auth.UserState) Response {
  1406  		return SyncResponse(nil)
  1407  	}
  1408  	cmd.POST = func(*Command, *http.Request, *auth.UserState) Response {
  1409  		return SyncResponse(nil)
  1410  	}
  1411  	cmd.ReadAccess = authenticatedAccess{}
  1412  	cmd.WriteAccess = authenticatedAccess{}
  1413  
  1414  	// pretend we are in degraded mode
  1415  	d.SetDegradedMode(fmt.Errorf("foo error"))
  1416  
  1417  	// GET is ok even in degraded mode
  1418  	rec := doTestReq(c, cmd, "GET")
  1419  	c.Check(rec.Code, check.Equals, 200)
  1420  	// POST is not allowed
  1421  	rec = doTestReq(c, cmd, "POST")
  1422  	c.Check(rec.Code, check.Equals, 500)
  1423  	// verify we get the error
  1424  	var v struct{ Result errorResult }
  1425  	c.Assert(json.NewDecoder(rec.Body).Decode(&v), check.IsNil)
  1426  	c.Check(v.Result.Message, check.Equals, "foo error")
  1427  
  1428  	// clean degraded mode
  1429  	d.SetDegradedMode(nil)
  1430  	rec = doTestReq(c, cmd, "POST")
  1431  	c.Check(rec.Code, check.Equals, 200)
  1432  }