github.com/sagernet/sing-box@v1.2.7/outbound/default.go (about)

     1  package outbound
     2  
     3  import (
     4  	"context"
     5  	"net"
     6  	"net/netip"
     7  	"os"
     8  	"runtime"
     9  	"time"
    10  
    11  	"github.com/sagernet/sing-box/adapter"
    12  	C "github.com/sagernet/sing-box/constant"
    13  	"github.com/sagernet/sing-box/log"
    14  	"github.com/sagernet/sing/common"
    15  	"github.com/sagernet/sing/common/buf"
    16  	"github.com/sagernet/sing/common/bufio"
    17  	"github.com/sagernet/sing/common/canceler"
    18  	E "github.com/sagernet/sing/common/exceptions"
    19  	N "github.com/sagernet/sing/common/network"
    20  )
    21  
    22  type myOutboundAdapter struct {
    23  	protocol string
    24  	network  []string
    25  	router   adapter.Router
    26  	logger   log.ContextLogger
    27  	tag      string
    28  }
    29  
    30  func (a *myOutboundAdapter) Type() string {
    31  	return a.protocol
    32  }
    33  
    34  func (a *myOutboundAdapter) Tag() string {
    35  	return a.tag
    36  }
    37  
    38  func (a *myOutboundAdapter) Network() []string {
    39  	return a.network
    40  }
    41  
    42  func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error {
    43  	ctx = adapter.WithContext(ctx, &metadata)
    44  	var outConn net.Conn
    45  	var err error
    46  	if len(metadata.DestinationAddresses) > 0 {
    47  		outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses)
    48  	} else {
    49  		outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
    50  	}
    51  	if err != nil {
    52  		return N.HandshakeFailure(conn, err)
    53  	}
    54  	return CopyEarlyConn(ctx, conn, outConn)
    55  }
    56  
    57  func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error {
    58  	ctx = adapter.WithContext(ctx, &metadata)
    59  	var outConn net.PacketConn
    60  	var destinationAddress netip.Addr
    61  	var err error
    62  	if len(metadata.DestinationAddresses) > 0 {
    63  		outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses)
    64  	} else {
    65  		outConn, err = this.ListenPacket(ctx, metadata.Destination)
    66  	}
    67  	if err != nil {
    68  		return N.HandshakeFailure(conn, err)
    69  	}
    70  	if destinationAddress.IsValid() {
    71  		if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
    72  			natConn.UpdateDestination(destinationAddress)
    73  		}
    74  	}
    75  	switch metadata.Protocol {
    76  	case C.ProtocolSTUN:
    77  		ctx, conn = canceler.NewPacketConn(ctx, conn, C.STUNTimeout)
    78  	case C.ProtocolQUIC:
    79  		ctx, conn = canceler.NewPacketConn(ctx, conn, C.QUICTimeout)
    80  	case C.ProtocolDNS:
    81  		ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout)
    82  	}
    83  	return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn))
    84  }
    85  
    86  func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error {
    87  	if cachedReader, isCached := conn.(N.CachedReader); isCached {
    88  		payload := cachedReader.ReadCached()
    89  		if payload != nil && !payload.IsEmpty() {
    90  			_, err := serverConn.Write(payload.Bytes())
    91  			if err != nil {
    92  				return err
    93  			}
    94  			return bufio.CopyConn(ctx, conn, serverConn)
    95  		}
    96  	}
    97  	if earlyConn, isEarlyConn := common.Cast[N.EarlyConn](serverConn); isEarlyConn && earlyConn.NeedHandshake() {
    98  		_payload := buf.StackNew()
    99  		payload := common.Dup(_payload)
   100  		err := conn.SetReadDeadline(time.Now().Add(C.ReadPayloadTimeout))
   101  		if err != os.ErrInvalid {
   102  			if err != nil {
   103  				return err
   104  			}
   105  			_, err = payload.ReadOnceFrom(conn)
   106  			if err != nil && !E.IsTimeout(err) {
   107  				return E.Cause(err, "read payload")
   108  			}
   109  			err = conn.SetReadDeadline(time.Time{})
   110  			if err != nil {
   111  				payload.Release()
   112  				return err
   113  			}
   114  		}
   115  		_, err = serverConn.Write(payload.Bytes())
   116  		if err != nil {
   117  			return N.HandshakeFailure(conn, err)
   118  		}
   119  		runtime.KeepAlive(_payload)
   120  		payload.Release()
   121  	}
   122  	return bufio.CopyConn(ctx, conn, serverConn)
   123  }