github.com/hugh712/snapd@v0.0.0-20200910133618-1a99902bd583/daemon/ucrednet_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2015 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package daemon
    21  
    22  import (
    23  	"errors"
    24  	"net"
    25  	"path/filepath"
    26  	sys "syscall"
    27  
    28  	"gopkg.in/check.v1"
    29  )
    30  
    // ucrednetSuite holds the canned result that the stubbed getUcred hands
    // back to the code under test.
    type ucrednetSuite struct {
    	// ucred is the credential returned by the stub; tests set it as needed.
    	ucred *sys.Ucred
    	// err is the error returned by the stub alongside ucred.
    	err error
    }
    35  
    36  var _ = check.Suite(&ucrednetSuite{})
    37  
    38  func (s *ucrednetSuite) getUcred(fd, level, opt int) (*sys.Ucred, error) {
    39  	return s.ucred, s.err
    40  }
    41  
    42  func (s *ucrednetSuite) SetUpSuite(c *check.C) {
    43  	getUcred = s.getUcred
    44  }
    45  
    46  func (s *ucrednetSuite) TearDownTest(c *check.C) {
    47  	s.ucred = nil
    48  	s.err = nil
    49  }
    50  func (s *ucrednetSuite) TearDownSuite(c *check.C) {
    51  	getUcred = sys.GetsockoptUcred
    52  }
    53  
    54  func (s *ucrednetSuite) TestAcceptConnRemoteAddrString(c *check.C) {
    55  	s.ucred = &sys.Ucred{Pid: 100, Uid: 42}
    56  	d := c.MkDir()
    57  	sock := filepath.Join(d, "sock")
    58  
    59  	l, err := net.Listen("unix", sock)
    60  	c.Assert(err, check.IsNil)
    61  	wl := &ucrednetListener{Listener: l}
    62  
    63  	defer wl.Close()
    64  
    65  	go func() {
    66  		cli, err := net.Dial("unix", sock)
    67  		c.Assert(err, check.IsNil)
    68  		cli.Close()
    69  	}()
    70  
    71  	conn, err := wl.Accept()
    72  	c.Assert(err, check.IsNil)
    73  	defer conn.Close()
    74  
    75  	remoteAddr := conn.RemoteAddr().String()
    76  	c.Check(remoteAddr, check.Matches, "pid=100;uid=42;.*")
    77  	pid, uid, _, err := ucrednetGet(remoteAddr)
    78  	c.Check(pid, check.Equals, int32(100))
    79  	c.Check(uid, check.Equals, uint32(42))
    80  	c.Check(err, check.IsNil)
    81  }
    82  
    83  func (s *ucrednetSuite) TestNonUnix(c *check.C) {
    84  	l, err := net.Listen("tcp", "localhost:0")
    85  	c.Assert(err, check.IsNil)
    86  
    87  	wl := &ucrednetListener{Listener: l}
    88  	defer wl.Close()
    89  
    90  	addr := l.Addr().String()
    91  
    92  	go func() {
    93  		cli, err := net.Dial("tcp", addr)
    94  		c.Assert(err, check.IsNil)
    95  		cli.Close()
    96  	}()
    97  
    98  	conn, err := wl.Accept()
    99  	c.Assert(err, check.IsNil)
   100  	defer conn.Close()
   101  
   102  	remoteAddr := conn.RemoteAddr().String()
   103  	c.Check(remoteAddr, check.Matches, "pid=;uid=;.*")
   104  	pid, uid, _, err := ucrednetGet(remoteAddr)
   105  	c.Check(pid, check.Equals, ucrednetNoProcess)
   106  	c.Check(uid, check.Equals, ucrednetNobody)
   107  	c.Check(err, check.Equals, errNoID)
   108  }
   109  
   110  func (s *ucrednetSuite) TestAcceptErrors(c *check.C) {
   111  	s.ucred = &sys.Ucred{Pid: 100, Uid: 42}
   112  	d := c.MkDir()
   113  	sock := filepath.Join(d, "sock")
   114  
   115  	l, err := net.Listen("unix", sock)
   116  	c.Assert(err, check.IsNil)
   117  	c.Assert(l.Close(), check.IsNil)
   118  
   119  	wl := &ucrednetListener{Listener: l}
   120  
   121  	_, err = wl.Accept()
   122  	c.Assert(err, check.NotNil)
   123  }
   124  
   125  func (s *ucrednetSuite) TestUcredErrors(c *check.C) {
   126  	s.err = errors.New("oopsie")
   127  	d := c.MkDir()
   128  	sock := filepath.Join(d, "sock")
   129  
   130  	l, err := net.Listen("unix", sock)
   131  	c.Assert(err, check.IsNil)
   132  
   133  	wl := &ucrednetListener{Listener: l}
   134  	defer wl.Close()
   135  
   136  	go func() {
   137  		cli, err := net.Dial("unix", sock)
   138  		c.Assert(err, check.IsNil)
   139  		cli.Close()
   140  	}()
   141  
   142  	_, err = wl.Accept()
   143  	c.Assert(err, check.Equals, s.err)
   144  }
   145  
   146  func (s *ucrednetSuite) TestIdempotentClose(c *check.C) {
   147  	s.ucred = &sys.Ucred{Pid: 100, Uid: 42}
   148  	d := c.MkDir()
   149  	sock := filepath.Join(d, "sock")
   150  
   151  	l, err := net.Listen("unix", sock)
   152  	c.Assert(err, check.IsNil)
   153  	wl := &ucrednetListener{Listener: l}
   154  
   155  	c.Assert(wl.Close(), check.IsNil)
   156  	c.Assert(wl.Close(), check.IsNil)
   157  }
   158  
   159  func (s *ucrednetSuite) TestGetNoUid(c *check.C) {
   160  	pid, uid, _, err := ucrednetGet("pid=100;uid=;socket=;")
   161  	c.Check(err, check.Equals, errNoID)
   162  	c.Check(pid, check.Equals, ucrednetNoProcess)
   163  	c.Check(uid, check.Equals, ucrednetNobody)
   164  }
   165  
   166  func (s *ucrednetSuite) TestGetBadUid(c *check.C) {
   167  	pid, uid, _, err := ucrednetGet("pid=100;uid=4294967296;socket=;")
   168  	c.Check(err, check.NotNil)
   169  	c.Check(pid, check.Equals, int32(100))
   170  	c.Check(uid, check.Equals, ucrednetNobody)
   171  }
   172  
   173  func (s *ucrednetSuite) TestGetNonUcrednet(c *check.C) {
   174  	pid, uid, _, err := ucrednetGet("hello")
   175  	c.Check(err, check.Equals, errNoID)
   176  	c.Check(pid, check.Equals, ucrednetNoProcess)
   177  	c.Check(uid, check.Equals, ucrednetNobody)
   178  }
   179  
   180  func (s *ucrednetSuite) TestGetNothing(c *check.C) {
   181  	pid, uid, _, err := ucrednetGet("")
   182  	c.Check(err, check.Equals, errNoID)
   183  	c.Check(pid, check.Equals, ucrednetNoProcess)
   184  	c.Check(uid, check.Equals, ucrednetNobody)
   185  }
   186  
   187  func (s *ucrednetSuite) TestGet(c *check.C) {
   188  	pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;")
   189  	c.Check(err, check.IsNil)
   190  	c.Check(pid, check.Equals, int32(100))
   191  	c.Check(uid, check.Equals, uint32(42))
   192  	c.Check(socket, check.Equals, "/run/snap.socket")
   193  }
   194  
   195  func (s *ucrednetSuite) TestGetSneak(c *check.C) {
   196  	pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;pid=0;uid=0;socket=/tmp/my.socket")
   197  	c.Check(err, check.Equals, errNoID)
   198  	c.Check(pid, check.Equals, ucrednetNoProcess)
   199  	c.Check(uid, check.Equals, ucrednetNobody)
   200  	c.Check(socket, check.Equals, "")
   201  }