github.com/rigado/snapd@v2.42.5-go-mod+incompatible/cmd/snap/cmd_userd_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2016-2019 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package main_test
    21  
    22  import (
    23  	"fmt"
    24  	"net"
    25  	"net/http"
    26  	"os"
    27  	"strings"
    28  	"syscall"
    29  	"time"
    30  
    31  	. "gopkg.in/check.v1"
    32  
    33  	snap "github.com/snapcore/snapd/cmd/snap"
    34  	"github.com/snapcore/snapd/dirs"
    35  	"github.com/snapcore/snapd/logger"
    36  	"github.com/snapcore/snapd/osutil"
    37  	"github.com/snapcore/snapd/testutil"
    38  )
    39  
    40  type userdSuite struct {
    41  	BaseSnapSuite
    42  	testutil.DBusTest
    43  
    44  	agentSocketPath string
    45  }
    46  
    47  var _ = Suite(&userdSuite{})
    48  
    49  func (s *userdSuite) SetUpTest(c *C) {
    50  	s.BaseSnapSuite.SetUpTest(c)
    51  	s.DBusTest.SetUpTest(c)
    52  
    53  	_, restore := logger.MockLogger()
    54  	s.AddCleanup(restore)
    55  
    56  	xdgRuntimeDir := fmt.Sprintf("%s/%d", dirs.XdgRuntimeDirBase, os.Getuid())
    57  	c.Assert(os.MkdirAll(xdgRuntimeDir, 0700), IsNil)
    58  	s.agentSocketPath = fmt.Sprintf("%s/snapd-session-agent.socket", xdgRuntimeDir)
    59  }
    60  
    61  func (s *userdSuite) TearDownTest(c *C) {
    62  	s.BaseSnapSuite.TearDownTest(c)
    63  	s.DBusTest.TearDownTest(c)
    64  }
    65  
    66  func (s *userdSuite) TestUserdBadCommandline(c *C) {
    67  	_, err := snap.Parser(snap.Client()).ParseArgs([]string{"userd", "extra-arg"})
    68  	c.Assert(err, ErrorMatches, "too many arguments for command")
    69  }
    70  
    71  type mockSignal struct{}
    72  
    73  func (m *mockSignal) String() string {
    74  	return "<test signal>"
    75  }
    76  
    77  func (m *mockSignal) Signal() {}
    78  
    79  func (s *userdSuite) TestUserdDBus(c *C) {
    80  	sigCh := make(chan os.Signal, 1)
    81  	sigStopCalls := 0
    82  
    83  	restore := snap.MockSignalNotify(func(sig ...os.Signal) (chan os.Signal, func()) {
    84  		c.Assert(sig, DeepEquals, []os.Signal{syscall.SIGINT, syscall.SIGTERM})
    85  		return sigCh, func() { sigStopCalls++ }
    86  	})
    87  	defer restore()
    88  
    89  	go func() {
    90  		myPid := os.Getpid()
    91  
    92  		defer func() {
    93  			sigCh <- &mockSignal{}
    94  		}()
    95  
    96  		names := map[string]bool{
    97  			"io.snapcraft.Launcher": false,
    98  			"io.snapcraft.Settings": false,
    99  		}
   100  		for i := 0; i < 1000; i++ {
   101  			seenCount := 0
   102  			for name, seen := range names {
   103  				if seen {
   104  					seenCount++
   105  					continue
   106  				}
   107  				pid, err := testutil.DBusGetConnectionUnixProcessID(s.SessionBus, name)
   108  				c.Logf("name: %v pid: %v err: %v", name, pid, err)
   109  				if pid == myPid {
   110  					names[name] = true
   111  					seenCount++
   112  				}
   113  			}
   114  			if seenCount == len(names) {
   115  				return
   116  			}
   117  			time.Sleep(10 * time.Millisecond)
   118  		}
   119  		c.Fatalf("not all names have appeared on the bus: %v", names)
   120  	}()
   121  
   122  	rest, err := snap.Parser(snap.Client()).ParseArgs([]string{"userd"})
   123  	c.Assert(err, IsNil)
   124  	c.Check(rest, DeepEquals, []string{})
   125  	c.Check(strings.ToLower(s.Stdout()), Equals, "exiting on <test signal>.\n")
   126  	c.Check(sigStopCalls, Equals, 1)
   127  }
   128  
   129  func (s *userdSuite) makeAgentClient() *http.Client {
   130  	transport := &http.Transport{
   131  		Dial: func(_, _ string) (net.Conn, error) {
   132  			return net.Dial("unix", s.agentSocketPath)
   133  		},
   134  		DisableKeepAlives: true,
   135  	}
   136  	return &http.Client{Transport: transport}
   137  }
   138  
   139  func (s *userdSuite) TestSessionAgentSocket(c *C) {
   140  	sigCh := make(chan os.Signal, 1)
   141  	sigStopCalls := 0
   142  
   143  	restore := snap.MockSignalNotify(func(sig ...os.Signal) (chan os.Signal, func()) {
   144  		c.Assert(sig, DeepEquals, []os.Signal{syscall.SIGINT, syscall.SIGTERM})
   145  		return sigCh, func() { sigStopCalls++ }
   146  	})
   147  	defer restore()
   148  
   149  	go func() {
   150  		defer func() {
   151  			sigCh <- &mockSignal{}
   152  		}()
   153  
   154  		// Wait for command to create socket file
   155  		for i := 0; i < 1000; i++ {
   156  			if osutil.FileExists(s.agentSocketPath) {
   157  				break
   158  			}
   159  			time.Sleep(10 * time.Millisecond)
   160  		}
   161  
   162  		// Check that agent functions
   163  		client := s.makeAgentClient()
   164  		response, err := client.Get("http://localhost/v1/session-info")
   165  		c.Assert(err, IsNil)
   166  		defer response.Body.Close()
   167  		c.Check(response.StatusCode, Equals, 200)
   168  	}()
   169  
   170  	rest, err := snap.Parser(snap.Client()).ParseArgs([]string{"userd", "--agent"})
   171  	c.Assert(err, IsNil)
   172  	c.Check(rest, DeepEquals, []string{})
   173  	c.Check(strings.ToLower(s.Stdout()), Equals, "exiting on <test signal>.\n")
   174  	c.Check(sigStopCalls, Equals, 1)
   175  }
   176  
   177  func (s *userdSuite) TestSignalNotify(c *C) {
   178  	ch, stop := snap.SignalNotify(syscall.SIGUSR1)
   179  	defer stop()
   180  	go func() {
   181  		myPid := os.Getpid()
   182  		me, err := os.FindProcess(myPid)
   183  		c.Assert(err, IsNil)
   184  		err = me.Signal(syscall.SIGUSR1)
   185  		c.Assert(err, IsNil)
   186  	}()
   187  	select {
   188  	case sig := <-ch:
   189  		c.Assert(sig, Equals, syscall.SIGUSR1)
   190  	case <-time.After(5 * time.Second):
   191  		c.Fatal("signal not received within 5s")
   192  	}
   193  }