github.com/anonymouse64/snapd@v0.0.0-20210824153203-04c4c42d842d/daemon/ucrednet_test.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 3 /* 4 * Copyright (C) 2015 Canonical Ltd 5 * 6 * This program is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU General Public License version 3 as 8 * published by the Free Software Foundation. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package daemon 21 22 import ( 23 "errors" 24 "net" 25 "path/filepath" 26 sys "syscall" 27 28 "gopkg.in/check.v1" 29 ) 30 31 type ucrednetSuite struct { 32 ucred *sys.Ucred 33 err error 34 } 35 36 var _ = check.Suite(&ucrednetSuite{}) 37 38 func (s *ucrednetSuite) getUcred(fd, level, opt int) (*sys.Ucred, error) { 39 return s.ucred, s.err 40 } 41 42 func (s *ucrednetSuite) SetUpSuite(c *check.C) { 43 getUcred = s.getUcred 44 } 45 46 func (s *ucrednetSuite) TearDownTest(c *check.C) { 47 s.ucred = nil 48 s.err = nil 49 } 50 func (s *ucrednetSuite) TearDownSuite(c *check.C) { 51 getUcred = sys.GetsockoptUcred 52 } 53 54 func (s *ucrednetSuite) TestAcceptConnRemoteAddrString(c *check.C) { 55 s.ucred = &sys.Ucred{Pid: 100, Uid: 42} 56 d := c.MkDir() 57 sock := filepath.Join(d, "sock") 58 59 l, err := net.Listen("unix", sock) 60 c.Assert(err, check.IsNil) 61 wl := &ucrednetListener{Listener: l} 62 63 defer wl.Close() 64 65 go func() { 66 cli, err := net.Dial("unix", sock) 67 c.Assert(err, check.IsNil) 68 cli.Close() 69 }() 70 71 conn, err := wl.Accept() 72 c.Assert(err, check.IsNil) 73 defer conn.Close() 74 75 remoteAddr := conn.RemoteAddr().String() 76 c.Check(remoteAddr, check.Matches, "pid=100;uid=42;.*") 77 u, err := ucrednetGet(remoteAddr) 78 c.Assert(err, check.IsNil) 79 c.Check(u.Pid, check.Equals, int32(100)) 80 c.Check(u.Uid, check.Equals, uint32(42)) 81 } 82 83 func (s *ucrednetSuite) TestNonUnix(c *check.C) { 84 l, err := net.Listen("tcp", "localhost:0") 85 c.Assert(err, check.IsNil) 86 87 wl := &ucrednetListener{Listener: l} 88 defer wl.Close() 89 90 addr := l.Addr().String() 91 92 go func() { 93 cli, err := net.Dial("tcp", addr) 94 c.Assert(err, check.IsNil) 95 cli.Close() 96 }() 97 98 conn, err := wl.Accept() 99 c.Assert(err, check.IsNil) 100 defer conn.Close() 101 102 remoteAddr := conn.RemoteAddr().String() 103 c.Check(remoteAddr, check.Matches, "pid=;uid=;.*") 104 u, err := ucrednetGet(remoteAddr) 105 c.Check(u, check.IsNil) 106 c.Check(err, check.Equals, errNoID) 107 } 108 109 func (s *ucrednetSuite) TestAcceptErrors(c *check.C) { 110 s.ucred = &sys.Ucred{Pid: 100, Uid: 42} 111 d := c.MkDir() 112 sock := filepath.Join(d, "sock") 113 114 l, err := net.Listen("unix", sock) 115 c.Assert(err, check.IsNil) 116 c.Assert(l.Close(), check.IsNil) 117 118 wl := &ucrednetListener{Listener: l} 119 120 _, err = wl.Accept() 121 c.Assert(err, check.NotNil) 122 } 123 124 func (s *ucrednetSuite) TestUcredErrors(c *check.C) { 125 s.err = errors.New("oopsie") 126 d := c.MkDir() 127 sock := filepath.Join(d, "sock") 128 129 l, err := net.Listen("unix", sock) 130 c.Assert(err, check.IsNil) 131 132 wl := &ucrednetListener{Listener: l} 133 defer wl.Close() 134 135 go func() { 136 cli, err := net.Dial("unix", sock) 137 c.Assert(err, check.IsNil) 138 cli.Close() 139 }() 140 141 _, err = wl.Accept() 142 c.Assert(err, check.Equals, s.err) 143 } 144 145 func (s *ucrednetSuite) TestIdempotentClose(c *check.C) { 146 s.ucred = &sys.Ucred{Pid: 100, Uid: 42} 147 d := c.MkDir() 148 sock := filepath.Join(d, "sock") 149 150 l, err := net.Listen("unix", sock) 151 c.Assert(err, check.IsNil) 152 wl := &ucrednetListener{Listener: l} 153 154 c.Assert(wl.Close(), check.IsNil) 155 c.Assert(wl.Close(), check.IsNil) 156 } 157 158 func (s *ucrednetSuite) TestGetNoUid(c *check.C) { 159 u, err := ucrednetGet("pid=100;uid=;socket=;") 160 c.Check(err, check.Equals, errNoID) 161 c.Check(u, check.IsNil) 162 } 163 164 func (s *ucrednetSuite) TestGetBadUid(c *check.C) { 165 u, err := ucrednetGet("pid=100;uid=4294967296;socket=;") 166 c.Check(err, check.Equals, errNoID) 167 c.Check(u, check.IsNil) 168 } 169 170 func (s *ucrednetSuite) TestGetNonUcrednet(c *check.C) { 171 u, err := ucrednetGet("hello") 172 c.Check(err, check.Equals, errNoID) 173 c.Check(u, check.IsNil) 174 } 175 176 func (s *ucrednetSuite) TestGetNothing(c *check.C) { 177 u, err := ucrednetGet("") 178 c.Check(err, check.Equals, errNoID) 179 c.Check(u, check.IsNil) 180 } 181 182 func (s *ucrednetSuite) TestGet(c *check.C) { 183 u, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;") 184 c.Assert(err, check.IsNil) 185 c.Check(u.Pid, check.Equals, int32(100)) 186 c.Check(u.Uid, check.Equals, uint32(42)) 187 c.Check(u.Socket, check.Equals, "/run/snap.socket") 188 } 189 190 func (s *ucrednetSuite) TestGetSneak(c *check.C) { 191 u, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;pid=0;uid=0;socket=/tmp/my.socket") 192 c.Check(err, check.Equals, errNoID) 193 c.Check(u, check.IsNil) 194 }