github.com/mholt/caddy-l4@v0.0.0-20241104153248-ec8fae209322/layer4/listener.go (about)

     1  package layer4
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"net"
     8  	"runtime"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/caddyserver/caddy/v2"
    14  	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
    15  	"go.uber.org/zap"
    16  )
    17  
    18  func init() {
    19  	caddy.RegisterModule(&ListenerWrapper{})
    20  }
    21  
    22  // ListenerWrapper is a Caddy module that wraps App as a listener wrapper, it doesn't support udp.
    23  type ListenerWrapper struct {
    24  	// Routes express composable logic for handling byte streams.
    25  	Routes RouteList `json:"routes,omitempty"`
    26  
    27  	// Maximum time connections have to complete the matching phase (the first terminal handler is matched). Default: 3s.
    28  	MatchingTimeout caddy.Duration `json:"matching_timeout,omitempty"`
    29  
    30  	compiledRoute Handler
    31  
    32  	logger *zap.Logger
    33  	ctx    caddy.Context
    34  }
    35  
    36  // CaddyModule returns the Caddy module information.
    37  func (*ListenerWrapper) CaddyModule() caddy.ModuleInfo {
    38  	return caddy.ModuleInfo{
    39  		ID:  "caddy.listeners.layer4",
    40  		New: func() caddy.Module { return new(ListenerWrapper) },
    41  	}
    42  }
    43  
    44  // Provision sets up the ListenerWrapper.
    45  func (lw *ListenerWrapper) Provision(ctx caddy.Context) error {
    46  	lw.ctx = ctx
    47  	lw.logger = ctx.Logger()
    48  
    49  	if lw.MatchingTimeout <= 0 {
    50  		lw.MatchingTimeout = caddy.Duration(MatchingTimeoutDefault)
    51  	}
    52  
    53  	err := lw.Routes.Provision(ctx)
    54  	if err != nil {
    55  		return err
    56  	}
    57  	lw.compiledRoute = lw.Routes.Compile(lw.logger, time.Duration(lw.MatchingTimeout), listenerHandler{})
    58  
    59  	return nil
    60  }
    61  
    62  func (lw *ListenerWrapper) WrapListener(l net.Listener) net.Listener {
    63  	// TODO make channel capacity configurable
    64  	connChan := make(chan net.Conn, runtime.GOMAXPROCS(0))
    65  	li := &listener{
    66  		Listener:      l,
    67  		logger:        lw.logger,
    68  		compiledRoute: lw.compiledRoute,
    69  		done:          make(chan struct{}),
    70  		connChan:      connChan,
    71  		wg:            new(sync.WaitGroup),
    72  	}
    73  	go li.loop()
    74  	return li
    75  }
    76  
    77  // UnmarshalCaddyfile sets up the ListenerWrapper from Caddyfile tokens. Syntax:
    78  //
    79  //	layer4 {
    80  //		matching_timeout <duration>
    81  //		@a <matcher> [<matcher_args>]
    82  //		@b {
    83  //			<matcher> [<matcher_args>]
    84  //			<matcher> [<matcher_args>]
    85  //		}
    86  //		route @a @b {
    87  //			<handler> [<handler_args>]
    88  //		}
    89  //		@c <matcher> {
    90  //			<matcher_option> [<matcher_option_args>]
    91  //		}
    92  //		route @c {
    93  //			<handler> [<handler_args>]
    94  //			<handler> {
    95  //				<handler_option> [<handler_option_args>]
    96  //			}
    97  //		}
    98  //	}
    99  func (lw *ListenerWrapper) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
   100  	d.Next() // consume wrapper name
   101  
   102  	// No same-line options are supported
   103  	if d.CountRemainingArgs() > 0 {
   104  		return d.ArgErr()
   105  	}
   106  
   107  	if err := ParseCaddyfileNestedRoutes(d, &lw.Routes, &lw.MatchingTimeout); err != nil {
   108  		return err
   109  	}
   110  
   111  	return nil
   112  }
   113  
   114  type listener struct {
   115  	net.Listener
   116  	logger        *zap.Logger
   117  	compiledRoute Handler
   118  
   119  	closed atomic.Bool
   120  	done   chan struct{}
   121  	// closed when there is a non-recoverable error and all handle goroutines are done
   122  	connChan chan net.Conn
   123  
   124  	// count running handles
   125  	wg *sync.WaitGroup
   126  }
   127  
   128  func (l *listener) Close() error {
   129  	l.closed.Store(true)
   130  	return l.Listener.Close()
   131  }
   132  
   133  // loop accept connection from underlying listener and pipe the connection if there are any
   134  func (l *listener) loop() {
   135  	for {
   136  		conn, err := l.Listener.Accept()
   137  		var nerr net.Error
   138  		if errors.As(err, &nerr) && nerr.Temporary() && !l.closed.Load() {
   139  			l.logger.Error("temporary error accepting connection", zap.Error(err))
   140  			continue
   141  		}
   142  		if err != nil {
   143  			break
   144  		}
   145  
   146  		l.wg.Add(1)
   147  		go l.handle(conn)
   148  	}
   149  
   150  	// closing remaining conns in channel to release resources
   151  	go func() {
   152  		l.wg.Wait()
   153  		close(l.connChan)
   154  	}()
   155  	close(l.done)
   156  	for conn := range l.connChan {
   157  		_ = conn.Close()
   158  	}
   159  }
   160  
   161  // errHijacked is used when a handler takes over the connection, it's lifetime is not managed by handle
   162  var errHijacked = errors.New("hijacked connection")
   163  
   164  func (l *listener) handle(conn net.Conn) {
   165  	var err error
   166  	defer func() {
   167  		l.wg.Done()
   168  		if !errors.Is(err, errHijacked) {
   169  			_ = conn.Close()
   170  		}
   171  	}()
   172  
   173  	buf := bufPool.Get().([]byte)
   174  	buf = buf[:0]
   175  	defer bufPool.Put(buf)
   176  
   177  	cx := WrapConnection(conn, buf, l.logger)
   178  	cx.Context = context.WithValue(cx.Context, listenerCtxKey, l)
   179  
   180  	start := time.Now()
   181  	err = l.compiledRoute.Handle(cx)
   182  	duration := time.Since(start)
   183  	if err != nil && !errors.Is(err, errHijacked) {
   184  		l.logger.Error("handling connection", zap.Error(err))
   185  	}
   186  
   187  	l.logger.Debug("connection stats",
   188  		zap.String("remote", cx.RemoteAddr().String()),
   189  		zap.Uint64("read", cx.bytesRead),
   190  		zap.Uint64("written", cx.bytesWritten),
   191  		zap.Duration("duration", duration),
   192  	)
   193  }
   194  
   195  func (l *listener) Accept() (net.Conn, error) {
   196  	select {
   197  	case conn, ok := <-l.connChan:
   198  		if ok {
   199  			return conn, nil
   200  		}
   201  		return nil, net.ErrClosed
   202  	case <-l.done:
   203  		return nil, net.ErrClosed
   204  	}
   205  }
   206  
   207  func (l *listener) pipeConnection(conn *Connection) error {
   208  	// can't use l4tls.GetConnectionStates because of import cycle
   209  	// TODO export tls_connection_states as a special constant
   210  	var connectionStates []*tls.ConnectionState
   211  	if val := conn.GetVar("tls_connection_states"); val != nil {
   212  		connectionStates = val.([]*tls.ConnectionState)
   213  	}
   214  	if len(connectionStates) > 0 {
   215  		l.connChan <- &tlsConnection{
   216  			Conn:      conn,
   217  			connState: connectionStates[len(connectionStates)-1],
   218  		}
   219  	} else {
   220  		l.connChan <- conn
   221  	}
   222  	return errHijacked
   223  }
   224  
   225  // tlsConnection implements ConnectionState interface to use it with h2
   226  type tlsConnection struct {
   227  	net.Conn
   228  	connState *tls.ConnectionState
   229  }
   230  
   231  func (tc *tlsConnection) ConnectionState() tls.ConnectionState {
   232  	return *tc.connState
   233  }
   234  
   235  // Interface guards
   236  var (
   237  	_ caddy.Module          = (*ListenerWrapper)(nil)
   238  	_ caddy.ListenerWrapper = (*ListenerWrapper)(nil)
   239  	_ caddyfile.Unmarshaler = (*ListenerWrapper)(nil)
   240  )