github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/common/dialer/tfo.go (about) 1 //go:build go1.20 2 3 package dialer 4 5 import ( 6 "context" 7 "io" 8 "net" 9 "os" 10 "time" 11 12 "github.com/sagernet/sing/common" 13 "github.com/sagernet/sing/common/bufio" 14 E "github.com/sagernet/sing/common/exceptions" 15 M "github.com/sagernet/sing/common/metadata" 16 N "github.com/sagernet/sing/common/network" 17 "github.com/sagernet/tfo-go" 18 ) 19 20 type slowOpenConn struct { 21 dialer *tfo.Dialer 22 ctx context.Context 23 network string 24 destination M.Socksaddr 25 conn net.Conn 26 create chan struct{} 27 err error 28 } 29 30 func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { 31 if dialer.DisableTFO || N.NetworkName(network) != N.NetworkTCP { 32 switch N.NetworkName(network) { 33 case N.NetworkTCP, N.NetworkUDP: 34 return dialer.Dialer.DialContext(ctx, network, destination.String()) 35 default: 36 return dialer.Dialer.DialContext(ctx, network, destination.AddrString()) 37 } 38 } 39 return &slowOpenConn{ 40 dialer: dialer, 41 ctx: ctx, 42 network: network, 43 destination: destination, 44 create: make(chan struct{}), 45 }, nil 46 } 47 48 func (c *slowOpenConn) Read(b []byte) (n int, err error) { 49 if c.conn == nil { 50 select { 51 case <-c.create: 52 if c.err != nil { 53 return 0, c.err 54 } 55 case <-c.ctx.Done(): 56 return 0, c.ctx.Err() 57 } 58 } 59 return c.conn.Read(b) 60 } 61 62 func (c *slowOpenConn) Write(b []byte) (n int, err error) { 63 if c.conn == nil { 64 c.conn, err = c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b) 65 if err != nil { 66 c.conn = nil 67 c.err = E.Cause(err, "dial tcp fast open") 68 } 69 close(c.create) 70 return 71 } 72 return c.conn.Write(b) 73 } 74 75 func (c *slowOpenConn) Close() error { 76 return common.Close(c.conn) 77 } 78 79 func (c *slowOpenConn) LocalAddr() net.Addr { 80 if c.conn == nil { 81 return M.Socksaddr{} 82 } 83 return c.conn.LocalAddr() 84 } 85 86 func (c *slowOpenConn) RemoteAddr() net.Addr { 87 if c.conn == nil { 88 return M.Socksaddr{} 89 } 90 return c.conn.RemoteAddr() 91 } 92 93 func (c *slowOpenConn) SetDeadline(t time.Time) error { 94 if c.conn == nil { 95 return os.ErrInvalid 96 } 97 return c.conn.SetDeadline(t) 98 } 99 100 func (c *slowOpenConn) SetReadDeadline(t time.Time) error { 101 if c.conn == nil { 102 return os.ErrInvalid 103 } 104 return c.conn.SetReadDeadline(t) 105 } 106 107 func (c *slowOpenConn) SetWriteDeadline(t time.Time) error { 108 if c.conn == nil { 109 return os.ErrInvalid 110 } 111 return c.conn.SetWriteDeadline(t) 112 } 113 114 func (c *slowOpenConn) Upstream() any { 115 return c.conn 116 } 117 118 func (c *slowOpenConn) ReaderReplaceable() bool { 119 return c.conn != nil 120 } 121 122 func (c *slowOpenConn) WriterReplaceable() bool { 123 return c.conn != nil 124 } 125 126 func (c *slowOpenConn) LazyHeadroom() bool { 127 return c.conn == nil 128 } 129 130 func (c *slowOpenConn) NeedHandshake() bool { 131 return c.conn == nil 132 } 133 134 func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) { 135 if c.conn == nil { 136 select { 137 case <-c.create: 138 if c.err != nil { 139 return 0, c.err 140 } 141 case <-c.ctx.Done(): 142 return 0, c.ctx.Err() 143 } 144 } 145 return bufio.Copy(w, c.conn) 146 }