github.com/kubiko/snapd@v0.0.0-20201013125620-d4f3094d9ddf/cmd/snap/cmd_userd_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  // +build !darwin
     3  
     4  /*
     5   * Copyright (C) 2016-2019 Canonical Ltd
     6   *
     7   * This program is free software: you can redistribute it and/or modify
     8   * it under the terms of the GNU General Public License version 3 as
     9   * published by the Free Software Foundation.
    10   *
    11   * This program is distributed in the hope that it will be useful,
    12   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    13   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    14   * GNU General Public License for more details.
    15   *
    16   * You should have received a copy of the GNU General Public License
    17   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    18   *
    19   */
    20  
    21  package main_test
    22  
    23  import (
    24  	"fmt"
    25  	"net"
    26  	"net/http"
    27  	"os"
    28  	"strings"
    29  	"syscall"
    30  	"time"
    31  
    32  	. "gopkg.in/check.v1"
    33  
    34  	snap "github.com/snapcore/snapd/cmd/snap"
    35  	"github.com/snapcore/snapd/dirs"
    36  	"github.com/snapcore/snapd/logger"
    37  	"github.com/snapcore/snapd/osutil"
    38  	"github.com/snapcore/snapd/testutil"
    39  )
    40  
    41  type userdSuite struct {
    42  	BaseSnapSuite
    43  	testutil.DBusTest
    44  
    45  	agentSocketPath string
    46  }
    47  
    48  var _ = Suite(&userdSuite{})
    49  
    50  func (s *userdSuite) SetUpTest(c *C) {
    51  	s.BaseSnapSuite.SetUpTest(c)
    52  	s.DBusTest.SetUpTest(c)
    53  
    54  	_, restore := logger.MockLogger()
    55  	s.AddCleanup(restore)
    56  
    57  	xdgRuntimeDir := fmt.Sprintf("%s/%d", dirs.XdgRuntimeDirBase, os.Getuid())
    58  	c.Assert(os.MkdirAll(xdgRuntimeDir, 0700), IsNil)
    59  	s.agentSocketPath = fmt.Sprintf("%s/snapd-session-agent.socket", xdgRuntimeDir)
    60  }
    61  
    62  func (s *userdSuite) TearDownTest(c *C) {
    63  	s.BaseSnapSuite.TearDownTest(c)
    64  	s.DBusTest.TearDownTest(c)
    65  }
    66  
    67  func (s *userdSuite) TestUserdBadCommandline(c *C) {
    68  	_, err := snap.Parser(snap.Client()).ParseArgs([]string{"userd", "extra-arg"})
    69  	c.Assert(err, ErrorMatches, "too many arguments for command")
    70  }
    71  
    72  type mockSignal struct{}
    73  
    74  func (m *mockSignal) String() string {
    75  	return "<test signal>"
    76  }
    77  
    78  func (m *mockSignal) Signal() {}
    79  
    80  func (s *userdSuite) TestUserdDBus(c *C) {
    81  	sigCh := make(chan os.Signal, 1)
    82  	sigStopCalls := 0
    83  
    84  	restore := snap.MockSignalNotify(func(sig ...os.Signal) (chan os.Signal, func()) {
    85  		c.Assert(sig, DeepEquals, []os.Signal{syscall.SIGINT, syscall.SIGTERM})
    86  		return sigCh, func() { sigStopCalls++ }
    87  	})
    88  	defer restore()
    89  
    90  	go func() {
    91  		myPid := os.Getpid()
    92  
    93  		defer func() {
    94  			sigCh <- &mockSignal{}
    95  		}()
    96  
    97  		names := map[string]bool{
    98  			"io.snapcraft.Launcher": false,
    99  			"io.snapcraft.Settings": false,
   100  		}
   101  		for i := 0; i < 1000; i++ {
   102  			seenCount := 0
   103  			for name, seen := range names {
   104  				if seen {
   105  					seenCount++
   106  					continue
   107  				}
   108  				pid, err := testutil.DBusGetConnectionUnixProcessID(s.SessionBus, name)
   109  				c.Logf("name: %v pid: %v err: %v", name, pid, err)
   110  				if pid == myPid {
   111  					names[name] = true
   112  					seenCount++
   113  				}
   114  			}
   115  			if seenCount == len(names) {
   116  				return
   117  			}
   118  			time.Sleep(10 * time.Millisecond)
   119  		}
   120  		c.Fatalf("not all names have appeared on the bus: %v", names)
   121  	}()
   122  
   123  	rest, err := snap.Parser(snap.Client()).ParseArgs([]string{"userd"})
   124  	c.Assert(err, IsNil)
   125  	c.Check(rest, DeepEquals, []string{})
   126  	c.Check(strings.ToLower(s.Stdout()), Equals, "exiting on <test signal>.\n")
   127  	c.Check(sigStopCalls, Equals, 1)
   128  }
   129  
   130  func (s *userdSuite) makeAgentClient() *http.Client {
   131  	transport := &http.Transport{
   132  		Dial: func(_, _ string) (net.Conn, error) {
   133  			return net.Dial("unix", s.agentSocketPath)
   134  		},
   135  		DisableKeepAlives: true,
   136  	}
   137  	return &http.Client{Transport: transport}
   138  }
   139  
   140  func (s *userdSuite) TestSessionAgentSocket(c *C) {
   141  	sigCh := make(chan os.Signal, 1)
   142  	sigStopCalls := 0
   143  
   144  	restore := snap.MockSignalNotify(func(sig ...os.Signal) (chan os.Signal, func()) {
   145  		c.Assert(sig, DeepEquals, []os.Signal{syscall.SIGINT, syscall.SIGTERM})
   146  		return sigCh, func() { sigStopCalls++ }
   147  	})
   148  	defer restore()
   149  
   150  	go func() {
   151  		defer func() {
   152  			sigCh <- &mockSignal{}
   153  		}()
   154  
   155  		// Wait for command to create socket file
   156  		for i := 0; i < 1000; i++ {
   157  			if osutil.FileExists(s.agentSocketPath) {
   158  				break
   159  			}
   160  			time.Sleep(10 * time.Millisecond)
   161  		}
   162  
   163  		// Check that agent functions
   164  		client := s.makeAgentClient()
   165  		response, err := client.Get("http://localhost/v1/session-info")
   166  		c.Assert(err, IsNil)
   167  		defer response.Body.Close()
   168  		c.Check(response.StatusCode, Equals, 200)
   169  	}()
   170  
   171  	rest, err := snap.Parser(snap.Client()).ParseArgs([]string{"userd", "--agent"})
   172  	c.Assert(err, IsNil)
   173  	c.Check(rest, DeepEquals, []string{})
   174  	c.Check(strings.ToLower(s.Stdout()), Equals, "exiting on <test signal>.\n")
   175  	c.Check(sigStopCalls, Equals, 1)
   176  }
   177  
   178  func (s *userdSuite) TestSignalNotify(c *C) {
   179  	ch, stop := snap.SignalNotify(syscall.SIGUSR1)
   180  	defer stop()
   181  	go func() {
   182  		myPid := os.Getpid()
   183  		me, err := os.FindProcess(myPid)
   184  		c.Assert(err, IsNil)
   185  		err = me.Signal(syscall.SIGUSR1)
   186  		c.Assert(err, IsNil)
   187  	}()
   188  	select {
   189  	case sig := <-ch:
   190  		c.Assert(sig, Equals, syscall.SIGUSR1)
   191  	case <-time.After(5 * time.Second):
   192  		c.Fatal("signal not received within 5s")
   193  	}
   194  }