golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/internal/socket/socket_test.go (about) 1 // Copyright 2017 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows || zos 6 7 package socket_test 8 9 import ( 10 "bytes" 11 "fmt" 12 "io/ioutil" 13 "net" 14 "os" 15 "os/exec" 16 "path/filepath" 17 "runtime" 18 "strings" 19 "syscall" 20 "testing" 21 22 "golang.org/x/net/internal/socket" 23 "golang.org/x/net/nettest" 24 ) 25 26 func TestSocket(t *testing.T) { 27 t.Run("Option", func(t *testing.T) { 28 testSocketOption(t, &socket.Option{Level: syscall.SOL_SOCKET, Name: syscall.SO_RCVBUF, Len: 4}) 29 }) 30 } 31 32 func testSocketOption(t *testing.T, so *socket.Option) { 33 c, err := nettest.NewLocalPacketListener("udp") 34 if err != nil { 35 t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) 36 } 37 defer c.Close() 38 cc, err := socket.NewConn(c.(net.Conn)) 39 if err != nil { 40 t.Fatal(err) 41 } 42 const N = 2048 43 if err := so.SetInt(cc, N); err != nil { 44 t.Fatal(err) 45 } 46 n, err := so.GetInt(cc) 47 if err != nil { 48 t.Fatal(err) 49 } 50 if n < N { 51 t.Fatalf("got %d; want greater than or equal to %d", n, N) 52 } 53 } 54 55 type mockControl struct { 56 Level int 57 Type int 58 Data []byte 59 } 60 61 func TestControlMessage(t *testing.T) { 62 switch runtime.GOOS { 63 case "windows": 64 t.Skipf("not supported on %s", runtime.GOOS) 65 } 66 67 for _, tt := range []struct { 68 cs []mockControl 69 }{ 70 { 71 []mockControl{ 72 {Level: 1, Type: 1}, 73 }, 74 }, 75 { 76 []mockControl{ 77 {Level: 2, Type: 2, Data: []byte{0xfe}}, 78 }, 79 }, 80 { 81 []mockControl{ 82 {Level: 3, Type: 3, Data: []byte{0xfe, 0xff, 0xff, 0xfe}}, 83 }, 84 }, 85 { 86 []mockControl{ 87 {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}}, 88 }, 89 }, 90 { 91 []mockControl{ 92 {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}}, 93 {Level: 2, Type: 2, Data: []byte{0xfe}}, 94 }, 95 }, 96 } { 97 var w []byte 98 var tailPadLen int 99 mm := socket.NewControlMessage([]int{0}) 100 for i, c := range tt.cs { 101 m := socket.NewControlMessage([]int{len(c.Data)}) 102 l := len(m) - len(mm) 103 if i == len(tt.cs)-1 && l > len(c.Data) { 104 tailPadLen = l - len(c.Data) 105 } 106 w = append(w, m...) 107 } 108 109 var err error 110 ww := make([]byte, len(w)) 111 copy(ww, w) 112 m := socket.ControlMessage(ww) 113 for _, c := range tt.cs { 114 if err = m.MarshalHeader(c.Level, c.Type, len(c.Data)); err != nil { 115 t.Fatalf("(%v).MarshalHeader() = %v", tt.cs, err) 116 } 117 copy(m.Data(len(c.Data)), c.Data) 118 m = m.Next(len(c.Data)) 119 } 120 m = socket.ControlMessage(w) 121 for _, c := range tt.cs { 122 m, err = m.Marshal(c.Level, c.Type, c.Data) 123 if err != nil { 124 t.Fatalf("(%v).Marshal() = %v", tt.cs, err) 125 } 126 } 127 if !bytes.Equal(ww, w) { 128 t.Fatalf("got %#v; want %#v", ww, w) 129 } 130 131 ws := [][]byte{w} 132 if tailPadLen > 0 { 133 // Test a message with no tail padding. 134 nopad := w[:len(w)-tailPadLen] 135 ws = append(ws, [][]byte{nopad}...) 136 } 137 for _, w := range ws { 138 ms, err := socket.ControlMessage(w).Parse() 139 if err != nil { 140 t.Fatalf("(%v).Parse() = %v", tt.cs, err) 141 } 142 for i, m := range ms { 143 lvl, typ, dataLen, err := m.ParseHeader() 144 if err != nil { 145 t.Fatalf("(%v).ParseHeader() = %v", tt.cs, err) 146 } 147 if lvl != tt.cs[i].Level || typ != tt.cs[i].Type || dataLen != len(tt.cs[i].Data) { 148 t.Fatalf("%v: got %d, %d, %d; want %d, %d, %d", tt.cs[i], lvl, typ, dataLen, tt.cs[i].Level, tt.cs[i].Type, len(tt.cs[i].Data)) 149 } 150 } 151 } 152 } 153 } 154 155 func TestUDP(t *testing.T) { 156 switch runtime.GOOS { 157 case "windows": 158 t.Skipf("not supported on %s", runtime.GOOS) 159 } 160 161 c, err := nettest.NewLocalPacketListener("udp") 162 if err != nil { 163 t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) 164 } 165 defer c.Close() 166 // test that wrapped connections work with NewConn too 167 type wrappedConn struct{ *net.UDPConn } 168 cc, err := socket.NewConn(&wrappedConn{c.(*net.UDPConn)}) 169 if err != nil { 170 t.Fatal(err) 171 } 172 173 // create a dialed connection talking (only) to c/cc 174 cDialed, err := net.Dial("udp", c.LocalAddr().String()) 175 if err != nil { 176 t.Fatal(err) 177 } 178 ccDialed, err := socket.NewConn(cDialed) 179 if err != nil { 180 t.Fatal(err) 181 } 182 183 const data = "HELLO-R-U-THERE" 184 messageTests := []struct { 185 name string 186 conn *socket.Conn 187 dest net.Addr 188 }{ 189 { 190 name: "Message", 191 conn: cc, 192 dest: c.LocalAddr(), 193 }, 194 { 195 name: "Message-dialed", 196 conn: ccDialed, 197 dest: nil, 198 }, 199 } 200 for _, tt := range messageTests { 201 t.Run(tt.name, func(t *testing.T) { 202 wm := socket.Message{ 203 Buffers: bytes.SplitAfter([]byte(data), []byte("-")), 204 Addr: tt.dest, 205 } 206 if err := tt.conn.SendMsg(&wm, 0); err != nil { 207 t.Fatal(err) 208 } 209 b := make([]byte, 32) 210 rm := socket.Message{ 211 Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]}, 212 } 213 if err := cc.RecvMsg(&rm, 0); err != nil { 214 t.Fatal(err) 215 } 216 received := string(b[:rm.N]) 217 if received != data { 218 t.Fatalf("Roundtrip SendMsg/RecvMsg got %q; want %q", received, data) 219 } 220 }) 221 } 222 223 switch runtime.GOOS { 224 case "android", "linux": 225 messagesTests := []struct { 226 name string 227 conn *socket.Conn 228 dest net.Addr 229 }{ 230 { 231 name: "Messages", 232 conn: cc, 233 dest: c.LocalAddr(), 234 }, 235 { 236 name: "Messages-dialed", 237 conn: ccDialed, 238 dest: nil, 239 }, 240 } 241 for _, tt := range messagesTests { 242 t.Run(tt.name, func(t *testing.T) { 243 wmbs := bytes.SplitAfter([]byte(data), []byte("-")) 244 wms := []socket.Message{ 245 {Buffers: wmbs[:1], Addr: tt.dest}, 246 {Buffers: wmbs[1:], Addr: tt.dest}, 247 } 248 n, err := tt.conn.SendMsgs(wms, 0) 249 if err != nil { 250 t.Fatal(err) 251 } 252 if n != len(wms) { 253 t.Fatalf("SendMsgs(%#v) != %d; want %d", wms, n, len(wms)) 254 } 255 rmbs := [][]byte{make([]byte, 32), make([]byte, 32)} 256 rms := []socket.Message{ 257 {Buffers: [][]byte{rmbs[0]}}, 258 {Buffers: [][]byte{rmbs[1][:1], rmbs[1][1:3], rmbs[1][3:7], rmbs[1][7:11], rmbs[1][11:]}}, 259 } 260 nrecv := 0 261 for nrecv < len(rms) { 262 n, err := cc.RecvMsgs(rms[nrecv:], 0) 263 if err != nil { 264 t.Fatal(err) 265 } 266 nrecv += n 267 } 268 received0, received1 := string(rmbs[0][:rms[0].N]), string(rmbs[1][:rms[1].N]) 269 assembled := received0 + received1 270 assembledReordered := received1 + received0 271 if assembled != data && assembledReordered != data { 272 t.Fatalf("Roundtrip SendMsgs/RecvMsgs got %q / %q; want %q", assembled, assembledReordered, data) 273 } 274 }) 275 } 276 t.Run("Messages-undialed-no-dst", func(t *testing.T) { 277 // sending without destination address should fail. 278 // This checks that the internally recycled buffers are reset correctly. 279 data := []byte("HELLO-R-U-THERE") 280 wmbs := bytes.SplitAfter(data, []byte("-")) 281 wms := []socket.Message{ 282 {Buffers: wmbs[:1], Addr: nil}, 283 {Buffers: wmbs[1:], Addr: nil}, 284 } 285 n, err := cc.SendMsgs(wms, 0) 286 if n != 0 && err == nil { 287 t.Fatal("expected error, destination address required") 288 } 289 }) 290 } 291 292 // The behavior of transmission for zero byte paylaod depends 293 // on each platform implementation. Some may transmit only 294 // protocol header and options, other may transmit nothing. 295 // We test only that SendMsg and SendMsgs will not crash with 296 // empty buffers. 297 wm := socket.Message{ 298 Buffers: [][]byte{{}}, 299 Addr: c.LocalAddr(), 300 } 301 cc.SendMsg(&wm, 0) 302 wms := []socket.Message{ 303 {Buffers: [][]byte{{}}, Addr: c.LocalAddr()}, 304 } 305 cc.SendMsgs(wms, 0) 306 } 307 308 func BenchmarkUDP(b *testing.B) { 309 c, err := nettest.NewLocalPacketListener("udp") 310 if err != nil { 311 b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err) 312 } 313 defer c.Close() 314 cc, err := socket.NewConn(c.(net.Conn)) 315 if err != nil { 316 b.Fatal(err) 317 } 318 data := []byte("HELLO-R-U-THERE") 319 wm := socket.Message{ 320 Buffers: [][]byte{data}, 321 Addr: c.LocalAddr(), 322 } 323 rm := socket.Message{ 324 Buffers: [][]byte{make([]byte, 128)}, 325 OOB: make([]byte, 128), 326 } 327 328 for M := 1; M <= 1<<9; M = M << 1 { 329 b.Run(fmt.Sprintf("Iter-%d", M), func(b *testing.B) { 330 for i := 0; i < b.N; i++ { 331 for j := 0; j < M; j++ { 332 if err := cc.SendMsg(&wm, 0); err != nil { 333 b.Fatal(err) 334 } 335 if err := cc.RecvMsg(&rm, 0); err != nil { 336 b.Fatal(err) 337 } 338 } 339 } 340 }) 341 switch runtime.GOOS { 342 case "android", "linux": 343 wms := make([]socket.Message, M) 344 for i := range wms { 345 wms[i].Buffers = [][]byte{data} 346 wms[i].Addr = c.LocalAddr() 347 } 348 rms := make([]socket.Message, M) 349 for i := range rms { 350 rms[i].Buffers = [][]byte{make([]byte, 128)} 351 rms[i].OOB = make([]byte, 128) 352 } 353 b.Run(fmt.Sprintf("Batch-%d", M), func(b *testing.B) { 354 for i := 0; i < b.N; i++ { 355 if _, err := cc.SendMsgs(wms, 0); err != nil { 356 b.Fatal(err) 357 } 358 if _, err := cc.RecvMsgs(rms, 0); err != nil { 359 b.Fatal(err) 360 } 361 } 362 }) 363 } 364 } 365 } 366 367 func TestRace(t *testing.T) { 368 tests := []string{ 369 ` 370 package main 371 import ( 372 "log" 373 "net" 374 375 "golang.org/x/net/ipv4" 376 ) 377 378 var g byte 379 380 func main() { 381 c, err := net.ListenPacket("udp", "127.0.0.1:0") 382 if err != nil { 383 log.Fatalf("ListenPacket: %v", err) 384 } 385 cc := ipv4.NewPacketConn(c) 386 sync := make(chan bool) 387 src := make([]byte, 100) 388 dst := make([]byte, 100) 389 go func() { 390 if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil { 391 log.Fatalf("WriteTo: %v", err) 392 } 393 }() 394 go func() { 395 if _, _, _, err := cc.ReadFrom(dst); err != nil { 396 log.Fatalf("ReadFrom: %v", err) 397 } 398 sync <- true 399 }() 400 g = dst[0] 401 <-sync 402 } 403 `, 404 ` 405 package main 406 import ( 407 "log" 408 "net" 409 410 "golang.org/x/net/ipv4" 411 ) 412 413 func main() { 414 c, err := net.ListenPacket("udp", "127.0.0.1:0") 415 if err != nil { 416 log.Fatalf("ListenPacket: %v", err) 417 } 418 cc := ipv4.NewPacketConn(c) 419 sync := make(chan bool) 420 src := make([]byte, 100) 421 dst := make([]byte, 100) 422 go func() { 423 if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil { 424 log.Fatalf("WriteTo: %v", err) 425 } 426 sync <- true 427 }() 428 src[0] = 0 429 go func() { 430 if _, _, _, err := cc.ReadFrom(dst); err != nil { 431 log.Fatalf("ReadFrom: %v", err) 432 } 433 }() 434 <-sync 435 } 436 `, 437 } 438 platforms := map[string]bool{ 439 "linux/amd64": true, 440 "linux/ppc64le": true, 441 "linux/arm64": true, 442 } 443 if !platforms[runtime.GOOS+"/"+runtime.GOARCH] { 444 t.Skip("skipping test on non-race-enabled host.") 445 } 446 if runtime.Compiler == "gccgo" { 447 t.Skip("skipping race test when built with gccgo") 448 } 449 dir, err := ioutil.TempDir("", "testrace") 450 if err != nil { 451 t.Fatalf("failed to create temp directory: %v", err) 452 } 453 defer os.RemoveAll(dir) 454 goBinary := filepath.Join(runtime.GOROOT(), "bin", "go") 455 t.Logf("%s version", goBinary) 456 got, err := exec.Command(goBinary, "version").CombinedOutput() 457 if len(got) > 0 { 458 t.Logf("%s", got) 459 } 460 if err != nil { 461 t.Fatalf("go version failed: %v", err) 462 } 463 for i, test := range tests { 464 t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) { 465 src := filepath.Join(dir, fmt.Sprintf("test%d.go", i)) 466 if err := ioutil.WriteFile(src, []byte(test), 0644); err != nil { 467 t.Fatalf("failed to write file: %v", err) 468 } 469 t.Logf("%s run -race %s", goBinary, src) 470 got, err := exec.Command(goBinary, "run", "-race", src).CombinedOutput() 471 if len(got) > 0 { 472 t.Logf("%s", got) 473 } 474 if strings.Contains(string(got), "-race requires cgo") { 475 t.Log("CGO is not enabled so can't use -race") 476 } else if !strings.Contains(string(got), "WARNING: DATA RACE") { 477 t.Errorf("race not detected for test %d: err:%v", i, err) 478 } 479 }) 480 } 481 }