github.com/hugh712/snapd@v0.0.0-20200910133618-1a99902bd583/daemon/daemon_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2014-2015 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package daemon
    21  
    22  import (
    23  	"fmt"
    24  
    25  	"bytes"
    26  	"encoding/json"
    27  	"errors"
    28  	"io/ioutil"
    29  	"net"
    30  	"net/http"
    31  	"net/http/httptest"
    32  	"os"
    33  	"path/filepath"
    34  	"sync"
    35  	"syscall"
    36  	"testing"
    37  	"time"
    38  
    39  	"github.com/gorilla/mux"
    40  	"gopkg.in/check.v1"
    41  
    42  	"github.com/snapcore/snapd/client"
    43  	"github.com/snapcore/snapd/dirs"
    44  	"github.com/snapcore/snapd/logger"
    45  	"github.com/snapcore/snapd/osutil"
    46  	"github.com/snapcore/snapd/overlord/auth"
    47  	"github.com/snapcore/snapd/overlord/devicestate/devicestatetest"
    48  	"github.com/snapcore/snapd/overlord/ifacestate"
    49  	"github.com/snapcore/snapd/overlord/patch"
    50  	"github.com/snapcore/snapd/overlord/snapstate"
    51  	"github.com/snapcore/snapd/overlord/standby"
    52  	"github.com/snapcore/snapd/overlord/state"
    53  	"github.com/snapcore/snapd/polkit"
    54  	"github.com/snapcore/snapd/snap"
    55  	"github.com/snapcore/snapd/store"
    56  	"github.com/snapcore/snapd/systemd"
    57  	"github.com/snapcore/snapd/testutil"
    58  )
    59  
    60  // Hook up check.v1 into the "go test" runner
    61  func Test(t *testing.T) { check.TestingT(t) }
    62  
    63  type daemonSuite struct {
    64  	authorized      bool
    65  	err             error
    66  	lastPolkitFlags polkit.CheckFlags
    67  	notified        []string
    68  	restoreBackends func()
    69  }
    70  
    71  var _ = check.Suite(&daemonSuite{})
    72  
    73  func (s *daemonSuite) checkAuthorization(pid int32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) {
    74  	s.lastPolkitFlags = flags
    75  	return s.authorized, s.err
    76  }
    77  
    78  func (s *daemonSuite) SetUpTest(c *check.C) {
    79  	dirs.SetRootDir(c.MkDir())
    80  	err := os.MkdirAll(filepath.Dir(dirs.SnapStateFile), 0755)
    81  	c.Assert(err, check.IsNil)
    82  	systemdSdNotify = func(notif string) error {
    83  		s.notified = append(s.notified, notif)
    84  		return nil
    85  	}
    86  	s.notified = nil
    87  	polkitCheckAuthorization = s.checkAuthorization
    88  	s.restoreBackends = ifacestate.MockSecurityBackends(nil)
    89  }
    90  
    91  func (s *daemonSuite) TearDownTest(c *check.C) {
    92  	systemdSdNotify = systemd.SdNotify
    93  	dirs.SetRootDir("")
    94  	s.authorized = false
    95  	s.err = nil
    96  	logger.SetLogger(logger.NullLogger)
    97  	s.restoreBackends()
    98  }
    99  
   100  func (s *daemonSuite) TearDownSuite(c *check.C) {
   101  	polkitCheckAuthorization = polkit.CheckAuthorization
   102  }
   103  
   104  // build a new daemon, with only a little of Init(), suitable for the tests
   105  func newTestDaemon(c *check.C) *Daemon {
   106  	d, err := New()
   107  	c.Assert(err, check.IsNil)
   108  	d.addRoutes()
   109  
   110  	// don't actually try to talk to the store on snapstate.Ensure
   111  	// needs doing after the call to devicestate.Manager (which
   112  	// happens in daemon.New via overlord.New)
   113  	snapstate.CanAutoRefresh = nil
   114  
   115  	return d
   116  }
   117  
   118  // a Response suitable for testing
   119  type mockHandler struct {
   120  	cmd        *Command
   121  	lastMethod string
   122  }
   123  
   124  func (mck *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   125  	mck.lastMethod = r.Method
   126  }
   127  
   128  func (s *daemonSuite) TestCommandMethodDispatch(c *check.C) {
   129  	fakeUserAgent := "some-agent-talking-to-snapd/1.0"
   130  
   131  	cmd := &Command{d: newTestDaemon(c)}
   132  	mck := &mockHandler{cmd: cmd}
   133  	rf := func(innerCmd *Command, req *http.Request, user *auth.UserState) Response {
   134  		c.Assert(cmd, check.Equals, innerCmd)
   135  		c.Check(store.ClientUserAgent(req.Context()), check.Equals, fakeUserAgent)
   136  		return mck
   137  	}
   138  	cmd.GET = rf
   139  	cmd.PUT = rf
   140  	cmd.POST = rf
   141  
   142  	for _, method := range []string{"GET", "POST", "PUT"} {
   143  		req, err := http.NewRequest(method, "", nil)
   144  		req.Header.Add("User-Agent", fakeUserAgent)
   145  		c.Assert(err, check.IsNil)
   146  
   147  		rec := httptest.NewRecorder()
   148  		cmd.ServeHTTP(rec, req)
   149  		c.Check(rec.Code, check.Equals, 401, check.Commentf(method))
   150  
   151  		rec = httptest.NewRecorder()
   152  		req.RemoteAddr = "pid=100;uid=0;socket=;"
   153  
   154  		cmd.ServeHTTP(rec, req)
   155  		c.Check(mck.lastMethod, check.Equals, method)
   156  		c.Check(rec.Code, check.Equals, 200)
   157  	}
   158  
   159  	req, err := http.NewRequest("POTATO", "", nil)
   160  	c.Assert(err, check.IsNil)
   161  	req.RemoteAddr = "pid=100;uid=0;socket=;"
   162  
   163  	rec := httptest.NewRecorder()
   164  	cmd.ServeHTTP(rec, req)
   165  	c.Check(rec.Code, check.Equals, 405)
   166  }
   167  
   168  func (s *daemonSuite) TestCommandRestartingState(c *check.C) {
   169  	d := newTestDaemon(c)
   170  
   171  	cmd := &Command{d: d}
   172  	cmd.GET = func(*Command, *http.Request, *auth.UserState) Response {
   173  		return SyncResponse(nil, nil)
   174  	}
   175  	req, err := http.NewRequest("GET", "", nil)
   176  	c.Assert(err, check.IsNil)
   177  	req.RemoteAddr = "pid=100;uid=0;socket=;"
   178  
   179  	rec := httptest.NewRecorder()
   180  	cmd.ServeHTTP(rec, req)
   181  	c.Check(rec.Code, check.Equals, 200)
   182  	var rst struct {
   183  		Maintenance *errorResult `json:"maintenance"`
   184  	}
   185  	err = json.Unmarshal(rec.Body.Bytes(), &rst)
   186  	c.Assert(err, check.IsNil)
   187  	c.Check(rst.Maintenance, check.IsNil)
   188  
   189  	state.MockRestarting(d.overlord.State(), state.RestartSystem)
   190  	rec = httptest.NewRecorder()
   191  	cmd.ServeHTTP(rec, req)
   192  	c.Check(rec.Code, check.Equals, 200)
   193  	err = json.Unmarshal(rec.Body.Bytes(), &rst)
   194  	c.Assert(err, check.IsNil)
   195  	c.Check(rst.Maintenance, check.DeepEquals, &errorResult{
   196  		Kind:    client.ErrorKindSystemRestart,
   197  		Message: "system is restarting",
   198  	})
   199  
   200  	state.MockRestarting(d.overlord.State(), state.RestartDaemon)
   201  	rec = httptest.NewRecorder()
   202  	cmd.ServeHTTP(rec, req)
   203  	c.Check(rec.Code, check.Equals, 200)
   204  	err = json.Unmarshal(rec.Body.Bytes(), &rst)
   205  	c.Assert(err, check.IsNil)
   206  	c.Check(rst.Maintenance, check.DeepEquals, &errorResult{
   207  		Kind:    client.ErrorKindDaemonRestart,
   208  		Message: "daemon is restarting",
   209  	})
   210  }
   211  
   212  func (s *daemonSuite) TestFillsWarnings(c *check.C) {
   213  	d := newTestDaemon(c)
   214  
   215  	cmd := &Command{d: d}
   216  	cmd.GET = func(*Command, *http.Request, *auth.UserState) Response {
   217  		return SyncResponse(nil, nil)
   218  	}
   219  	req, err := http.NewRequest("GET", "", nil)
   220  	c.Assert(err, check.IsNil)
   221  	req.RemoteAddr = "pid=100;uid=0;socket=;"
   222  
   223  	rec := httptest.NewRecorder()
   224  	cmd.ServeHTTP(rec, req)
   225  	c.Check(rec.Code, check.Equals, 200)
   226  	var rst struct {
   227  		WarningTimestamp *time.Time `json:"warning-timestamp,omitempty"`
   228  		WarningCount     int        `json:"warning-count,omitempty"`
   229  	}
   230  	err = json.Unmarshal(rec.Body.Bytes(), &rst)
   231  	c.Assert(err, check.IsNil)
   232  	c.Check(rst.WarningCount, check.Equals, 0)
   233  	c.Check(rst.WarningTimestamp, check.IsNil)
   234  
   235  	st := d.overlord.State()
   236  	st.Lock()
   237  	st.Warnf("hello world")
   238  	st.Unlock()
   239  
   240  	rec = httptest.NewRecorder()
   241  	cmd.ServeHTTP(rec, req)
   242  	c.Check(rec.Code, check.Equals, 200)
   243  	err = json.Unmarshal(rec.Body.Bytes(), &rst)
   244  	c.Assert(err, check.IsNil)
   245  	c.Check(rst.WarningCount, check.Equals, 1)
   246  	c.Check(rst.WarningTimestamp, check.NotNil)
   247  }
   248  
   249  func (s *daemonSuite) TestGuestAccess(c *check.C) {
   250  	get := &http.Request{Method: "GET"}
   251  	put := &http.Request{Method: "PUT"}
   252  	pst := &http.Request{Method: "POST"}
   253  	del := &http.Request{Method: "DELETE"}
   254  
   255  	cmd := &Command{d: newTestDaemon(c)}
   256  	c.Check(cmd.canAccess(get, nil), check.Equals, accessUnauthorized)
   257  	c.Check(cmd.canAccess(put, nil), check.Equals, accessUnauthorized)
   258  	c.Check(cmd.canAccess(pst, nil), check.Equals, accessUnauthorized)
   259  	c.Check(cmd.canAccess(del, nil), check.Equals, accessUnauthorized)
   260  
   261  	cmd = &Command{d: newTestDaemon(c), RootOnly: true}
   262  	c.Check(cmd.canAccess(get, nil), check.Equals, accessUnauthorized)
   263  	c.Check(cmd.canAccess(put, nil), check.Equals, accessUnauthorized)
   264  	c.Check(cmd.canAccess(pst, nil), check.Equals, accessUnauthorized)
   265  	c.Check(cmd.canAccess(del, nil), check.Equals, accessUnauthorized)
   266  
   267  	cmd = &Command{d: newTestDaemon(c), UserOK: true}
   268  	c.Check(cmd.canAccess(get, nil), check.Equals, accessUnauthorized)
   269  	c.Check(cmd.canAccess(put, nil), check.Equals, accessUnauthorized)
   270  	c.Check(cmd.canAccess(pst, nil), check.Equals, accessUnauthorized)
   271  	c.Check(cmd.canAccess(del, nil), check.Equals, accessUnauthorized)
   272  
   273  	cmd = &Command{d: newTestDaemon(c), GuestOK: true}
   274  	c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
   275  	c.Check(cmd.canAccess(put, nil), check.Equals, accessUnauthorized)
   276  	c.Check(cmd.canAccess(pst, nil), check.Equals, accessUnauthorized)
   277  	c.Check(cmd.canAccess(del, nil), check.Equals, accessUnauthorized)
   278  }
   279  
   280  func (s *daemonSuite) TestSnapctlAccessSnapOKWithUser(c *check.C) {
   281  	remoteAddr := "pid=100;uid=1000;socket=" + dirs.SnapSocket + ";"
   282  	get := &http.Request{Method: "GET", RemoteAddr: remoteAddr}
   283  	put := &http.Request{Method: "PUT", RemoteAddr: remoteAddr}
   284  	pst := &http.Request{Method: "POST", RemoteAddr: remoteAddr}
   285  	del := &http.Request{Method: "DELETE", RemoteAddr: remoteAddr}
   286  
   287  	cmd := &Command{d: newTestDaemon(c), SnapOK: true}
   288  	c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
   289  	c.Check(cmd.canAccess(put, nil), check.Equals, accessOK)
   290  	c.Check(cmd.canAccess(pst, nil), check.Equals, accessOK)
   291  	c.Check(cmd.canAccess(del, nil), check.Equals, accessOK)
   292  }
   293  
   294  func (s *daemonSuite) TestSnapctlAccessSnapOKWithRoot(c *check.C) {
   295  	remoteAddr := "pid=100;uid=0;socket=" + dirs.SnapSocket + ";"
   296  	get := &http.Request{Method: "GET", RemoteAddr: remoteAddr}
   297  	put := &http.Request{Method: "PUT", RemoteAddr: remoteAddr}
   298  	pst := &http.Request{Method: "POST", RemoteAddr: remoteAddr}
   299  	del := &http.Request{Method: "DELETE", RemoteAddr: remoteAddr}
   300  
   301  	cmd := &Command{d: newTestDaemon(c), SnapOK: true}
   302  	c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
   303  	c.Check(cmd.canAccess(put, nil), check.Equals, accessOK)
   304  	c.Check(cmd.canAccess(pst, nil), check.Equals, accessOK)
   305  	c.Check(cmd.canAccess(del, nil), check.Equals, accessOK)
   306  }
   307  
   308  func (s *daemonSuite) TestUserAccess(c *check.C) {
   309  	get := &http.Request{Method: "GET", RemoteAddr: "pid=100;uid=42;socket=;"}
   310  	put := &http.Request{Method: "PUT", RemoteAddr: "pid=100;uid=42;socket=;"}
   311  
   312  	cmd := &Command{d: newTestDaemon(c)}
   313  	c.Check(cmd.canAccess(get, nil), check.Equals, accessUnauthorized)
   314  	c.Check(cmd.canAccess(put, nil), check.Equals, accessUnauthorized)
   315  
   316  	cmd = &Command{d: newTestDaemon(c), RootOnly: true}
   317  	c.Check(cmd.canAccess(get, nil), check.Equals, accessUnauthorized)
   318  	c.Check(cmd.canAccess(put, nil), check.Equals, accessUnauthorized)
   319  
   320  	cmd = &Command{d: newTestDaemon(c), UserOK: true}
   321  	c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
   322  	c.Check(cmd.canAccess(put, nil), check.Equals, accessUnauthorized)
   323  
   324  	cmd = &Command{d: newTestDaemon(c), GuestOK: true}
   325  	c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
   326  	c.Check(cmd.canAccess(put, nil), check.Equals, accessUnauthorized)
   327  
   328  	// Since this request has a RemoteAddr, it must be coming from the snapd
   329  	// socket instead of the snap one. In that case, SnapOK should have no
   330  	// bearing on the default behavior, which is to deny access.
   331  	cmd = &Command{d: newTestDaemon(c), SnapOK: true}
   332  	c.Check(cmd.canAccess(get, nil), check.Equals, accessUnauthorized)
   333  	c.Check(cmd.canAccess(put, nil), check.Equals, accessUnauthorized)
   334  }
   335  
   336  func (s *daemonSuite) TestLoggedInUserAccess(c *check.C) {
   337  	user := &auth.UserState{}
   338  	get := &http.Request{Method: "GET", RemoteAddr: "pid=100;uid=42;socket=;"}
   339  	put := &http.Request{Method: "PUT", RemoteAddr: "pid=100;uid=42;socket=;"}
   340  
   341  	cmd := &Command{d: newTestDaemon(c)}
   342  	c.Check(cmd.canAccess(get, user), check.Equals, accessOK)
   343  	c.Check(cmd.canAccess(put, user), check.Equals, accessOK)
   344  
   345  	cmd = &Command{d: newTestDaemon(c), RootOnly: true}
   346  	c.Check(cmd.canAccess(get, user), check.Equals, accessUnauthorized)
   347  	c.Check(cmd.canAccess(put, user), check.Equals, accessUnauthorized)
   348  
   349  	cmd = &Command{d: newTestDaemon(c), UserOK: true}
   350  	c.Check(cmd.canAccess(get, user), check.Equals, accessOK)
   351  	c.Check(cmd.canAccess(put, user), check.Equals, accessOK)
   352  
   353  	cmd = &Command{d: newTestDaemon(c), GuestOK: true}
   354  	c.Check(cmd.canAccess(get, user), check.Equals, accessOK)
   355  	c.Check(cmd.canAccess(put, user), check.Equals, accessOK)
   356  
   357  	cmd = &Command{d: newTestDaemon(c), SnapOK: true}
   358  	c.Check(cmd.canAccess(get, user), check.Equals, accessOK)
   359  	c.Check(cmd.canAccess(put, user), check.Equals, accessOK)
   360  }
   361  
   362  func (s *daemonSuite) TestSuperAccess(c *check.C) {
   363  	get := &http.Request{Method: "GET", RemoteAddr: "pid=100;uid=0;socket=;"}
   364  	put := &http.Request{Method: "PUT", RemoteAddr: "pid=100;uid=0;socket=;"}
   365  
   366  	cmd := &Command{d: newTestDaemon(c)}
   367  	c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
   368  	c.Check(cmd.canAccess(put, nil), check.Equals, accessOK)
   369  
   370  	cmd = &Command{d: newTestDaemon(c), RootOnly: true}
   371  	c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
   372  	c.Check(cmd.canAccess(put, nil), check.Equals, accessOK)
   373  
   374  	cmd = &Command{d: newTestDaemon(c), UserOK: true}
   375  	c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
   376  	c.Check(cmd.canAccess(put, nil), check.Equals, accessOK)
   377  
   378  	cmd = &Command{d: newTestDaemon(c), GuestOK: true}
   379  	c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
   380  	c.Check(cmd.canAccess(put, nil), check.Equals, accessOK)
   381  
   382  	cmd = &Command{d: newTestDaemon(c), SnapOK: true}
   383  	c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
   384  	c.Check(cmd.canAccess(put, nil), check.Equals, accessOK)
   385  }
   386  
   387  func (s *daemonSuite) TestPolkitAccess(c *check.C) {
   388  	put := &http.Request{Method: "PUT", RemoteAddr: "pid=100;uid=42;socket=;"}
   389  	cmd := &Command{d: newTestDaemon(c), PolkitOK: "polkit.action"}
   390  
   391  	// polkit says user is not authorised
   392  	s.authorized = false
   393  	c.Check(cmd.canAccess(put, nil), check.Equals, accessUnauthorized)
   394  
   395  	// polkit grants authorisation
   396  	s.authorized = true
   397  	c.Check(cmd.canAccess(put, nil), check.Equals, accessOK)
   398  
   399  	// an error occurs communicating with polkit
   400  	s.err = errors.New("error")
   401  	c.Check(cmd.canAccess(put, nil), check.Equals, accessUnauthorized)
   402  
   403  	// if the user dismisses the auth request, forbid access
   404  	s.err = polkit.ErrDismissed
   405  	c.Check(cmd.canAccess(put, nil), check.Equals, accessCancelled)
   406  }
   407  
   408  func (s *daemonSuite) TestPolkitAccessForGet(c *check.C) {
   409  	get := &http.Request{Method: "GET", RemoteAddr: "pid=100;uid=42;socket=;"}
   410  	cmd := &Command{d: newTestDaemon(c), PolkitOK: "polkit.action"}
   411  
   412  	// polkit can grant authorisation for GET requests
   413  	s.authorized = true
   414  	c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
   415  
   416  	// for UserOK commands, polkit is not consulted
   417  	cmd.UserOK = true
   418  	polkitCheckAuthorization = func(pid int32, uid uint32, actionId string, details map[string]string, flags polkit.CheckFlags) (bool, error) {
   419  		panic("polkit.CheckAuthorization called")
   420  	}
   421  	c.Check(cmd.canAccess(get, nil), check.Equals, accessOK)
   422  }
   423  
   424  func (s *daemonSuite) TestPolkitInteractivity(c *check.C) {
   425  	put := &http.Request{Method: "PUT", RemoteAddr: "pid=100;uid=42;socket=;", Header: make(http.Header)}
   426  	cmd := &Command{d: newTestDaemon(c), PolkitOK: "polkit.action"}
   427  	s.authorized = true
   428  
   429  	var logbuf bytes.Buffer
   430  	log, err := logger.New(&logbuf, logger.DefaultFlags)
   431  	c.Assert(err, check.IsNil)
   432  	logger.SetLogger(log)
   433  
   434  	c.Check(cmd.canAccess(put, nil), check.Equals, accessOK)
   435  	c.Check(s.lastPolkitFlags, check.Equals, polkit.CheckNone)
   436  	c.Check(logbuf.String(), check.Equals, "")
   437  
   438  	put.Header.Set(client.AllowInteractionHeader, "true")
   439  	c.Check(cmd.canAccess(put, nil), check.Equals, accessOK)
   440  	c.Check(s.lastPolkitFlags, check.Equals, polkit.CheckAllowInteraction)
   441  	c.Check(logbuf.String(), check.Equals, "")
   442  
   443  	// bad values are logged and treated as false
   444  	put.Header.Set(client.AllowInteractionHeader, "garbage")
   445  	c.Check(cmd.canAccess(put, nil), check.Equals, accessOK)
   446  	c.Check(s.lastPolkitFlags, check.Equals, polkit.CheckNone)
   447  	c.Check(logbuf.String(), testutil.Contains, "error parsing X-Allow-Interaction header:")
   448  }
   449  
   450  func (s *daemonSuite) TestAddRoutes(c *check.C) {
   451  	d := newTestDaemon(c)
   452  
   453  	expected := make([]string, len(api))
   454  	for i, v := range api {
   455  		if v.PathPrefix != "" {
   456  			expected[i] = v.PathPrefix
   457  			continue
   458  		}
   459  		expected[i] = v.Path
   460  	}
   461  
   462  	got := make([]string, 0, len(api))
   463  	c.Assert(d.router.Walk(func(route *mux.Route, router *mux.Router, ancestors []*mux.Route) error {
   464  		got = append(got, route.GetName())
   465  		return nil
   466  	}), check.IsNil)
   467  
   468  	c.Check(got, check.DeepEquals, expected) // this'll stop being true if routes are added that aren't commands (e.g. for the favicon)
   469  
   470  	// XXX: still waiting to know how to check d.router.NotFoundHandler has been set to NotFound
   471  	//      the old test relied on undefined behaviour:
   472  	//      c.Check(fmt.Sprintf("%p", d.router.NotFoundHandler), check.Equals, fmt.Sprintf("%p", NotFound))
   473  }
   474  
   475  type witnessAcceptListener struct {
   476  	net.Listener
   477  
   478  	accept  chan struct{}
   479  	accept1 bool
   480  
   481  	idempotClose sync.Once
   482  	closeErr     error
   483  	closed       chan struct{}
   484  }
   485  
   486  func (l *witnessAcceptListener) Accept() (net.Conn, error) {
   487  	if !l.accept1 {
   488  		l.accept1 = true
   489  		close(l.accept)
   490  	}
   491  	return l.Listener.Accept()
   492  }
   493  
   494  func (l *witnessAcceptListener) Close() error {
   495  	l.idempotClose.Do(func() {
   496  		l.closeErr = l.Listener.Close()
   497  		if l.closed != nil {
   498  			close(l.closed)
   499  		}
   500  	})
   501  	return l.closeErr
   502  }
   503  
   504  func (s *daemonSuite) markSeeded(d *Daemon) {
   505  	st := d.overlord.State()
   506  	st.Lock()
   507  	st.Set("seeded", true)
   508  	devicestatetest.SetDevice(st, &auth.DeviceState{
   509  		Brand:  "canonical",
   510  		Model:  "pc",
   511  		Serial: "serialserial",
   512  	})
   513  	st.Unlock()
   514  }
   515  
   516  func (s *daemonSuite) TestStartStop(c *check.C) {
   517  	d := newTestDaemon(c)
   518  	// mark as already seeded
   519  	s.markSeeded(d)
   520  	// and pretend we have snaps
   521  	st := d.overlord.State()
   522  	st.Lock()
   523  	snapstate.Set(st, "core", &snapstate.SnapState{
   524  		Active: true,
   525  		Sequence: []*snap.SideInfo{
   526  			{RealName: "core", Revision: snap.R(1), SnapID: "core-snap-id"},
   527  		},
   528  		Current: snap.R(1),
   529  	})
   530  	st.Unlock()
   531  	// 1 snap => extended timeout 30s + 5s
   532  	const extendedTimeoutUSec = "EXTEND_TIMEOUT_USEC=35000000"
   533  
   534  	l1, err := net.Listen("tcp", "127.0.0.1:0")
   535  	c.Assert(err, check.IsNil)
   536  	l2, err := net.Listen("tcp", "127.0.0.1:0")
   537  	c.Assert(err, check.IsNil)
   538  
   539  	snapdAccept := make(chan struct{})
   540  	d.snapdListener = &witnessAcceptListener{Listener: l1, accept: snapdAccept}
   541  
   542  	snapAccept := make(chan struct{})
   543  	d.snapListener = &witnessAcceptListener{Listener: l2, accept: snapAccept}
   544  
   545  	c.Assert(d.Start(), check.IsNil)
   546  
   547  	c.Check(s.notified, check.DeepEquals, []string{extendedTimeoutUSec, "READY=1"})
   548  
   549  	snapdDone := make(chan struct{})
   550  	go func() {
   551  		select {
   552  		case <-snapdAccept:
   553  		case <-time.After(2 * time.Second):
   554  			c.Fatal("snapd accept was not called")
   555  		}
   556  		close(snapdDone)
   557  	}()
   558  
   559  	snapDone := make(chan struct{})
   560  	go func() {
   561  		select {
   562  		case <-snapAccept:
   563  		case <-time.After(2 * time.Second):
   564  			c.Fatal("snapd accept was not called")
   565  		}
   566  		close(snapDone)
   567  	}()
   568  
   569  	<-snapdDone
   570  	<-snapDone
   571  
   572  	err = d.Stop(nil)
   573  	c.Check(err, check.IsNil)
   574  
   575  	c.Check(s.notified, check.DeepEquals, []string{extendedTimeoutUSec, "READY=1", "STOPPING=1"})
   576  }
   577  
   578  func (s *daemonSuite) TestRestartWiring(c *check.C) {
   579  	d := newTestDaemon(c)
   580  	// mark as already seeded
   581  	s.markSeeded(d)
   582  
   583  	l, err := net.Listen("tcp", "127.0.0.1:0")
   584  	c.Assert(err, check.IsNil)
   585  
   586  	snapdAccept := make(chan struct{})
   587  	d.snapdListener = &witnessAcceptListener{Listener: l, accept: snapdAccept}
   588  
   589  	snapAccept := make(chan struct{})
   590  	d.snapListener = &witnessAcceptListener{Listener: l, accept: snapAccept}
   591  
   592  	c.Assert(d.Start(), check.IsNil)
   593  	defer d.Stop(nil)
   594  
   595  	snapdDone := make(chan struct{})
   596  	go func() {
   597  		select {
   598  		case <-snapdAccept:
   599  		case <-time.After(2 * time.Second):
   600  			c.Fatal("snapd accept was not called")
   601  		}
   602  		close(snapdDone)
   603  	}()
   604  
   605  	snapDone := make(chan struct{})
   606  	go func() {
   607  		select {
   608  		case <-snapAccept:
   609  		case <-time.After(2 * time.Second):
   610  			c.Fatal("snap accept was not called")
   611  		}
   612  		close(snapDone)
   613  	}()
   614  
   615  	<-snapdDone
   616  	<-snapDone
   617  
   618  	d.overlord.State().RequestRestart(state.RestartDaemon)
   619  
   620  	select {
   621  	case <-d.Dying():
   622  	case <-time.After(2 * time.Second):
   623  		c.Fatal("RequestRestart -> overlord -> Kill chain didn't work")
   624  	}
   625  }
   626  
   627  func (s *daemonSuite) TestGracefulStop(c *check.C) {
   628  	d := newTestDaemon(c)
   629  
   630  	responding := make(chan struct{})
   631  	doRespond := make(chan bool, 1)
   632  
   633  	d.router.HandleFunc("/endp", func(w http.ResponseWriter, r *http.Request) {
   634  		close(responding)
   635  		if <-doRespond {
   636  			w.Write([]byte("OKOK"))
   637  		} else {
   638  			w.Write([]byte("Gone"))
   639  		}
   640  	})
   641  
   642  	// mark as already seeded
   643  	s.markSeeded(d)
   644  	// and pretend we have snaps
   645  	st := d.overlord.State()
   646  	st.Lock()
   647  	snapstate.Set(st, "core", &snapstate.SnapState{
   648  		Active: true,
   649  		Sequence: []*snap.SideInfo{
   650  			{RealName: "core", Revision: snap.R(1), SnapID: "core-snap-id"},
   651  		},
   652  		Current: snap.R(1),
   653  	})
   654  	st.Unlock()
   655  
   656  	snapdL, err := net.Listen("tcp", "127.0.0.1:0")
   657  	c.Assert(err, check.IsNil)
   658  
   659  	snapL, err := net.Listen("tcp", "127.0.0.1:0")
   660  	c.Assert(err, check.IsNil)
   661  
   662  	snapdAccept := make(chan struct{})
   663  	snapdClosed := make(chan struct{})
   664  	d.snapdListener = &witnessAcceptListener{Listener: snapdL, accept: snapdAccept, closed: snapdClosed}
   665  
   666  	snapAccept := make(chan struct{})
   667  	d.snapListener = &witnessAcceptListener{Listener: snapL, accept: snapAccept}
   668  
   669  	c.Assert(d.Start(), check.IsNil)
   670  
   671  	snapdAccepting := make(chan struct{})
   672  	go func() {
   673  		select {
   674  		case <-snapdAccept:
   675  		case <-time.After(2 * time.Second):
   676  			c.Fatal("snapd accept was not called")
   677  		}
   678  		close(snapdAccepting)
   679  	}()
   680  
   681  	snapAccepting := make(chan struct{})
   682  	go func() {
   683  		select {
   684  		case <-snapAccept:
   685  		case <-time.After(2 * time.Second):
   686  			c.Fatal("snapd accept was not called")
   687  		}
   688  		close(snapAccepting)
   689  	}()
   690  
   691  	<-snapdAccepting
   692  	<-snapAccepting
   693  
   694  	alright := make(chan struct{})
   695  
   696  	go func() {
   697  		res, err := http.Get(fmt.Sprintf("http://%s/endp", snapdL.Addr()))
   698  		c.Assert(err, check.IsNil)
   699  		c.Check(res.StatusCode, check.Equals, 200)
   700  		body, err := ioutil.ReadAll(res.Body)
   701  		res.Body.Close()
   702  		c.Assert(err, check.IsNil)
   703  		c.Check(string(body), check.Equals, "OKOK")
   704  		close(alright)
   705  	}()
   706  	go func() {
   707  		<-snapdClosed
   708  		time.Sleep(200 * time.Millisecond)
   709  		doRespond <- true
   710  	}()
   711  
   712  	<-responding
   713  	err = d.Stop(nil)
   714  	doRespond <- false
   715  	c.Check(err, check.IsNil)
   716  
   717  	select {
   718  	case <-alright:
   719  	case <-time.After(2 * time.Second):
   720  		c.Fatal("never got proper response")
   721  	}
   722  }
   723  
   724  func (s *daemonSuite) TestGracefulStopHasLimits(c *check.C) {
   725  	d := newTestDaemon(c)
   726  
   727  	// mark as already seeded
   728  	s.markSeeded(d)
   729  
   730  	restore := MockShutdownTimeout(time.Second)
   731  	defer restore()
   732  
   733  	responding := make(chan struct{})
   734  	doRespond := make(chan bool, 1)
   735  
   736  	d.router.HandleFunc("/endp", func(w http.ResponseWriter, r *http.Request) {
   737  		close(responding)
   738  		if <-doRespond {
   739  			for {
   740  				// write in a loop to keep the handler running
   741  				if _, err := w.Write([]byte("OKOK")); err != nil {
   742  					break
   743  				}
   744  				time.Sleep(50 * time.Millisecond)
   745  			}
   746  		} else {
   747  			w.Write([]byte("Gone"))
   748  		}
   749  	})
   750  
   751  	snapdL, err := net.Listen("tcp", "127.0.0.1:0")
   752  	c.Assert(err, check.IsNil)
   753  
   754  	snapL, err := net.Listen("tcp", "127.0.0.1:0")
   755  	c.Assert(err, check.IsNil)
   756  
   757  	snapdAccept := make(chan struct{})
   758  	snapdClosed := make(chan struct{})
   759  	d.snapdListener = &witnessAcceptListener{Listener: snapdL, accept: snapdAccept, closed: snapdClosed}
   760  
   761  	snapAccept := make(chan struct{})
   762  	d.snapListener = &witnessAcceptListener{Listener: snapL, accept: snapAccept}
   763  
   764  	c.Assert(d.Start(), check.IsNil)
   765  
   766  	snapdAccepting := make(chan struct{})
   767  	go func() {
   768  		select {
   769  		case <-snapdAccept:
   770  		case <-time.After(2 * time.Second):
   771  			c.Fatal("snapd accept was not called")
   772  		}
   773  		close(snapdAccepting)
   774  	}()
   775  
   776  	snapAccepting := make(chan struct{})
   777  	go func() {
   778  		select {
   779  		case <-snapAccept:
   780  		case <-time.After(2 * time.Second):
   781  			c.Fatal("snapd accept was not called")
   782  		}
   783  		close(snapAccepting)
   784  	}()
   785  
   786  	<-snapdAccepting
   787  	<-snapAccepting
   788  
   789  	clientErr := make(chan error)
   790  
   791  	go func() {
   792  		_, err := http.Get(fmt.Sprintf("http://%s/endp", snapdL.Addr()))
   793  		c.Assert(err, check.NotNil)
   794  		clientErr <- err
   795  		close(clientErr)
   796  	}()
   797  	go func() {
   798  		<-snapdClosed
   799  		time.Sleep(200 * time.Millisecond)
   800  		doRespond <- true
   801  	}()
   802  
   803  	<-responding
   804  	err = d.Stop(nil)
   805  	doRespond <- false
   806  	c.Check(err, check.IsNil)
   807  
   808  	select {
   809  	case cErr := <-clientErr:
   810  		c.Check(cErr, check.ErrorMatches, ".*: EOF")
   811  	case <-time.After(5 * time.Second):
   812  		c.Fatal("never got proper response")
   813  	}
   814  }
   815  
   816  func (s *daemonSuite) testRestartSystemWiring(c *check.C, restartKind state.RestartType) {
   817  	d := newTestDaemon(c)
   818  	// mark as already seeded
   819  	s.markSeeded(d)
   820  
   821  	l, err := net.Listen("tcp", "127.0.0.1:0")
   822  	c.Assert(err, check.IsNil)
   823  
   824  	snapdAccept := make(chan struct{})
   825  	d.snapdListener = &witnessAcceptListener{Listener: l, accept: snapdAccept}
   826  
   827  	snapAccept := make(chan struct{})
   828  	d.snapListener = &witnessAcceptListener{Listener: l, accept: snapAccept}
   829  
   830  	c.Assert(d.Start(), check.IsNil)
   831  	defer d.Stop(nil)
   832  
   833  	st := d.overlord.State()
   834  
   835  	snapdDone := make(chan struct{})
   836  	go func() {
   837  		select {
   838  		case <-snapdAccept:
   839  		case <-time.After(2 * time.Second):
   840  			c.Fatal("snapd accept was not called")
   841  		}
   842  		close(snapdDone)
   843  	}()
   844  
   845  	snapDone := make(chan struct{})
   846  	go func() {
   847  		select {
   848  		case <-snapAccept:
   849  		case <-time.After(2 * time.Second):
   850  			c.Fatal("snap accept was not called")
   851  		}
   852  		close(snapDone)
   853  	}()
   854  
   855  	<-snapdDone
   856  	<-snapDone
   857  
   858  	oldRebootNoticeWait := rebootNoticeWait
   859  	oldRebootWaitTimeout := rebootWaitTimeout
   860  	defer func() {
   861  		reboot = rebootImpl
   862  		rebootNoticeWait = oldRebootNoticeWait
   863  		rebootWaitTimeout = oldRebootWaitTimeout
   864  	}()
   865  	rebootWaitTimeout = 100 * time.Millisecond
   866  	rebootNoticeWait = 150 * time.Millisecond
   867  
   868  	var delays []time.Duration
   869  	reboot = func(d time.Duration) error {
   870  		delays = append(delays, d)
   871  		return nil
   872  	}
   873  
   874  	st.Lock()
   875  	st.RequestRestart(restartKind)
   876  	st.Unlock()
   877  
   878  	defer func() {
   879  		d.mu.Lock()
   880  		d.restartSystem = state.RestartUnset
   881  		d.mu.Unlock()
   882  	}()
   883  
   884  	select {
   885  	case <-d.Dying():
   886  	case <-time.After(2 * time.Second):
   887  		c.Fatal("RequestRestart -> overlord -> Kill chain didn't work")
   888  	}
   889  
   890  	d.mu.Lock()
   891  	rs := d.restartSystem
   892  	d.mu.Unlock()
   893  
   894  	c.Check(rs, check.Equals, restartKind)
   895  
   896  	c.Check(delays, check.HasLen, 1)
   897  	c.Check(delays[0], check.DeepEquals, rebootWaitTimeout)
   898  
   899  	now := time.Now()
   900  
   901  	err = d.Stop(nil)
   902  
   903  	c.Check(err, check.ErrorMatches, "expected reboot did not happen")
   904  
   905  	c.Check(delays, check.HasLen, 2)
   906  	if restartKind == state.RestartSystem {
   907  		c.Check(delays[1], check.DeepEquals, 1*time.Minute)
   908  	} else if restartKind == state.RestartSystemNow {
   909  		c.Check(delays[1], check.DeepEquals, time.Duration(0))
   910  	}
   911  
   912  	// we are not stopping, we wait for the reboot instead
   913  	c.Check(s.notified, check.DeepEquals, []string{"EXTEND_TIMEOUT_USEC=30000000", "READY=1"})
   914  
   915  	st.Lock()
   916  	defer st.Unlock()
   917  	var rebootAt time.Time
   918  	err = st.Get("daemon-system-restart-at", &rebootAt)
   919  	c.Assert(err, check.IsNil)
   920  	if restartKind == state.RestartSystem {
   921  		approxAt := now.Add(time.Minute)
   922  		c.Check(rebootAt.After(approxAt) || rebootAt.Equal(approxAt), check.Equals, true)
   923  	} else if restartKind == state.RestartSystemNow {
   924  		// should be good enough
   925  		c.Check(rebootAt.Before(now.Add(10*time.Second)), check.Equals, true)
   926  	}
   927  }
   928  
   929  func (s *daemonSuite) TestRestartSystemGracefulWiring(c *check.C) {
   930  	s.testRestartSystemWiring(c, state.RestartSystem)
   931  }
   932  
   933  func (s *daemonSuite) TestRestartSystemImmediateWiring(c *check.C) {
   934  	s.testRestartSystemWiring(c, state.RestartSystemNow)
   935  }
   936  
   937  func (s *daemonSuite) TestRebootHelper(c *check.C) {
   938  	cmd := testutil.MockCommand(c, "shutdown", "")
   939  	defer cmd.Restore()
   940  
   941  	tests := []struct {
   942  		delay    time.Duration
   943  		delayArg string
   944  	}{
   945  		{-1, "+0"},
   946  		{0, "+0"},
   947  		{time.Minute, "+1"},
   948  		{10 * time.Minute, "+10"},
   949  		{30 * time.Second, "+0"},
   950  	}
   951  
   952  	for _, t := range tests {
   953  		err := reboot(t.delay)
   954  		c.Assert(err, check.IsNil)
   955  		c.Check(cmd.Calls(), check.DeepEquals, [][]string{
   956  			{"shutdown", "-r", t.delayArg, "reboot scheduled to update the system"},
   957  		})
   958  
   959  		cmd.ForgetCalls()
   960  	}
   961  }
   962  
   963  func makeDaemonListeners(c *check.C, d *Daemon) {
   964  	snapdL, err := net.Listen("tcp", "127.0.0.1:0")
   965  	c.Assert(err, check.IsNil)
   966  
   967  	snapL, err := net.Listen("tcp", "127.0.0.1:0")
   968  	c.Assert(err, check.IsNil)
   969  
   970  	snapdAccept := make(chan struct{})
   971  	snapdClosed := make(chan struct{})
   972  	d.snapdListener = &witnessAcceptListener{Listener: snapdL, accept: snapdAccept, closed: snapdClosed}
   973  
   974  	snapAccept := make(chan struct{})
   975  	d.snapListener = &witnessAcceptListener{Listener: snapL, accept: snapAccept}
   976  }
   977  
   978  // This test tests that when the snapd calls a restart of the system
   979  // a sigterm (from e.g. systemd) is handled when it arrives before
   980  // stop is fully done.
   981  func (s *daemonSuite) TestRestartShutdownWithSigtermInBetween(c *check.C) {
   982  	oldRebootNoticeWait := rebootNoticeWait
   983  	defer func() {
   984  		rebootNoticeWait = oldRebootNoticeWait
   985  	}()
   986  	rebootNoticeWait = 150 * time.Millisecond
   987  
   988  	cmd := testutil.MockCommand(c, "shutdown", "")
   989  	defer cmd.Restore()
   990  
   991  	d := newTestDaemon(c)
   992  	makeDaemonListeners(c, d)
   993  	s.markSeeded(d)
   994  
   995  	c.Assert(d.Start(), check.IsNil)
   996  	st := d.overlord.State()
   997  
   998  	st.Lock()
   999  	st.RequestRestart(state.RestartSystem)
  1000  	st.Unlock()
  1001  
  1002  	ch := make(chan os.Signal, 2)
  1003  	ch <- syscall.SIGTERM
  1004  	// stop will check if we got a sigterm in between (which we did)
  1005  	err := d.Stop(ch)
  1006  	c.Assert(err, check.IsNil)
  1007  }
  1008  
  1009  // This test tests that when there is a shutdown we close the sigterm
  1010  // handler so that systemd can kill snapd.
  1011  func (s *daemonSuite) TestRestartShutdown(c *check.C) {
  1012  	oldRebootNoticeWait := rebootNoticeWait
  1013  	oldRebootWaitTimeout := rebootWaitTimeout
  1014  	defer func() {
  1015  		rebootNoticeWait = oldRebootNoticeWait
  1016  		rebootWaitTimeout = oldRebootWaitTimeout
  1017  	}()
  1018  	rebootWaitTimeout = 100 * time.Millisecond
  1019  	rebootNoticeWait = 150 * time.Millisecond
  1020  
  1021  	cmd := testutil.MockCommand(c, "shutdown", "")
  1022  	defer cmd.Restore()
  1023  
  1024  	d := newTestDaemon(c)
  1025  	makeDaemonListeners(c, d)
  1026  	s.markSeeded(d)
  1027  
  1028  	c.Assert(d.Start(), check.IsNil)
  1029  	st := d.overlord.State()
  1030  
  1031  	st.Lock()
  1032  	st.RequestRestart(state.RestartSystem)
  1033  	st.Unlock()
  1034  
  1035  	sigCh := make(chan os.Signal, 2)
  1036  	// stop (this will timeout but that's not relevant for this test)
  1037  	d.Stop(sigCh)
  1038  
  1039  	// ensure that the sigCh got closed as part of the stop
  1040  	_, chOpen := <-sigCh
  1041  	c.Assert(chOpen, check.Equals, false)
  1042  }
  1043  
  1044  func (s *daemonSuite) TestRestartExpectedRebootDidNotHappen(c *check.C) {
  1045  	curBootID, err := osutil.BootID()
  1046  	c.Assert(err, check.IsNil)
  1047  
  1048  	fakeState := []byte(fmt.Sprintf(`{"data":{"patch-level":%d,"patch-sublevel":%d,"some":"data","refresh-privacy-key":"0123456789ABCDEF","system-restart-from-boot-id":%q,"daemon-system-restart-at":"%s"},"changes":null,"tasks":null,"last-change-id":0,"last-task-id":0,"last-lane-id":0}`, patch.Level, patch.Sublevel, curBootID, time.Now().UTC().Format(time.RFC3339)))
  1049  	err = ioutil.WriteFile(dirs.SnapStateFile, fakeState, 0600)
  1050  	c.Assert(err, check.IsNil)
  1051  
  1052  	oldRebootNoticeWait := rebootNoticeWait
  1053  	oldRebootRetryWaitTimeout := rebootRetryWaitTimeout
  1054  	defer func() {
  1055  		rebootNoticeWait = oldRebootNoticeWait
  1056  		rebootRetryWaitTimeout = oldRebootRetryWaitTimeout
  1057  	}()
  1058  	rebootRetryWaitTimeout = 100 * time.Millisecond
  1059  	rebootNoticeWait = 150 * time.Millisecond
  1060  
  1061  	cmd := testutil.MockCommand(c, "shutdown", "")
  1062  	defer cmd.Restore()
  1063  
  1064  	d := newTestDaemon(c)
  1065  	c.Check(d.overlord, check.IsNil)
  1066  	c.Check(d.expectedRebootDidNotHappen, check.Equals, true)
  1067  
  1068  	var n int
  1069  	d.state.Lock()
  1070  	err = d.state.Get("daemon-system-restart-tentative", &n)
  1071  	d.state.Unlock()
  1072  	c.Check(err, check.IsNil)
  1073  	c.Check(n, check.Equals, 1)
  1074  
  1075  	c.Assert(d.Start(), check.IsNil)
  1076  
  1077  	c.Check(s.notified, check.DeepEquals, []string{"READY=1"})
  1078  
  1079  	select {
  1080  	case <-d.Dying():
  1081  	case <-time.After(2 * time.Second):
  1082  		c.Fatal("expected reboot not happening should proceed to try to shutdown again")
  1083  	}
  1084  
  1085  	sigCh := make(chan os.Signal, 2)
  1086  	// stop (this will timeout but thats not relevant for this test)
  1087  	d.Stop(sigCh)
  1088  
  1089  	// an immediate shutdown was scheduled again
  1090  	c.Check(cmd.Calls(), check.DeepEquals, [][]string{
  1091  		{"shutdown", "-r", "+0", "reboot scheduled to update the system"},
  1092  	})
  1093  }
  1094  
  1095  func (s *daemonSuite) TestRestartExpectedRebootOK(c *check.C) {
  1096  	fakeState := []byte(fmt.Sprintf(`{"data":{"patch-level":%d,"patch-sublevel":%d,"some":"data","refresh-privacy-key":"0123456789ABCDEF","system-restart-from-boot-id":%q,"daemon-system-restart-at":"%s"},"changes":null,"tasks":null,"last-change-id":0,"last-task-id":0,"last-lane-id":0}`, patch.Level, patch.Sublevel, "boot-id-0", time.Now().UTC().Format(time.RFC3339)))
  1097  	err := ioutil.WriteFile(dirs.SnapStateFile, fakeState, 0600)
  1098  	c.Assert(err, check.IsNil)
  1099  
  1100  	cmd := testutil.MockCommand(c, "shutdown", "")
  1101  	defer cmd.Restore()
  1102  
  1103  	d := newTestDaemon(c)
  1104  	c.Assert(d.overlord, check.NotNil)
  1105  
  1106  	st := d.overlord.State()
  1107  	st.Lock()
  1108  	defer st.Unlock()
  1109  	var v interface{}
  1110  	// these were cleared
  1111  	c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState)
  1112  	c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState)
  1113  }
  1114  
  1115  func (s *daemonSuite) TestRestartExpectedRebootGiveUp(c *check.C) {
  1116  	// we give up trying to restart the system after 3 retry tentatives
  1117  	curBootID, err := osutil.BootID()
  1118  	c.Assert(err, check.IsNil)
  1119  
  1120  	fakeState := []byte(fmt.Sprintf(`{"data":{"patch-level":%d,"patch-sublevel":%d,"some":"data","refresh-privacy-key":"0123456789ABCDEF","system-restart-from-boot-id":%q,"daemon-system-restart-at":"%s","daemon-system-restart-tentative":3},"changes":null,"tasks":null,"last-change-id":0,"last-task-id":0,"last-lane-id":0}`, patch.Level, patch.Sublevel, curBootID, time.Now().UTC().Format(time.RFC3339)))
  1121  	err = ioutil.WriteFile(dirs.SnapStateFile, fakeState, 0600)
  1122  	c.Assert(err, check.IsNil)
  1123  
  1124  	cmd := testutil.MockCommand(c, "shutdown", "")
  1125  	defer cmd.Restore()
  1126  
  1127  	d := newTestDaemon(c)
  1128  	c.Assert(d.overlord, check.NotNil)
  1129  
  1130  	st := d.overlord.State()
  1131  	st.Lock()
  1132  	defer st.Unlock()
  1133  	var v interface{}
  1134  	// these were cleared
  1135  	c.Check(st.Get("daemon-system-restart-at", &v), check.Equals, state.ErrNoState)
  1136  	c.Check(st.Get("system-restart-from-boot-id", &v), check.Equals, state.ErrNoState)
  1137  	c.Check(st.Get("daemon-system-restart-tentative", &v), check.Equals, state.ErrNoState)
  1138  }
  1139  
  1140  func (s *daemonSuite) TestRestartIntoSocketModeNoNewChanges(c *check.C) {
  1141  	restore := standby.MockStandbyWait(5 * time.Millisecond)
  1142  	defer restore()
  1143  
  1144  	d := newTestDaemon(c)
  1145  	makeDaemonListeners(c, d)
  1146  
  1147  	// mark as already seeded, we also have no snaps so this will
  1148  	// go into socket activation mode
  1149  	s.markSeeded(d)
  1150  
  1151  	c.Assert(d.Start(), check.IsNil)
  1152  	// pretend some ensure happened
  1153  	for i := 0; i < 5; i++ {
  1154  		c.Check(d.overlord.StateEngine().Ensure(), check.IsNil)
  1155  		time.Sleep(5 * time.Millisecond)
  1156  	}
  1157  
  1158  	select {
  1159  	case <-d.Dying():
  1160  		// exit the loop
  1161  	case <-time.After(15 * time.Second):
  1162  		c.Errorf("daemon did not stop after 15s")
  1163  	}
  1164  	err := d.Stop(nil)
  1165  	c.Check(err, check.Equals, ErrRestartSocket)
  1166  	c.Check(d.restartSocket, check.Equals, true)
  1167  }
  1168  
  1169  func (s *daemonSuite) TestRestartIntoSocketModePendingChanges(c *check.C) {
  1170  	restore := standby.MockStandbyWait(5 * time.Millisecond)
  1171  	defer restore()
  1172  
  1173  	d := newTestDaemon(c)
  1174  	makeDaemonListeners(c, d)
  1175  
  1176  	// mark as already seeded, we also have no snaps so this will
  1177  	// go into socket activation mode
  1178  	s.markSeeded(d)
  1179  	st := d.overlord.State()
  1180  
  1181  	c.Assert(d.Start(), check.IsNil)
  1182  	// pretend some ensure happened
  1183  	for i := 0; i < 5; i++ {
  1184  		c.Check(d.overlord.StateEngine().Ensure(), check.IsNil)
  1185  		time.Sleep(5 * time.Millisecond)
  1186  	}
  1187  
  1188  	select {
  1189  	case <-d.Dying():
  1190  		// Pretend we got change while shutting down, this can
  1191  		// happen when e.g. the user requested a `snap install
  1192  		// foo` at the same time as the code in the overlord
  1193  		// checked that it can go into socket activated
  1194  		// mode. I.e. the daemon was processing the request
  1195  		// but no change was generated at the time yet.
  1196  		st.Lock()
  1197  		chg := st.NewChange("fake-install", "fake install some snap")
  1198  		chg.AddTask(st.NewTask("fake-install-task", "fake install task"))
  1199  		chgStatus := chg.Status()
  1200  		st.Unlock()
  1201  		// ensure our change is valid and ready
  1202  		c.Check(chgStatus, check.Equals, state.DoStatus)
  1203  	case <-time.After(5 * time.Second):
  1204  		c.Errorf("daemon did not stop after 5s")
  1205  	}
  1206  	// when the daemon got a pending change it just restarts
  1207  	err := d.Stop(nil)
  1208  	c.Check(err, check.IsNil)
  1209  	c.Check(d.restartSocket, check.Equals, false)
  1210  }
  1211  
  1212  func (s *daemonSuite) TestConnTrackerCanShutdown(c *check.C) {
  1213  	ct := &connTracker{conns: make(map[net.Conn]struct{})}
  1214  	c.Check(ct.CanStandby(), check.Equals, true)
  1215  
  1216  	con := &net.IPConn{}
  1217  	ct.trackConn(con, http.StateActive)
  1218  	c.Check(ct.CanStandby(), check.Equals, false)
  1219  
  1220  	ct.trackConn(con, http.StateIdle)
  1221  	c.Check(ct.CanStandby(), check.Equals, true)
  1222  }
  1223  
  1224  func doTestReq(c *check.C, cmd *Command, mth string) *httptest.ResponseRecorder {
  1225  	req, err := http.NewRequest(mth, "", nil)
  1226  	c.Assert(err, check.IsNil)
  1227  	req.RemoteAddr = "pid=100;uid=0;socket=;"
  1228  	rec := httptest.NewRecorder()
  1229  	cmd.ServeHTTP(rec, req)
  1230  	return rec
  1231  }
  1232  
  1233  func (s *daemonSuite) TestDegradedModeReply(c *check.C) {
  1234  	d := newTestDaemon(c)
  1235  	cmd := &Command{d: d}
  1236  	cmd.GET = func(*Command, *http.Request, *auth.UserState) Response {
  1237  		return SyncResponse(nil, nil)
  1238  	}
  1239  	cmd.POST = func(*Command, *http.Request, *auth.UserState) Response {
  1240  		return SyncResponse(nil, nil)
  1241  	}
  1242  
  1243  	// pretend we are in degraded mode
  1244  	d.SetDegradedMode(fmt.Errorf("foo error"))
  1245  
  1246  	// GET is ok even in degraded mode
  1247  	rec := doTestReq(c, cmd, "GET")
  1248  	c.Check(rec.Code, check.Equals, 200)
  1249  	// POST is not allowed
  1250  	rec = doTestReq(c, cmd, "POST")
  1251  	c.Check(rec.Code, check.Equals, 500)
  1252  	// verify we get the error
  1253  	var v struct{ Result errorResult }
  1254  	c.Assert(json.NewDecoder(rec.Body).Decode(&v), check.IsNil)
  1255  	c.Check(v.Result.Message, check.Equals, "foo error")
  1256  
  1257  	// clean degraded mode
  1258  	d.SetDegradedMode(nil)
  1259  	rec = doTestReq(c, cmd, "POST")
  1260  	c.Check(rec.Code, check.Equals, 200)
  1261  }