github.com/neonyo/sys@v0.0.0-20230720094341-b1ee14be3ce8/unix/creds_test.go (about) 1 // Copyright 2012 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 //go:build linux 6 // +build linux 7 8 package unix_test 9 10 import ( 11 "bytes" 12 "errors" 13 "net" 14 "os" 15 "testing" 16 "time" 17 18 "golang.org/x/sys/unix" 19 ) 20 21 // TestSCMCredentials tests the sending and receiving of credentials 22 // (PID, UID, GID) in an ancillary message between two UNIX 23 // sockets. The SO_PASSCRED socket option is enabled on the sending 24 // socket for this to work. 25 func TestSCMCredentials(t *testing.T) { 26 socketTypeTests := []struct { 27 socketType int 28 dataLen int 29 }{ 30 { 31 unix.SOCK_STREAM, 32 1, 33 }, { 34 unix.SOCK_DGRAM, 35 0, 36 }, 37 } 38 39 for _, tt := range socketTypeTests { 40 fds, err := unix.Socketpair(unix.AF_LOCAL, tt.socketType, 0) 41 if err != nil { 42 t.Fatalf("Socketpair: %v", err) 43 } 44 45 err = unix.SetsockoptInt(fds[0], unix.SOL_SOCKET, unix.SO_PASSCRED, 1) 46 if err != nil { 47 unix.Close(fds[0]) 48 unix.Close(fds[1]) 49 t.Fatalf("SetsockoptInt: %v", err) 50 } 51 52 srvFile := os.NewFile(uintptr(fds[0]), "server") 53 cliFile := os.NewFile(uintptr(fds[1]), "client") 54 defer srvFile.Close() 55 defer cliFile.Close() 56 57 srv, err := net.FileConn(srvFile) 58 if err != nil { 59 t.Errorf("FileConn: %v", err) 60 return 61 } 62 defer srv.Close() 63 64 cli, err := net.FileConn(cliFile) 65 if err != nil { 66 t.Errorf("FileConn: %v", err) 67 return 68 } 69 defer cli.Close() 70 71 var ucred unix.Ucred 72 ucred.Pid = int32(os.Getpid()) 73 ucred.Uid = uint32(os.Getuid()) 74 ucred.Gid = uint32(os.Getgid()) 75 oob := unix.UnixCredentials(&ucred) 76 77 // On SOCK_STREAM, this is internally going to send a dummy byte 78 n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil) 79 if err != nil { 80 t.Fatalf("WriteMsgUnix: %v", err) 81 } 82 if n != 0 { 83 t.Fatalf("WriteMsgUnix n = %d, want 0", n) 84 } 85 if oobn != len(oob) { 86 t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob)) 87 } 88 89 oob2 := make([]byte, 10*len(oob)) 90 n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2) 91 if err != nil { 92 t.Fatalf("ReadMsgUnix: %v", err) 93 } 94 if flags != 0 && flags != unix.MSG_CMSG_CLOEXEC { 95 t.Fatalf("ReadMsgUnix flags = %#x, want 0 or %#x (MSG_CMSG_CLOEXEC)", flags, unix.MSG_CMSG_CLOEXEC) 96 } 97 if n != tt.dataLen { 98 t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen) 99 } 100 if oobn2 != oobn { 101 // without SO_PASSCRED set on the socket, ReadMsgUnix will 102 // return zero oob bytes 103 t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn) 104 } 105 oob2 = oob2[:oobn2] 106 if !bytes.Equal(oob, oob2) { 107 t.Fatal("ReadMsgUnix oob bytes don't match") 108 } 109 110 scm, err := unix.ParseSocketControlMessage(oob2) 111 if err != nil { 112 t.Fatalf("ParseSocketControlMessage: %v", err) 113 } 114 newUcred, err := unix.ParseUnixCredentials(&scm[0]) 115 if err != nil { 116 t.Fatalf("ParseUnixCredentials: %v", err) 117 } 118 if *newUcred != ucred { 119 t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred) 120 } 121 } 122 } 123 124 func TestPktInfo(t *testing.T) { 125 testcases := []struct { 126 network string 127 address *net.UDPAddr 128 }{ 129 {"udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}}, 130 {"udp6", &net.UDPAddr{IP: net.ParseIP("::1")}}, 131 } 132 for _, test := range testcases { 133 t.Run(test.network, func(t *testing.T) { 134 conn, err := net.ListenUDP(test.network, test.address) 135 if errors.Is(err, unix.EADDRNOTAVAIL) || errors.Is(err, unix.EAFNOSUPPORT) { 136 t.Skipf("%v is not available", test.address) 137 } 138 if err != nil { 139 t.Fatal("Listen:", err) 140 } 141 defer conn.Close() 142 143 var pktInfo []byte 144 var src net.IP 145 switch test.network { 146 case "udp4": 147 var info4 unix.Inet4Pktinfo 148 src = net.ParseIP("127.0.0.2").To4() 149 copy(info4.Spec_dst[:], src) 150 pktInfo = unix.PktInfo4(&info4) 151 152 case "udp6": 153 var info6 unix.Inet6Pktinfo 154 src = net.ParseIP("2001:0DB8::1") 155 copy(info6.Addr[:], src) 156 pktInfo = unix.PktInfo6(&info6) 157 158 raw, err := conn.SyscallConn() 159 if err != nil { 160 t.Fatal("SyscallConn:", err) 161 } 162 var opErr error 163 err = raw.Control(func(fd uintptr) { 164 opErr = unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_FREEBIND, 1) 165 }) 166 if err != nil { 167 t.Fatal("Control:", err) 168 } 169 if errors.Is(opErr, unix.ENOPROTOOPT) { 170 // Happens on android-amd64-emu, maybe Android has disabled 171 // IPV6_FREEBIND? 172 t.Skip("IPV6_FREEBIND not supported") 173 } 174 if opErr != nil { 175 t.Fatal("Can't enable IPV6_FREEBIND:", opErr) 176 } 177 } 178 179 msg := []byte{1} 180 addr := conn.LocalAddr().(*net.UDPAddr) 181 _, _, err = conn.WriteMsgUDP(msg, pktInfo, addr) 182 if err != nil { 183 t.Fatal("WriteMsgUDP:", err) 184 } 185 186 conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) 187 _, _, _, remote, err := conn.ReadMsgUDP(msg, nil) 188 if err != nil { 189 t.Fatal("ReadMsgUDP:", err) 190 } 191 192 if !remote.IP.Equal(src) { 193 t.Errorf("Got packet from %v, want %v", remote.IP, src) 194 } 195 }) 196 } 197 } 198 199 func TestParseOrigDstAddr(t *testing.T) { 200 testcases := []struct { 201 network string 202 address *net.UDPAddr 203 }{ 204 {"udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}}, 205 {"udp6", &net.UDPAddr{IP: net.IPv6loopback}}, 206 } 207 208 for _, test := range testcases { 209 t.Run(test.network, func(t *testing.T) { 210 conn, err := net.ListenUDP(test.network, test.address) 211 if errors.Is(err, unix.EADDRNOTAVAIL) || errors.Is(err, unix.EAFNOSUPPORT) { 212 t.Skipf("%v is not available", test.address) 213 } 214 if err != nil { 215 t.Fatal("Listen:", err) 216 } 217 defer conn.Close() 218 219 raw, err := conn.SyscallConn() 220 if err != nil { 221 t.Fatal("SyscallConn:", err) 222 } 223 224 var opErr error 225 err = raw.Control(func(fd uintptr) { 226 switch test.network { 227 case "udp4": 228 opErr = unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1) 229 case "udp6": 230 opErr = unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1) 231 } 232 }) 233 if err != nil { 234 t.Fatal("Control:", err) 235 } 236 if opErr != nil { 237 t.Fatal("Can't enable RECVORIGDSTADDR:", err) 238 } 239 240 msg := []byte{1} 241 addr := conn.LocalAddr().(*net.UDPAddr) 242 _, err = conn.WriteToUDP(msg, addr) 243 if err != nil { 244 t.Fatal("WriteToUDP:", err) 245 } 246 247 conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) 248 oob := make([]byte, unix.CmsgSpace(unix.SizeofSockaddrInet6)) 249 _, oobn, _, _, err := conn.ReadMsgUDP(msg, oob) 250 if err != nil { 251 t.Fatal("ReadMsgUDP:", err) 252 } 253 254 scms, err := unix.ParseSocketControlMessage(oob[:oobn]) 255 if err != nil { 256 t.Fatal("ParseSocketControlMessage:", err) 257 } 258 259 sa, err := unix.ParseOrigDstAddr(&scms[0]) 260 if err != nil { 261 t.Fatal("ParseOrigDstAddr:", err) 262 } 263 264 switch test.network { 265 case "udp4": 266 sa4, ok := sa.(*unix.SockaddrInet4) 267 if !ok { 268 t.Fatalf("Got %T not *SockaddrInet4", sa) 269 } 270 271 lo := net.IPv4(127, 0, 0, 1) 272 if addr := net.IP(sa4.Addr[:]); !lo.Equal(addr) { 273 t.Errorf("Got address %v, want %v", addr, lo) 274 } 275 276 if sa4.Port != addr.Port { 277 t.Errorf("Got port %d, want %d", sa4.Port, addr.Port) 278 } 279 280 case "udp6": 281 sa6, ok := sa.(*unix.SockaddrInet6) 282 if !ok { 283 t.Fatalf("Got %T, want *SockaddrInet6", sa) 284 } 285 286 if addr := net.IP(sa6.Addr[:]); !net.IPv6loopback.Equal(addr) { 287 t.Errorf("Got address %v, want %v", addr, net.IPv6loopback) 288 } 289 290 if sa6.Port != addr.Port { 291 t.Errorf("Got port %d, want %d", sa6.Port, addr.Port) 292 } 293 } 294 }) 295 } 296 }