github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/handle_test.go (about) 1 //go:build linux 2 // +build linux 3 4 package netlink 5 6 import ( 7 "crypto/rand" 8 "encoding/hex" 9 "fmt" 10 "io" 11 "net" 12 "sync" 13 "sync/atomic" 14 "testing" 15 "time" 16 "unsafe" 17 18 "github.com/sagernet/netlink/nl" 19 "github.com/vishvananda/netns" 20 "golang.org/x/sys/unix" 21 ) 22 23 func TestHandleCreateClose(t *testing.T) { 24 h, err := NewHandle() 25 if err != nil { 26 t.Fatal(err) 27 } 28 for _, f := range nl.SupportedNlFamilies { 29 sh, ok := h.sockets[f] 30 if !ok { 31 t.Fatalf("Handle socket(s) for family %d was not created", f) 32 } 33 if sh.Socket == nil { 34 t.Fatalf("Socket for family %d was not created", f) 35 } 36 } 37 38 h.Close() 39 if h.sockets != nil { 40 t.Fatalf("Handle socket(s) were not closed") 41 } 42 } 43 44 func TestHandleCreateNetns(t *testing.T) { 45 skipUnlessRoot(t) 46 47 id := make([]byte, 4) 48 if _, err := io.ReadFull(rand.Reader, id); err != nil { 49 t.Fatal(err) 50 } 51 ifName := "dummy-" + hex.EncodeToString(id) 52 53 // Create an handle on the current netns 54 curNs, err := netns.Get() 55 if err != nil { 56 t.Fatal(err) 57 } 58 defer curNs.Close() 59 60 ch, err := NewHandleAt(curNs) 61 if err != nil { 62 t.Fatal(err) 63 } 64 defer ch.Close() 65 66 // Create an handle on a custom netns 67 newNs, err := netns.New() 68 if err != nil { 69 t.Fatal(err) 70 } 71 defer newNs.Close() 72 73 nh, err := NewHandleAt(newNs) 74 if err != nil { 75 t.Fatal(err) 76 } 77 defer nh.Close() 78 79 // Create an interface using the current handle 80 err = ch.LinkAdd(&Dummy{LinkAttrs{Name: ifName}}) 81 if err != nil { 82 t.Fatal(err) 83 } 84 l, err := ch.LinkByName(ifName) 85 if err != nil { 86 t.Fatal(err) 87 } 88 if l.Type() != "dummy" { 89 t.Fatalf("Unexpected link type: %s", l.Type()) 90 } 91 92 // Verify the new handle cannot find the interface 93 ll, err := nh.LinkByName(ifName) 94 if err == nil { 95 t.Fatalf("Unexpected link found on netns %s: %v", newNs, ll) 96 } 97 98 // Move the interface to the new netns 99 err = ch.LinkSetNsFd(l, int(newNs)) 100 if err != nil { 101 t.Fatal(err) 102 } 103 104 // Verify new netns handle can find the interface while current cannot 105 ll, err = nh.LinkByName(ifName) 106 if err != nil { 107 t.Fatal(err) 108 } 109 if ll.Type() != "dummy" { 110 t.Fatalf("Unexpected link type: %s", ll.Type()) 111 } 112 ll, err = ch.LinkByName(ifName) 113 if err == nil { 114 t.Fatalf("Unexpected link found on netns %s: %v", curNs, ll) 115 } 116 } 117 118 func TestHandleTimeout(t *testing.T) { 119 h, err := NewHandle() 120 if err != nil { 121 t.Fatal(err) 122 } 123 defer h.Close() 124 125 for _, sh := range h.sockets { 126 verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 0, Usec: 0}) 127 } 128 129 h.SetSocketTimeout(2*time.Second + 8*time.Millisecond) 130 131 for _, sh := range h.sockets { 132 verifySockTimeVal(t, sh.Socket.GetFd(), unix.Timeval{Sec: 2, Usec: 8000}) 133 } 134 } 135 136 func TestHandleReceiveBuffer(t *testing.T) { 137 h, err := NewHandle() 138 if err != nil { 139 t.Fatal(err) 140 } 141 defer h.Close() 142 if err := h.SetSocketReceiveBufferSize(65536, false); err != nil { 143 t.Fatal(err) 144 } 145 sizes, err := h.GetSocketReceiveBufferSize() 146 if err != nil { 147 t.Fatal(err) 148 } 149 if len(sizes) != len(h.sockets) { 150 t.Fatalf("Unexpected number of socket buffer sizes: %d (expected %d)", 151 len(sizes), len(h.sockets)) 152 } 153 for _, s := range sizes { 154 if s < 65536 || s > 2*65536 { 155 t.Fatalf("Unexpected socket receive buffer size: %d (expected around %d)", 156 s, 65536) 157 } 158 } 159 } 160 161 func verifySockTimeVal(t *testing.T, fd int, tv unix.Timeval) { 162 var ( 163 tr unix.Timeval 164 v = uint32(0x10) 165 ) 166 _, _, errno := unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_SNDTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0) 167 if errno != 0 { 168 t.Fatal(errno) 169 } 170 171 if tr.Sec != tv.Sec || tr.Usec != tv.Usec { 172 t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv) 173 } 174 175 _, _, errno = unix.Syscall6(unix.SYS_GETSOCKOPT, uintptr(fd), unix.SOL_SOCKET, unix.SO_RCVTIMEO, uintptr(unsafe.Pointer(&tr)), uintptr(unsafe.Pointer(&v)), 0) 176 if errno != 0 { 177 t.Fatal(errno) 178 } 179 180 if tr.Sec != tv.Sec || tr.Usec != tv.Usec { 181 t.Fatalf("Unexpected timeout value read: %v. Expected: %v", tr, tv) 182 } 183 } 184 185 var ( 186 iter = 10 187 numThread = uint32(4) 188 prefix = "iface" 189 handle1 *Handle 190 handle2 *Handle 191 ns1 netns.NsHandle 192 ns2 netns.NsHandle 193 done uint32 194 initError error 195 once sync.Once 196 ) 197 198 func getXfrmState(thread int) *XfrmState { 199 return &XfrmState{ 200 Src: net.IPv4(byte(192), byte(168), 1, byte(1+thread)), 201 Dst: net.IPv4(byte(192), byte(168), 2, byte(1+thread)), 202 Proto: XFRM_PROTO_AH, 203 Mode: XFRM_MODE_TUNNEL, 204 Spi: thread, 205 Auth: &XfrmStateAlgo{ 206 Name: "hmac(sha256)", 207 Key: []byte("abcdefghijklmnopqrstuvwzyzABCDEF"), 208 }, 209 } 210 } 211 212 func getXfrmPolicy(thread int) *XfrmPolicy { 213 return &XfrmPolicy{ 214 Src: &net.IPNet{IP: net.IPv4(byte(10), byte(10), byte(thread), 0), Mask: []byte{255, 255, 255, 0}}, 215 Dst: &net.IPNet{IP: net.IPv4(byte(10), byte(10), byte(thread), 0), Mask: []byte{255, 255, 255, 0}}, 216 Proto: 17, 217 DstPort: 1234, 218 SrcPort: 5678, 219 Dir: XFRM_DIR_OUT, 220 Tmpls: []XfrmPolicyTmpl{ 221 { 222 Src: net.IPv4(byte(192), byte(168), 1, byte(thread)), 223 Dst: net.IPv4(byte(192), byte(168), 2, byte(thread)), 224 Proto: XFRM_PROTO_ESP, 225 Mode: XFRM_MODE_TUNNEL, 226 }, 227 }, 228 } 229 } 230 func initParallel() { 231 ns1, initError = netns.New() 232 if initError != nil { 233 return 234 } 235 handle1, initError = NewHandleAt(ns1) 236 if initError != nil { 237 return 238 } 239 ns2, initError = netns.New() 240 if initError != nil { 241 return 242 } 243 handle2, initError = NewHandleAt(ns2) 244 if initError != nil { 245 return 246 } 247 } 248 249 func parallelDone() { 250 atomic.AddUint32(&done, 1) 251 if done == numThread { 252 if ns1.IsOpen() { 253 ns1.Close() 254 } 255 if ns2.IsOpen() { 256 ns2.Close() 257 } 258 if handle1 != nil { 259 handle1.Close() 260 } 261 if handle2 != nil { 262 handle2.Close() 263 } 264 } 265 } 266 267 // Do few route and xfrm operation on the two handles in parallel 268 func runParallelTests(t *testing.T, thread int) { 269 skipUnlessRoot(t) 270 defer parallelDone() 271 272 t.Parallel() 273 274 once.Do(initParallel) 275 if initError != nil { 276 t.Fatal(initError) 277 } 278 279 state := getXfrmState(thread) 280 policy := getXfrmPolicy(thread) 281 for i := 0; i < iter; i++ { 282 ifName := fmt.Sprintf("%s_%d_%d", prefix, thread, i) 283 link := &Dummy{LinkAttrs{Name: ifName}} 284 err := handle1.LinkAdd(link) 285 if err != nil { 286 t.Fatal(err) 287 } 288 l, err := handle1.LinkByName(ifName) 289 if err != nil { 290 t.Fatal(err) 291 } 292 err = handle1.LinkSetUp(l) 293 if err != nil { 294 t.Fatal(err) 295 } 296 handle1.LinkSetNsFd(l, int(ns2)) 297 if err != nil { 298 t.Fatal(err) 299 } 300 err = handle1.XfrmStateAdd(state) 301 if err != nil { 302 t.Fatal(err) 303 } 304 err = handle1.XfrmPolicyAdd(policy) 305 if err != nil { 306 t.Fatal(err) 307 } 308 err = handle2.LinkSetDown(l) 309 if err != nil { 310 t.Fatal(err) 311 } 312 err = handle2.XfrmStateAdd(state) 313 if err != nil { 314 t.Fatal(err) 315 } 316 err = handle2.XfrmPolicyAdd(policy) 317 if err != nil { 318 t.Fatal(err) 319 } 320 _, err = handle2.LinkByName(ifName) 321 if err != nil { 322 t.Fatal(err) 323 } 324 handle2.LinkSetNsFd(l, int(ns1)) 325 if err != nil { 326 t.Fatal(err) 327 } 328 err = handle1.LinkSetUp(l) 329 if err != nil { 330 t.Fatal(err) 331 } 332 _, err = handle1.LinkByName(ifName) 333 if err != nil { 334 t.Fatal(err) 335 } 336 err = handle1.XfrmPolicyDel(policy) 337 if err != nil { 338 t.Fatal(err) 339 } 340 err = handle2.XfrmPolicyDel(policy) 341 if err != nil { 342 t.Fatal(err) 343 } 344 err = handle1.XfrmStateDel(state) 345 if err != nil { 346 t.Fatal(err) 347 } 348 err = handle2.XfrmStateDel(state) 349 if err != nil { 350 t.Fatal(err) 351 } 352 } 353 } 354 355 func TestHandleParallel1(t *testing.T) { 356 runParallelTests(t, 1) 357 } 358 359 func TestHandleParallel2(t *testing.T) { 360 runParallelTests(t, 2) 361 } 362 363 func TestHandleParallel3(t *testing.T) { 364 runParallelTests(t, 3) 365 } 366 367 func TestHandleParallel4(t *testing.T) { 368 runParallelTests(t, 4) 369 }