github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/common/dialer/tfo.go (about)

     1  //go:build go1.20
     2  
     3  package dialer
     4  
     5  import (
     6  	"context"
     7  	"io"
     8  	"net"
     9  	"os"
    10  	"time"
    11  
    12  	"github.com/sagernet/sing/common"
    13  	"github.com/sagernet/sing/common/bufio"
    14  	E "github.com/sagernet/sing/common/exceptions"
    15  	M "github.com/sagernet/sing/common/metadata"
    16  	N "github.com/sagernet/sing/common/network"
    17  	"github.com/sagernet/tfo-go"
    18  )
    19  
    20  type slowOpenConn struct {
    21  	dialer      *tfo.Dialer
    22  	ctx         context.Context
    23  	network     string
    24  	destination M.Socksaddr
    25  	conn        net.Conn
    26  	create      chan struct{}
    27  	err         error
    28  }
    29  
    30  func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
    31  	if dialer.DisableTFO || N.NetworkName(network) != N.NetworkTCP {
    32  		switch N.NetworkName(network) {
    33  		case N.NetworkTCP, N.NetworkUDP:
    34  			return dialer.Dialer.DialContext(ctx, network, destination.String())
    35  		default:
    36  			return dialer.Dialer.DialContext(ctx, network, destination.AddrString())
    37  		}
    38  	}
    39  	return &slowOpenConn{
    40  		dialer:      dialer,
    41  		ctx:         ctx,
    42  		network:     network,
    43  		destination: destination,
    44  		create:      make(chan struct{}),
    45  	}, nil
    46  }
    47  
    48  func (c *slowOpenConn) Read(b []byte) (n int, err error) {
    49  	if c.conn == nil {
    50  		select {
    51  		case <-c.create:
    52  			if c.err != nil {
    53  				return 0, c.err
    54  			}
    55  		case <-c.ctx.Done():
    56  			return 0, c.ctx.Err()
    57  		}
    58  	}
    59  	return c.conn.Read(b)
    60  }
    61  
    62  func (c *slowOpenConn) Write(b []byte) (n int, err error) {
    63  	if c.conn == nil {
    64  		c.conn, err = c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b)
    65  		if err != nil {
    66  			c.conn = nil
    67  			c.err = E.Cause(err, "dial tcp fast open")
    68  		}
    69  		close(c.create)
    70  		return
    71  	}
    72  	return c.conn.Write(b)
    73  }
    74  
    75  func (c *slowOpenConn) Close() error {
    76  	return common.Close(c.conn)
    77  }
    78  
    79  func (c *slowOpenConn) LocalAddr() net.Addr {
    80  	if c.conn == nil {
    81  		return M.Socksaddr{}
    82  	}
    83  	return c.conn.LocalAddr()
    84  }
    85  
    86  func (c *slowOpenConn) RemoteAddr() net.Addr {
    87  	if c.conn == nil {
    88  		return M.Socksaddr{}
    89  	}
    90  	return c.conn.RemoteAddr()
    91  }
    92  
    93  func (c *slowOpenConn) SetDeadline(t time.Time) error {
    94  	if c.conn == nil {
    95  		return os.ErrInvalid
    96  	}
    97  	return c.conn.SetDeadline(t)
    98  }
    99  
   100  func (c *slowOpenConn) SetReadDeadline(t time.Time) error {
   101  	if c.conn == nil {
   102  		return os.ErrInvalid
   103  	}
   104  	return c.conn.SetReadDeadline(t)
   105  }
   106  
   107  func (c *slowOpenConn) SetWriteDeadline(t time.Time) error {
   108  	if c.conn == nil {
   109  		return os.ErrInvalid
   110  	}
   111  	return c.conn.SetWriteDeadline(t)
   112  }
   113  
   114  func (c *slowOpenConn) Upstream() any {
   115  	return c.conn
   116  }
   117  
   118  func (c *slowOpenConn) ReaderReplaceable() bool {
   119  	return c.conn != nil
   120  }
   121  
   122  func (c *slowOpenConn) WriterReplaceable() bool {
   123  	return c.conn != nil
   124  }
   125  
   126  func (c *slowOpenConn) LazyHeadroom() bool {
   127  	return c.conn == nil
   128  }
   129  
   130  func (c *slowOpenConn) NeedHandshake() bool {
   131  	return c.conn == nil
   132  }
   133  
   134  func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) {
   135  	if c.conn == nil {
   136  		select {
   137  		case <-c.create:
   138  			if c.err != nil {
   139  				return 0, c.err
   140  			}
   141  		case <-c.ctx.Done():
   142  			return 0, c.ctx.Err()
   143  		}
   144  	}
   145  	return bufio.Copy(w, c.conn)
   146  }