github.com/GFW-knocker/wireguard@v1.0.1/conn/sticky_linux_test.go (about) 1 //go:build linux && !android 2 3 /* SPDX-License-Identifier: MIT 4 * 5 * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 6 */ 7 8 package conn 9 10 import ( 11 "context" 12 "net" 13 "net/netip" 14 "runtime" 15 "testing" 16 "unsafe" 17 18 "golang.org/x/sys/unix" 19 ) 20 21 func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) { 22 var buf []byte 23 if addr.Is4() { 24 buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) 25 hdr := unix.Cmsghdr{ 26 Level: unix.IPPROTO_IP, 27 Type: unix.IP_PKTINFO, 28 } 29 hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo)) 30 copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) 31 32 info := unix.Inet4Pktinfo{ 33 Ifindex: ifidx, 34 Spec_dst: addr.As4(), 35 } 36 copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo)) 37 } else { 38 buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo)) 39 hdr := unix.Cmsghdr{ 40 Level: unix.IPPROTO_IPV6, 41 Type: unix.IPV6_PKTINFO, 42 } 43 hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo)) 44 copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr)))) 45 46 info := unix.Inet6Pktinfo{ 47 Ifindex: uint32(ifidx), 48 Addr: addr.As16(), 49 } 50 copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo)) 51 } 52 53 ep.src = buf 54 } 55 56 func Test_setSrcControl(t *testing.T) { 57 t.Run("IPv4", func(t *testing.T) { 58 ep := &StdNetEndpoint{ 59 AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"), 60 } 61 setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5) 62 63 control := make([]byte, stickyControlSize) 64 65 setSrcControl(&control, ep) 66 67 hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) 68 if hdr.Level != unix.IPPROTO_IP { 69 t.Errorf("unexpected level: %d", hdr.Level) 70 } 71 if hdr.Type != unix.IP_PKTINFO { 72 t.Errorf("unexpected type: %d", hdr.Type) 73 } 74 if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) { 75 t.Errorf("unexpected length: %d", hdr.Len) 76 } 77 info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) 78 if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 { 79 t.Errorf("unexpected address: %v", info.Spec_dst) 80 } 81 if info.Ifindex != 5 { 82 t.Errorf("unexpected ifindex: %d", info.Ifindex) 83 } 84 }) 85 86 t.Run("IPv6", func(t *testing.T) { 87 ep := &StdNetEndpoint{ 88 AddrPort: netip.MustParseAddrPort("[::1]:1234"), 89 } 90 setSrc(ep, netip.MustParseAddr("::1"), 5) 91 92 control := make([]byte, stickyControlSize) 93 94 setSrcControl(&control, ep) 95 96 hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) 97 if hdr.Level != unix.IPPROTO_IPV6 { 98 t.Errorf("unexpected level: %d", hdr.Level) 99 } 100 if hdr.Type != unix.IPV6_PKTINFO { 101 t.Errorf("unexpected type: %d", hdr.Type) 102 } 103 if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) { 104 t.Errorf("unexpected length: %d", hdr.Len) 105 } 106 info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) 107 if info.Addr != ep.SrcIP().As16() { 108 t.Errorf("unexpected address: %v", info.Addr) 109 } 110 if info.Ifindex != 5 { 111 t.Errorf("unexpected ifindex: %d", info.Ifindex) 112 } 113 }) 114 115 t.Run("ClearOnNoSrc", func(t *testing.T) { 116 control := make([]byte, stickyControlSize) 117 hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) 118 hdr.Level = 1 119 hdr.Type = 2 120 hdr.Len = 3 121 122 setSrcControl(&control, &StdNetEndpoint{}) 123 124 if len(control) != 0 { 125 t.Errorf("unexpected control: %v", control) 126 } 127 }) 128 } 129 130 func Test_getSrcFromControl(t *testing.T) { 131 t.Run("IPv4", func(t *testing.T) { 132 control := make([]byte, stickyControlSize) 133 hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) 134 hdr.Level = unix.IPPROTO_IP 135 hdr.Type = unix.IP_PKTINFO 136 hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) 137 info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) 138 info.Spec_dst = [4]byte{127, 0, 0, 1} 139 info.Ifindex = 5 140 141 ep := &StdNetEndpoint{} 142 getSrcFromControl(control, ep) 143 144 if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { 145 t.Errorf("unexpected address: %v", ep.SrcIP()) 146 } 147 if ep.SrcIfidx() != 5 { 148 t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) 149 } 150 }) 151 t.Run("IPv6", func(t *testing.T) { 152 control := make([]byte, stickyControlSize) 153 hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) 154 hdr.Level = unix.IPPROTO_IPV6 155 hdr.Type = unix.IPV6_PKTINFO 156 hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) 157 info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) 158 info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1} 159 info.Ifindex = 5 160 161 ep := &StdNetEndpoint{} 162 getSrcFromControl(control, ep) 163 164 if ep.SrcIP() != netip.MustParseAddr("::1") { 165 t.Errorf("unexpected address: %v", ep.SrcIP()) 166 } 167 if ep.SrcIfidx() != 5 { 168 t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) 169 } 170 }) 171 t.Run("ClearOnEmpty", func(t *testing.T) { 172 var control []byte 173 ep := &StdNetEndpoint{} 174 setSrc(ep, netip.MustParseAddr("::1"), 5) 175 176 getSrcFromControl(control, ep) 177 if ep.SrcIP().IsValid() { 178 t.Errorf("unexpected address: %v", ep.SrcIP()) 179 } 180 if ep.SrcIfidx() != 0 { 181 t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) 182 } 183 }) 184 t.Run("Multiple", func(t *testing.T) { 185 zeroControl := make([]byte, unix.CmsgSpace(0)) 186 zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0])) 187 zeroHdr.SetLen(unix.CmsgLen(0)) 188 189 control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo)) 190 hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0])) 191 hdr.Level = unix.IPPROTO_IP 192 hdr.Type = unix.IP_PKTINFO 193 hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) 194 info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)])) 195 info.Spec_dst = [4]byte{127, 0, 0, 1} 196 info.Ifindex = 5 197 198 combined := make([]byte, 0) 199 combined = append(combined, zeroControl...) 200 combined = append(combined, control...) 201 202 ep := &StdNetEndpoint{} 203 getSrcFromControl(combined, ep) 204 205 if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") { 206 t.Errorf("unexpected address: %v", ep.SrcIP()) 207 } 208 if ep.SrcIfidx() != 5 { 209 t.Errorf("unexpected ifindex: %d", ep.SrcIfidx()) 210 } 211 }) 212 } 213 214 func Test_listenConfig(t *testing.T) { 215 t.Run("IPv4", func(t *testing.T) { 216 conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0") 217 if err != nil { 218 t.Fatal(err) 219 } 220 defer conn.Close() 221 sc, err := conn.(*net.UDPConn).SyscallConn() 222 if err != nil { 223 t.Fatal(err) 224 } 225 226 if runtime.GOOS == "linux" { 227 var i int 228 sc.Control(func(fd uintptr) { 229 i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO) 230 }) 231 if err != nil { 232 t.Fatal(err) 233 } 234 if i != 1 { 235 t.Error("IP_PKTINFO not set!") 236 } 237 } else { 238 t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) 239 } 240 }) 241 t.Run("IPv6", func(t *testing.T) { 242 conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0") 243 if err != nil { 244 t.Fatal(err) 245 } 246 sc, err := conn.(*net.UDPConn).SyscallConn() 247 if err != nil { 248 t.Fatal(err) 249 } 250 251 if runtime.GOOS == "linux" { 252 var i int 253 sc.Control(func(fd uintptr) { 254 i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO) 255 }) 256 if err != nil { 257 t.Fatal(err) 258 } 259 if i != 1 { 260 t.Error("IPV6_PKTINFO not set!") 261 } 262 } else { 263 t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS) 264 } 265 }) 266 }