github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/pkg/socket/udp_socket.go (about) 1 // Copyright (c) 2023 Paweł Gaczyński 2 // Copyright (c) 2020 Andy Pan 3 // Copyright (c) 2017 Max Riveiro 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 17 package socket 18 19 import ( 20 "errors" 21 "fmt" 22 "net" 23 "os" 24 25 gainErrors "github.com/pawelgaczynski/gain/pkg/errors" 26 gainNet "github.com/pawelgaczynski/gain/pkg/net" 27 "golang.org/x/sys/unix" 28 ) 29 30 // GetUDPSockAddr the structured addresses based on the protocol and raw address. 31 // 32 //nolint:dupl // dupl marks this incorrectly as duplicate of GetTCPSockAddr 33 func GetUDPSockAddr(proto, addr string) (unix.Sockaddr, int, *net.UDPAddr, bool, error) { 34 var ( 35 sockAddr unix.Sockaddr 36 family int 37 udpAddr *net.UDPAddr 38 ipv6only bool 39 err error 40 udpVersion string 41 ) 42 43 udpAddr, err = net.ResolveUDPAddr(proto, addr) 44 if err != nil { 45 return sockAddr, family, udpAddr, ipv6only, fmt.Errorf("resolveUDPAddr error: %w", err) 46 } 47 48 udpVersion, err = determineUDPProto(proto, udpAddr) 49 if err != nil { 50 return sockAddr, family, udpAddr, ipv6only, err 51 } 52 53 switch udpVersion { 54 case gainNet.UDP4: 55 family = unix.AF_INET 56 sockAddr, err = ipToSockaddr(family, udpAddr.IP, udpAddr.Port, "") 57 58 case gainNet.UDP6: 59 ipv6only = true 60 61 fallthrough 62 63 case gainNet.UDP: 64 family = unix.AF_INET6 65 sockAddr, err = ipToSockaddr(family, udpAddr.IP, udpAddr.Port, udpAddr.Zone) 66 67 default: 68 err = gainErrors.ErrUnsupportedProtocol 69 } 70 71 return sockAddr, family, udpAddr, ipv6only, err 72 } 73 74 func determineUDPProto(proto string, addr *net.UDPAddr) (string, error) { 75 // If the protocol is set to "udp", we try to determine the actual protocol 76 // version from the size of the resolved IP address. Otherwise, we simple use 77 // the protocol given to us by the caller. 78 if addr.IP.To4() != nil { 79 return gainNet.UDP4, nil 80 } 81 82 if addr.IP.To16() != nil { 83 return gainNet.UDP6, nil 84 } 85 86 switch proto { 87 case gainNet.UDP, gainNet.UDP4, gainNet.UDP6: 88 return proto, nil 89 } 90 91 return "", gainErrors.ErrUnsupportedUDPProtocol 92 } 93 94 // udpSocket creates an endpoint for communication and returns a file descriptor that refers to that endpoint. 95 func udpSocket(proto, addr string, connect bool, sockOpts ...Option) (int, net.Addr, error) { 96 var ( 97 fd int 98 netAddr net.Addr 99 err error 100 family int 101 ipv6only bool 102 sockAddr unix.Sockaddr 103 ) 104 105 if sockAddr, family, netAddr, ipv6only, err = GetUDPSockAddr(proto, addr); err != nil { 106 return fd, netAddr, err 107 } 108 109 if fd, err = sysSocket(family, unix.SOCK_DGRAM, unix.IPPROTO_UDP); err != nil { 110 err = os.NewSyscallError("socket", err) 111 112 return fd, netAddr, err 113 } 114 115 defer func() { 116 // ignore EINPROGRESS for non-blocking socket connect, should be processed by caller 117 if err != nil { 118 var syscallErr *os.SyscallError 119 if errors.As(err, &syscallErr) && errors.Is(syscallErr.Err, unix.EINPROGRESS) { 120 return 121 } 122 _ = unix.Close(fd) 123 } 124 }() 125 126 if family == unix.AF_INET6 && ipv6only { 127 if err = SetIPv6Only(fd, 1); err != nil { 128 return fd, netAddr, err 129 } 130 } 131 132 // Allow broadcast. 133 if err = os.NewSyscallError("setsockopt", unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_BROADCAST, 1)); err != nil { 134 return fd, netAddr, err 135 } 136 137 for _, sockOpt := range sockOpts { 138 if err = sockOpt.SetSockOpt(fd, sockOpt.Opt); err != nil { 139 return fd, netAddr, err 140 } 141 } 142 143 if connect { 144 err = os.NewSyscallError("connect", unix.Connect(fd, sockAddr)) 145 } else { 146 err = os.NewSyscallError("bind", unix.Bind(fd, sockAddr)) 147 } 148 149 return fd, netAddr, err 150 }