trpc.group/trpc-go/trpc-go@v1.0.3/internal/reuseport/tcp.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  //go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd
    15  // +build linux darwin dragonfly freebsd netbsd openbsd
    16  
    17  package reuseport
    18  
    import (
    	"errors"
    	"net"
    	"os"
    	"strconv"
    	"syscall"
    )
    25  
    26  var (
    27  	// ListenerBacklogMaxSize setting backlog size
    28  	ListenerBacklogMaxSize    = maxListenerBacklog()
    29  	errUnsupportedTCPProtocol = errors.New("only tcp, tcp4, tcp6 are supported")
    30  )
    31  
    // getTCP4Sockaddr converts tcp into an IPv4 socket address for use with
    // syscall.Bind on an AF_INET socket, and returns syscall.AF_INET as the
    // address family.
    //
    // A nil or empty tcp.IP yields the unspecified address 0.0.0.0. Both the
    // 4-byte and the 16-byte (IPv4-mapped) representations of an IPv4 address
    // are accepted. Any other address (e.g. a genuine IPv6 address) is rejected
    // with a *net.AddrError instead of being silently truncated to its last
    // four bytes.
    func getTCP4Sockaddr(tcp *net.TCPAddr) (syscall.Sockaddr, int, error) {
    	sa := &syscall.SockaddrInet4{Port: tcp.Port}
    	if len(tcp.IP) == 0 {
    		return sa, syscall.AF_INET, nil
    	}

    	ip4 := tcp.IP.To4()
    	if ip4 == nil {
    		return nil, -1, &net.AddrError{Err: "non-IPv4 address", Addr: tcp.IP.String()}
    	}
    	copy(sa.Addr[:], ip4)
    	return sa, syscall.AF_INET, nil
    }
    45  
    // getTCP6Sockaddr converts tcp into an IPv6 socket address for use with
    // syscall.Bind on an AF_INET6 socket, and returns syscall.AF_INET6 as the
    // address family.
    //
    // A nil or empty tcp.IP yields the unspecified address "::". A 4-byte IPv4
    // address is widened to its IPv4-mapped form (::ffff:a.b.c.d) rather than
    // being copied into the first four bytes; an IP of any other invalid
    // length is rejected with a *net.AddrError.
    //
    // tcp.Zone may be an interface name or, as the net package also accepts,
    // a decimal interface index (the "2" in "fe80::1%2"). The name is tried
    // first; if no such interface exists and the zone is not a number, the
    // interface-lookup error is returned.
    func getTCP6Sockaddr(tcp *net.TCPAddr) (syscall.Sockaddr, int, error) {
    	sa := &syscall.SockaddrInet6{Port: tcp.Port}

    	if len(tcp.IP) != 0 {
    		ip16 := tcp.IP.To16()
    		if ip16 == nil {
    			return nil, -1, &net.AddrError{Err: "invalid IP address", Addr: tcp.IP.String()}
    		}
    		copy(sa.Addr[:], ip16)
    	}

    	if tcp.Zone != "" {
    		iface, err := net.InterfaceByName(tcp.Zone)
    		if err == nil {
    			sa.ZoneId = uint32(iface.Index)
    		} else if idx, perr := strconv.ParseUint(tcp.Zone, 10, 32); perr == nil {
    			sa.ZoneId = uint32(idx)
    		} else {
    			return nil, -1, err
    		}
    	}

    	return sa, syscall.AF_INET6, nil
    }
    64  
    65  func getTCPAddr(proto, addr string) (*net.TCPAddr, string, error) {
    66  	var tcp *net.TCPAddr
    67  
    68  	// fix bugs https://github.com/kavu/go_reuseport/pull/33
    69  	tcp, err := net.ResolveTCPAddr(proto, addr)
    70  	if err != nil {
    71  		return nil, "", err
    72  	}
    73  
    74  	tcpVersion, err := determineTCPProto(proto, tcp)
    75  	if err != nil {
    76  		return nil, "", err
    77  	}
    78  	return tcp, tcpVersion, nil
    79  }
    80  
    81  func getTCPSockaddr(proto, addr string) (sa syscall.Sockaddr, soType int, err error) {
    82  	tcp, tcpVersion, err := getTCPAddr(proto, addr)
    83  	if err != nil {
    84  		return nil, -1, err
    85  	}
    86  	switch tcpVersion {
    87  	case "tcp":
    88  		return &syscall.SockaddrInet4{Port: tcp.Port}, syscall.AF_INET, nil
    89  	case "tcp4":
    90  		return getTCP4Sockaddr(tcp)
    91  	default:
    92  		// must be "tcp6"
    93  		return getTCP6Sockaddr(tcp)
    94  	}
    95  }
    96  
    97  func determineTCPProto(proto string, ip *net.TCPAddr) (string, error) {
    98  	// If the protocol is set to "tcp", we try to determine the actual protocol
    99  	// version from the size of the resolved IP address. Otherwise, we simple use
   100  	// the protocol given to us by the caller.
   101  
   102  	if ip.IP.To4() != nil {
   103  		return "tcp4", nil
   104  	}
   105  
   106  	if ip.IP.To16() != nil {
   107  		return "tcp6", nil
   108  	}
   109  
   110  	switch proto {
   111  	case "tcp", "tcp4", "tcp6":
   112  		return proto, nil
   113  	default:
   114  		return "", errUnsupportedTCPProtocol
   115  	}
   116  }
   117  
   118  // NewReusablePortListener returns net.FileListener that created from
   119  // a file descriptor for a socket with SO_REUSEPORT option.
   120  func NewReusablePortListener(proto, addr string) (l net.Listener, err error) {
   121  	var (
   122  		soType, fd int
   123  		sockaddr   syscall.Sockaddr
   124  	)
   125  	if sockaddr, soType, err = getSockaddr(proto, addr); err != nil {
   126  		return nil, err
   127  	}
   128  
   129  	syscall.ForkLock.RLock()
   130  	if fd, err = syscall.Socket(soType, syscall.SOCK_STREAM, syscall.IPPROTO_TCP); err != nil {
   131  		syscall.ForkLock.RUnlock()
   132  		return nil, err
   133  	}
   134  	syscall.ForkLock.RUnlock()
   135  
   136  	if err = createReusableFd(fd, sockaddr); err != nil {
   137  		return nil, err
   138  	}
   139  	return createReusableListener(fd, proto, addr)
   140  }
   141  
   142  func createReusableListener(fd int, proto, addr string) (l net.Listener, err error) {
   143  	file := os.NewFile(uintptr(fd), getSocketFileName(proto, addr))
   144  	if l, err = net.FileListener(file); err != nil {
   145  		file.Close()
   146  		return nil, err
   147  	}
   148  
   149  	if err = file.Close(); err != nil {
   150  		return nil, err
   151  	}
   152  	return l, err
   153  }
   154  
   155  func createReusableFd(fd int, sockaddr syscall.Sockaddr) (err error) {
   156  	defer func() {
   157  		if err != nil {
   158  			syscall.Close(fd)
   159  		}
   160  	}()
   161  
   162  	if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
   163  		return err
   164  	}
   165  
   166  	if err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, reusePort, 1); err != nil {
   167  		return err
   168  	}
   169  
   170  	if err = syscall.Bind(fd, sockaddr); err != nil {
   171  		return err
   172  	}
   173  
   174  	// Set backlog size to the maximum
   175  	if err = syscall.Listen(fd, ListenerBacklogMaxSize); err != nil {
   176  		return err
   177  	}
   178  
   179  	return nil
   180  }