lab.nexedi.com/kirr/go123@v0.0.0-20240207185015-8299741fa871/xnet/net.go (about) 1 // Copyright (C) 2017-2020 Nexedi SA and Contributors. 2 // Kirill Smelkov <kirr@nexedi.com> 3 // 4 // This program is free software: you can Use, Study, Modify and Redistribute 5 // it under the terms of the GNU General Public License version 3, or (at your 6 // option) any later version, as published by the Free Software Foundation. 7 // 8 // You can also Link and Combine this program with other software covered by 9 // the terms of any of the Free Software licenses or any of the Open Source 10 // Initiative approved licenses and Convey the resulting work. Corresponding 11 // source of such a combination shall include the source code for all other 12 // software used. 13 // 14 // This program is distributed WITHOUT ANY WARRANTY; without even the implied 15 // warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 16 // 17 // See COPYING file for full licensing terms. 18 // See https://www.nexedi.com/licensing for rationale and options. 19 20 // Package xnet provides addons to std package net. 21 package xnet 22 23 import ( 24 "context" 25 "errors" 26 "fmt" 27 "net" 28 "os" 29 30 "crypto/tls" 31 32 "lab.nexedi.com/kirr/go123/xcontext" 33 "lab.nexedi.com/kirr/go123/xsync" 34 ) 35 36 // Networker is interface representing access-point to a streaming network. 37 type Networker interface { 38 // Network returns name of the network. 39 Network() string 40 41 // Name returns name of the access-point on the network. 42 // 43 // Example of name is local hostname if networker provides access to 44 // OS-level dial/listen. 45 Name() string 46 47 // Dial connects to addr on underlying network. 48 // 49 // See net.Dial for semantic details. 50 Dial(ctx context.Context, addr string) (net.Conn, error) 51 52 // Listen starts listening on local address laddr on underlying network access-point. 53 // 54 // See net.Listen for semantic details. 55 Listen(ctx context.Context, laddr string) (Listener, error) 56 57 // Close releases resources associated with the network access-point. 58 // 59 // In-progress and future network operations such as Dial and Listen, 60 // originated via this access-point, will return with an error. 61 Close() error 62 } 63 64 // Listener amends net.Listener for Accept to handle cancellation. 65 type Listener interface { 66 Accept(ctx context.Context) (net.Conn, error) 67 68 // same as in net.Listener 69 Close() error 70 Addr() net.Addr 71 } 72 73 74 var hostname string 75 func init() { 76 host, err := os.Hostname() 77 if err != nil { 78 panic(fmt.Errorf("cannot detect hostname: %s", err)) 79 } 80 hostname = host 81 } 82 83 var errNetClosed = errors.New("network access-point is closed") 84 85 86 // NetPlain creates Networker corresponding to regular network accessors from std package net. 87 // 88 // network is "tcp", "tcp4", "tcp6", "unix", etc... 89 func NetPlain(network string) Networker { 90 n := &netPlain{network: network, hostname: hostname} 91 n.ctx, n.cancel = context.WithCancel(context.Background()) 92 return n 93 } 94 95 type netPlain struct { 96 network, hostname string 97 98 // ctx.cancel is merged into context of network operations. 99 // ctx is cancelled on Close. 100 ctx context.Context 101 cancel func() 102 } 103 104 func (n *netPlain) Network() string { 105 return n.network 106 } 107 108 func (n *netPlain) Name() string { 109 return n.hostname 110 } 111 112 func (n *netPlain) Close() error { 113 n.cancel() 114 return nil 115 } 116 117 func (n *netPlain) Dial(ctx context.Context, addr string) (net.Conn, error) { 118 ctx, cancel := xcontext.Merge(ctx, n.ctx) 119 defer cancel() 120 121 dialErr := func(err error) error { 122 return &net.OpError{Op: "dial", Net: n.network, Addr: &strAddr{n.network, addr}, Err: err} 123 } 124 125 // don't try to call Dial if already closed / canceled 126 var conn net.Conn 127 err := ctx.Err() 128 if err == nil { 129 d := net.Dialer{} 130 conn, err = d.DialContext(ctx, n.network, addr) 131 } else { 132 err = dialErr(err) 133 } 134 135 if err != nil { 136 // convert n.ctx cancel -> "closed" error 137 if n.ctx.Err() != nil { 138 switch e := err.(type) { 139 case *net.OpError: 140 e.Err = errNetClosed 141 default: 142 // just in case 143 err = dialErr(errNetClosed) 144 } 145 } 146 } 147 return conn, err 148 } 149 150 func (n *netPlain) Listen(ctx context.Context, laddr string) (Listener, error) { 151 ctx, cancel := xcontext.Merge(ctx, n.ctx) 152 defer cancel() 153 154 listenErr := func(err error) error { 155 return &net.OpError{Op: "listen", Net: n.network, Addr: &strAddr{n.network, laddr}, Err: err} 156 } 157 158 // don't try to call Listen if already closed / canceled 159 var rawl net.Listener 160 err := ctx.Err() 161 if err == nil { 162 lc := net.ListenConfig{} 163 rawl, err = lc.Listen(ctx, n.network, laddr) 164 } else { 165 err = listenErr(err) 166 } 167 168 if err != nil { 169 // convert n.ctx cancel -> "closed" error 170 if n.ctx.Err() != nil { 171 switch e := err.(type) { 172 case *net.OpError: 173 e.Err = errNetClosed 174 default: 175 // just in case 176 err = listenErr(errNetClosed) 177 } 178 } 179 return nil, err 180 } 181 182 return WithCtxL(rawl), nil 183 } 184 185 // NetTLS wraps underlying networker with TLS layer according to config. 186 // 187 // The config must be valid: 188 // 189 // - for tls.Client -- for Dial to work, 190 // - for tls.Server -- for Listen to work. 191 func NetTLS(inner Networker, config *tls.Config) Networker { 192 return &netTLS{inner, config} 193 } 194 195 type netTLS struct { 196 inner Networker 197 config *tls.Config 198 } 199 200 func (n *netTLS) Network() string { 201 return n.inner.Network() + "+tls" 202 } 203 204 func (n *netTLS) Name() string { 205 return n.inner.Name() 206 } 207 208 func (n *netTLS) Close() error { 209 return n.inner.Close() 210 } 211 212 func (n *netTLS) Dial(ctx context.Context, addr string) (net.Conn, error) { 213 c, err := n.inner.Dial(ctx, addr) 214 if err != nil { 215 return nil, err 216 } 217 return tls.Client(c, n.config), nil 218 } 219 220 func (n *netTLS) Listen(ctx context.Context, laddr string) (Listener, error) { 221 l, err := n.inner.Listen(ctx, laddr) 222 if err != nil { 223 return nil, err 224 } 225 return &listenerTLS{l, n}, nil 226 } 227 228 // listenerTLS implements Listener for netTLS. 229 type listenerTLS struct { 230 innerl Listener 231 net *netTLS 232 } 233 234 func (l *listenerTLS) Close() error { 235 return l.innerl.Close() 236 } 237 238 func (l *listenerTLS) Addr() net.Addr { 239 return l.innerl.Addr() 240 } 241 242 func (l *listenerTLS) Accept(ctx context.Context) (net.Conn, error) { 243 conn, err := l.innerl.Accept(ctx) 244 if err != nil { 245 return nil, err 246 } 247 return tls.Server(conn, l.net.config), nil 248 } 249 250 251 // ---- misc ---- 252 253 // strAddr turns string into net.Addr. 254 type strAddr struct { 255 net string 256 addr string 257 } 258 func (a *strAddr) Network() string { return a.net } 259 func (a *strAddr) String() string { return a.addr } 260 261 262 // ---------------------------------------- 263 264 // BindCtx*(xnet.X, ctx) -> net.X 265 266 // BindCtxL binds Listener l and ctx into net.Listener which passes ctx to l on every Accept. 267 func BindCtxL(l Listener, ctx context.Context) net.Listener { 268 // NOTE even if l is listenerCtx we cannot return raw underlying listener 269 // because listenerCtx continues to call Accept in its serve goroutine. 270 // -> always wrap with bindCtx. 271 return &bindCtxL{l, ctx} 272 } 273 type bindCtxL struct {l Listener; ctx context.Context} 274 func (b *bindCtxL) Accept() (net.Conn, error) { return b.l.Accept(b.ctx) } 275 func (b *bindCtxL) Close() error { return b.l.Close() } 276 func (b *bindCtxL) Addr() net.Addr { return b.l.Addr() } 277 278 // WithCtx*(net.X) -> xnet.X that handles ctx. 279 280 // WithCtxL converts net.Listener l into Listener that accepts ctx in Accept. 281 // 282 // It returns original xnet object if l was created via BindCtx*. 283 func WithCtxL(l net.Listener) Listener { 284 // WithCtx(BindCtx(X)) = X 285 switch b := l.(type) { 286 case *bindCtxL: return b.l 287 } 288 289 return newListenerCtx(l) 290 } 291 292 293 // listenerCtx provides Listener given net.Listener. 294 type listenerCtx struct { 295 rawl net.Listener // underlying listener 296 serveWG *xsync.WorkGroup // Accept loop is run under serveWG 297 serveCancel func() // Close calls serveCancel to request Accept loop shutdown 298 acceptq chan accepted // Accept results go -> acceptq 299 } 300 301 // accepted represents Accept result. 302 type accepted struct { 303 conn net.Conn 304 err error 305 } 306 307 func newListenerCtx(rawl net.Listener) *listenerCtx { 308 l := &listenerCtx{rawl: rawl, acceptq: make(chan accepted)} 309 ctx, cancel := context.WithCancel(context.Background()) 310 l.serveWG = xsync.NewWorkGroup(ctx) 311 l.serveCancel = cancel 312 l.serveWG.Go(l.serve) 313 return l 314 } 315 316 func (l *listenerCtx) serve(ctx context.Context) error { 317 for { 318 // raw Accept. This should not stuck overliving ctx as Close closes rawl 319 conn, err := l.rawl.Accept() 320 321 // send result to Accept, but don't try to send if we are closed 322 ctxErr := ctx.Err() 323 if ctxErr == nil { 324 select { 325 case <-ctx.Done(): 326 // closed 327 ctxErr = ctx.Err() 328 329 case l.acceptq <- accepted{conn, err}: 330 // ok 331 } 332 } 333 // shutdown if we are closed 334 if ctxErr != nil { 335 if conn != nil { 336 conn.Close() // ignore err 337 } 338 return ctxErr 339 } 340 } 341 } 342 343 func (l *listenerCtx) Close() error { 344 l.serveCancel() 345 err := l.rawl.Close() 346 _ = l.serveWG.Wait() // ignore err - it is always "canceled" 347 return err 348 } 349 350 func (l *listenerCtx) Accept(ctx context.Context) (_ net.Conn, err error) { 351 err = ctx.Err() 352 353 // don't try to pull from acceptq if ctx is already canceled 354 if err == nil { 355 select { 356 case <-ctx.Done(): 357 err = ctx.Err() 358 359 case a := <-l.acceptq: 360 return a.conn, a.err 361 } 362 } 363 364 // here it is always due to ctx cancel 365 laddr := l.rawl.Addr() 366 return nil, &net.OpError{ 367 Op: "accept", 368 Net: laddr.Network(), 369 Source: nil, 370 Addr: laddr, 371 Err: err, 372 } 373 } 374 375 func (l *listenerCtx) Addr() net.Addr { 376 return l.rawl.Addr() 377 }