golang.org/x/sys@v0.20.1-0.20240517151509-673e0f94c16d/unix/creds_test.go (about) 1 // Copyright 2012 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 //go:build linux 6 7 package unix_test 8 9 import ( 10 "bytes" 11 "errors" 12 "net" 13 "os" 14 "testing" 15 "time" 16 17 "golang.org/x/sys/unix" 18 ) 19 20 // TestSCMCredentials tests the sending and receiving of credentials 21 // (PID, UID, GID) in an ancillary message between two UNIX 22 // sockets. The SO_PASSCRED socket option is enabled on the sending 23 // socket for this to work. 24 func TestSCMCredentials(t *testing.T) { 25 socketTypeTests := []struct { 26 socketType int 27 dataLen int 28 }{ 29 { 30 unix.SOCK_STREAM, 31 1, 32 }, { 33 unix.SOCK_DGRAM, 34 0, 35 }, 36 } 37 38 for _, tt := range socketTypeTests { 39 fds, err := unix.Socketpair(unix.AF_LOCAL, tt.socketType, 0) 40 if err != nil { 41 t.Fatalf("Socketpair: %v", err) 42 } 43 44 err = unix.SetsockoptInt(fds[0], unix.SOL_SOCKET, unix.SO_PASSCRED, 1) 45 if err != nil { 46 unix.Close(fds[0]) 47 unix.Close(fds[1]) 48 t.Fatalf("SetsockoptInt: %v", err) 49 } 50 51 srvFile := os.NewFile(uintptr(fds[0]), "server") 52 cliFile := os.NewFile(uintptr(fds[1]), "client") 53 defer srvFile.Close() 54 defer cliFile.Close() 55 56 srv, err := net.FileConn(srvFile) 57 if err != nil { 58 t.Errorf("FileConn: %v", err) 59 return 60 } 61 defer srv.Close() 62 63 cli, err := net.FileConn(cliFile) 64 if err != nil { 65 t.Errorf("FileConn: %v", err) 66 return 67 } 68 defer cli.Close() 69 70 var ucred unix.Ucred 71 ucred.Pid = int32(os.Getpid()) 72 ucred.Uid = uint32(os.Getuid()) 73 ucred.Gid = uint32(os.Getgid()) 74 oob := unix.UnixCredentials(&ucred) 75 76 // On SOCK_STREAM, this is internally going to send a dummy byte 77 n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil) 78 if err != nil { 79 t.Fatalf("WriteMsgUnix: %v", err) 80 } 81 if n != 0 { 82 t.Fatalf("WriteMsgUnix n = %d, want 0", n) 83 } 84 if oobn != len(oob) { 85 t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob)) 86 } 87 88 oob2 := make([]byte, 10*len(oob)) 89 n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2) 90 if err != nil { 91 t.Fatalf("ReadMsgUnix: %v", err) 92 } 93 if flags != 0 && flags != unix.MSG_CMSG_CLOEXEC { 94 t.Fatalf("ReadMsgUnix flags = %#x, want 0 or %#x (MSG_CMSG_CLOEXEC)", flags, unix.MSG_CMSG_CLOEXEC) 95 } 96 if n != tt.dataLen { 97 t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen) 98 } 99 if oobn2 != oobn { 100 // without SO_PASSCRED set on the socket, ReadMsgUnix will 101 // return zero oob bytes 102 t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn) 103 } 104 oob2 = oob2[:oobn2] 105 if !bytes.Equal(oob, oob2) { 106 t.Fatal("ReadMsgUnix oob bytes don't match") 107 } 108 109 scm, err := unix.ParseSocketControlMessage(oob2) 110 if err != nil { 111 t.Fatalf("ParseSocketControlMessage: %v", err) 112 } 113 newUcred, err := unix.ParseUnixCredentials(&scm[0]) 114 if err != nil { 115 t.Fatalf("ParseUnixCredentials: %v", err) 116 } 117 if *newUcred != ucred { 118 t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred) 119 } 120 } 121 } 122 123 func TestPktInfo(t *testing.T) { 124 testcases := []struct { 125 network string 126 address *net.UDPAddr 127 }{ 128 {"udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}}, 129 {"udp6", &net.UDPAddr{IP: net.ParseIP("::1")}}, 130 } 131 for _, test := range testcases { 132 t.Run(test.network, func(t *testing.T) { 133 conn, err := net.ListenUDP(test.network, test.address) 134 if errors.Is(err, unix.EADDRNOTAVAIL) || errors.Is(err, unix.EAFNOSUPPORT) { 135 t.Skipf("%v is not available", test.address) 136 } 137 if err != nil { 138 t.Fatal("Listen:", err) 139 } 140 defer conn.Close() 141 142 var pktInfo []byte 143 var src net.IP 144 switch test.network { 145 case "udp4": 146 var info4 unix.Inet4Pktinfo 147 src = net.ParseIP("127.0.0.2").To4() 148 copy(info4.Spec_dst[:], src) 149 pktInfo = unix.PktInfo4(&info4) 150 151 case "udp6": 152 var info6 unix.Inet6Pktinfo 153 src = net.ParseIP("2001:0DB8::1") 154 copy(info6.Addr[:], src) 155 pktInfo = unix.PktInfo6(&info6) 156 157 raw, err := conn.SyscallConn() 158 if err != nil { 159 t.Fatal("SyscallConn:", err) 160 } 161 var opErr error 162 err = raw.Control(func(fd uintptr) { 163 opErr = unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_FREEBIND, 1) 164 }) 165 if err != nil { 166 t.Fatal("Control:", err) 167 } 168 if errors.Is(opErr, unix.ENOPROTOOPT) { 169 // Happens on android-amd64-emu, maybe Android has disabled 170 // IPV6_FREEBIND? 171 t.Skip("IPV6_FREEBIND not supported") 172 } 173 if opErr != nil { 174 t.Fatal("Can't enable IPV6_FREEBIND:", opErr) 175 } 176 } 177 178 msg := []byte{1} 179 addr := conn.LocalAddr().(*net.UDPAddr) 180 _, _, err = conn.WriteMsgUDP(msg, pktInfo, addr) 181 if err != nil { 182 t.Fatal("WriteMsgUDP:", err) 183 } 184 185 conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) 186 _, _, _, remote, err := conn.ReadMsgUDP(msg, nil) 187 if err != nil { 188 t.Fatal("ReadMsgUDP:", err) 189 } 190 191 if !remote.IP.Equal(src) { 192 t.Errorf("Got packet from %v, want %v", remote.IP, src) 193 } 194 }) 195 } 196 } 197 198 func TestParseOrigDstAddr(t *testing.T) { 199 testcases := []struct { 200 network string 201 address *net.UDPAddr 202 }{ 203 {"udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}}, 204 {"udp6", &net.UDPAddr{IP: net.IPv6loopback}}, 205 } 206 207 for _, test := range testcases { 208 t.Run(test.network, func(t *testing.T) { 209 conn, err := net.ListenUDP(test.network, test.address) 210 if errors.Is(err, unix.EADDRNOTAVAIL) || errors.Is(err, unix.EAFNOSUPPORT) { 211 t.Skipf("%v is not available", test.address) 212 } 213 if err != nil { 214 t.Fatal("Listen:", err) 215 } 216 defer conn.Close() 217 218 raw, err := conn.SyscallConn() 219 if err != nil { 220 t.Fatal("SyscallConn:", err) 221 } 222 223 var opErr error 224 err = raw.Control(func(fd uintptr) { 225 switch test.network { 226 case "udp4": 227 opErr = unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1) 228 case "udp6": 229 opErr = unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1) 230 } 231 }) 232 if err != nil { 233 t.Fatal("Control:", err) 234 } 235 if opErr != nil { 236 t.Fatal("Can't enable RECVORIGDSTADDR:", err) 237 } 238 239 msg := []byte{1} 240 addr := conn.LocalAddr().(*net.UDPAddr) 241 _, err = conn.WriteToUDP(msg, addr) 242 if err != nil { 243 t.Fatal("WriteToUDP:", err) 244 } 245 246 conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) 247 oob := make([]byte, unix.CmsgSpace(unix.SizeofSockaddrInet6)) 248 _, oobn, _, _, err := conn.ReadMsgUDP(msg, oob) 249 if err != nil { 250 t.Fatal("ReadMsgUDP:", err) 251 } 252 253 scms, err := unix.ParseSocketControlMessage(oob[:oobn]) 254 if err != nil { 255 t.Fatal("ParseSocketControlMessage:", err) 256 } 257 258 sa, err := unix.ParseOrigDstAddr(&scms[0]) 259 if err != nil { 260 t.Fatal("ParseOrigDstAddr:", err) 261 } 262 263 switch test.network { 264 case "udp4": 265 sa4, ok := sa.(*unix.SockaddrInet4) 266 if !ok { 267 t.Fatalf("Got %T not *SockaddrInet4", sa) 268 } 269 270 lo := net.IPv4(127, 0, 0, 1) 271 if addr := net.IP(sa4.Addr[:]); !lo.Equal(addr) { 272 t.Errorf("Got address %v, want %v", addr, lo) 273 } 274 275 if sa4.Port != addr.Port { 276 t.Errorf("Got port %d, want %d", sa4.Port, addr.Port) 277 } 278 279 case "udp6": 280 sa6, ok := sa.(*unix.SockaddrInet6) 281 if !ok { 282 t.Fatalf("Got %T, want *SockaddrInet6", sa) 283 } 284 285 if addr := net.IP(sa6.Addr[:]); !net.IPv6loopback.Equal(addr) { 286 t.Errorf("Got address %v, want %v", addr, net.IPv6loopback) 287 } 288 289 if sa6.Port != addr.Port { 290 t.Errorf("Got port %d, want %d", sa6.Port, addr.Port) 291 } 292 } 293 }) 294 } 295 }