github.com/sagernet/sing-box@v1.9.0-rc.20/common/dialer/tfo.go (about) 1 //go:build go1.20 2 3 package dialer 4 5 import ( 6 "context" 7 "io" 8 "net" 9 "os" 10 "sync" 11 "time" 12 13 "github.com/sagernet/sing/common" 14 "github.com/sagernet/sing/common/bufio" 15 E "github.com/sagernet/sing/common/exceptions" 16 M "github.com/sagernet/sing/common/metadata" 17 N "github.com/sagernet/sing/common/network" 18 "github.com/sagernet/tfo-go" 19 ) 20 21 type slowOpenConn struct { 22 dialer *tfo.Dialer 23 ctx context.Context 24 network string 25 destination M.Socksaddr 26 conn net.Conn 27 create chan struct{} 28 access sync.Mutex 29 err error 30 } 31 32 func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { 33 if dialer.DisableTFO || N.NetworkName(network) != N.NetworkTCP { 34 switch N.NetworkName(network) { 35 case N.NetworkTCP, N.NetworkUDP: 36 return dialer.Dialer.DialContext(ctx, network, destination.String()) 37 default: 38 return dialer.Dialer.DialContext(ctx, network, destination.AddrString()) 39 } 40 } 41 return &slowOpenConn{ 42 dialer: dialer, 43 ctx: ctx, 44 network: network, 45 destination: destination, 46 create: make(chan struct{}), 47 }, nil 48 } 49 50 func (c *slowOpenConn) Read(b []byte) (n int, err error) { 51 if c.conn == nil { 52 select { 53 case <-c.create: 54 if c.err != nil { 55 return 0, c.err 56 } 57 case <-c.ctx.Done(): 58 return 0, c.ctx.Err() 59 } 60 } 61 return c.conn.Read(b) 62 } 63 64 func (c *slowOpenConn) Write(b []byte) (n int, err error) { 65 if c.conn != nil { 66 return c.conn.Write(b) 67 } 68 c.access.Lock() 69 defer c.access.Unlock() 70 select { 71 case <-c.create: 72 if c.err != nil { 73 return 0, c.err 74 } 75 return c.conn.Write(b) 76 default: 77 } 78 c.conn, err = c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b) 79 if err != nil { 80 c.conn = nil 81 c.err = E.Cause(err, "dial tcp fast open") 82 } 83 n = len(b) 84 close(c.create) 85 return 86 } 87 88 func (c *slowOpenConn) Close() error { 89 return common.Close(c.conn) 90 } 91 92 func (c *slowOpenConn) LocalAddr() net.Addr { 93 if c.conn == nil { 94 return M.Socksaddr{} 95 } 96 return c.conn.LocalAddr() 97 } 98 99 func (c *slowOpenConn) RemoteAddr() net.Addr { 100 if c.conn == nil { 101 return M.Socksaddr{} 102 } 103 return c.conn.RemoteAddr() 104 } 105 106 func (c *slowOpenConn) SetDeadline(t time.Time) error { 107 if c.conn == nil { 108 return os.ErrInvalid 109 } 110 return c.conn.SetDeadline(t) 111 } 112 113 func (c *slowOpenConn) SetReadDeadline(t time.Time) error { 114 if c.conn == nil { 115 return os.ErrInvalid 116 } 117 return c.conn.SetReadDeadline(t) 118 } 119 120 func (c *slowOpenConn) SetWriteDeadline(t time.Time) error { 121 if c.conn == nil { 122 return os.ErrInvalid 123 } 124 return c.conn.SetWriteDeadline(t) 125 } 126 127 func (c *slowOpenConn) Upstream() any { 128 return c.conn 129 } 130 131 func (c *slowOpenConn) ReaderReplaceable() bool { 132 return c.conn != nil 133 } 134 135 func (c *slowOpenConn) WriterReplaceable() bool { 136 return c.conn != nil 137 } 138 139 func (c *slowOpenConn) LazyHeadroom() bool { 140 return c.conn == nil 141 } 142 143 func (c *slowOpenConn) NeedHandshake() bool { 144 return c.conn == nil 145 } 146 147 func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) { 148 if c.conn == nil { 149 select { 150 case <-c.create: 151 if c.err != nil { 152 return 0, c.err 153 } 154 case <-c.ctx.Done(): 155 return 0, c.ctx.Err() 156 } 157 } 158 return bufio.Copy(w, c.conn) 159 }