github.com/rigado/snapd@v2.42.5-go-mod+incompatible/daemon/ucrednet_test.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 3 /* 4 * Copyright (C) 2015 Canonical Ltd 5 * 6 * This program is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU General Public License version 3 as 8 * published by the Free Software Foundation. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package daemon 21 22 import ( 23 "errors" 24 "net" 25 "path/filepath" 26 sys "syscall" 27 28 "gopkg.in/check.v1" 29 ) 30 31 type ucrednetSuite struct { 32 ucred *sys.Ucred 33 err error 34 } 35 36 var _ = check.Suite(&ucrednetSuite{}) 37 38 func (s *ucrednetSuite) getUcred(fd, level, opt int) (*sys.Ucred, error) { 39 return s.ucred, s.err 40 } 41 42 func (s *ucrednetSuite) SetUpSuite(c *check.C) { 43 getUcred = s.getUcred 44 } 45 46 func (s *ucrednetSuite) TearDownTest(c *check.C) { 47 s.ucred = nil 48 s.err = nil 49 } 50 func (s *ucrednetSuite) TearDownSuite(c *check.C) { 51 getUcred = sys.GetsockoptUcred 52 } 53 54 func (s *ucrednetSuite) TestAcceptConnRemoteAddrString(c *check.C) { 55 s.ucred = &sys.Ucred{Pid: 100, Uid: 42} 56 d := c.MkDir() 57 sock := filepath.Join(d, "sock") 58 59 l, err := net.Listen("unix", sock) 60 c.Assert(err, check.IsNil) 61 wl := &ucrednetListener{Listener: l} 62 63 defer wl.Close() 64 65 go func() { 66 cli, err := net.Dial("unix", sock) 67 c.Assert(err, check.IsNil) 68 cli.Close() 69 }() 70 71 conn, err := wl.Accept() 72 c.Assert(err, check.IsNil) 73 defer conn.Close() 74 75 remoteAddr := conn.RemoteAddr().String() 76 c.Check(remoteAddr, check.Matches, "pid=100;uid=42;.*") 77 pid, uid, _, err := ucrednetGet(remoteAddr) 78 c.Check(pid, check.Equals, int32(100)) 79 c.Check(uid, check.Equals, uint32(42)) 80 c.Check(err, check.IsNil) 81 } 82 83 func (s *ucrednetSuite) TestNonUnix(c *check.C) { 84 l, err := net.Listen("tcp", "localhost:0") 85 c.Assert(err, check.IsNil) 86 87 wl := &ucrednetListener{Listener: l} 88 defer wl.Close() 89 90 addr := l.Addr().String() 91 92 go func() { 93 cli, err := net.Dial("tcp", addr) 94 c.Assert(err, check.IsNil) 95 cli.Close() 96 }() 97 98 conn, err := wl.Accept() 99 c.Assert(err, check.IsNil) 100 defer conn.Close() 101 102 remoteAddr := conn.RemoteAddr().String() 103 c.Check(remoteAddr, check.Matches, "pid=;uid=;.*") 104 pid, uid, _, err := ucrednetGet(remoteAddr) 105 c.Check(pid, check.Equals, ucrednetNoProcess) 106 c.Check(uid, check.Equals, ucrednetNobody) 107 c.Check(err, check.Equals, errNoID) 108 } 109 110 func (s *ucrednetSuite) TestAcceptErrors(c *check.C) { 111 s.ucred = &sys.Ucred{Pid: 100, Uid: 42} 112 d := c.MkDir() 113 sock := filepath.Join(d, "sock") 114 115 l, err := net.Listen("unix", sock) 116 c.Assert(err, check.IsNil) 117 c.Assert(l.Close(), check.IsNil) 118 119 wl := &ucrednetListener{Listener: l} 120 121 _, err = wl.Accept() 122 c.Assert(err, check.NotNil) 123 } 124 125 func (s *ucrednetSuite) TestUcredErrors(c *check.C) { 126 s.err = errors.New("oopsie") 127 d := c.MkDir() 128 sock := filepath.Join(d, "sock") 129 130 l, err := net.Listen("unix", sock) 131 c.Assert(err, check.IsNil) 132 133 wl := &ucrednetListener{Listener: l} 134 defer wl.Close() 135 136 go func() { 137 cli, err := net.Dial("unix", sock) 138 c.Assert(err, check.IsNil) 139 cli.Close() 140 }() 141 142 _, err = wl.Accept() 143 c.Assert(err, check.Equals, s.err) 144 } 145 146 func (s *ucrednetSuite) TestIdempotentClose(c *check.C) { 147 s.ucred = &sys.Ucred{Pid: 100, Uid: 42} 148 d := c.MkDir() 149 sock := filepath.Join(d, "sock") 150 151 l, err := net.Listen("unix", sock) 152 c.Assert(err, check.IsNil) 153 wl := &ucrednetListener{Listener: l} 154 155 c.Assert(wl.Close(), check.IsNil) 156 c.Assert(wl.Close(), check.IsNil) 157 } 158 159 func (s *ucrednetSuite) TestGetNoUid(c *check.C) { 160 pid, uid, _, err := ucrednetGet("pid=100;uid=;socket=;") 161 c.Check(err, check.Equals, errNoID) 162 c.Check(pid, check.Equals, ucrednetNoProcess) 163 c.Check(uid, check.Equals, ucrednetNobody) 164 } 165 166 func (s *ucrednetSuite) TestGetBadUid(c *check.C) { 167 pid, uid, _, err := ucrednetGet("pid=100;uid=4294967296;socket=;") 168 c.Check(err, check.NotNil) 169 c.Check(pid, check.Equals, int32(100)) 170 c.Check(uid, check.Equals, ucrednetNobody) 171 } 172 173 func (s *ucrednetSuite) TestGetNonUcrednet(c *check.C) { 174 pid, uid, _, err := ucrednetGet("hello") 175 c.Check(err, check.Equals, errNoID) 176 c.Check(pid, check.Equals, ucrednetNoProcess) 177 c.Check(uid, check.Equals, ucrednetNobody) 178 } 179 180 func (s *ucrednetSuite) TestGetNothing(c *check.C) { 181 pid, uid, _, err := ucrednetGet("") 182 c.Check(err, check.Equals, errNoID) 183 c.Check(pid, check.Equals, ucrednetNoProcess) 184 c.Check(uid, check.Equals, ucrednetNobody) 185 } 186 187 func (s *ucrednetSuite) TestGet(c *check.C) { 188 pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;") 189 c.Check(err, check.IsNil) 190 c.Check(pid, check.Equals, int32(100)) 191 c.Check(uid, check.Equals, uint32(42)) 192 c.Check(socket, check.Equals, "/run/snap.socket") 193 } 194 195 func (s *ucrednetSuite) TestGetSneak(c *check.C) { 196 pid, uid, socket, err := ucrednetGet("pid=100;uid=42;socket=/run/snap.socket;pid=0;uid=0;socket=/tmp/my.socket") 197 c.Check(err, check.Equals, errNoID) 198 c.Check(pid, check.Equals, ucrednetNoProcess) 199 c.Check(uid, check.Equals, ucrednetNobody) 200 c.Check(socket, check.Equals, "") 201 }