github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/pkg/socket/udp_socket.go (about)

     1  // Copyright (c) 2023 Paweł Gaczyński
     2  // Copyright (c) 2020 Andy Pan
     3  // Copyright (c) 2017 Max Riveiro
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package socket
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"net"
    23  	"os"
    24  
    25  	gainErrors "github.com/pawelgaczynski/gain/pkg/errors"
    26  	gainNet "github.com/pawelgaczynski/gain/pkg/net"
    27  	"golang.org/x/sys/unix"
    28  )
    29  
    30  // GetUDPSockAddr the structured addresses based on the protocol and raw address.
    31  //
    32  //nolint:dupl // dupl marks this incorrectly as duplicate of GetTCPSockAddr
    33  func GetUDPSockAddr(proto, addr string) (unix.Sockaddr, int, *net.UDPAddr, bool, error) {
    34  	var (
    35  		sockAddr   unix.Sockaddr
    36  		family     int
    37  		udpAddr    *net.UDPAddr
    38  		ipv6only   bool
    39  		err        error
    40  		udpVersion string
    41  	)
    42  
    43  	udpAddr, err = net.ResolveUDPAddr(proto, addr)
    44  	if err != nil {
    45  		return sockAddr, family, udpAddr, ipv6only, fmt.Errorf("resolveUDPAddr error: %w", err)
    46  	}
    47  
    48  	udpVersion, err = determineUDPProto(proto, udpAddr)
    49  	if err != nil {
    50  		return sockAddr, family, udpAddr, ipv6only, err
    51  	}
    52  
    53  	switch udpVersion {
    54  	case gainNet.UDP4:
    55  		family = unix.AF_INET
    56  		sockAddr, err = ipToSockaddr(family, udpAddr.IP, udpAddr.Port, "")
    57  
    58  	case gainNet.UDP6:
    59  		ipv6only = true
    60  
    61  		fallthrough
    62  
    63  	case gainNet.UDP:
    64  		family = unix.AF_INET6
    65  		sockAddr, err = ipToSockaddr(family, udpAddr.IP, udpAddr.Port, udpAddr.Zone)
    66  
    67  	default:
    68  		err = gainErrors.ErrUnsupportedProtocol
    69  	}
    70  
    71  	return sockAddr, family, udpAddr, ipv6only, err
    72  }
    73  
    74  func determineUDPProto(proto string, addr *net.UDPAddr) (string, error) {
    75  	// If the protocol is set to "udp", we try to determine the actual protocol
    76  	// version from the size of the resolved IP address. Otherwise, we simple use
    77  	// the protocol given to us by the caller.
    78  	if addr.IP.To4() != nil {
    79  		return gainNet.UDP4, nil
    80  	}
    81  
    82  	if addr.IP.To16() != nil {
    83  		return gainNet.UDP6, nil
    84  	}
    85  
    86  	switch proto {
    87  	case gainNet.UDP, gainNet.UDP4, gainNet.UDP6:
    88  		return proto, nil
    89  	}
    90  
    91  	return "", gainErrors.ErrUnsupportedUDPProtocol
    92  }
    93  
    94  // udpSocket creates an endpoint for communication and returns a file descriptor that refers to that endpoint.
    95  func udpSocket(proto, addr string, connect bool, sockOpts ...Option) (int, net.Addr, error) {
    96  	var (
    97  		fd       int
    98  		netAddr  net.Addr
    99  		err      error
   100  		family   int
   101  		ipv6only bool
   102  		sockAddr unix.Sockaddr
   103  	)
   104  
   105  	if sockAddr, family, netAddr, ipv6only, err = GetUDPSockAddr(proto, addr); err != nil {
   106  		return fd, netAddr, err
   107  	}
   108  
   109  	if fd, err = sysSocket(family, unix.SOCK_DGRAM, unix.IPPROTO_UDP); err != nil {
   110  		err = os.NewSyscallError("socket", err)
   111  
   112  		return fd, netAddr, err
   113  	}
   114  
   115  	defer func() {
   116  		// ignore EINPROGRESS for non-blocking socket connect, should be processed by caller
   117  		if err != nil {
   118  			var syscallErr *os.SyscallError
   119  			if errors.As(err, &syscallErr) && errors.Is(syscallErr.Err, unix.EINPROGRESS) {
   120  				return
   121  			}
   122  			_ = unix.Close(fd)
   123  		}
   124  	}()
   125  
   126  	if family == unix.AF_INET6 && ipv6only {
   127  		if err = SetIPv6Only(fd, 1); err != nil {
   128  			return fd, netAddr, err
   129  		}
   130  	}
   131  
   132  	// Allow broadcast.
   133  	if err = os.NewSyscallError("setsockopt", unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1)); err != nil {
   134  		return fd, netAddr, err
   135  	}
   136  
   137  	for _, sockOpt := range sockOpts {
   138  		if err = sockOpt.SetSockOpt(fd, sockOpt.Opt); err != nil {
   139  			return fd, netAddr, err
   140  		}
   141  	}
   142  
   143  	if connect {
   144  		err = os.NewSyscallError("connect", unix.Connect(fd, sockAddr))
   145  	} else {
   146  		err = os.NewSyscallError("bind", unix.Bind(fd, sockAddr))
   147  	}
   148  
   149  	return fd, netAddr, err
   150  }