github.com/Cloud-Foundations/Dominator@v0.3.4/lib/net/bind.go (about)

     1  package net
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"sync"
     9  	"syscall"
    10  	"time"
    11  )
    12  
    13  const SO_REUSEPORT = 15
    14  
    15  var errorTimeout = errors.New("timeout")
    16  
    17  type connection struct {
    18  	fd            int
    19  	localAddress  *net.TCPAddr
    20  	remoteAddress *net.TCPAddr
    21  	lock          sync.Mutex
    22  	deadline      time.Time
    23  	readDeadline  time.Time
    24  	writeDeadline time.Time
    25  }
    26  
    27  func bindAndDial(network, localAddr, remoteAddr string, timeout time.Duration) (
    28  	net.Conn, error) {
    29  	if network != "tcp" && network != "tcp4" {
    30  		return net.DialTimeout(network, remoteAddr, timeout)
    31  	}
    32  	sockFd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
    33  	if err != nil {
    34  		return nil, fmt.Errorf("error creating socket: %s", err)
    35  	}
    36  	defer func() {
    37  		if sockFd >= 0 {
    38  			syscall.Close(sockFd)
    39  		}
    40  	}()
    41  	if err := setReuse(sockFd); err != nil {
    42  		return nil, err
    43  	}
    44  	if err := setReadTimeout(sockFd, timeout); err != nil {
    45  		return nil, err
    46  	}
    47  	if err := setWriteTimeout(sockFd, timeout); err != nil {
    48  		return nil, err
    49  	}
    50  	localTCPAddr, localSockAddr, err := resolveAddr(localAddr)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  	if err := syscall.Bind(sockFd, localSockAddr); err != nil {
    55  		return nil, fmt.Errorf("error binding to: %s : %s", localAddr, err)
    56  	}
    57  	remTCPAddr, remSockAddr, err := resolveAddr(remoteAddr)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	if err := syscall.Connect(sockFd, remSockAddr); err != nil {
    62  		return nil, fmt.Errorf("error binding to: %s : %s", remoteAddr, err)
    63  	}
    64  	if err := setReadTimeout(sockFd, 0); err != nil {
    65  		return nil, err
    66  	}
    67  	if err := setWriteTimeout(sockFd, 0); err != nil {
    68  		return nil, err
    69  	}
    70  	conn := &connection{
    71  		fd:            sockFd,
    72  		localAddress:  localTCPAddr,
    73  		remoteAddress: remTCPAddr,
    74  	}
    75  	sockFd = -1 // Prevent Close on return.
    76  	return conn, nil
    77  }
    78  
    79  func listenWithReuse(network, address string) (net.Listener, error) {
    80  	listener, err := net.Listen(network, address)
    81  	if err != nil {
    82  		return nil, fmt.Errorf("error creating %s listener for %s : %s",
    83  			network, address, err.Error)
    84  	}
    85  	doClose := true
    86  	defer func() {
    87  		if doClose {
    88  			listener.Close()
    89  		}
    90  	}()
    91  	if tcpListener, ok := listener.(*net.TCPListener); ok {
    92  		rawConn, err := tcpListener.SyscallConn()
    93  		if err != nil {
    94  			return nil, err
    95  		}
    96  		e := rawConn.Control(func(fd uintptr) { err = setReuse(int(fd)) })
    97  		if e != nil {
    98  			return nil, e
    99  		}
   100  		if err != nil {
   101  			return nil, err
   102  		}
   103  	} else {
   104  		return nil, errors.New("not a TCPlistener")
   105  	}
   106  	doClose = false
   107  	return listener, nil
   108  }
   109  
   110  func read(fd int, b []byte) (int, error) {
   111  	if nRead, err := syscall.Read(fd, b); err != nil {
   112  		return 0, err
   113  	} else if nRead <= 0 {
   114  		return 0, io.EOF
   115  	} else {
   116  		return nRead, nil
   117  	}
   118  }
   119  
   120  func resolveAddr(address string) (*net.TCPAddr, *syscall.SockaddrInet4, error) {
   121  	tcpAddr, err := net.ResolveTCPAddr("tcp", address)
   122  	if err != nil {
   123  		return nil, nil, err
   124  	}
   125  	if len(tcpAddr.IP) < 1 {
   126  		return tcpAddr, &syscall.SockaddrInet4{Port: tcpAddr.Port}, nil
   127  	}
   128  	tcp4IP := tcpAddr.IP.To4()
   129  	if tcp4IP == nil {
   130  		return nil, nil, errors.New(address + " is not an IPv4 address")
   131  	}
   132  	var ip4 [4]byte
   133  	for index, b := range tcp4IP {
   134  		ip4[index] = b
   135  	}
   136  	return tcpAddr, &syscall.SockaddrInet4{Port: tcpAddr.Port, Addr: ip4}, nil
   137  }
   138  
   139  func setReadTimeout(fd int, timeout time.Duration) error {
   140  	timeval := syscall.NsecToTimeval(timeout.Nanoseconds())
   141  	return syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET,
   142  		syscall.SO_RCVTIMEO, &timeval)
   143  }
   144  
   145  func setReuse(fd int) error {
   146  	err := syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR,
   147  		1)
   148  	if err != nil {
   149  		return fmt.Errorf("error setting SO_REUSEADDR: %s", err)
   150  	}
   151  	err = syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, SO_REUSEPORT, 1)
   152  	if err != nil {
   153  		return fmt.Errorf("error setting SO_REUSEPORT: %s", err)
   154  	}
   155  	return nil
   156  }
   157  
   158  func setWriteTimeout(fd int, timeout time.Duration) error {
   159  	timeval := syscall.NsecToTimeval(timeout.Nanoseconds())
   160  	return syscall.SetsockoptTimeval(fd, syscall.SOL_SOCKET,
   161  		syscall.SO_SNDTIMEO, &timeval)
   162  }
   163  
   164  func write(fd int, b []byte) (int, error) {
   165  	if nWritten, err := syscall.Write(fd, b); err != nil {
   166  		return 0, err
   167  	} else if nWritten < len(b) {
   168  		return nWritten, io.EOF
   169  	} else {
   170  		return nWritten, nil
   171  	}
   172  }
   173  
   174  func (conn *connection) Close() error {
   175  	return syscall.Close(conn.fd)
   176  }
   177  
   178  func (conn *connection) getDeadline(write bool) time.Time {
   179  	conn.lock.Lock()
   180  	defer conn.lock.Unlock()
   181  	deadline := conn.readDeadline
   182  	if write {
   183  		deadline = conn.writeDeadline
   184  	}
   185  	if conn.deadline.IsZero() {
   186  		return deadline
   187  	} else if !deadline.IsZero() && deadline.Before(conn.deadline) {
   188  		return deadline
   189  	} else {
   190  		return conn.deadline
   191  	}
   192  }
   193  
   194  func (conn *connection) LocalAddr() net.Addr {
   195  	return conn.localAddress
   196  }
   197  
   198  func (conn *connection) Read(b []byte) (int, error) {
   199  	deadline := conn.getDeadline(false)
   200  	if deadline.IsZero() { // Fast check.
   201  		return read(conn.fd, b)
   202  	}
   203  	timeout := time.Until(deadline)
   204  	if timeout <= 0 {
   205  		return 0, errorTimeout
   206  	}
   207  	if err := setReadTimeout(conn.fd, timeout); err != nil {
   208  		return 0, err
   209  	}
   210  	nRead, err := read(conn.fd, b)
   211  	if err == syscall.EAGAIN {
   212  		err = errorTimeout
   213  	}
   214  	if e := setReadTimeout(conn.fd, 0); err == nil {
   215  		err = e
   216  	}
   217  	if err != nil {
   218  		return 0, err
   219  	}
   220  	return nRead, nil
   221  }
   222  
   223  func (conn *connection) RemoteAddr() net.Addr {
   224  	return conn.remoteAddress
   225  }
   226  
   227  func (conn *connection) SetDeadline(t time.Time) error {
   228  	conn.lock.Lock()
   229  	defer conn.lock.Unlock()
   230  	conn.deadline = t
   231  	return nil
   232  }
   233  
   234  func (conn *connection) SetKeepAlive(keepalive bool) error {
   235  	var ka int
   236  	if keepalive {
   237  		ka = 1
   238  	}
   239  	return syscall.SetsockoptInt(conn.fd, syscall.SOL_SOCKET,
   240  		syscall.SO_KEEPALIVE, ka)
   241  }
   242  
   243  func (conn *connection) SetReadDeadline(t time.Time) error {
   244  	conn.lock.Lock()
   245  	defer conn.lock.Unlock()
   246  	conn.readDeadline = t
   247  	return nil
   248  }
   249  
   250  func (conn *connection) SetWriteDeadline(t time.Time) error {
   251  	conn.lock.Lock()
   252  	defer conn.lock.Unlock()
   253  	conn.writeDeadline = t
   254  	return nil
   255  }
   256  
   257  func (conn *connection) Write(b []byte) (int, error) {
   258  	deadline := conn.getDeadline(true)
   259  	if deadline.IsZero() { // Fast check.
   260  		return write(conn.fd, b)
   261  	}
   262  	timeout := time.Until(deadline)
   263  	if timeout <= 0 {
   264  		return 0, errorTimeout
   265  	}
   266  	if err := setWriteTimeout(conn.fd, timeout); err != nil {
   267  		return 0, err
   268  	}
   269  	nWritten, err := write(conn.fd, b)
   270  	if err == syscall.EAGAIN {
   271  		err = errorTimeout
   272  	}
   273  	if e := setWriteTimeout(conn.fd, 0); err == nil {
   274  		err = e
   275  	}
   276  	if err != nil {
   277  		return 0, err
   278  	}
   279  	return nWritten, nil
   280  }