lab.nexedi.com/kirr/go123@v0.0.0-20240207185015-8299741fa871/xnet/net.go (about)

     1  // Copyright (C) 2017-2020  Nexedi SA and Contributors.
     2  //                          Kirill Smelkov <kirr@nexedi.com>
     3  //
     4  // This program is free software: you can Use, Study, Modify and Redistribute
     5  // it under the terms of the GNU General Public License version 3, or (at your
     6  // option) any later version, as published by the Free Software Foundation.
     7  //
     8  // You can also Link and Combine this program with other software covered by
     9  // the terms of any of the Free Software licenses or any of the Open Source
    10  // Initiative approved licenses and Convey the resulting work. Corresponding
    11  // source of such a combination shall include the source code for all other
    12  // software used.
    13  //
    14  // This program is distributed WITHOUT ANY WARRANTY; without even the implied
    15  // warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
    16  //
    17  // See COPYING file for full licensing terms.
    18  // See https://www.nexedi.com/licensing for rationale and options.
    19  
    20  // Package xnet provides addons to std package net.
    21  package xnet
    22  
    23  import (
    24  	"context"
    25  	"errors"
    26  	"fmt"
    27  	"net"
    28  	"os"
    29  
    30  	"crypto/tls"
    31  
    32  	"lab.nexedi.com/kirr/go123/xcontext"
    33  	"lab.nexedi.com/kirr/go123/xsync"
    34  )
    35  
    36  // Networker is interface representing access-point to a streaming network.
    37  type Networker interface {
    38  	// Network returns name of the network.
    39  	Network() string
    40  
    41  	// Name returns name of the access-point on the network.
    42  	//
    43  	// Example of name is local hostname if networker provides access to
    44  	// OS-level dial/listen.
    45  	Name() string
    46  
    47  	// Dial connects to addr on underlying network.
    48  	//
    49  	// See net.Dial for semantic details.
    50  	Dial(ctx context.Context, addr string) (net.Conn, error)
    51  
    52  	// Listen starts listening on local address laddr on underlying network access-point.
    53  	//
    54  	// See net.Listen for semantic details.
    55  	Listen(ctx context.Context, laddr string) (Listener, error)
    56  
    57  	// Close releases resources associated with the network access-point.
    58  	//
    59  	// In-progress and future network operations such as Dial and Listen,
    60  	// originated via this access-point, will return with an error.
    61  	Close() error
    62  }
    63  
    64  // Listener amends net.Listener for Accept to handle cancellation.
    65  type Listener interface {
    66  	Accept(ctx context.Context) (net.Conn, error)
    67  
    68  	// same as in net.Listener
    69  	Close() error
    70  	Addr() net.Addr
    71  }
    72  
    73  
    74  var hostname string
    75  func init() {
    76  	host, err := os.Hostname()
    77  	if err != nil {
    78  		panic(fmt.Errorf("cannot detect hostname: %s", err))
    79  	}
    80  	hostname = host
    81  }
    82  
    83  var errNetClosed = errors.New("network access-point is closed")
    84  
    85  
    86  // NetPlain creates Networker corresponding to regular network accessors from std package net.
    87  //
    88  // network is "tcp", "tcp4", "tcp6", "unix", etc...
    89  func NetPlain(network string) Networker {
    90  	n := &netPlain{network: network, hostname: hostname}
    91  	n.ctx, n.cancel = context.WithCancel(context.Background())
    92  	return n
    93  }
    94  
    95  type netPlain struct {
    96  	network, hostname string
    97  
    98  	// ctx.cancel is merged into context of network operations.
    99  	// ctx is cancelled on Close.
   100  	ctx    context.Context
   101  	cancel func()
   102  }
   103  
   104  func (n *netPlain) Network() string {
   105  	return n.network
   106  }
   107  
   108  func (n *netPlain) Name() string {
   109  	return n.hostname
   110  }
   111  
   112  func (n *netPlain) Close() error {
   113  	n.cancel()
   114  	return nil
   115  }
   116  
   117  func (n *netPlain) Dial(ctx context.Context, addr string) (net.Conn, error) {
   118  	ctx, cancel := xcontext.Merge(ctx, n.ctx)
   119  	defer cancel()
   120  
   121  	dialErr := func(err error) error {
   122  		return &net.OpError{Op: "dial", Net: n.network, Addr: &strAddr{n.network, addr}, Err: err}
   123  	}
   124  
   125  	// don't try to call Dial if already closed / canceled
   126  	var conn net.Conn
   127  	err := ctx.Err()
   128  	if err == nil {
   129  		d := net.Dialer{}
   130  		conn, err = d.DialContext(ctx, n.network, addr)
   131  	} else {
   132  		err = dialErr(err)
   133  	}
   134  
   135  	if err != nil {
   136  		// convert n.ctx cancel -> "closed" error
   137  		if n.ctx.Err() != nil {
   138  			switch e := err.(type) {
   139  			case *net.OpError:
   140  				e.Err = errNetClosed
   141  			default:
   142  				// just in case
   143  				err = dialErr(errNetClosed)
   144  			}
   145  		}
   146  	}
   147  	return conn, err
   148  }
   149  
   150  func (n *netPlain) Listen(ctx context.Context, laddr string) (Listener, error) {
   151  	ctx, cancel := xcontext.Merge(ctx, n.ctx)
   152  	defer cancel()
   153  
   154  	listenErr := func(err error) error {
   155  		return &net.OpError{Op: "listen", Net: n.network, Addr: &strAddr{n.network, laddr}, Err: err}
   156  	}
   157  
   158  	// don't try to call Listen if already closed / canceled
   159  	var rawl net.Listener
   160  	err := ctx.Err()
   161  	if err == nil {
   162  		lc := net.ListenConfig{}
   163  		rawl, err = lc.Listen(ctx, n.network, laddr)
   164  	} else {
   165  		err = listenErr(err)
   166  	}
   167  
   168  	if err != nil {
   169  		// convert n.ctx cancel -> "closed" error
   170  		if n.ctx.Err() != nil {
   171  			switch e := err.(type) {
   172  			case *net.OpError:
   173  				e.Err = errNetClosed
   174  			default:
   175  				// just in case
   176  				err = listenErr(errNetClosed)
   177  			}
   178  		}
   179  		return nil, err
   180  	}
   181  
   182  	return WithCtxL(rawl), nil
   183  }
   184  
   185  // NetTLS wraps underlying networker with TLS layer according to config.
   186  //
   187  // The config must be valid:
   188  //
   189  //	- for tls.Client -- for Dial to work,
   190  //	- for tls.Server -- for Listen to work.
   191  func NetTLS(inner Networker, config *tls.Config) Networker {
   192  	return &netTLS{inner, config}
   193  }
   194  
   195  type netTLS struct {
   196  	inner  Networker
   197  	config *tls.Config
   198  }
   199  
   200  func (n *netTLS) Network() string {
   201  	return n.inner.Network() + "+tls"
   202  }
   203  
   204  func (n *netTLS) Name() string {
   205  	return n.inner.Name()
   206  }
   207  
   208  func (n *netTLS) Close() error {
   209  	return n.inner.Close()
   210  }
   211  
   212  func (n *netTLS) Dial(ctx context.Context, addr string) (net.Conn, error) {
   213  	c, err := n.inner.Dial(ctx, addr)
   214  	if err != nil {
   215  		return nil, err
   216  	}
   217  	return tls.Client(c, n.config), nil
   218  }
   219  
   220  func (n *netTLS) Listen(ctx context.Context, laddr string) (Listener, error) {
   221  	l, err := n.inner.Listen(ctx, laddr)
   222  	if err != nil {
   223  		return nil, err
   224  	}
   225  	return &listenerTLS{l, n}, nil
   226  }
   227  
   228  // listenerTLS implements Listener for netTLS.
   229  type listenerTLS struct {
   230  	innerl Listener
   231  	net    *netTLS
   232  }
   233  
   234  func (l *listenerTLS) Close() error {
   235  	return l.innerl.Close()
   236  }
   237  
   238  func (l *listenerTLS) Addr() net.Addr {
   239  	return l.innerl.Addr()
   240  }
   241  
   242  func (l *listenerTLS) Accept(ctx context.Context) (net.Conn, error) {
   243  	conn, err := l.innerl.Accept(ctx)
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  	return tls.Server(conn, l.net.config), nil
   248  }
   249  
   250  
   251  // ---- misc ----
   252  
   253  // strAddr turns string into net.Addr.
   254  type strAddr struct {
   255  	net  string
   256  	addr string
   257  }
   258  func (a *strAddr) Network() string { return a.net  }
   259  func (a *strAddr) String()  string { return a.addr }
   260  
   261  
   262  // ----------------------------------------
   263  
   264  // BindCtx*(xnet.X, ctx) -> net.X
   265  
   266  // BindCtxL binds Listener l and ctx into net.Listener which passes ctx to l on every Accept.
   267  func BindCtxL(l Listener, ctx context.Context) net.Listener {
   268  	// NOTE even if l is listenerCtx we cannot return raw underlying listener
   269  	// because listenerCtx continues to call Accept in its serve goroutine.
   270  	// -> always wrap with bindCtx.
   271  	return &bindCtxL{l, ctx}
   272  }
   273  type bindCtxL struct {l Listener; ctx context.Context}
   274  func (b *bindCtxL) Accept() (net.Conn, error)  { return b.l.Accept(b.ctx) }
   275  func (b *bindCtxL) Close() error               { return b.l.Close() }
   276  func (b *bindCtxL) Addr() net.Addr             { return b.l.Addr() }
   277  
   278  // WithCtx*(net.X) -> xnet.X that handles ctx.
   279  
   280  // WithCtxL converts net.Listener l into Listener that accepts ctx in Accept.
   281  //
   282  // It returns original xnet object if l was created via BindCtx*.
   283  func WithCtxL(l net.Listener) Listener {
   284  	// WithCtx(BindCtx(X)) = X
   285  	switch b := l.(type) {
   286  	case *bindCtxL: return b.l
   287  	}
   288  
   289  	return newListenerCtx(l)
   290  }
   291  
   292  
   293  // listenerCtx provides Listener given net.Listener.
   294  type listenerCtx struct {
   295  	rawl        net.Listener     // underlying listener
   296  	serveWG     *xsync.WorkGroup // Accept loop is run under serveWG
   297  	serveCancel func()           // Close calls serveCancel to request Accept loop shutdown
   298  	acceptq     chan accepted    // Accept results go -> acceptq
   299  }
   300  
   301  // accepted represents Accept result.
   302  type accepted struct {
   303  	conn net.Conn
   304  	err  error
   305  }
   306  
   307  func newListenerCtx(rawl net.Listener) *listenerCtx {
   308  	l := &listenerCtx{rawl: rawl, acceptq: make(chan accepted)}
   309  	ctx, cancel := context.WithCancel(context.Background())
   310  	l.serveWG = xsync.NewWorkGroup(ctx)
   311  	l.serveCancel = cancel
   312  	l.serveWG.Go(l.serve)
   313  	return l
   314  }
   315  
   316  func (l *listenerCtx) serve(ctx context.Context) error {
   317  	for {
   318  		// raw Accept. This should not stuck overliving ctx as Close closes rawl
   319  		conn, err := l.rawl.Accept()
   320  
   321  		// send result to Accept, but don't try to send if we are closed
   322  		ctxErr := ctx.Err()
   323  		if ctxErr == nil {
   324  			select {
   325  			case <-ctx.Done():
   326  				// closed
   327  				ctxErr = ctx.Err()
   328  
   329  			case l.acceptq <- accepted{conn, err}:
   330  				// ok
   331  			}
   332  		}
   333  		// shutdown if we are closed
   334  		if ctxErr != nil {
   335  			if conn != nil {
   336  				conn.Close() // ignore err
   337  			}
   338  			return ctxErr
   339  		}
   340  	}
   341  }
   342  
   343  func (l *listenerCtx) Close() error {
   344  	l.serveCancel()
   345  	err := l.rawl.Close()
   346  	_ = l.serveWG.Wait() // ignore err - it is always "canceled"
   347  	return err
   348  }
   349  
   350  func (l *listenerCtx) Accept(ctx context.Context) (_ net.Conn, err error) {
   351  	err = ctx.Err()
   352  
   353  	// don't try to pull from acceptq if ctx is already canceled
   354  	if err == nil {
   355  		select {
   356  		case <-ctx.Done():
   357  			err = ctx.Err()
   358  
   359  		case a := <-l.acceptq:
   360  			return a.conn, a.err
   361  		}
   362  	}
   363  
   364  	// here it is always due to ctx cancel
   365  	laddr := l.rawl.Addr()
   366  	return nil, &net.OpError{
   367  		Op:     "accept",
   368  		Net:    laddr.Network(),
   369  		Source: nil,
   370  		Addr:   laddr,
   371  		Err:    err,
   372  	}
   373  }
   374  
   375  func (l *listenerCtx) Addr() net.Addr {
   376  	return l.rawl.Addr()
   377  }