github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/pkg/socket/tcp_socket.go (about)

     1  // Copyright (c) 2023 Paweł Gaczyński
     2  // Copyright (c) 2020 Andy Pan
     3  // Copyright (c) 2017 Max Riveiro
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //	http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package socket
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"net"
    23  	"os"
    24  
    25  	gainErrors "github.com/pawelgaczynski/gain/pkg/errors"
    26  	gainNet "github.com/pawelgaczynski/gain/pkg/net"
    27  	"golang.org/x/sys/unix"
    28  )
    29  
    30  var listenerBacklogMaxSize = maxListenerBacklog()
    31  
    32  // GetTCPSockAddr the structured addresses based on the protocol and raw address.
    33  //
    34  //nolint:dupl // dupl marks this incorrectly as duplicate of GetUDPSockAddr
    35  func GetTCPSockAddr(proto, addr string) (unix.Sockaddr, int, *net.TCPAddr, bool, error) {
    36  	var (
    37  		sockAddr   unix.Sockaddr
    38  		family     int
    39  		tcpAddr    *net.TCPAddr
    40  		ipv6only   bool
    41  		err        error
    42  		tcpVersion string
    43  	)
    44  
    45  	tcpAddr, err = net.ResolveTCPAddr(proto, addr)
    46  	if err != nil {
    47  		return sockAddr, family, tcpAddr, ipv6only, fmt.Errorf("resolveTCPAddr error: %w", err)
    48  	}
    49  
    50  	tcpVersion, err = determineTCPProto(proto, tcpAddr)
    51  	if err != nil {
    52  		return sockAddr, family, tcpAddr, ipv6only, err
    53  	}
    54  
    55  	switch tcpVersion {
    56  	case gainNet.TCP4:
    57  		family = unix.AF_INET
    58  		sockAddr, err = ipToSockaddr(family, tcpAddr.IP, tcpAddr.Port, "")
    59  
    60  	case gainNet.TCP6:
    61  		ipv6only = true
    62  
    63  		fallthrough
    64  
    65  	case gainNet.TCP:
    66  		family = unix.AF_INET6
    67  		sockAddr, err = ipToSockaddr(family, tcpAddr.IP, tcpAddr.Port, tcpAddr.Zone)
    68  
    69  	default:
    70  		err = gainErrors.ErrUnsupportedProtocol
    71  	}
    72  
    73  	return sockAddr, family, tcpAddr, ipv6only, err
    74  }
    75  
    76  func determineTCPProto(proto string, addr *net.TCPAddr) (string, error) {
    77  	// If the protocol is set to "tcp", we try to determine the actual protocol
    78  	// version from the size of the resolved IP address. Otherwise, we simple use
    79  	// the protocol given to us by the caller.
    80  	if addr.IP.To4() != nil {
    81  		return gainNet.TCP4, nil
    82  	}
    83  
    84  	if addr.IP.To16() != nil {
    85  		return gainNet.TCP6, nil
    86  	}
    87  
    88  	switch proto {
    89  	case gainNet.TCP, gainNet.TCP4, gainNet.TCP6:
    90  		return proto, nil
    91  	}
    92  
    93  	return "", gainErrors.ErrUnsupportedTCPProtocol
    94  }
    95  
    96  // tcpSocket creates an endpoint for communication and returns a file descriptor that refers to that endpoint.
    97  func tcpSocket(proto, addr string, passive bool, sockOpts ...Option) (int, net.Addr, error) {
    98  	var (
    99  		fd       int
   100  		netAddr  net.Addr
   101  		err      error
   102  		family   int
   103  		ipv6only bool
   104  		sockAddr unix.Sockaddr
   105  	)
   106  
   107  	if sockAddr, family, netAddr, ipv6only, err = GetTCPSockAddr(proto, addr); err != nil {
   108  		return fd, netAddr, err
   109  	}
   110  
   111  	if fd, err = sysSocket(family, unix.SOCK_STREAM, unix.IPPROTO_TCP); err != nil {
   112  		err = os.NewSyscallError("socket", err)
   113  
   114  		return fd, netAddr, err
   115  	}
   116  
   117  	defer func() {
   118  		// ignore EINPROGRESS for non-blocking socket connect, should be processed by caller
   119  		if err != nil {
   120  			var syscallErr *os.SyscallError
   121  			if errors.As(err, &syscallErr) && errors.Is(syscallErr.Err, unix.EINPROGRESS) {
   122  				return
   123  			}
   124  			_ = unix.Close(fd)
   125  		}
   126  	}()
   127  
   128  	if family == unix.AF_INET6 && ipv6only {
   129  		if err = SetIPv6Only(fd, 1); err != nil {
   130  			return fd, netAddr, err
   131  		}
   132  	}
   133  
   134  	for _, sockOpt := range sockOpts {
   135  		if err = sockOpt.SetSockOpt(fd, sockOpt.Opt); err != nil {
   136  			return fd, netAddr, err
   137  		}
   138  	}
   139  
   140  	if passive {
   141  		if err = os.NewSyscallError("bind", unix.Bind(fd, sockAddr)); err != nil {
   142  			return fd, netAddr, err
   143  		}
   144  		// Set backlog size to the maximum.
   145  		err = os.NewSyscallError("listen", unix.Listen(fd, listenerBacklogMaxSize))
   146  	} else {
   147  		err = os.NewSyscallError("connect", unix.Connect(fd, sockAddr))
   148  	}
   149  
   150  	return fd, netAddr, err
   151  }