github.com/sagernet/sing-box@v1.9.0-rc.20/common/dialer/tfo.go (about)

     1  //go:build go1.20
     2  
     3  package dialer
     4  
     5  import (
     6  	"context"
     7  	"io"
     8  	"net"
     9  	"os"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/sagernet/sing/common"
    14  	"github.com/sagernet/sing/common/bufio"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  	M "github.com/sagernet/sing/common/metadata"
    17  	N "github.com/sagernet/sing/common/network"
    18  	"github.com/sagernet/tfo-go"
    19  )
    20  
// slowOpenConn is a net.Conn whose TCP Fast Open dial is deferred until the
// first Write, so that the first payload can be passed to the dialer together
// with the connection request.
type slowOpenConn struct {
	// dialer performs the deferred dial.
	dialer *tfo.Dialer
	// ctx is passed to the deferred dial; Read and WriteTo stop waiting for
	// the dial once it is done.
	ctx context.Context
	// network is the network name handed to the dialer.
	network string
	// destination is the dial target.
	destination M.Socksaddr
	// conn is the connection produced by the deferred dial; nil until that
	// dial succeeds.
	conn net.Conn
	// create is closed when conn/err become final; Write closes it after the
	// deferred dial finishes (successfully or not).
	create chan struct{}
	// access serializes the deferred dial in Write.
	access sync.Mutex
	// err is reported by Read, Write and WriteTo once create is closed; Write
	// sets it when the deferred dial fails.
	err error
}
    31  
    32  func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
    33  	if dialer.DisableTFO || N.NetworkName(network) != N.NetworkTCP {
    34  		switch N.NetworkName(network) {
    35  		case N.NetworkTCP, N.NetworkUDP:
    36  			return dialer.Dialer.DialContext(ctx, network, destination.String())
    37  		default:
    38  			return dialer.Dialer.DialContext(ctx, network, destination.AddrString())
    39  		}
    40  	}
    41  	return &slowOpenConn{
    42  		dialer:      dialer,
    43  		ctx:         ctx,
    44  		network:     network,
    45  		destination: destination,
    46  		create:      make(chan struct{}),
    47  	}, nil
    48  }
    49  
    50  func (c *slowOpenConn) Read(b []byte) (n int, err error) {
    51  	if c.conn == nil {
    52  		select {
    53  		case <-c.create:
    54  			if c.err != nil {
    55  				return 0, c.err
    56  			}
    57  		case <-c.ctx.Done():
    58  			return 0, c.ctx.Err()
    59  		}
    60  	}
    61  	return c.conn.Read(b)
    62  }
    63  
    64  func (c *slowOpenConn) Write(b []byte) (n int, err error) {
    65  	if c.conn != nil {
    66  		return c.conn.Write(b)
    67  	}
    68  	c.access.Lock()
    69  	defer c.access.Unlock()
    70  	select {
    71  	case <-c.create:
    72  		if c.err != nil {
    73  			return 0, c.err
    74  		}
    75  		return c.conn.Write(b)
    76  	default:
    77  	}
    78  	c.conn, err = c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b)
    79  	if err != nil {
    80  		c.conn = nil
    81  		c.err = E.Cause(err, "dial tcp fast open")
    82  	}
    83  	n = len(b)
    84  	close(c.create)
    85  	return
    86  }
    87  
    88  func (c *slowOpenConn) Close() error {
    89  	return common.Close(c.conn)
    90  }
    91  
    92  func (c *slowOpenConn) LocalAddr() net.Addr {
    93  	if c.conn == nil {
    94  		return M.Socksaddr{}
    95  	}
    96  	return c.conn.LocalAddr()
    97  }
    98  
    99  func (c *slowOpenConn) RemoteAddr() net.Addr {
   100  	if c.conn == nil {
   101  		return M.Socksaddr{}
   102  	}
   103  	return c.conn.RemoteAddr()
   104  }
   105  
   106  func (c *slowOpenConn) SetDeadline(t time.Time) error {
   107  	if c.conn == nil {
   108  		return os.ErrInvalid
   109  	}
   110  	return c.conn.SetDeadline(t)
   111  }
   112  
   113  func (c *slowOpenConn) SetReadDeadline(t time.Time) error {
   114  	if c.conn == nil {
   115  		return os.ErrInvalid
   116  	}
   117  	return c.conn.SetReadDeadline(t)
   118  }
   119  
   120  func (c *slowOpenConn) SetWriteDeadline(t time.Time) error {
   121  	if c.conn == nil {
   122  		return os.ErrInvalid
   123  	}
   124  	return c.conn.SetWriteDeadline(t)
   125  }
   126  
   127  func (c *slowOpenConn) Upstream() any {
   128  	return c.conn
   129  }
   130  
   131  func (c *slowOpenConn) ReaderReplaceable() bool {
   132  	return c.conn != nil
   133  }
   134  
   135  func (c *slowOpenConn) WriterReplaceable() bool {
   136  	return c.conn != nil
   137  }
   138  
   139  func (c *slowOpenConn) LazyHeadroom() bool {
   140  	return c.conn == nil
   141  }
   142  
   143  func (c *slowOpenConn) NeedHandshake() bool {
   144  	return c.conn == nil
   145  }
   146  
   147  func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) {
   148  	if c.conn == nil {
   149  		select {
   150  		case <-c.create:
   151  			if c.err != nil {
   152  				return 0, c.err
   153  			}
   154  		case <-c.ctx.Done():
   155  			return 0, c.ctx.Err()
   156  		}
   157  	}
   158  	return bufio.Copy(w, c.conn)
   159  }