github.com/metacubex/mihomo@v1.18.5/component/dialer/tfo.go (about)

     1  package dialer
     2  
import (
	"context"
	"io"
	"net"
	"sync"
	"time"

	"github.com/metacubex/tfo-go"
)
    11  
    12  var DisableTFO = false
    13  
    14  type tfoConn struct {
    15  	net.Conn
    16  	closed bool
    17  	dialed chan bool
    18  	cancel context.CancelFunc
    19  	ctx    context.Context
    20  	dialFn func(ctx context.Context, earlyData []byte) (net.Conn, error)
    21  }
    22  
    23  func (c *tfoConn) Dial(earlyData []byte) (err error) {
    24  	conn, err := c.dialFn(c.ctx, earlyData)
    25  	if err != nil {
    26  		return
    27  	}
    28  	c.Conn = conn
    29  	c.dialed <- true
    30  	return err
    31  }
    32  
    33  func (c *tfoConn) Read(b []byte) (n int, err error) {
    34  	if c.closed {
    35  		return 0, io.ErrClosedPipe
    36  	}
    37  	if c.Conn == nil {
    38  		select {
    39  		case <-c.ctx.Done():
    40  			return 0, io.ErrUnexpectedEOF
    41  		case <-c.dialed:
    42  		}
    43  	}
    44  	return c.Conn.Read(b)
    45  }
    46  
    47  func (c *tfoConn) Write(b []byte) (n int, err error) {
    48  	if c.closed {
    49  		return 0, io.ErrClosedPipe
    50  	}
    51  	if c.Conn == nil {
    52  		if err := c.Dial(b); err != nil {
    53  			return 0, err
    54  		}
    55  		return len(b), nil
    56  	}
    57  
    58  	return c.Conn.Write(b)
    59  }
    60  
    61  func (c *tfoConn) Close() error {
    62  	c.closed = true
    63  	c.cancel()
    64  	if c.Conn == nil {
    65  		return nil
    66  	}
    67  	return c.Conn.Close()
    68  }
    69  
    70  func (c *tfoConn) LocalAddr() net.Addr {
    71  	if c.Conn == nil {
    72  		return &net.TCPAddr{}
    73  	}
    74  	return c.Conn.LocalAddr()
    75  }
    76  
    77  func (c *tfoConn) RemoteAddr() net.Addr {
    78  	if c.Conn == nil {
    79  		return &net.TCPAddr{}
    80  	}
    81  	return c.Conn.RemoteAddr()
    82  }
    83  
    84  func (c *tfoConn) SetDeadline(t time.Time) error {
    85  	if err := c.SetReadDeadline(t); err != nil {
    86  		return err
    87  	}
    88  	return c.SetWriteDeadline(t)
    89  }
    90  
    91  func (c *tfoConn) SetReadDeadline(t time.Time) error {
    92  	if c.Conn == nil {
    93  		return nil
    94  	}
    95  	return c.Conn.SetReadDeadline(t)
    96  }
    97  
    98  func (c *tfoConn) SetWriteDeadline(t time.Time) error {
    99  	if c.Conn == nil {
   100  		return nil
   101  	}
   102  	return c.Conn.SetWriteDeadline(t)
   103  }
   104  
   105  func (c *tfoConn) Upstream() any {
   106  	if c.Conn == nil { // ensure return a nil interface not an interface with nil value
   107  		return nil
   108  	}
   109  	return c.Conn
   110  }
   111  
   112  func (c *tfoConn) NeedAdditionalReadDeadline() bool {
   113  	return c.Conn == nil
   114  }
   115  
   116  func (c *tfoConn) NeedHandshake() bool {
   117  	return c.Conn == nil
   118  }
   119  
   120  func (c *tfoConn) ReaderReplaceable() bool {
   121  	return c.Conn != nil
   122  }
   123  
   124  func (c *tfoConn) WriterReplaceable() bool {
   125  	return c.Conn != nil
   126  }
   127  
   128  func dialTFO(ctx context.Context, netDialer net.Dialer, network, address string) (net.Conn, error) {
   129  	ctx, cancel := context.WithTimeout(context.Background(), DefaultTCPTimeout)
   130  	dialer := tfo.Dialer{Dialer: netDialer, DisableTFO: false}
   131  	return &tfoConn{
   132  		dialed: make(chan bool, 1),
   133  		cancel: cancel,
   134  		ctx:    ctx,
   135  		dialFn: func(ctx context.Context, earlyData []byte) (net.Conn, error) {
   136  			return dialer.DialContext(ctx, network, address, earlyData)
   137  		},
   138  	}, nil
   139  }