github.com/sagernet/sing-box@v1.9.0-rc.20/transport/wireguard/gonet.go (about)

     1  //go:build with_gvisor
     2  
     3  package wireguard
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"net/netip"
    11  	"time"
    12  
    13  	"github.com/sagernet/gvisor/pkg/tcpip"
    14  	"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
    15  	"github.com/sagernet/gvisor/pkg/tcpip/stack"
    16  	"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
    17  	"github.com/sagernet/gvisor/pkg/waiter"
    18  	"github.com/sagernet/sing-tun"
    19  	M "github.com/sagernet/sing/common/metadata"
    20  )
    21  
    22  func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*gonet.TCPConn, error) {
    23  	// Create TCP endpoint, then connect.
    24  	var wq waiter.Queue
    25  	ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
    26  	if err != nil {
    27  		return nil, errors.New(err.String())
    28  	}
    29  
    30  	// Create wait queue entry that notifies a channel.
    31  	//
    32  	// We do this unconditionally as Connect will always return an error.
    33  	waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents)
    34  	wq.EventRegister(&waitEntry)
    35  	defer wq.EventUnregister(&waitEntry)
    36  
    37  	select {
    38  	case <-ctx.Done():
    39  		return nil, ctx.Err()
    40  	default:
    41  	}
    42  
    43  	// Bind before connect if requested.
    44  	if localAddr != (tcpip.FullAddress{}) {
    45  		if err = ep.Bind(localAddr); err != nil {
    46  			return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err)
    47  		}
    48  	}
    49  
    50  	err = ep.Connect(remoteAddr)
    51  	if _, ok := err.(*tcpip.ErrConnectStarted); ok {
    52  		select {
    53  		case <-ctx.Done():
    54  			ep.Close()
    55  			return nil, ctx.Err()
    56  		case <-notifyCh:
    57  		}
    58  
    59  		err = ep.LastError()
    60  	}
    61  	if err != nil {
    62  		ep.Close()
    63  		return nil, &net.OpError{
    64  			Op:   "connect",
    65  			Net:  "tcp",
    66  			Addr: M.SocksaddrFromNetIP(netip.AddrPortFrom(tun.AddrFromAddress(remoteAddr.Addr), remoteAddr.Port)).TCPAddr(),
    67  			Err:  errors.New(err.String()),
    68  		}
    69  	}
    70  
    71  	// sing-box added: set keepalive
    72  	ep.SocketOptions().SetKeepAlive(true)
    73  	keepAliveIdle := tcpip.KeepaliveIdleOption(15 * time.Second)
    74  	ep.SetSockOpt(&keepAliveIdle)
    75  	keepAliveInterval := tcpip.KeepaliveIntervalOption(15 * time.Second)
    76  	ep.SetSockOpt(&keepAliveInterval)
    77  
    78  	return gonet.NewTCPConn(&wq, ep), nil
    79  }