github.com/kelleygo/clashcore@v1.0.2/component/dialer/tfo.go (about)

     1  package dialer
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net"
     7  	"time"
     8  
     9  	"github.com/metacubex/tfo-go"
    10  )
    11  
    12  type tfoConn struct {
    13  	net.Conn
    14  	closed bool
    15  	dialed chan bool
    16  	cancel context.CancelFunc
    17  	ctx    context.Context
    18  	dialFn func(ctx context.Context, earlyData []byte) (net.Conn, error)
    19  }
    20  
    21  func (c *tfoConn) Dial(earlyData []byte) (err error) {
    22  	conn, err := c.dialFn(c.ctx, earlyData)
    23  	if err != nil {
    24  		return
    25  	}
    26  	c.Conn = conn
    27  	c.dialed <- true
    28  	return err
    29  }
    30  
    31  func (c *tfoConn) Read(b []byte) (n int, err error) {
    32  	if c.closed {
    33  		return 0, io.ErrClosedPipe
    34  	}
    35  	if c.Conn == nil {
    36  		select {
    37  		case <-c.ctx.Done():
    38  			return 0, io.ErrUnexpectedEOF
    39  		case <-c.dialed:
    40  		}
    41  	}
    42  	return c.Conn.Read(b)
    43  }
    44  
    45  func (c *tfoConn) Write(b []byte) (n int, err error) {
    46  	if c.closed {
    47  		return 0, io.ErrClosedPipe
    48  	}
    49  	if c.Conn == nil {
    50  		if err := c.Dial(b); err != nil {
    51  			return 0, err
    52  		}
    53  		return len(b), nil
    54  	}
    55  
    56  	return c.Conn.Write(b)
    57  }
    58  
    59  func (c *tfoConn) Close() error {
    60  	c.closed = true
    61  	c.cancel()
    62  	if c.Conn == nil {
    63  		return nil
    64  	}
    65  	return c.Conn.Close()
    66  }
    67  
    68  func (c *tfoConn) LocalAddr() net.Addr {
    69  	if c.Conn == nil {
    70  		return &net.TCPAddr{}
    71  	}
    72  	return c.Conn.LocalAddr()
    73  }
    74  
    75  func (c *tfoConn) RemoteAddr() net.Addr {
    76  	if c.Conn == nil {
    77  		return &net.TCPAddr{}
    78  	}
    79  	return c.Conn.RemoteAddr()
    80  }
    81  
    82  func (c *tfoConn) SetDeadline(t time.Time) error {
    83  	if err := c.SetReadDeadline(t); err != nil {
    84  		return err
    85  	}
    86  	return c.SetWriteDeadline(t)
    87  }
    88  
    89  func (c *tfoConn) SetReadDeadline(t time.Time) error {
    90  	if c.Conn == nil {
    91  		return nil
    92  	}
    93  	return c.Conn.SetReadDeadline(t)
    94  }
    95  
    96  func (c *tfoConn) SetWriteDeadline(t time.Time) error {
    97  	if c.Conn == nil {
    98  		return nil
    99  	}
   100  	return c.Conn.SetWriteDeadline(t)
   101  }
   102  
   103  func (c *tfoConn) Upstream() any {
   104  	if c.Conn == nil { // ensure return a nil interface not an interface with nil value
   105  		return nil
   106  	}
   107  	return c.Conn
   108  }
   109  
   110  func (c *tfoConn) NeedAdditionalReadDeadline() bool {
   111  	return c.Conn == nil
   112  }
   113  
   114  func (c *tfoConn) NeedHandshake() bool {
   115  	return c.Conn == nil
   116  }
   117  
   118  func (c *tfoConn) ReaderReplaceable() bool {
   119  	return c.Conn != nil
   120  }
   121  
   122  func (c *tfoConn) WriterReplaceable() bool {
   123  	return c.Conn != nil
   124  }
   125  
   126  func dialTFO(ctx context.Context, netDialer net.Dialer, network, address string) (net.Conn, error) {
   127  	ctx, cancel := context.WithTimeout(context.Background(), DefaultTCPTimeout)
   128  	dialer := tfo.Dialer{Dialer: netDialer, DisableTFO: false}
   129  	return &tfoConn{
   130  		dialed: make(chan bool, 1),
   131  		cancel: cancel,
   132  		ctx:    ctx,
   133  		dialFn: func(ctx context.Context, earlyData []byte) (net.Conn, error) {
   134  			return dialer.DialContext(ctx, network, address, earlyData)
   135  		},
   136  	}, nil
   137  }