github.com/metacubex/tfo-go@v0.0.0-20240228025757-be1269474a66/tfo.go (about) 1 // Package tfo provides TCP Fast Open support for the [net] dialer and listener. 2 // 3 // The dial functions have an additional buffer parameter, which specifies data in SYN. 4 // If the buffer is empty, TFO is not used. 5 // 6 // This package supports Linux, Windows, macOS, and FreeBSD. 7 // On unsupported platforms, [ErrPlatformUnsupported] is returned. 8 // 9 // FreeBSD code is completely untested. Use at your own risk. Feedback is welcome. 10 package tfo 11 12 import ( 13 "context" 14 "errors" 15 "net" 16 "os" 17 "sync/atomic" 18 "syscall" 19 "time" 20 ) 21 22 var ( 23 ErrPlatformUnsupported PlatformUnsupportedError 24 errMissingAddress = errors.New("missing address") 25 ) 26 27 // PlatformUnsupportedError is returned when tfo-go does not support TCP Fast Open on the current platform. 28 type PlatformUnsupportedError struct{} 29 30 func (PlatformUnsupportedError) Error() string { 31 return "tfo-go does not support TCP Fast Open on this platform" 32 } 33 34 func (PlatformUnsupportedError) Is(target error) bool { 35 return target == ErrUnsupported 36 } 37 38 var runtimeListenNoTFO atomic.Bool 39 40 // ListenConfig wraps [net.ListenConfig] with TFO-related options. 41 type ListenConfig struct { 42 net.ListenConfig 43 44 // Backlog specifies the maximum number of pending TFO connections on supported platforms. 45 // If the value is 0, Go std's listen(2) backlog (4096, as of the current version) is used. 46 // If the value is negative, TFO is disabled. 47 Backlog int 48 49 // DisableTFO controls whether TCP Fast Open is disabled when the Listen method is called. 50 // TFO is enabled by default, unless [ListenConfig.Backlog] is negative. 51 // Set to true to disable TFO and it will behave exactly the same as [net.ListenConfig]. 52 DisableTFO bool 53 54 // Fallback controls whether to proceed without TFO when TFO is enabled but not supported 55 // on the system. 56 Fallback bool 57 } 58 59 func (lc *ListenConfig) tfoDisabled() bool { 60 return lc.Backlog < 0 || lc.DisableTFO 61 } 62 63 func (lc *ListenConfig) tfoNeedsFallback() bool { 64 return lc.Fallback && (comptimeNoTFO || runtimeListenNoTFO.Load()) 65 } 66 67 // Listen is like [net.ListenConfig.Listen] but enables TFO whenever possible, 68 // unless [ListenConfig.Backlog] is negative or [ListenConfig.DisableTFO] is set to true. 69 func (lc *ListenConfig) Listen(ctx context.Context, network, address string) (net.Listener, error) { 70 if lc.tfoDisabled() || !networkIsTCP(network) || lc.tfoNeedsFallback() { 71 return lc.ListenConfig.Listen(ctx, network, address) 72 } 73 return lc.listenTFO(ctx, network, address) // tfo_darwin.go, tfo_listen_generic.go, tfo_unsupported.go 74 } 75 76 // ListenContext is like [net.ListenContext] but enables TFO whenever possible. 77 func ListenContext(ctx context.Context, network, address string) (net.Listener, error) { 78 var lc ListenConfig 79 return lc.Listen(ctx, network, address) 80 } 81 82 // Listen is like [net.Listen] but enables TFO whenever possible. 83 func Listen(network, address string) (net.Listener, error) { 84 return ListenContext(context.Background(), network, address) 85 } 86 87 // ListenTCP is like [net.ListenTCP] but enables TFO whenever possible. 88 func ListenTCP(network string, laddr *net.TCPAddr) (*net.TCPListener, error) { 89 if !networkIsTCP(network) { 90 return nil, &net.OpError{Op: "listen", Net: network, Source: nil, Addr: opAddr(laddr), Err: net.UnknownNetworkError(network)} 91 } 92 var address string 93 if laddr != nil { 94 address = laddr.String() 95 } 96 var lc ListenConfig 97 ln, err := lc.listenTFO(context.Background(), network, address) // tfo_darwin.go, tfo_listen_generic.go, tfo_unsupported.go 98 if err != nil { 99 return nil, err 100 } 101 return ln.(*net.TCPListener), err 102 } 103 104 type dialTFOSupport uint32 105 106 const ( 107 dialTFOSupportDefault dialTFOSupport = iota 108 dialTFOSupportNone 109 dialTFOSupportLinuxSendto 110 ) 111 112 type atomicDialTFOSupport struct { 113 v atomic.Uint32 114 } 115 116 func (a *atomicDialTFOSupport) load() dialTFOSupport { 117 return dialTFOSupport(a.v.Load()) 118 } 119 120 func (a *atomicDialTFOSupport) storeNone() { 121 a.v.Store(uint32(dialTFOSupportNone)) 122 } 123 124 var runtimeDialTFOSupport atomicDialTFOSupport 125 126 // Dialer wraps [net.Dialer] with an additional option that allows you to disable TFO. 127 type Dialer struct { 128 net.Dialer 129 130 // DisableTFO controls whether TCP Fast Open is disabled when the dial methods are called. 131 // TFO is enabled by default. 132 // Set to true to disable TFO and it will behave exactly the same as [net.Dialer]. 133 DisableTFO bool 134 135 // Fallback controls whether to proceed without TFO when TFO is enabled but not supported 136 // on the system. 137 // On Linux this also controls whether the sendto(MSG_FASTOPEN) fallback path is tried 138 // before giving up on TFO. 139 Fallback bool 140 } 141 142 func (d *Dialer) dialAndWrite(ctx context.Context, network, address string, b []byte) (net.Conn, error) { 143 c, err := d.Dialer.DialContext(ctx, network, address) 144 if err != nil { 145 return nil, err 146 } 147 if err = netConnWriteBytes(ctx, c, b); err != nil { 148 c.Close() 149 return nil, err 150 } 151 return c, nil 152 } 153 154 func (d *Dialer) dialAndWriteTCPConn(ctx context.Context, network, address string, b []byte) (*net.TCPConn, error) { 155 c, err := d.Dialer.DialContext(ctx, network, address) 156 if err != nil { 157 return nil, err 158 } 159 if err = netConnWriteBytes(ctx, c, b); err != nil { 160 c.Close() 161 return nil, err 162 } 163 return c.(*net.TCPConn), nil 164 } 165 166 // DialContext is like [net.Dialer.DialContext] but enables TFO whenever possible, 167 // unless [Dialer.DisableTFO] is set to true. 168 func (d *Dialer) DialContext(ctx context.Context, network, address string, b []byte) (net.Conn, error) { 169 if len(b) == 0 { 170 return d.Dialer.DialContext(ctx, network, address) 171 } 172 if d.DisableTFO || !networkIsTCP(network) { 173 return d.dialAndWrite(ctx, network, address, b) 174 } 175 return d.dialTFO(ctx, network, address, b) // tfo_bsd+windows.go, tfo_linux.go, tfo_unsupported.go 176 } 177 178 // Dial is like [net.Dialer.Dial] but enables TFO whenever possible, 179 // unless [Dialer.DisableTFO] is set to true. 180 func (d *Dialer) Dial(network, address string, b []byte) (net.Conn, error) { 181 return d.DialContext(context.Background(), network, address, b) 182 } 183 184 // Dial is like [net.Dial] but enables TFO whenever possible. 185 func Dial(network, address string, b []byte) (net.Conn, error) { 186 var d Dialer 187 return d.DialContext(context.Background(), network, address, b) 188 } 189 190 // DialTimeout is like [net.DialTimeout] but enables TFO whenever possible. 191 func DialTimeout(network, address string, timeout time.Duration, b []byte) (net.Conn, error) { 192 var d Dialer 193 d.Timeout = timeout 194 return d.DialContext(context.Background(), network, address, b) 195 } 196 197 // DialTCP is like [net.DialTCP] but enables TFO whenever possible. 198 func DialTCP(network string, laddr, raddr *net.TCPAddr, b []byte) (*net.TCPConn, error) { 199 if len(b) == 0 { 200 return net.DialTCP(network, laddr, raddr) 201 } 202 if !networkIsTCP(network) { 203 return nil, &net.OpError{Op: "dial", Net: network, Source: opAddr(laddr), Addr: opAddr(raddr), Err: net.UnknownNetworkError(network)} 204 } 205 if raddr == nil { 206 return nil, &net.OpError{Op: "dial", Net: network, Source: opAddr(laddr), Addr: nil, Err: errMissingAddress} 207 } 208 return dialTCPAddr(network, laddr, raddr, b) // tfo_bsd+windows.go, tfo_linux.go, tfo_unsupported.go 209 } 210 211 func networkIsTCP(network string) bool { 212 switch network { 213 case "tcp", "tcp4", "tcp6": 214 return true 215 default: 216 return false 217 } 218 } 219 220 func opAddr(a *net.TCPAddr) net.Addr { 221 if a == nil { 222 return nil 223 } 224 return a 225 } 226 227 // wrapSyscallError takes an error and a syscall name. If the error is 228 // a syscall.Errno, it wraps it in a os.SyscallError using the syscall name. 229 func wrapSyscallError(name string, err error) error { 230 if _, ok := err.(syscall.Errno); ok { 231 err = os.NewSyscallError(name, err) 232 } 233 return err 234 } 235 236 // aLongTimeAgo is a non-zero time, far in the past, used for immediate deadlines. 237 var aLongTimeAgo = time.Unix(0, 0) 238 239 // writeDeadliner allows cancellation of ongoing write operations. 240 type writeDeadliner interface { 241 SetWriteDeadline(t time.Time) error 242 } 243 244 // connWriteFunc invokes the given function on a [writeDeadliner] to execute any arbitrary write operation. 245 // If the given context can be canceled, it will spin up an interruptor goroutine to cancel the write operation 246 // when the context is canceled. 247 func connWriteFunc[C writeDeadliner](ctx context.Context, c C, fn func(C) error) (err error) { 248 if ctxDone := ctx.Done(); ctxDone != nil { 249 done := make(chan struct{}) 250 interruptRes := make(chan error) 251 252 defer func() { 253 close(done) 254 if ctxErr := <-interruptRes; ctxErr != nil && err == nil { 255 err = ctxErr 256 } 257 }() 258 259 go func() { 260 select { 261 case <-ctxDone: 262 c.SetWriteDeadline(aLongTimeAgo) 263 interruptRes <- ctx.Err() 264 case <-done: 265 interruptRes <- nil 266 } 267 }() 268 } 269 270 return fn(c) 271 } 272 273 // netConnWriteBytes is a convenience wrapper around [connWriteFunc] for writing bytes to a [net.Conn]. 274 func netConnWriteBytes(ctx context.Context, c net.Conn, b []byte) error { 275 return connWriteFunc(ctx, c, func(c net.Conn) error { 276 _, err := c.Write(b) 277 return err 278 }) 279 }