golang.org/x/sys@v0.20.1-0.20240517151509-673e0f94c16d/unix/creds_test.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build linux
     6  
     7  package unix_test
     8  
     9  import (
    10  	"bytes"
    11  	"errors"
    12  	"net"
    13  	"os"
    14  	"testing"
    15  	"time"
    16  
    17  	"golang.org/x/sys/unix"
    18  )
    19  
    20  // TestSCMCredentials tests the sending and receiving of credentials
    21  // (PID, UID, GID) in an ancillary message between two UNIX
    22  // sockets. The SO_PASSCRED socket option is enabled on the sending
    23  // socket for this to work.
    24  func TestSCMCredentials(t *testing.T) {
    25  	socketTypeTests := []struct {
    26  		socketType int
    27  		dataLen    int
    28  	}{
    29  		{
    30  			unix.SOCK_STREAM,
    31  			1,
    32  		}, {
    33  			unix.SOCK_DGRAM,
    34  			0,
    35  		},
    36  	}
    37  
    38  	for _, tt := range socketTypeTests {
    39  		fds, err := unix.Socketpair(unix.AF_LOCAL, tt.socketType, 0)
    40  		if err != nil {
    41  			t.Fatalf("Socketpair: %v", err)
    42  		}
    43  
    44  		err = unix.SetsockoptInt(fds[0], unix.SOL_SOCKET, unix.SO_PASSCRED, 1)
    45  		if err != nil {
    46  			unix.Close(fds[0])
    47  			unix.Close(fds[1])
    48  			t.Fatalf("SetsockoptInt: %v", err)
    49  		}
    50  
    51  		srvFile := os.NewFile(uintptr(fds[0]), "server")
    52  		cliFile := os.NewFile(uintptr(fds[1]), "client")
    53  		defer srvFile.Close()
    54  		defer cliFile.Close()
    55  
    56  		srv, err := net.FileConn(srvFile)
    57  		if err != nil {
    58  			t.Errorf("FileConn: %v", err)
    59  			return
    60  		}
    61  		defer srv.Close()
    62  
    63  		cli, err := net.FileConn(cliFile)
    64  		if err != nil {
    65  			t.Errorf("FileConn: %v", err)
    66  			return
    67  		}
    68  		defer cli.Close()
    69  
    70  		var ucred unix.Ucred
    71  		ucred.Pid = int32(os.Getpid())
    72  		ucred.Uid = uint32(os.Getuid())
    73  		ucred.Gid = uint32(os.Getgid())
    74  		oob := unix.UnixCredentials(&ucred)
    75  
    76  		// On SOCK_STREAM, this is internally going to send a dummy byte
    77  		n, oobn, err := cli.(*net.UnixConn).WriteMsgUnix(nil, oob, nil)
    78  		if err != nil {
    79  			t.Fatalf("WriteMsgUnix: %v", err)
    80  		}
    81  		if n != 0 {
    82  			t.Fatalf("WriteMsgUnix n = %d, want 0", n)
    83  		}
    84  		if oobn != len(oob) {
    85  			t.Fatalf("WriteMsgUnix oobn = %d, want %d", oobn, len(oob))
    86  		}
    87  
    88  		oob2 := make([]byte, 10*len(oob))
    89  		n, oobn2, flags, _, err := srv.(*net.UnixConn).ReadMsgUnix(nil, oob2)
    90  		if err != nil {
    91  			t.Fatalf("ReadMsgUnix: %v", err)
    92  		}
    93  		if flags != 0 && flags != unix.MSG_CMSG_CLOEXEC {
    94  			t.Fatalf("ReadMsgUnix flags = %#x, want 0 or %#x (MSG_CMSG_CLOEXEC)", flags, unix.MSG_CMSG_CLOEXEC)
    95  		}
    96  		if n != tt.dataLen {
    97  			t.Fatalf("ReadMsgUnix n = %d, want %d", n, tt.dataLen)
    98  		}
    99  		if oobn2 != oobn {
   100  			// without SO_PASSCRED set on the socket, ReadMsgUnix will
   101  			// return zero oob bytes
   102  			t.Fatalf("ReadMsgUnix oobn = %d, want %d", oobn2, oobn)
   103  		}
   104  		oob2 = oob2[:oobn2]
   105  		if !bytes.Equal(oob, oob2) {
   106  			t.Fatal("ReadMsgUnix oob bytes don't match")
   107  		}
   108  
   109  		scm, err := unix.ParseSocketControlMessage(oob2)
   110  		if err != nil {
   111  			t.Fatalf("ParseSocketControlMessage: %v", err)
   112  		}
   113  		newUcred, err := unix.ParseUnixCredentials(&scm[0])
   114  		if err != nil {
   115  			t.Fatalf("ParseUnixCredentials: %v", err)
   116  		}
   117  		if *newUcred != ucred {
   118  			t.Fatalf("ParseUnixCredentials = %+v, want %+v", newUcred, ucred)
   119  		}
   120  	}
   121  }
   122  
   123  func TestPktInfo(t *testing.T) {
   124  	testcases := []struct {
   125  		network string
   126  		address *net.UDPAddr
   127  	}{
   128  		{"udp4", &net.UDPAddr{IP: net.ParseIP("127.0.0.1")}},
   129  		{"udp6", &net.UDPAddr{IP: net.ParseIP("::1")}},
   130  	}
   131  	for _, test := range testcases {
   132  		t.Run(test.network, func(t *testing.T) {
   133  			conn, err := net.ListenUDP(test.network, test.address)
   134  			if errors.Is(err, unix.EADDRNOTAVAIL) || errors.Is(err, unix.EAFNOSUPPORT) {
   135  				t.Skipf("%v is not available", test.address)
   136  			}
   137  			if err != nil {
   138  				t.Fatal("Listen:", err)
   139  			}
   140  			defer conn.Close()
   141  
   142  			var pktInfo []byte
   143  			var src net.IP
   144  			switch test.network {
   145  			case "udp4":
   146  				var info4 unix.Inet4Pktinfo
   147  				src = net.ParseIP("127.0.0.2").To4()
   148  				copy(info4.Spec_dst[:], src)
   149  				pktInfo = unix.PktInfo4(&info4)
   150  
   151  			case "udp6":
   152  				var info6 unix.Inet6Pktinfo
   153  				src = net.ParseIP("2001:0DB8::1")
   154  				copy(info6.Addr[:], src)
   155  				pktInfo = unix.PktInfo6(&info6)
   156  
   157  				raw, err := conn.SyscallConn()
   158  				if err != nil {
   159  					t.Fatal("SyscallConn:", err)
   160  				}
   161  				var opErr error
   162  				err = raw.Control(func(fd uintptr) {
   163  					opErr = unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_FREEBIND, 1)
   164  				})
   165  				if err != nil {
   166  					t.Fatal("Control:", err)
   167  				}
   168  				if errors.Is(opErr, unix.ENOPROTOOPT) {
   169  					// Happens on android-amd64-emu, maybe Android has disabled
   170  					// IPV6_FREEBIND?
   171  					t.Skip("IPV6_FREEBIND not supported")
   172  				}
   173  				if opErr != nil {
   174  					t.Fatal("Can't enable IPV6_FREEBIND:", opErr)
   175  				}
   176  			}
   177  
   178  			msg := []byte{1}
   179  			addr := conn.LocalAddr().(*net.UDPAddr)
   180  			_, _, err = conn.WriteMsgUDP(msg, pktInfo, addr)
   181  			if err != nil {
   182  				t.Fatal("WriteMsgUDP:", err)
   183  			}
   184  
   185  			conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
   186  			_, _, _, remote, err := conn.ReadMsgUDP(msg, nil)
   187  			if err != nil {
   188  				t.Fatal("ReadMsgUDP:", err)
   189  			}
   190  
   191  			if !remote.IP.Equal(src) {
   192  				t.Errorf("Got packet from %v, want %v", remote.IP, src)
   193  			}
   194  		})
   195  	}
   196  }
   197  
   198  func TestParseOrigDstAddr(t *testing.T) {
   199  	testcases := []struct {
   200  		network string
   201  		address *net.UDPAddr
   202  	}{
   203  		{"udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)}},
   204  		{"udp6", &net.UDPAddr{IP: net.IPv6loopback}},
   205  	}
   206  
   207  	for _, test := range testcases {
   208  		t.Run(test.network, func(t *testing.T) {
   209  			conn, err := net.ListenUDP(test.network, test.address)
   210  			if errors.Is(err, unix.EADDRNOTAVAIL) || errors.Is(err, unix.EAFNOSUPPORT) {
   211  				t.Skipf("%v is not available", test.address)
   212  			}
   213  			if err != nil {
   214  				t.Fatal("Listen:", err)
   215  			}
   216  			defer conn.Close()
   217  
   218  			raw, err := conn.SyscallConn()
   219  			if err != nil {
   220  				t.Fatal("SyscallConn:", err)
   221  			}
   222  
   223  			var opErr error
   224  			err = raw.Control(func(fd uintptr) {
   225  				switch test.network {
   226  				case "udp4":
   227  					opErr = unix.SetsockoptInt(int(fd), unix.SOL_IP, unix.IP_RECVORIGDSTADDR, 1)
   228  				case "udp6":
   229  					opErr = unix.SetsockoptInt(int(fd), unix.SOL_IPV6, unix.IPV6_RECVORIGDSTADDR, 1)
   230  				}
   231  			})
   232  			if err != nil {
   233  				t.Fatal("Control:", err)
   234  			}
   235  			if opErr != nil {
   236  				t.Fatal("Can't enable RECVORIGDSTADDR:", err)
   237  			}
   238  
   239  			msg := []byte{1}
   240  			addr := conn.LocalAddr().(*net.UDPAddr)
   241  			_, err = conn.WriteToUDP(msg, addr)
   242  			if err != nil {
   243  				t.Fatal("WriteToUDP:", err)
   244  			}
   245  
   246  			conn.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
   247  			oob := make([]byte, unix.CmsgSpace(unix.SizeofSockaddrInet6))
   248  			_, oobn, _, _, err := conn.ReadMsgUDP(msg, oob)
   249  			if err != nil {
   250  				t.Fatal("ReadMsgUDP:", err)
   251  			}
   252  
   253  			scms, err := unix.ParseSocketControlMessage(oob[:oobn])
   254  			if err != nil {
   255  				t.Fatal("ParseSocketControlMessage:", err)
   256  			}
   257  
   258  			sa, err := unix.ParseOrigDstAddr(&scms[0])
   259  			if err != nil {
   260  				t.Fatal("ParseOrigDstAddr:", err)
   261  			}
   262  
   263  			switch test.network {
   264  			case "udp4":
   265  				sa4, ok := sa.(*unix.SockaddrInet4)
   266  				if !ok {
   267  					t.Fatalf("Got %T not *SockaddrInet4", sa)
   268  				}
   269  
   270  				lo := net.IPv4(127, 0, 0, 1)
   271  				if addr := net.IP(sa4.Addr[:]); !lo.Equal(addr) {
   272  					t.Errorf("Got address %v, want %v", addr, lo)
   273  				}
   274  
   275  				if sa4.Port != addr.Port {
   276  					t.Errorf("Got port %d, want %d", sa4.Port, addr.Port)
   277  				}
   278  
   279  			case "udp6":
   280  				sa6, ok := sa.(*unix.SockaddrInet6)
   281  				if !ok {
   282  					t.Fatalf("Got %T, want *SockaddrInet6", sa)
   283  				}
   284  
   285  				if addr := net.IP(sa6.Addr[:]); !net.IPv6loopback.Equal(addr) {
   286  					t.Errorf("Got address %v, want %v", addr, net.IPv6loopback)
   287  				}
   288  
   289  				if sa6.Port != addr.Port {
   290  					t.Errorf("Got port %d, want %d", sa6.Port, addr.Port)
   291  				}
   292  			}
   293  		})
   294  	}
   295  }