// Copyright (c) 2023 Paweł Gaczyński
// Copyright (c) 2020 Andy Pan
// Copyright (c) 2017 Max Riveiro
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package socket

import (
	"errors"
	"fmt"
	"net"
	"os"

	gainErrors "github.com/pawelgaczynski/gain/pkg/errors"
	gainNet "github.com/pawelgaczynski/gain/pkg/net"
	"golang.org/x/sys/unix"
)

// listenerBacklogMaxSize is the backlog argument handed to listen(2); in this
// file tcpSocket uses it for passive sockets. It is evaluated once, during
// package initialization, by maxListenerBacklog (defined elsewhere in the
// package). NOTE(review): presumably the system's maximum accept-queue length
// — confirm against maxListenerBacklog.
var listenerBacklogMaxSize = maxListenerBacklog()

// GetTCPSockAddr resolves addr on the TCP network proto and returns the
// structured socket addresses derived from it.
33 // 34 //nolint:dupl // dupl marks this incorrectly as duplicate of GetUDPSockAddr 35 func GetTCPSockAddr(proto, addr string) (unix.Sockaddr, int, *net.TCPAddr, bool, error) { 36 var ( 37 sockAddr unix.Sockaddr 38 family int 39 tcpAddr *net.TCPAddr 40 ipv6only bool 41 err error 42 tcpVersion string 43 ) 44 45 tcpAddr, err = net.ResolveTCPAddr(proto, addr) 46 if err != nil { 47 return sockAddr, family, tcpAddr, ipv6only, fmt.Errorf("resolveTCPAddr error: %w", err) 48 } 49 50 tcpVersion, err = determineTCPProto(proto, tcpAddr) 51 if err != nil { 52 return sockAddr, family, tcpAddr, ipv6only, err 53 } 54 55 switch tcpVersion { 56 case gainNet.TCP4: 57 family = unix.AF_INET 58 sockAddr, err = ipToSockaddr(family, tcpAddr.IP, tcpAddr.Port, "") 59 60 case gainNet.TCP6: 61 ipv6only = true 62 63 fallthrough 64 65 case gainNet.TCP: 66 family = unix.AF_INET6 67 sockAddr, err = ipToSockaddr(family, tcpAddr.IP, tcpAddr.Port, tcpAddr.Zone) 68 69 default: 70 err = gainErrors.ErrUnsupportedProtocol 71 } 72 73 return sockAddr, family, tcpAddr, ipv6only, err 74 } 75 76 func determineTCPProto(proto string, addr *net.TCPAddr) (string, error) { 77 // If the protocol is set to "tcp", we try to determine the actual protocol 78 // version from the size of the resolved IP address. Otherwise, we simple use 79 // the protocol given to us by the caller. 80 if addr.IP.To4() != nil { 81 return gainNet.TCP4, nil 82 } 83 84 if addr.IP.To16() != nil { 85 return gainNet.TCP6, nil 86 } 87 88 switch proto { 89 case gainNet.TCP, gainNet.TCP4, gainNet.TCP6: 90 return proto, nil 91 } 92 93 return "", gainErrors.ErrUnsupportedTCPProtocol 94 } 95 96 // tcpSocket creates an endpoint for communication and returns a file descriptor that refers to that endpoint. 
97 func tcpSocket(proto, addr string, passive bool, sockOpts ...Option) (int, net.Addr, error) { 98 var ( 99 fd int 100 netAddr net.Addr 101 err error 102 family int 103 ipv6only bool 104 sockAddr unix.Sockaddr 105 ) 106 107 if sockAddr, family, netAddr, ipv6only, err = GetTCPSockAddr(proto, addr); err != nil { 108 return fd, netAddr, err 109 } 110 111 if fd, err = sysSocket(family, unix.SOCK_STREAM, unix.IPPROTO_TCP); err != nil { 112 err = os.NewSyscallError("socket", err) 113 114 return fd, netAddr, err 115 } 116 117 defer func() { 118 // ignore EINPROGRESS for non-blocking socket connect, should be processed by caller 119 if err != nil { 120 var syscallErr *os.SyscallError 121 if errors.As(err, &syscallErr) && errors.Is(syscallErr.Err, unix.EINPROGRESS) { 122 return 123 } 124 _ = unix.Close(fd) 125 } 126 }() 127 128 if family == unix.AF_INET6 && ipv6only { 129 if err = SetIPv6Only(fd, 1); err != nil { 130 return fd, netAddr, err 131 } 132 } 133 134 for _, sockOpt := range sockOpts { 135 if err = sockOpt.SetSockOpt(fd, sockOpt.Opt); err != nil { 136 return fd, netAddr, err 137 } 138 } 139 140 if passive { 141 if err = os.NewSyscallError("bind", unix.Bind(fd, sockAddr)); err != nil { 142 return fd, netAddr, err 143 } 144 // Set backlog size to the maximum. 145 err = os.NewSyscallError("listen", unix.Listen(fd, listenerBacklogMaxSize)) 146 } else { 147 err = os.NewSyscallError("connect", unix.Connect(fd, sockAddr)) 148 } 149 150 return fd, netAddr, err 151 }