github.com/rafaeltorres324/go/src@v0.0.0-20210519164414-9fdf653a9838/net/splice_test.go (about)

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // +build linux
     6  
     7  package net
     8  
     9  import (
    10  	"io"
    11  	"log"
    12  	"os"
    13  	"os/exec"
    14  	"strconv"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  )
    19  
    20  func TestSplice(t *testing.T) {
    21  	t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
    22  	if !testableNetwork("unixgram") {
    23  		t.Skip("skipping unix-to-tcp tests")
    24  	}
    25  	t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
    26  	t.Run("no-unixpacket", testSpliceNoUnixpacket)
    27  	t.Run("no-unixgram", testSpliceNoUnixgram)
    28  }
    29  
    30  func testSplice(t *testing.T, upNet, downNet string) {
    31  	t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
    32  	t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
    33  	t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
    34  	t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
    35  	t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
    36  	t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
    37  	t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
    38  	t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
    39  }
    40  
    41  type spliceTestCase struct {
    42  	upNet, downNet string
    43  
    44  	chunkSize, totalSize int
    45  	limitReadSize        int
    46  }
    47  
    48  func (tc spliceTestCase) test(t *testing.T) {
    49  	clientUp, serverUp, err := spliceTestSocketPair(tc.upNet)
    50  	if err != nil {
    51  		t.Fatal(err)
    52  	}
    53  	defer serverUp.Close()
    54  	cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
    55  	if err != nil {
    56  		t.Fatal(err)
    57  	}
    58  	defer cleanup()
    59  	clientDown, serverDown, err := spliceTestSocketPair(tc.downNet)
    60  	if err != nil {
    61  		t.Fatal(err)
    62  	}
    63  	defer serverDown.Close()
    64  	cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize)
    65  	if err != nil {
    66  		t.Fatal(err)
    67  	}
    68  	defer cleanup()
    69  	var (
    70  		r    io.Reader = serverUp
    71  		size           = tc.totalSize
    72  	)
    73  	if tc.limitReadSize > 0 {
    74  		if tc.limitReadSize < size {
    75  			size = tc.limitReadSize
    76  		}
    77  
    78  		r = &io.LimitedReader{
    79  			N: int64(tc.limitReadSize),
    80  			R: serverUp,
    81  		}
    82  		defer serverUp.Close()
    83  	}
    84  	n, err := io.Copy(serverDown, r)
    85  	serverDown.Close()
    86  	if err != nil {
    87  		t.Fatal(err)
    88  	}
    89  	if want := int64(size); want != n {
    90  		t.Errorf("want %d bytes spliced, got %d", want, n)
    91  	}
    92  
    93  	if tc.limitReadSize > 0 {
    94  		wantN := 0
    95  		if tc.limitReadSize > size {
    96  			wantN = tc.limitReadSize - size
    97  		}
    98  
    99  		if n := r.(*io.LimitedReader).N; n != int64(wantN) {
   100  			t.Errorf("r.N = %d, want %d", n, wantN)
   101  		}
   102  	}
   103  }
   104  
   105  func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
   106  	clientUp, serverUp, err := spliceTestSocketPair(upNet)
   107  	if err != nil {
   108  		t.Fatal(err)
   109  	}
   110  	defer clientUp.Close()
   111  	clientDown, serverDown, err := spliceTestSocketPair(downNet)
   112  	if err != nil {
   113  		t.Fatal(err)
   114  	}
   115  	defer clientDown.Close()
   116  
   117  	serverUp.Close()
   118  
   119  	// We'd like to call net.splice here and check the handled return
   120  	// value, but we disable splice on old Linux kernels.
   121  	//
   122  	// In that case, poll.Splice and net.splice return a non-nil error
   123  	// and handled == false. We'd ideally like to see handled == true
   124  	// because the source reader is at EOF, but if we're running on an old
   125  	// kernel, and splice is disabled, we won't see EOF from net.splice,
   126  	// because we won't touch the reader at all.
   127  	//
   128  	// Trying to untangle the errors from net.splice and match them
   129  	// against the errors created by the poll package would be brittle,
   130  	// so this is a higher level test.
   131  	//
   132  	// The following ReadFrom should return immediately, regardless of
   133  	// whether splice is disabled or not. The other side should then
   134  	// get a goodbye signal. Test for the goodbye signal.
   135  	msg := "bye"
   136  	go func() {
   137  		serverDown.(io.ReaderFrom).ReadFrom(serverUp)
   138  		io.WriteString(serverDown, msg)
   139  		serverDown.Close()
   140  	}()
   141  
   142  	buf := make([]byte, 3)
   143  	_, err = io.ReadFull(clientDown, buf)
   144  	if err != nil {
   145  		t.Errorf("clientDown: %v", err)
   146  	}
   147  	if string(buf) != msg {
   148  		t.Errorf("clientDown got %q, want %q", buf, msg)
   149  	}
   150  }
   151  
   152  func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
   153  	front, err := newLocalListener(upNet)
   154  	if err != nil {
   155  		t.Fatal(err)
   156  	}
   157  	defer front.Close()
   158  	back, err := newLocalListener(downNet)
   159  	if err != nil {
   160  		t.Fatal(err)
   161  	}
   162  	defer back.Close()
   163  
   164  	var wg sync.WaitGroup
   165  	wg.Add(2)
   166  
   167  	proxy := func() {
   168  		src, err := front.Accept()
   169  		if err != nil {
   170  			return
   171  		}
   172  		dst, err := Dial(downNet, back.Addr().String())
   173  		if err != nil {
   174  			return
   175  		}
   176  		defer dst.Close()
   177  		defer src.Close()
   178  		go func() {
   179  			io.Copy(src, dst)
   180  			wg.Done()
   181  		}()
   182  		go func() {
   183  			io.Copy(dst, src)
   184  			wg.Done()
   185  		}()
   186  	}
   187  
   188  	go proxy()
   189  
   190  	toFront, err := Dial(upNet, front.Addr().String())
   191  	if err != nil {
   192  		t.Fatal(err)
   193  	}
   194  
   195  	io.WriteString(toFront, "foo")
   196  	toFront.Close()
   197  
   198  	fromProxy, err := back.Accept()
   199  	if err != nil {
   200  		t.Fatal(err)
   201  	}
   202  	defer fromProxy.Close()
   203  
   204  	_, err = io.ReadAll(fromProxy)
   205  	if err != nil {
   206  		t.Fatal(err)
   207  	}
   208  
   209  	wg.Wait()
   210  }
   211  
   212  func testSpliceNoUnixpacket(t *testing.T) {
   213  	clientUp, serverUp, err := spliceTestSocketPair("unixpacket")
   214  	if err != nil {
   215  		t.Fatal(err)
   216  	}
   217  	defer clientUp.Close()
   218  	defer serverUp.Close()
   219  	clientDown, serverDown, err := spliceTestSocketPair("tcp")
   220  	if err != nil {
   221  		t.Fatal(err)
   222  	}
   223  	defer clientDown.Close()
   224  	defer serverDown.Close()
   225  	// If splice called poll.Splice here, we'd get err == syscall.EINVAL
   226  	// and handled == false.  If poll.Splice gets an EINVAL on the first
   227  	// try, it assumes the kernel it's running on doesn't support splice
   228  	// for unix sockets and returns handled == false. This works for our
   229  	// purposes by somewhat of an accident, but is not entirely correct.
   230  	//
   231  	// What we want is err == nil and handled == false, i.e. we never
   232  	// called poll.Splice, because we know the unix socket's network.
   233  	_, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
   234  	if err != nil || handled != false {
   235  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   236  	}
   237  }
   238  
   239  func testSpliceNoUnixgram(t *testing.T) {
   240  	addr, err := ResolveUnixAddr("unixgram", testUnixAddr())
   241  	if err != nil {
   242  		t.Fatal(err)
   243  	}
   244  	defer os.Remove(addr.Name)
   245  	up, err := ListenUnixgram("unixgram", addr)
   246  	if err != nil {
   247  		t.Fatal(err)
   248  	}
   249  	defer up.Close()
   250  	clientDown, serverDown, err := spliceTestSocketPair("tcp")
   251  	if err != nil {
   252  		t.Fatal(err)
   253  	}
   254  	defer clientDown.Close()
   255  	defer serverDown.Close()
   256  	// Analogous to testSpliceNoUnixpacket.
   257  	_, err, handled := splice(serverDown.(*TCPConn).fd, up)
   258  	if err != nil || handled != false {
   259  		t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
   260  	}
   261  }
   262  
   263  func BenchmarkSplice(b *testing.B) {
   264  	testHookUninstaller.Do(uninstallTestHooks)
   265  
   266  	b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
   267  	b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
   268  }
   269  
   270  func benchSplice(b *testing.B, upNet, downNet string) {
   271  	for i := 0; i <= 10; i++ {
   272  		chunkSize := 1 << uint(i+10)
   273  		tc := spliceTestCase{
   274  			upNet:     upNet,
   275  			downNet:   downNet,
   276  			chunkSize: chunkSize,
   277  		}
   278  
   279  		b.Run(strconv.Itoa(chunkSize), tc.bench)
   280  	}
   281  }
   282  
   283  func (tc spliceTestCase) bench(b *testing.B) {
   284  	// To benchmark the genericReadFrom code path, set this to false.
   285  	useSplice := true
   286  
   287  	clientUp, serverUp, err := spliceTestSocketPair(tc.upNet)
   288  	if err != nil {
   289  		b.Fatal(err)
   290  	}
   291  	defer serverUp.Close()
   292  
   293  	cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
   294  	if err != nil {
   295  		b.Fatal(err)
   296  	}
   297  	defer cleanup()
   298  
   299  	clientDown, serverDown, err := spliceTestSocketPair(tc.downNet)
   300  	if err != nil {
   301  		b.Fatal(err)
   302  	}
   303  	defer serverDown.Close()
   304  
   305  	cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
   306  	if err != nil {
   307  		b.Fatal(err)
   308  	}
   309  	defer cleanup()
   310  
   311  	b.SetBytes(int64(tc.chunkSize))
   312  	b.ResetTimer()
   313  
   314  	if useSplice {
   315  		_, err := io.Copy(serverDown, serverUp)
   316  		if err != nil {
   317  			b.Fatal(err)
   318  		}
   319  	} else {
   320  		type onlyReader struct {
   321  			io.Reader
   322  		}
   323  		_, err := io.Copy(serverDown, onlyReader{serverUp})
   324  		if err != nil {
   325  			b.Fatal(err)
   326  		}
   327  	}
   328  }
   329  
   330  func spliceTestSocketPair(net string) (client, server Conn, err error) {
   331  	ln, err := newLocalListener(net)
   332  	if err != nil {
   333  		return nil, nil, err
   334  	}
   335  	defer ln.Close()
   336  	var cerr, serr error
   337  	acceptDone := make(chan struct{})
   338  	go func() {
   339  		server, serr = ln.Accept()
   340  		acceptDone <- struct{}{}
   341  	}()
   342  	client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
   343  	<-acceptDone
   344  	if cerr != nil {
   345  		if server != nil {
   346  			server.Close()
   347  		}
   348  		return nil, nil, cerr
   349  	}
   350  	if serr != nil {
   351  		if client != nil {
   352  			client.Close()
   353  		}
   354  		return nil, nil, serr
   355  	}
   356  	return client, server, nil
   357  }
   358  
   359  func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) {
   360  	f, err := conn.(interface{ File() (*os.File, error) }).File()
   361  	if err != nil {
   362  		return nil, err
   363  	}
   364  
   365  	cmd := exec.Command(os.Args[0], os.Args[1:]...)
   366  	cmd.Env = []string{
   367  		"GO_NET_TEST_SPLICE=1",
   368  		"GO_NET_TEST_SPLICE_OP=" + op,
   369  		"GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
   370  		"GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize),
   371  		"TMPDIR=" + os.Getenv("TMPDIR"),
   372  	}
   373  	cmd.ExtraFiles = append(cmd.ExtraFiles, f)
   374  	cmd.Stdout = os.Stdout
   375  	cmd.Stderr = os.Stderr
   376  
   377  	if err := cmd.Start(); err != nil {
   378  		return nil, err
   379  	}
   380  
   381  	donec := make(chan struct{})
   382  	go func() {
   383  		cmd.Wait()
   384  		conn.Close()
   385  		f.Close()
   386  		close(donec)
   387  	}()
   388  
   389  	return func() {
   390  		select {
   391  		case <-donec:
   392  		case <-time.After(5 * time.Second):
   393  			log.Printf("killing splice client after 5 second shutdown timeout")
   394  			cmd.Process.Kill()
   395  			select {
   396  			case <-donec:
   397  			case <-time.After(5 * time.Second):
   398  				log.Printf("splice client didn't die after 10 seconds")
   399  			}
   400  		}
   401  	}, nil
   402  }
   403  
   404  func init() {
   405  	if os.Getenv("GO_NET_TEST_SPLICE") == "" {
   406  		return
   407  	}
   408  	defer os.Exit(0)
   409  
   410  	f := os.NewFile(uintptr(3), "splice-test-conn")
   411  	defer f.Close()
   412  
   413  	conn, err := FileConn(f)
   414  	if err != nil {
   415  		log.Fatal(err)
   416  	}
   417  
   418  	var chunkSize int
   419  	if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil {
   420  		log.Fatal(err)
   421  	}
   422  	buf := make([]byte, chunkSize)
   423  
   424  	var totalSize int
   425  	if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil {
   426  		log.Fatal(err)
   427  	}
   428  
   429  	var fn func([]byte) (int, error)
   430  	switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op {
   431  	case "r":
   432  		fn = conn.Read
   433  	case "w":
   434  		defer conn.Close()
   435  
   436  		fn = conn.Write
   437  	default:
   438  		log.Fatalf("unknown op %q", op)
   439  	}
   440  
   441  	var n int
   442  	for count := 0; count < totalSize; count += n {
   443  		if count+chunkSize > totalSize {
   444  			buf = buf[:totalSize-count]
   445  		}
   446  
   447  		var err error
   448  		if n, err = fn(buf); err != nil {
   449  			return
   450  		}
   451  	}
   452  }