gitee.com/zhaochuninhefei/gmgo@v0.0.31-0.20240209061119-069254a02979/net/internal/socket/socket_test.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows || zos
     6  // +build aix darwin dragonfly freebsd linux netbsd openbsd solaris windows zos
     7  
     8  package socket_test
     9  
    10  import (
    11  	"bytes"
    12  	"fmt"
    13  	"io/ioutil"
    14  	"net"
    15  	"os"
    16  	"os/exec"
    17  	"path/filepath"
    18  	"runtime"
    19  	"strings"
    20  	"syscall"
    21  	"testing"
    22  
    23  	"gitee.com/zhaochuninhefei/gmgo/net/internal/socket"
    24  	"gitee.com/zhaochuninhefei/gmgo/net/nettest"
    25  )
    26  
    27  func TestSocket(t *testing.T) {
    28  	t.Run("Option", func(t *testing.T) {
    29  		testSocketOption(t, &socket.Option{Level: syscall.SOL_SOCKET, Name: syscall.SO_RCVBUF, Len: 4})
    30  	})
    31  }
    32  
    33  func testSocketOption(t *testing.T, so *socket.Option) {
    34  	c, err := nettest.NewLocalPacketListener("udp")
    35  	if err != nil {
    36  		t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
    37  	}
    38  	defer c.Close()
    39  	cc, err := socket.NewConn(c.(net.Conn))
    40  	if err != nil {
    41  		t.Fatal(err)
    42  	}
    43  	const N = 2048
    44  	if err := so.SetInt(cc, N); err != nil {
    45  		t.Fatal(err)
    46  	}
    47  	n, err := so.GetInt(cc)
    48  	if err != nil {
    49  		t.Fatal(err)
    50  	}
    51  	if n < N {
    52  		t.Fatalf("got %d; want greater than or equal to %d", n, N)
    53  	}
    54  }
    55  
    56  type mockControl struct {
    57  	Level int
    58  	Type  int
    59  	Data  []byte
    60  }
    61  
    62  func TestControlMessage(t *testing.T) {
    63  	switch runtime.GOOS {
    64  	case "windows":
    65  		t.Skipf("not supported on %s", runtime.GOOS)
    66  	}
    67  
    68  	for _, tt := range []struct {
    69  		cs []mockControl
    70  	}{
    71  		{
    72  			[]mockControl{
    73  				{Level: 1, Type: 1},
    74  			},
    75  		},
    76  		{
    77  			[]mockControl{
    78  				{Level: 2, Type: 2, Data: []byte{0xfe}},
    79  			},
    80  		},
    81  		{
    82  			[]mockControl{
    83  				{Level: 3, Type: 3, Data: []byte{0xfe, 0xff, 0xff, 0xfe}},
    84  			},
    85  		},
    86  		{
    87  			[]mockControl{
    88  				{Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
    89  			},
    90  		},
    91  		{
    92  			[]mockControl{
    93  				{Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
    94  				{Level: 2, Type: 2, Data: []byte{0xfe}},
    95  			},
    96  		},
    97  	} {
    98  		var w []byte
    99  		var tailPadLen int
   100  		mm := socket.NewControlMessage([]int{0})
   101  		for i, c := range tt.cs {
   102  			m := socket.NewControlMessage([]int{len(c.Data)})
   103  			l := len(m) - len(mm)
   104  			if i == len(tt.cs)-1 && l > len(c.Data) {
   105  				tailPadLen = l - len(c.Data)
   106  			}
   107  			w = append(w, m...)
   108  		}
   109  
   110  		var err error
   111  		ww := make([]byte, len(w))
   112  		copy(ww, w)
   113  		m := socket.ControlMessage(ww)
   114  		for _, c := range tt.cs {
   115  			if err = m.MarshalHeader(c.Level, c.Type, len(c.Data)); err != nil {
   116  				t.Fatalf("(%v).MarshalHeader() = %v", tt.cs, err)
   117  			}
   118  			copy(m.Data(len(c.Data)), c.Data)
   119  			m = m.Next(len(c.Data))
   120  		}
   121  		m = socket.ControlMessage(w)
   122  		for _, c := range tt.cs {
   123  			m, err = m.Marshal(c.Level, c.Type, c.Data)
   124  			if err != nil {
   125  				t.Fatalf("(%v).Marshal() = %v", tt.cs, err)
   126  			}
   127  		}
   128  		if !bytes.Equal(ww, w) {
   129  			t.Fatalf("got %#v; want %#v", ww, w)
   130  		}
   131  
   132  		ws := [][]byte{w}
   133  		if tailPadLen > 0 {
   134  			// Test a message with no tail padding.
   135  			nopad := w[:len(w)-tailPadLen]
   136  			ws = append(ws, [][]byte{nopad}...)
   137  		}
   138  		for _, w := range ws {
   139  			ms, err := socket.ControlMessage(w).Parse()
   140  			if err != nil {
   141  				t.Fatalf("(%v).Parse() = %v", tt.cs, err)
   142  			}
   143  			for i, m := range ms {
   144  				lvl, typ, dataLen, err := m.ParseHeader()
   145  				if err != nil {
   146  					t.Fatalf("(%v).ParseHeader() = %v", tt.cs, err)
   147  				}
   148  				if lvl != tt.cs[i].Level || typ != tt.cs[i].Type || dataLen != len(tt.cs[i].Data) {
   149  					t.Fatalf("%v: got %d, %d, %d; want %d, %d, %d", tt.cs[i], lvl, typ, dataLen, tt.cs[i].Level, tt.cs[i].Type, len(tt.cs[i].Data))
   150  				}
   151  			}
   152  		}
   153  	}
   154  }
   155  
   156  func TestUDP(t *testing.T) {
   157  	switch runtime.GOOS {
   158  	case "windows":
   159  		t.Skipf("not supported on %s", runtime.GOOS)
   160  	}
   161  
   162  	c, err := nettest.NewLocalPacketListener("udp")
   163  	if err != nil {
   164  		t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
   165  	}
   166  	defer c.Close()
   167  	// test that wrapped connections work with NewConn too
   168  	type wrappedConn struct{ *net.UDPConn }
   169  	cc, err := socket.NewConn(&wrappedConn{c.(*net.UDPConn)})
   170  	if err != nil {
   171  		t.Fatal(err)
   172  	}
   173  
   174  	// create a dialed connection talking (only) to c/cc
   175  	cDialed, err := net.Dial("udp", c.LocalAddr().String())
   176  	if err != nil {
   177  		t.Fatal(err)
   178  	}
   179  	ccDialed, err := socket.NewConn(cDialed)
   180  	if err != nil {
   181  		t.Fatal(err)
   182  	}
   183  
   184  	const data = "HELLO-R-U-THERE"
   185  	messageTests := []struct {
   186  		name string
   187  		conn *socket.Conn
   188  		dest net.Addr
   189  	}{
   190  		{
   191  			name: "Message",
   192  			conn: cc,
   193  			dest: c.LocalAddr(),
   194  		},
   195  		{
   196  			name: "Message-dialed",
   197  			conn: ccDialed,
   198  			dest: nil,
   199  		},
   200  	}
   201  	for _, tt := range messageTests {
   202  		t.Run(tt.name, func(t *testing.T) {
   203  			wm := socket.Message{
   204  				Buffers: bytes.SplitAfter([]byte(data), []byte("-")),
   205  				Addr:    tt.dest,
   206  			}
   207  			if err := tt.conn.SendMsg(&wm, 0); err != nil {
   208  				t.Fatal(err)
   209  			}
   210  			b := make([]byte, 32)
   211  			rm := socket.Message{
   212  				Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
   213  			}
   214  			if err := cc.RecvMsg(&rm, 0); err != nil {
   215  				t.Fatal(err)
   216  			}
   217  			received := string(b[:rm.N])
   218  			if received != data {
   219  				t.Fatalf("Roundtrip SendMsg/RecvMsg got %q; want %q", received, data)
   220  			}
   221  		})
   222  	}
   223  
   224  	switch runtime.GOOS {
   225  	case "android", "linux":
   226  		messagesTests := []struct {
   227  			name string
   228  			conn *socket.Conn
   229  			dest net.Addr
   230  		}{
   231  			{
   232  				name: "Messages",
   233  				conn: cc,
   234  				dest: c.LocalAddr(),
   235  			},
   236  			{
   237  				name: "Messages-dialed",
   238  				conn: ccDialed,
   239  				dest: nil,
   240  			},
   241  		}
   242  		for _, tt := range messagesTests {
   243  			t.Run(tt.name, func(t *testing.T) {
   244  				wmbs := bytes.SplitAfter([]byte(data), []byte("-"))
   245  				wms := []socket.Message{
   246  					{Buffers: wmbs[:1], Addr: tt.dest},
   247  					{Buffers: wmbs[1:], Addr: tt.dest},
   248  				}
   249  				n, err := tt.conn.SendMsgs(wms, 0)
   250  				if err != nil {
   251  					t.Fatal(err)
   252  				}
   253  				if n != len(wms) {
   254  					t.Fatalf("SendMsgs(%#v) != %d; want %d", wms, n, len(wms))
   255  				}
   256  				rmbs := [][]byte{make([]byte, 32), make([]byte, 32)}
   257  				rms := []socket.Message{
   258  					{Buffers: [][]byte{rmbs[0]}},
   259  					{Buffers: [][]byte{rmbs[1][:1], rmbs[1][1:3], rmbs[1][3:7], rmbs[1][7:11], rmbs[1][11:]}},
   260  				}
   261  				nrecv := 0
   262  				for nrecv < len(rms) {
   263  					n, err := cc.RecvMsgs(rms[nrecv:], 0)
   264  					if err != nil {
   265  						t.Fatal(err)
   266  					}
   267  					nrecv += n
   268  				}
   269  				received0, received1 := string(rmbs[0][:rms[0].N]), string(rmbs[1][:rms[1].N])
   270  				assembled := received0 + received1
   271  				assembledReordered := received1 + received0
   272  				if assembled != data && assembledReordered != data {
   273  					t.Fatalf("Roundtrip SendMsgs/RecvMsgs got %q / %q; want %q", assembled, assembledReordered, data)
   274  				}
   275  			})
   276  		}
   277  		t.Run("Messages-undialed-no-dst", func(t *testing.T) {
   278  			// sending without destination address should fail.
   279  			// This checks that the internally recycled buffers are reset correctly.
   280  			data := []byte("HELLO-R-U-THERE")
   281  			wmbs := bytes.SplitAfter(data, []byte("-"))
   282  			wms := []socket.Message{
   283  				{Buffers: wmbs[:1], Addr: nil},
   284  				{Buffers: wmbs[1:], Addr: nil},
   285  			}
   286  			n, err := cc.SendMsgs(wms, 0)
   287  			if n != 0 && err == nil {
   288  				t.Fatal("expected error, destination address required")
   289  			}
   290  		})
   291  	}
   292  
   293  	// The behavior of transmission for zero byte paylaod depends
   294  	// on each platform implementation. Some may transmit only
   295  	// protocol header and options, other may transmit nothing.
   296  	// We test only that SendMsg and SendMsgs will not crash with
   297  	// empty buffers.
   298  	wm := socket.Message{
   299  		Buffers: [][]byte{{}},
   300  		Addr:    c.LocalAddr(),
   301  	}
   302  	cc.SendMsg(&wm, 0)
   303  	wms := []socket.Message{
   304  		{Buffers: [][]byte{{}}, Addr: c.LocalAddr()},
   305  	}
   306  	cc.SendMsgs(wms, 0)
   307  }
   308  
   309  func BenchmarkUDP(b *testing.B) {
   310  	c, err := nettest.NewLocalPacketListener("udp")
   311  	if err != nil {
   312  		b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
   313  	}
   314  	defer c.Close()
   315  	cc, err := socket.NewConn(c.(net.Conn))
   316  	if err != nil {
   317  		b.Fatal(err)
   318  	}
   319  	data := []byte("HELLO-R-U-THERE")
   320  	wm := socket.Message{
   321  		Buffers: [][]byte{data},
   322  		Addr:    c.LocalAddr(),
   323  	}
   324  	rm := socket.Message{
   325  		Buffers: [][]byte{make([]byte, 128)},
   326  		OOB:     make([]byte, 128),
   327  	}
   328  
   329  	for M := 1; M <= 1<<9; M = M << 1 {
   330  		b.Run(fmt.Sprintf("Iter-%d", M), func(b *testing.B) {
   331  			for i := 0; i < b.N; i++ {
   332  				for j := 0; j < M; j++ {
   333  					if err := cc.SendMsg(&wm, 0); err != nil {
   334  						b.Fatal(err)
   335  					}
   336  					if err := cc.RecvMsg(&rm, 0); err != nil {
   337  						b.Fatal(err)
   338  					}
   339  				}
   340  			}
   341  		})
   342  		switch runtime.GOOS {
   343  		case "android", "linux":
   344  			wms := make([]socket.Message, M)
   345  			for i := range wms {
   346  				wms[i].Buffers = [][]byte{data}
   347  				wms[i].Addr = c.LocalAddr()
   348  			}
   349  			rms := make([]socket.Message, M)
   350  			for i := range rms {
   351  				rms[i].Buffers = [][]byte{make([]byte, 128)}
   352  				rms[i].OOB = make([]byte, 128)
   353  			}
   354  			b.Run(fmt.Sprintf("Batch-%d", M), func(b *testing.B) {
   355  				for i := 0; i < b.N; i++ {
   356  					if _, err := cc.SendMsgs(wms, 0); err != nil {
   357  						b.Fatal(err)
   358  					}
   359  					if _, err := cc.RecvMsgs(rms, 0); err != nil {
   360  						b.Fatal(err)
   361  					}
   362  				}
   363  			})
   364  		}
   365  	}
   366  }
   367  
   368  func TestRace(t *testing.T) {
   369  	tests := []string{
   370  		`
   371  package main
   372  import (
   373  	"log"
   374  	"net"
   375  
   376  	"gitee.com/zhaochuninhefei/gmgo/net/ipv4"
   377  )
   378  
   379  var g byte
   380  
   381  func main() {
   382  	c, err := net.ListenPacket("udp", "127.0.0.1:0")
   383  	if err != nil {
   384  		log.Fatalf("ListenPacket: %v", err)
   385  	}
   386  	cc := ipv4.NewPacketConn(c)
   387  	sync := make(chan bool)
   388  	src := make([]byte, 100)
   389  	dst := make([]byte, 100)
   390  	go func() {
   391  		if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil {
   392  			log.Fatalf("WriteTo: %v", err)
   393  		}
   394  	}()
   395  	go func() {
   396  		if _, _, _, err := cc.ReadFrom(dst); err != nil {
   397  			log.Fatalf("ReadFrom: %v", err)
   398  		}
   399  		sync <- true
   400  	}()
   401  	g = dst[0]
   402  	<-sync
   403  }
   404  `,
   405  		`
   406  package main
   407  import (
   408  	"log"
   409  	"net"
   410  
   411  	"gitee.com/zhaochuninhefei/gmgo/net/ipv4"
   412  )
   413  
   414  func main() {
   415  	c, err := net.ListenPacket("udp", "127.0.0.1:0")
   416  	if err != nil {
   417  		log.Fatalf("ListenPacket: %v", err)
   418  	}
   419  	cc := ipv4.NewPacketConn(c)
   420  	sync := make(chan bool)
   421  	src := make([]byte, 100)
   422  	dst := make([]byte, 100)
   423  	go func() {
   424  		if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil {
   425  			log.Fatalf("WriteTo: %v", err)
   426  		}
   427  		sync <- true
   428  	}()
   429  	src[0] = 0
   430  	go func() {
   431  		if _, _, _, err := cc.ReadFrom(dst); err != nil {
   432  			log.Fatalf("ReadFrom: %v", err)
   433  		}
   434  	}()
   435  	<-sync
   436  }
   437  `,
   438  	}
   439  	platforms := map[string]bool{
   440  		"linux/amd64":   true,
   441  		"linux/ppc64le": true,
   442  		"linux/arm64":   true,
   443  	}
   444  	if !platforms[runtime.GOOS+"/"+runtime.GOARCH] {
   445  		t.Skip("skipping test on non-race-enabled host.")
   446  	}
   447  	dir, err := ioutil.TempDir("", "testrace")
   448  	if err != nil {
   449  		t.Fatalf("failed to create temp directory: %v", err)
   450  	}
   451  	defer os.RemoveAll(dir)
   452  	goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
   453  	t.Logf("%s version", goBinary)
   454  	got, err := exec.Command(goBinary, "version").CombinedOutput()
   455  	if len(got) > 0 {
   456  		t.Logf("%s", got)
   457  	}
   458  	if err != nil {
   459  		t.Fatalf("go version failed: %v", err)
   460  	}
   461  	for i, test := range tests {
   462  		t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
   463  			src := filepath.Join(dir, fmt.Sprintf("test%d.go", i))
   464  			if err := ioutil.WriteFile(src, []byte(test), 0644); err != nil {
   465  				t.Fatalf("failed to write file: %v", err)
   466  			}
   467  			t.Logf("%s run -race %s", goBinary, src)
   468  			got, err := exec.Command(goBinary, "run", "-race", src).CombinedOutput()
   469  			if len(got) > 0 {
   470  				t.Logf("%s", got)
   471  			}
   472  			if strings.Contains(string(got), "-race requires cgo") {
   473  				t.Log("CGO is not enabled so can't use -race")
   474  			} else if !strings.Contains(string(got), "WARNING: DATA RACE") {
   475  				t.Errorf("race not detected for test %d: err:%v", i, err)
   476  			}
   477  		})
   478  	}
   479  }