github.com/david-imola/snapd@v0.0.0-20210611180407-2de8ddeece6d/daemon/ucrednet_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2015 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package daemon
    21  
    22  import (
    23  	"errors"
    24  	"net"
    25  	"path/filepath"
    26  	sys "syscall"
    27  
    28  	"gopkg.in/check.v1"
    29  )
    30  
    31  type ucrednetSuite struct {
    32  	ucred *sys.Ucred
    33  	err   error
    34  }
    35  
    36  var _ = check.Suite(&ucrednetSuite{})
    37  
    38  func (s *ucrednetSuite) getUcred(fd, level, opt int) (*sys.Ucred, error) {
    39  	return s.ucred, s.err
    40  }
    41  
    42  func (s *ucrednetSuite) SetUpSuite(c *check.C) {
    43  	getUcred = s.getUcred
    44  }
    45  
    46  func (s *ucrednetSuite) TearDownTest(c *check.C) {
    47  	s.ucred = nil
    48  	s.err = nil
    49  }
    50  func (s *ucrednetSuite) TearDownSuite(c *check.C) {
    51  	getUcred = sys.GetsockoptUcred
    52  }
    53  
    54  func (s *ucrednetSuite) TestAcceptConnRemoteAddrString(c *check.C) {
    55  	s.ucred = &sys.Ucred{Pid: 100, Uid: 42}
    56  	d := c.MkDir()
    57  	sock := filepath.Join(d, "sock")
    58  
    59  	l, err := net.Listen("unix", sock)
    60  	c.Assert(err, check.IsNil)
    61  	wl := &ucrednetListener{Listener: l}
    62  
    63  	defer wl.Close()
    64  
    65  	go func() {
    66  		cli, err := net.Dial("unix", sock)
    67  		c.Assert(err, check.IsNil)
    68  		cli.Close()
    69  	}()
    70  
    71  	conn, err := wl.Accept()
    72  	c.Assert(err, check.IsNil)
    73  	defer conn.Close()
    74  
    75  	remoteAddr := conn.RemoteAddr().String()
    76  	c.Check(remoteAddr, check.Matches, "pid=100;uid=42;.*")
    77  	u, err := ucrednetGet(remoteAddr)
    78  	c.Assert(err, check.IsNil)
    79  	c.Check(u.Pid, check.Equals, int32(100))
    80  	c.Check(u.Uid, check.Equals, uint32(42))
    81  }
    82  
    83  func (s *ucrednetSuite) TestNonUnix(c *check.C) {
    84  	l, err := net.Listen("tcp", "localhost:0")
    85  	c.Assert(err, check.IsNil)
    86  
    87  	wl := &ucrednetListener{Listener: l}
    88  	defer wl.Close()
    89  
    90  	addr := l.Addr().String()
    91  
    92  	go func() {
    93  		cli, err := net.Dial("tcp", addr)
    94  		c.Assert(err, check.IsNil)
    95  		cli.Close()
    96  	}()
    97  
    98  	conn, err := wl.Accept()
    99  	c.Assert(err, check.IsNil)
   100  	defer conn.Close()
   101  
   102  	remoteAddr := conn.RemoteAddr().String()
   103  	c.Check(remoteAddr, check.Matches, "pid=;uid=;.*")
   104  	u, err := ucrednetGet(remoteAddr)
   105  	c.Check(u, check.IsNil)
   106  	c.Check(err, check.Equals, errNoID)
   107  }
   108  
   109  func (s *ucrednetSuite) TestAcceptErrors(c *check.C) {
   110  	s.ucred = &sys.Ucred{Pid: 100, Uid: 42}
   111  	d := c.MkDir()
   112  	sock := filepath.Join(d, "sock")
   113  
   114  	l, err := net.Listen("unix", sock)
   115  	c.Assert(err, check.IsNil)
   116  	c.Assert(l.Close(), check.IsNil)
   117  
   118  	wl := &ucrednetListener{Listener: l}
   119  
   120  	_, err = wl.Accept()
   121  	c.Assert(err, check.NotNil)
   122  }
   123  
   124  func (s *ucrednetSuite) TestUcredErrors(c *check.C) {
   125  	s.err = errors.New("oopsie")
   126  	d := c.MkDir()
   127  	sock := filepath.Join(d, "sock")
   128  
   129  	l, err := net.Listen("unix", sock)
   130  	c.Assert(err, check.IsNil)
   131  
   132  	wl := &ucrednetListener{Listener: l}
   133  	defer wl.Close()
   134  
   135  	go func() {
   136  		cli, err := net.Dial("unix", sock)
   137  		c.Assert(err, check.IsNil)
   138  		cli.Close()
   139  	}()
   140  
   141  	_, err = wl.Accept()
   142  	c.Assert(err, check.Equals, s.err)
   143  }
   144  
   145  func (s *ucrednetSuite) TestIdempotentClose(c *check.C) {
   146  	s.ucred = &sys.Ucred{Pid: 100, Uid: 42}
   147  	d := c.MkDir()
   148  	sock := filepath.Join(d, "sock")
   149  
   150  	l, err := net.Listen("unix", sock)
   151  	c.Assert(err, check.IsNil)
   152  	wl := &ucrednetListener{Listener: l}
   153  
   154  	c.Assert(wl.Close(), check.IsNil)
   155  	c.Assert(wl.Close(), check.IsNil)
   156  }
   157  
   158  func (s *ucrednetSuite) TestGetNoUid(c *check.C) {
   159  	u, err := ucrednetGet("pid=100;uid=;socket=;")
   160  	c.Check(err, check.Equals, errNoID)
   161  	c.Check(u, check.IsNil)
   162  }
   163  
   164  func (s *ucrednetSuite) TestGetBadUid(c *check.C) {
   165  	u, err := ucrednetGet("pid=100;uid=4294967296;socket=;")
   166  	c.Check(err, check.Equals, errNoID)
   167  	c.Check(u, check.IsNil)
   168  }
   169  
   170  func (s *ucrednetSuite) TestGetNonUcrednet(c *check.C) {
   171  	u, err := ucrednetGet("hello")
   172  	c.Check(err, check.Equals, errNoID)
   173  	c.Check(u, check.IsNil)
   174  }
   175  
   176  func (s *ucrednetSuite) TestGetNothing(c *check.C) {
   177  	u, err := ucrednetGet("")
   178  	c.Check(err, check.Equals, errNoID)
   179  	c.Check(u, check.IsNil)
   180  }
   181  
   182  func (s *ucrednetSuite) TestGet(c *check.C) {
   183  	u, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;")
   184  	c.Assert(err, check.IsNil)
   185  	c.Check(u.Pid, check.Equals, int32(100))
   186  	c.Check(u.Uid, check.Equals, uint32(42))
   187  	c.Check(u.Socket, check.Equals, "/run/snap.socket")
   188  }
   189  
   190  func (s *ucrednetSuite) TestGetSneak(c *check.C) {
   191  	u, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;pid=0;uid=0;socket=/tmp/my.socket")
   192  	c.Check(err, check.Equals, errNoID)
   193  	c.Check(u, check.IsNil)
   194  }