trpc.group/trpc-go/trpc-go@v1.0.3/internal/reuseport/udp.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  //go:build linux || darwin || dragonfly || freebsd || netbsd || openbsd
    15  // +build linux darwin dragonfly freebsd netbsd openbsd
    16  
    17  package reuseport
    18  
    19  import (
    20  	"errors"
    21  	"net"
    22  	"os"
    23  	"syscall"
    24  )
    25  
    26  var errUnsupportedUDPProtocol = errors.New("only udp, udp4, udp6 are supported")
    27  
    28  func getUDP4Sockaddr(udp *net.UDPAddr) (syscall.Sockaddr, int, error) {
    29  	sa := &syscall.SockaddrInet4{Port: udp.Port}
    30  
    31  	if udp.IP != nil {
    32  		if len(udp.IP) == 16 {
    33  			copy(sa.Addr[:], udp.IP[12:16]) // copy last 4 bytes of slice to array
    34  		} else {
    35  			copy(sa.Addr[:], udp.IP) // copy all bytes of slice to array
    36  		}
    37  	}
    38  
    39  	return sa, syscall.AF_INET, nil
    40  }
    41  
    42  func getUDP6Sockaddr(udp *net.UDPAddr) (syscall.Sockaddr, int, error) {
    43  	sa := &syscall.SockaddrInet6{Port: udp.Port}
    44  
    45  	if udp.IP != nil {
    46  		copy(sa.Addr[:], udp.IP) // copy all bytes of slice to array
    47  	}
    48  
    49  	if udp.Zone != "" {
    50  		iface, err := net.InterfaceByName(udp.Zone)
    51  		if err != nil {
    52  			return nil, -1, err
    53  		}
    54  
    55  		sa.ZoneId = uint32(iface.Index)
    56  	}
    57  
    58  	return sa, syscall.AF_INET6, nil
    59  }
    60  
    61  func getUDPAddr(proto, addr string) (*net.UDPAddr, string, error) {
    62  
    63  	var udp *net.UDPAddr
    64  
    65  	udp, err := net.ResolveUDPAddr(proto, addr)
    66  	if err != nil {
    67  		return nil, "", err
    68  	}
    69  
    70  	udpVersion, err := determineUDPProto(proto, udp)
    71  	if err != nil {
    72  		return nil, "", err
    73  	}
    74  
    75  	return udp, udpVersion, nil
    76  }
    77  
    78  func getUDPSockaddr(proto, addr string) (sa syscall.Sockaddr, soType int, err error) {
    79  	udp, udpVersion, err := getUDPAddr(proto, addr)
    80  	if err != nil {
    81  		return nil, -1, err
    82  	}
    83  
    84  	switch udpVersion {
    85  	case "udp":
    86  		return &syscall.SockaddrInet4{Port: udp.Port}, syscall.AF_INET, nil
    87  	case "udp4":
    88  		return getUDP4Sockaddr(udp)
    89  	default:
    90  		// must be "udp6"
    91  		return getUDP6Sockaddr(udp)
    92  	}
    93  }
    94  
    95  func determineUDPProto(proto string, ip *net.UDPAddr) (string, error) {
    96  	// If the protocol is set to "udp", we try to determine the actual protocol
    97  	// version from the size of the resolved IP address. Otherwise, we simple use
    98  	// the protocol given to us by the caller.
    99  
   100  	if ip.IP.To4() != nil {
   101  		return "udp4", nil
   102  	}
   103  
   104  	if ip.IP.To16() != nil {
   105  		return "udp6", nil
   106  	}
   107  
   108  	switch proto {
   109  	case "udp", "udp4", "udp6":
   110  		return proto, nil
   111  	default:
   112  		return "", errUnsupportedUDPProtocol
   113  	}
   114  }
   115  
   116  // NewReusablePortPacketConn returns net.FilePacketConn that created from
   117  // a file descriptor for a socket with SO_REUSEPORT option.
   118  func NewReusablePortPacketConn(proto, addr string) (net.PacketConn, error) {
   119  	sockaddr, soType, err := getSockaddr(proto, addr)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  
   124  	syscall.ForkLock.RLock()
   125  	fd, err := syscall.Socket(soType, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
   126  	if err == nil {
   127  		syscall.CloseOnExec(fd)
   128  	}
   129  	syscall.ForkLock.RUnlock()
   130  	if err != nil {
   131  		syscall.Close(fd)
   132  		return nil, err
   133  	}
   134  	return createPacketConn(fd, sockaddr, getSocketFileName(proto, addr))
   135  }
   136  
   137  func createPacketConn(fd int, sockaddr syscall.Sockaddr, fdName string) (net.PacketConn, error) {
   138  	if err := setPacketConnSockOpt(fd, sockaddr); err != nil {
   139  		syscall.Close(fd)
   140  		return nil, err
   141  	}
   142  
   143  	file := os.NewFile(uintptr(fd), fdName)
   144  	l, err := net.FilePacketConn(file)
   145  	if err != nil {
   146  		syscall.Close(fd)
   147  		return nil, err
   148  	}
   149  
   150  	if err = file.Close(); err != nil {
   151  		syscall.Close(fd)
   152  		return nil, err
   153  	}
   154  	return l, err
   155  }
   156  
   157  func setPacketConnSockOpt(fd int, sockaddr syscall.Sockaddr) error {
   158  	if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1); err != nil {
   159  		return err
   160  	}
   161  
   162  	if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, reusePort, 1); err != nil {
   163  		return err
   164  	}
   165  
   166  	if err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_BROADCAST, 1); err != nil {
   167  		return err
   168  	}
   169  
   170  	return syscall.Bind(fd, sockaddr)
   171  }