github.com/kubiko/snapd@v0.0.0-20201013125620-d4f3094d9ddf/usersession/agent/session_agent_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2019 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package agent_test
    21  
    22  import (
    23  	"encoding/json"
    24  	"fmt"
    25  	"net"
    26  	"net/http"
    27  	"os"
    28  	"syscall"
    29  	"testing"
    30  	"time"
    31  
    32  	. "gopkg.in/check.v1"
    33  
    34  	"github.com/snapcore/snapd/dirs"
    35  	"github.com/snapcore/snapd/logger"
    36  	"github.com/snapcore/snapd/osutil/sys"
    37  	"github.com/snapcore/snapd/testutil"
    38  	"github.com/snapcore/snapd/usersession/agent"
    39  )
    40  
    41  func Test(t *testing.T) { TestingT(t) }
    42  
    43  type sessionAgentSuite struct {
    44  	socketPath string
    45  	client     *http.Client
    46  }
    47  
    48  var _ = Suite(&sessionAgentSuite{})
    49  
    50  func (s *sessionAgentSuite) SetUpTest(c *C) {
    51  	dirs.SetRootDir(c.MkDir())
    52  	xdgRuntimeDir := fmt.Sprintf("%s/%d", dirs.XdgRuntimeDirBase, os.Getuid())
    53  	c.Assert(os.MkdirAll(xdgRuntimeDir, 0700), IsNil)
    54  	s.socketPath = fmt.Sprintf("%s/snapd-session-agent.socket", xdgRuntimeDir)
    55  
    56  	transport := &http.Transport{
    57  		Dial: func(_, _ string) (net.Conn, error) {
    58  			return net.Dial("unix", s.socketPath)
    59  		},
    60  		DisableKeepAlives: true,
    61  	}
    62  	s.client = &http.Client{Transport: transport}
    63  }
    64  
    65  func (s *sessionAgentSuite) TearDownTest(c *C) {
    66  	dirs.SetRootDir("")
    67  	logger.SetLogger(logger.NullLogger)
    68  }
    69  
    70  func (s *sessionAgentSuite) TestStartStop(c *C) {
    71  	agent, err := agent.New()
    72  	c.Assert(err, IsNil)
    73  	agent.Version = "42"
    74  	agent.Start()
    75  	defer func() { c.Check(agent.Stop(), IsNil) }()
    76  
    77  	response, err := s.client.Get("http://localhost/v1/session-info")
    78  	c.Assert(err, IsNil)
    79  	defer response.Body.Close()
    80  	c.Check(response.StatusCode, Equals, 200)
    81  
    82  	var rst struct {
    83  		Result struct {
    84  			Version string `json:"version"`
    85  		} `json:"result"`
    86  	}
    87  	c.Assert(json.NewDecoder(response.Body).Decode(&rst), IsNil)
    88  	c.Check(rst.Result.Version, Equals, "42")
    89  	response.Body.Close()
    90  
    91  	c.Check(agent.Stop(), IsNil)
    92  }
    93  
    94  func (s *sessionAgentSuite) TestDying(c *C) {
    95  	agent, err := agent.New()
    96  	c.Assert(err, IsNil)
    97  	agent.Start()
    98  	select {
    99  	case <-agent.Dying():
   100  		c.Error("agent.Dying() channel closed prematurely")
   101  	default:
   102  	}
   103  	go func() {
   104  		time.Sleep(5 * time.Millisecond)
   105  		c.Check(agent.Stop(), IsNil)
   106  	}()
   107  	select {
   108  	case <-agent.Dying():
   109  	case <-time.After(2 * time.Second):
   110  		c.Error("agent.Dying() channel was not closed when agent stopped")
   111  	}
   112  }
   113  
   114  func (s *sessionAgentSuite) TestExitOnIdle(c *C) {
   115  	agent, err := agent.New()
   116  	c.Assert(err, IsNil)
   117  	agent.IdleTimeout = 150 * time.Millisecond
   118  	startTime := time.Now()
   119  	agent.Start()
   120  	defer agent.Stop()
   121  
   122  	makeRequest := func() {
   123  		response, err := s.client.Get("http://localhost/v1/session-info")
   124  		c.Assert(err, IsNil)
   125  		defer response.Body.Close()
   126  		c.Check(response.StatusCode, Equals, 200)
   127  	}
   128  	makeRequest()
   129  	time.Sleep(25 * time.Millisecond)
   130  	makeRequest()
   131  
   132  	select {
   133  	case <-agent.Dying():
   134  	case <-time.After(2 * time.Second):
   135  		c.Fatal("agent did not exit after idle timeout expired")
   136  	}
   137  	elapsed := time.Since(startTime)
   138  	if elapsed < 175*time.Millisecond || elapsed > 450*time.Millisecond {
   139  		// The idle timeout should have been extended when we
   140  		// issued a second request after 25ms.
   141  		c.Errorf("Expected ellaped time close to 175 ms, but got %v", elapsed)
   142  	}
   143  }
   144  
   145  func (s *sessionAgentSuite) TestConnectFromOtherUser(c *C) {
   146  	logbuf, restore := logger.MockLogger()
   147  	defer restore()
   148  
   149  	// Mock connections to appear to come from a different user ID
   150  	uid := uint32(sys.Geteuid())
   151  	restore = agent.MockUcred(&syscall.Ucred{Uid: uid + 1}, nil)
   152  	defer restore()
   153  
   154  	sa, err := agent.New()
   155  	c.Assert(err, IsNil)
   156  	sa.Start()
   157  	defer sa.Stop()
   158  
   159  	_, err = s.client.Get("http://localhost/v1/session-info")
   160  	// This could be an EOF error or a failed read, depending on timing
   161  	c.Assert(err, ErrorMatches, "Get \"?http://localhost/v1/session-info\"?: .*")
   162  	logger.WithLoggerLock(func() {
   163  		c.Check(logbuf.String(), testutil.Contains, "Blocking request from user ID")
   164  	})
   165  }
   166  
   167  func (s *sessionAgentSuite) TestConnectFromRoot(c *C) {
   168  	logbuf, restore := logger.MockLogger()
   169  	defer restore()
   170  
   171  	// Mock connections to appear to come from root
   172  	restore = agent.MockUcred(&syscall.Ucred{Uid: 0}, nil)
   173  	defer restore()
   174  
   175  	sa, err := agent.New()
   176  	c.Assert(err, IsNil)
   177  	sa.Start()
   178  	defer sa.Stop()
   179  
   180  	response, err := s.client.Get("http://localhost/v1/session-info")
   181  	c.Assert(err, IsNil)
   182  	defer response.Body.Close()
   183  	c.Check(response.StatusCode, Equals, 200)
   184  	logger.WithLoggerLock(func() {
   185  		c.Check(logbuf.String(), Equals, "")
   186  	})
   187  }
   188  
   189  func (s *sessionAgentSuite) TestConnectWithFailedPeerCredentials(c *C) {
   190  	logbuf, restore := logger.MockLogger()
   191  	defer restore()
   192  
   193  	// Connections are dropped if peer credential lookup fails.
   194  	restore = agent.MockUcred(nil, fmt.Errorf("SO_PEERCRED failed"))
   195  	defer restore()
   196  
   197  	sa, err := agent.New()
   198  	c.Assert(err, IsNil)
   199  	sa.Start()
   200  	defer sa.Stop()
   201  
   202  	_, err = s.client.Get("http://localhost/v1/session-info")
   203  	c.Assert(err, ErrorMatches, "Get \"?http://localhost/v1/session-info\"?: .*")
   204  	logger.WithLoggerLock(func() {
   205  		c.Check(logbuf.String(), testutil.Contains, "Failed to retrieve peer credentials: SO_PEERCRED failed")
   206  	})
   207  }