trpc.group/trpc-go/trpc-go@v1.0.3/internal/reuseport/tcp_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  //go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd
    15  // +build linux darwin dragonfly freebsd netbsd openbsd
    16  
    17  package reuseport
    18  
    19  import (
    20  	"fmt"
    21  	"html"
    22  	"io/ioutil"
    23  	"net"
    24  	"net/http"
    25  	"net/http/httptest"
    26  	"os"
    27  	"syscall"
    28  	"testing"
    29  
    30  	"github.com/stretchr/testify/assert"
    31  )
    32  
// Response bodies written by the two test servers. A test reads the body of a
// reply to tell which of the two servers handled the request.
const (
	httpServerOneResponse = "1"
	httpServerTwoResponse = "2"
)

// Unstarted test servers, one per response body. They are created once at
// package initialization; TestNewReusablePortServers assigns each of them a
// listener and starts it.
var (
	httpServerOne = NewHTTPServer(httpServerOneResponse)
	httpServerTwo = NewHTTPServer(httpServerTwoResponse)
)
    42  
    43  func NewHTTPServer(resp string) *httptest.Server {
    44  	return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    45  		fmt.Fprint(w, resp)
    46  	}))
    47  }
    48  
    49  func TestNewReusablePortListener(t *testing.T) {
    50  	listenerOne, err := NewReusablePortListener("tcp4", "localhost:10081")
    51  	assert.Nil(t, err)
    52  	defer listenerOne.Close()
    53  
    54  	listenerTwo, err := NewReusablePortListener("tcp", "127.0.0.1:10081")
    55  	assert.Nil(t, err)
    56  	defer listenerTwo.Close()
    57  
    58  	// devcloud ipv6地址无效
    59  	_, err = NewReusablePortListener("tcp6", "[::x]:10081")
    60  	if err == nil {
    61  		t.Errorf("expect err, err[%v]", err)
    62  	}
    63  
    64  	listenerFour, err := NewReusablePortListener("tcp6", ":10081")
    65  	assert.Nil(t, err)
    66  	defer listenerFour.Close()
    67  
    68  	listenerFive, err := NewReusablePortListener("tcp4", ":10081")
    69  	assert.Nil(t, err)
    70  	defer listenerFive.Close()
    71  
    72  	listenerSix, err := NewReusablePortListener("tcp", ":10081")
    73  	assert.Nil(t, err)
    74  	defer listenerSix.Close()
    75  
    76  	// proto invalid 非法协议
    77  	_, err = NewReusablePortListener("xxx", "")
    78  	if err == nil {
    79  		t.Errorf("expect err")
    80  	}
    81  }
    82  
    83  func TestListen(t *testing.T) {
    84  	listenerOne, err := Listen("tcp4", "localhost:10081")
    85  	assert.Nil(t, err)
    86  	defer listenerOne.Close()
    87  
    88  	listenerTwo, err := Listen("tcp", "127.0.0.1:10081")
    89  	assert.Nil(t, err)
    90  	defer listenerTwo.Close()
    91  
    92  	listenerThree, err := Listen("tcp6", ":10081")
    93  	assert.Nil(t, err)
    94  	defer listenerThree.Close()
    95  
    96  	listenerFour, err := Listen("tcp6", ":10081")
    97  	assert.Nil(t, err)
    98  	defer listenerFour.Close()
    99  
   100  	listenerFive, err := Listen("tcp4", ":10081")
   101  	assert.Nil(t, err)
   102  	defer listenerFive.Close()
   103  
   104  	listenerSix, err := Listen("tcp", ":10081")
   105  	assert.Nil(t, err)
   106  	defer listenerSix.Close()
   107  }
   108  
   109  func TestNewReusablePortServers(t *testing.T) {
   110  	listenerOne, err := NewReusablePortListener("tcp4", "localhost:10081")
   111  	assert.Nil(t, err)
   112  	defer listenerOne.Close()
   113  
   114  	// listenerTwo, err := NewReusablePortListener("tcp6", ":10081")
   115  	listenerTwo, err := NewReusablePortListener("tcp", "localhost:10081")
   116  	assert.Nil(t, err)
   117  	defer listenerTwo.Close()
   118  
   119  	httpServerOne.Listener = listenerOne
   120  	httpServerTwo.Listener = listenerTwo
   121  
   122  	httpServerOne.Start()
   123  	httpServerTwo.Start()
   124  
   125  	// Server One — First Response
   126  	httpGet(httpServerOne.URL, httpServerOneResponse, httpServerTwoResponse, t)
   127  
   128  	// Server Two — First Response
   129  	httpGet(httpServerTwo.URL, httpServerOneResponse, httpServerTwoResponse, t)
   130  	httpServerTwo.Close()
   131  
   132  	// Server One — Second Response
   133  	httpGet(httpServerOne.URL, httpServerOneResponse, "", t)
   134  
   135  	// Server One — Third Response
   136  	httpGet(httpServerOne.URL, httpServerOneResponse, "", t)
   137  	httpServerOne.Close()
   138  }
   139  
   140  func httpGet(url string, expected1 string, expected2 string, t *testing.T) {
   141  	resp, err := http.Get(url)
   142  	assert.Nil(t, err)
   143  	body, err := ioutil.ReadAll(resp.Body)
   144  	resp.Body.Close()
   145  	assert.Nil(t, err)
   146  	if string(body) != expected1 && string(body) != expected2 {
   147  		t.Errorf("Expected %#v or %#v, got %#v.", expected1, expected2, string(body))
   148  	}
   149  }
   150  
   151  func BenchmarkNewReusablePortListener(b *testing.B) {
   152  	for i := 0; i < b.N; i++ {
   153  		listener, err := NewReusablePortListener("tcp", ":10081")
   154  
   155  		if err != nil {
   156  			b.Error(err)
   157  		} else {
   158  			listener.Close()
   159  		}
   160  	}
   161  }
   162  
   163  func ExampleNewReusablePortListener() {
   164  	listener, err := NewReusablePortListener("tcp", ":8881")
   165  	if err != nil {
   166  		panic(err)
   167  	}
   168  	defer listener.Close()
   169  
   170  	server := &http.Server{}
   171  	http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
   172  		fmt.Println(os.Getgid())
   173  		fmt.Fprintf(w, "Hello, %q\n", html.EscapeString(r.URL.Path))
   174  	})
   175  
   176  	panic(server.Serve(listener))
   177  }
   178  
   179  // TestBoundaryCase 一些边界条件覆盖
   180  func TestBoundaryCase(t *testing.T) {
   181  	proto, err := determineTCPProto("tcp", &net.TCPAddr{})
   182  	if proto != "tcp" {
   183  		t.Errorf("proto not tcp")
   184  	}
   185  	assert.Nil(t, err)
   186  	_, err = determineTCPProto("udp", &net.TCPAddr{})
   187  	if err == nil {
   188  		t.Errorf("expect error")
   189  	}
   190  
   191  	// getTCPAddr 边界
   192  	if _, _, err := getTCPAddr("udp", "localhost:8001"); err == nil {
   193  		t.Error("expect error")
   194  	}
   195  
   196  	// ipv6 zone id,不存在的网卡
   197  	addr := &net.TCPAddr{
   198  		IP:   net.IPv4(127, 0, 0, 1),
   199  		Zone: "ethx",
   200  	}
   201  	_, _, err = getTCP6Sockaddr(addr)
   202  	assert.NotNil(t, err)
   203  
   204  	// udp ipv6
   205  	udpAddr := &net.UDPAddr{
   206  		IP:   net.IPv4(127, 0, 0, 1),
   207  		Zone: "ethx",
   208  	}
   209  	_, _, err = getUDP6Sockaddr(udpAddr)
   210  	assert.NotNil(t, err)
   211  
   212  	// ResolveUDPAddr failed
   213  	_, _, err = getUDPSockaddr("xxx", ":10086")
   214  	assert.NotNil(t, err)
   215  }
   216  
   217  func TestCreateReusableFd(t *testing.T) {
   218  	fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, syscall.IPPROTO_TCP)
   219  	assert.Nil(t, err)
   220  	assert.NotZero(t, fd)
   221  
   222  	// set opt failed, bad fd: -1
   223  	sa := &syscall.SockaddrInet4{}
   224  	err = createReusableFd(-1, sa)
   225  	assert.NotNil(t, err)
   226  
   227  	// set opt failed
   228  	oldReusePort := reusePort
   229  	defer func() {
   230  		reusePort = oldReusePort
   231  	}()
   232  	reusePort = 0
   233  	err = createReusableFd(fd, sa)
   234  	assert.NotNil(t, err)
   235  
   236  	// file descriptor invalid
   237  	_, err = createReusableListener(10081, "tcp", "localhost:8001")
   238  	assert.NotNil(t, err)
   239  }