golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/internal/socket/socket_test.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows || zos
     6  
     7  package socket_test
     8  
     9  import (
    10  	"bytes"
    11  	"fmt"
    12  	"io/ioutil"
    13  	"net"
    14  	"os"
    15  	"os/exec"
    16  	"path/filepath"
    17  	"runtime"
    18  	"strings"
    19  	"syscall"
    20  	"testing"
    21  
    22  	"golang.org/x/net/internal/socket"
    23  	"golang.org/x/net/nettest"
    24  )
    25  
    26  func TestSocket(t *testing.T) {
    27  	t.Run("Option", func(t *testing.T) {
    28  		testSocketOption(t, &socket.Option{Level: syscall.SOL_SOCKET, Name: syscall.SO_RCVBUF, Len: 4})
    29  	})
    30  }
    31  
    32  func testSocketOption(t *testing.T, so *socket.Option) {
    33  	c, err := nettest.NewLocalPacketListener("udp")
    34  	if err != nil {
    35  		t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
    36  	}
    37  	defer c.Close()
    38  	cc, err := socket.NewConn(c.(net.Conn))
    39  	if err != nil {
    40  		t.Fatal(err)
    41  	}
    42  	const N = 2048
    43  	if err := so.SetInt(cc, N); err != nil {
    44  		t.Fatal(err)
    45  	}
    46  	n, err := so.GetInt(cc)
    47  	if err != nil {
    48  		t.Fatal(err)
    49  	}
    50  	if n < N {
    51  		t.Fatalf("got %d; want greater than or equal to %d", n, N)
    52  	}
    53  }
    54  
    55  type mockControl struct {
    56  	Level int
    57  	Type  int
    58  	Data  []byte
    59  }
    60  
    61  func TestControlMessage(t *testing.T) {
    62  	switch runtime.GOOS {
    63  	case "windows":
    64  		t.Skipf("not supported on %s", runtime.GOOS)
    65  	}
    66  
    67  	for _, tt := range []struct {
    68  		cs []mockControl
    69  	}{
    70  		{
    71  			[]mockControl{
    72  				{Level: 1, Type: 1},
    73  			},
    74  		},
    75  		{
    76  			[]mockControl{
    77  				{Level: 2, Type: 2, Data: []byte{0xfe}},
    78  			},
    79  		},
    80  		{
    81  			[]mockControl{
    82  				{Level: 3, Type: 3, Data: []byte{0xfe, 0xff, 0xff, 0xfe}},
    83  			},
    84  		},
    85  		{
    86  			[]mockControl{
    87  				{Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
    88  			},
    89  		},
    90  		{
    91  			[]mockControl{
    92  				{Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
    93  				{Level: 2, Type: 2, Data: []byte{0xfe}},
    94  			},
    95  		},
    96  	} {
    97  		var w []byte
    98  		var tailPadLen int
    99  		mm := socket.NewControlMessage([]int{0})
   100  		for i, c := range tt.cs {
   101  			m := socket.NewControlMessage([]int{len(c.Data)})
   102  			l := len(m) - len(mm)
   103  			if i == len(tt.cs)-1 && l > len(c.Data) {
   104  				tailPadLen = l - len(c.Data)
   105  			}
   106  			w = append(w, m...)
   107  		}
   108  
   109  		var err error
   110  		ww := make([]byte, len(w))
   111  		copy(ww, w)
   112  		m := socket.ControlMessage(ww)
   113  		for _, c := range tt.cs {
   114  			if err = m.MarshalHeader(c.Level, c.Type, len(c.Data)); err != nil {
   115  				t.Fatalf("(%v).MarshalHeader() = %v", tt.cs, err)
   116  			}
   117  			copy(m.Data(len(c.Data)), c.Data)
   118  			m = m.Next(len(c.Data))
   119  		}
   120  		m = socket.ControlMessage(w)
   121  		for _, c := range tt.cs {
   122  			m, err = m.Marshal(c.Level, c.Type, c.Data)
   123  			if err != nil {
   124  				t.Fatalf("(%v).Marshal() = %v", tt.cs, err)
   125  			}
   126  		}
   127  		if !bytes.Equal(ww, w) {
   128  			t.Fatalf("got %#v; want %#v", ww, w)
   129  		}
   130  
   131  		ws := [][]byte{w}
   132  		if tailPadLen > 0 {
   133  			// Test a message with no tail padding.
   134  			nopad := w[:len(w)-tailPadLen]
   135  			ws = append(ws, [][]byte{nopad}...)
   136  		}
   137  		for _, w := range ws {
   138  			ms, err := socket.ControlMessage(w).Parse()
   139  			if err != nil {
   140  				t.Fatalf("(%v).Parse() = %v", tt.cs, err)
   141  			}
   142  			for i, m := range ms {
   143  				lvl, typ, dataLen, err := m.ParseHeader()
   144  				if err != nil {
   145  					t.Fatalf("(%v).ParseHeader() = %v", tt.cs, err)
   146  				}
   147  				if lvl != tt.cs[i].Level || typ != tt.cs[i].Type || dataLen != len(tt.cs[i].Data) {
   148  					t.Fatalf("%v: got %d, %d, %d; want %d, %d, %d", tt.cs[i], lvl, typ, dataLen, tt.cs[i].Level, tt.cs[i].Type, len(tt.cs[i].Data))
   149  				}
   150  			}
   151  		}
   152  	}
   153  }
   154  
   155  func TestUDP(t *testing.T) {
   156  	switch runtime.GOOS {
   157  	case "windows":
   158  		t.Skipf("not supported on %s", runtime.GOOS)
   159  	}
   160  
   161  	c, err := nettest.NewLocalPacketListener("udp")
   162  	if err != nil {
   163  		t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
   164  	}
   165  	defer c.Close()
   166  	// test that wrapped connections work with NewConn too
   167  	type wrappedConn struct{ *net.UDPConn }
   168  	cc, err := socket.NewConn(&wrappedConn{c.(*net.UDPConn)})
   169  	if err != nil {
   170  		t.Fatal(err)
   171  	}
   172  
   173  	// create a dialed connection talking (only) to c/cc
   174  	cDialed, err := net.Dial("udp", c.LocalAddr().String())
   175  	if err != nil {
   176  		t.Fatal(err)
   177  	}
   178  	ccDialed, err := socket.NewConn(cDialed)
   179  	if err != nil {
   180  		t.Fatal(err)
   181  	}
   182  
   183  	const data = "HELLO-R-U-THERE"
   184  	messageTests := []struct {
   185  		name string
   186  		conn *socket.Conn
   187  		dest net.Addr
   188  	}{
   189  		{
   190  			name: "Message",
   191  			conn: cc,
   192  			dest: c.LocalAddr(),
   193  		},
   194  		{
   195  			name: "Message-dialed",
   196  			conn: ccDialed,
   197  			dest: nil,
   198  		},
   199  	}
   200  	for _, tt := range messageTests {
   201  		t.Run(tt.name, func(t *testing.T) {
   202  			wm := socket.Message{
   203  				Buffers: bytes.SplitAfter([]byte(data), []byte("-")),
   204  				Addr:    tt.dest,
   205  			}
   206  			if err := tt.conn.SendMsg(&wm, 0); err != nil {
   207  				t.Fatal(err)
   208  			}
   209  			b := make([]byte, 32)
   210  			rm := socket.Message{
   211  				Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
   212  			}
   213  			if err := cc.RecvMsg(&rm, 0); err != nil {
   214  				t.Fatal(err)
   215  			}
   216  			received := string(b[:rm.N])
   217  			if received != data {
   218  				t.Fatalf("Roundtrip SendMsg/RecvMsg got %q; want %q", received, data)
   219  			}
   220  		})
   221  	}
   222  
   223  	switch runtime.GOOS {
   224  	case "android", "linux":
   225  		messagesTests := []struct {
   226  			name string
   227  			conn *socket.Conn
   228  			dest net.Addr
   229  		}{
   230  			{
   231  				name: "Messages",
   232  				conn: cc,
   233  				dest: c.LocalAddr(),
   234  			},
   235  			{
   236  				name: "Messages-dialed",
   237  				conn: ccDialed,
   238  				dest: nil,
   239  			},
   240  		}
   241  		for _, tt := range messagesTests {
   242  			t.Run(tt.name, func(t *testing.T) {
   243  				wmbs := bytes.SplitAfter([]byte(data), []byte("-"))
   244  				wms := []socket.Message{
   245  					{Buffers: wmbs[:1], Addr: tt.dest},
   246  					{Buffers: wmbs[1:], Addr: tt.dest},
   247  				}
   248  				n, err := tt.conn.SendMsgs(wms, 0)
   249  				if err != nil {
   250  					t.Fatal(err)
   251  				}
   252  				if n != len(wms) {
   253  					t.Fatalf("SendMsgs(%#v) != %d; want %d", wms, n, len(wms))
   254  				}
   255  				rmbs := [][]byte{make([]byte, 32), make([]byte, 32)}
   256  				rms := []socket.Message{
   257  					{Buffers: [][]byte{rmbs[0]}},
   258  					{Buffers: [][]byte{rmbs[1][:1], rmbs[1][1:3], rmbs[1][3:7], rmbs[1][7:11], rmbs[1][11:]}},
   259  				}
   260  				nrecv := 0
   261  				for nrecv < len(rms) {
   262  					n, err := cc.RecvMsgs(rms[nrecv:], 0)
   263  					if err != nil {
   264  						t.Fatal(err)
   265  					}
   266  					nrecv += n
   267  				}
   268  				received0, received1 := string(rmbs[0][:rms[0].N]), string(rmbs[1][:rms[1].N])
   269  				assembled := received0 + received1
   270  				assembledReordered := received1 + received0
   271  				if assembled != data && assembledReordered != data {
   272  					t.Fatalf("Roundtrip SendMsgs/RecvMsgs got %q / %q; want %q", assembled, assembledReordered, data)
   273  				}
   274  			})
   275  		}
   276  		t.Run("Messages-undialed-no-dst", func(t *testing.T) {
   277  			// sending without destination address should fail.
   278  			// This checks that the internally recycled buffers are reset correctly.
   279  			data := []byte("HELLO-R-U-THERE")
   280  			wmbs := bytes.SplitAfter(data, []byte("-"))
   281  			wms := []socket.Message{
   282  				{Buffers: wmbs[:1], Addr: nil},
   283  				{Buffers: wmbs[1:], Addr: nil},
   284  			}
   285  			n, err := cc.SendMsgs(wms, 0)
   286  			if n != 0 && err == nil {
   287  				t.Fatal("expected error, destination address required")
   288  			}
   289  		})
   290  	}
   291  
   292  	// The behavior of transmission for zero byte paylaod depends
   293  	// on each platform implementation. Some may transmit only
   294  	// protocol header and options, other may transmit nothing.
   295  	// We test only that SendMsg and SendMsgs will not crash with
   296  	// empty buffers.
   297  	wm := socket.Message{
   298  		Buffers: [][]byte{{}},
   299  		Addr:    c.LocalAddr(),
   300  	}
   301  	cc.SendMsg(&wm, 0)
   302  	wms := []socket.Message{
   303  		{Buffers: [][]byte{{}}, Addr: c.LocalAddr()},
   304  	}
   305  	cc.SendMsgs(wms, 0)
   306  }
   307  
   308  func BenchmarkUDP(b *testing.B) {
   309  	c, err := nettest.NewLocalPacketListener("udp")
   310  	if err != nil {
   311  		b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
   312  	}
   313  	defer c.Close()
   314  	cc, err := socket.NewConn(c.(net.Conn))
   315  	if err != nil {
   316  		b.Fatal(err)
   317  	}
   318  	data := []byte("HELLO-R-U-THERE")
   319  	wm := socket.Message{
   320  		Buffers: [][]byte{data},
   321  		Addr:    c.LocalAddr(),
   322  	}
   323  	rm := socket.Message{
   324  		Buffers: [][]byte{make([]byte, 128)},
   325  		OOB:     make([]byte, 128),
   326  	}
   327  
   328  	for M := 1; M <= 1<<9; M = M << 1 {
   329  		b.Run(fmt.Sprintf("Iter-%d", M), func(b *testing.B) {
   330  			for i := 0; i < b.N; i++ {
   331  				for j := 0; j < M; j++ {
   332  					if err := cc.SendMsg(&wm, 0); err != nil {
   333  						b.Fatal(err)
   334  					}
   335  					if err := cc.RecvMsg(&rm, 0); err != nil {
   336  						b.Fatal(err)
   337  					}
   338  				}
   339  			}
   340  		})
   341  		switch runtime.GOOS {
   342  		case "android", "linux":
   343  			wms := make([]socket.Message, M)
   344  			for i := range wms {
   345  				wms[i].Buffers = [][]byte{data}
   346  				wms[i].Addr = c.LocalAddr()
   347  			}
   348  			rms := make([]socket.Message, M)
   349  			for i := range rms {
   350  				rms[i].Buffers = [][]byte{make([]byte, 128)}
   351  				rms[i].OOB = make([]byte, 128)
   352  			}
   353  			b.Run(fmt.Sprintf("Batch-%d", M), func(b *testing.B) {
   354  				for i := 0; i < b.N; i++ {
   355  					if _, err := cc.SendMsgs(wms, 0); err != nil {
   356  						b.Fatal(err)
   357  					}
   358  					if _, err := cc.RecvMsgs(rms, 0); err != nil {
   359  						b.Fatal(err)
   360  					}
   361  				}
   362  			})
   363  		}
   364  	}
   365  }
   366  
   367  func TestRace(t *testing.T) {
   368  	tests := []string{
   369  		`
   370  package main
   371  import (
   372  	"log"
   373  	"net"
   374  
   375  	"golang.org/x/net/ipv4"
   376  )
   377  
   378  var g byte
   379  
   380  func main() {
   381  	c, err := net.ListenPacket("udp", "127.0.0.1:0")
   382  	if err != nil {
   383  		log.Fatalf("ListenPacket: %v", err)
   384  	}
   385  	cc := ipv4.NewPacketConn(c)
   386  	sync := make(chan bool)
   387  	src := make([]byte, 100)
   388  	dst := make([]byte, 100)
   389  	go func() {
   390  		if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil {
   391  			log.Fatalf("WriteTo: %v", err)
   392  		}
   393  	}()
   394  	go func() {
   395  		if _, _, _, err := cc.ReadFrom(dst); err != nil {
   396  			log.Fatalf("ReadFrom: %v", err)
   397  		}
   398  		sync <- true
   399  	}()
   400  	g = dst[0]
   401  	<-sync
   402  }
   403  `,
   404  		`
   405  package main
   406  import (
   407  	"log"
   408  	"net"
   409  
   410  	"golang.org/x/net/ipv4"
   411  )
   412  
   413  func main() {
   414  	c, err := net.ListenPacket("udp", "127.0.0.1:0")
   415  	if err != nil {
   416  		log.Fatalf("ListenPacket: %v", err)
   417  	}
   418  	cc := ipv4.NewPacketConn(c)
   419  	sync := make(chan bool)
   420  	src := make([]byte, 100)
   421  	dst := make([]byte, 100)
   422  	go func() {
   423  		if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil {
   424  			log.Fatalf("WriteTo: %v", err)
   425  		}
   426  		sync <- true
   427  	}()
   428  	src[0] = 0
   429  	go func() {
   430  		if _, _, _, err := cc.ReadFrom(dst); err != nil {
   431  			log.Fatalf("ReadFrom: %v", err)
   432  		}
   433  	}()
   434  	<-sync
   435  }
   436  `,
   437  	}
   438  	platforms := map[string]bool{
   439  		"linux/amd64":   true,
   440  		"linux/ppc64le": true,
   441  		"linux/arm64":   true,
   442  	}
   443  	if !platforms[runtime.GOOS+"/"+runtime.GOARCH] {
   444  		t.Skip("skipping test on non-race-enabled host.")
   445  	}
   446  	if runtime.Compiler == "gccgo" {
   447  		t.Skip("skipping race test when built with gccgo")
   448  	}
   449  	dir, err := ioutil.TempDir("", "testrace")
   450  	if err != nil {
   451  		t.Fatalf("failed to create temp directory: %v", err)
   452  	}
   453  	defer os.RemoveAll(dir)
   454  	goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
   455  	t.Logf("%s version", goBinary)
   456  	got, err := exec.Command(goBinary, "version").CombinedOutput()
   457  	if len(got) > 0 {
   458  		t.Logf("%s", got)
   459  	}
   460  	if err != nil {
   461  		t.Fatalf("go version failed: %v", err)
   462  	}
   463  	for i, test := range tests {
   464  		t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
   465  			src := filepath.Join(dir, fmt.Sprintf("test%d.go", i))
   466  			if err := ioutil.WriteFile(src, []byte(test), 0644); err != nil {
   467  				t.Fatalf("failed to write file: %v", err)
   468  			}
   469  			t.Logf("%s run -race %s", goBinary, src)
   470  			got, err := exec.Command(goBinary, "run", "-race", src).CombinedOutput()
   471  			if len(got) > 0 {
   472  				t.Logf("%s", got)
   473  			}
   474  			if strings.Contains(string(got), "-race requires cgo") {
   475  				t.Log("CGO is not enabled so can't use -race")
   476  			} else if !strings.Contains(string(got), "WARNING: DATA RACE") {
   477  				t.Errorf("race not detected for test %d: err:%v", i, err)
   478  			}
   479  		})
   480  	}
   481  }