github.com/mholt/caddy-l4@v0.0.0-20241104153248-ec8fae209322/layer4/server.go

// Copyright 2020 Matthew Holt
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package layer4

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"net"
	"os"
	"sync"
	"sync/atomic"
	"time"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
	"go.uber.org/zap"
)

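// MatchingTimeoutDefault is the default value of Server.MatchingTimeout.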
const MatchingTimeoutDefault = 3 * time.Second

// Server represents a Caddy layer4 server.
type Server struct {
	// The network address to bind to. Any Caddy network address
	// is an acceptable value:
	// https://caddyserver.com/docs/conventions#network-addresses
	Listen []string `json:"listen,omitempty"`

	// Routes express composable logic for handling byte streams.
	Routes RouteList `json:"routes,omitempty"`

	// Maximum time connections have to complete the matching phase (i.e., until the first terminal handler is matched). Default: 3s.
	MatchingTimeout caddy.Duration `json:"matching_timeout,omitempty"`

	logger        *zap.Logger
	listenAddrs   []caddy.NetworkAddress
	compiledRoute Handler
}

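// For illustration only, a Server might be configured via JSON roughly as
// follows; the "tls" matcher and "echo" handler names are examples and depend
// on which layer4 modules are available:
//
//	{
//		"listen": ["0.0.0.0:5000"],
//		"matching_timeout": "5s",
//		"routes": [
//			{
//				"match": [{"tls": {}}],
//				"handle": [{"handler": "echo"}]
//			}
//		]
//	}
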
// Provision sets up the server.
func (s *Server) Provision(ctx caddy.Context, logger *zap.Logger) error {
	s.logger = logger

	if s.MatchingTimeout <= 0 {
		s.MatchingTimeout = caddy.Duration(MatchingTimeoutDefault)
	}

	repl := caddy.NewReplacer()
	for i, address := range s.Listen {
		address = repl.ReplaceAll(address, "")
		addr, err := caddy.ParseNetworkAddress(address)
		if err != nil {
			return fmt.Errorf("parsing listener address '%s' in position %d: %v", address, i, err)
		}
		s.listenAddrs = append(s.listenAddrs, addr)
	}

	err := s.Routes.Provision(ctx)
	if err != nil {
		return err
	}
	s.compiledRoute = s.Routes.Compile(s.logger, time.Duration(s.MatchingTimeout), nopHandler{})

	return nil
}

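// serve accepts connections from ln until the listener fails, handling each
// accepted connection in its own goroutine.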
func (s *Server) serve(ln net.Listener) error {
	for {
		conn, err := ln.Accept()
		var nerr net.Error
		if errors.As(err, &nerr) && nerr.Timeout() {
			s.logger.Error("timeout accepting connection", zap.Error(err))
			continue
		}
		if err != nil {
			return err
		}
		go s.handle(conn)
	}
}

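// servePacket reads datagrams from pc and demultiplexes them by downstream
// address into per-downstream packetConns, each of which is handled by its
// own handler goroutine as if it were a stream connection.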
func (s *Server) servePacket(pc net.PacketConn) error {
	// Spawn a goroutine whose only job is to consume packets from the socket
	// and send them to the packets channel.
	packets := make(chan packet, 10)
	go func(packets chan packet) {
		for {
			buf := udpBufPool.Get().([]byte)
			n, addr, err := pc.ReadFrom(buf)
			if err != nil {
				var netErr net.Error
				if errors.As(err, &netErr) && netErr.Timeout() {
					continue
				}
				packets <- packet{err: err}
				return
			}
			packets <- packet{
				pooledBuf: buf,
				n:         n,
				addr:      addr,
			}
		}
	}(packets)

	// udpConns tracks active packetConns by downstream address:port. They will
	// be removed from this map after being closed.
	udpConns := make(map[string]*packetConn)
	// closeCh is used to receive notifications of socket closures from
	// packetConn, which allows us to remove stale connections (whose
	// proxy handlers have completed) from the udpConns map.
	closeCh := make(chan string, 10)
	for {
		select {
		case addr := <-closeCh:
			// UDP connection is closed (either implicitly through timeout or by
			// explicit call to Close()).
			delete(udpConns, addr)

		case pkt := <-packets:
			if pkt.err != nil {
				return pkt.err
			}
			conn, ok := udpConns[pkt.addr.String()]
			if !ok {
				// No existing proxy handler is running for this downstream.
				// Create one now.
				conn = &packetConn{
					PacketConn: pc,
					readCh:     make(chan *packet, 5),
					addr:       pkt.addr,
					closeCh:    closeCh,
				}
				udpConns[pkt.addr.String()] = conn
				go func(conn *packetConn) {
					s.handle(conn)
					// It might seem cleaner to send to closeCh here rather than
					// in packetConn, but notifying earlier from within packetConn
					// narrows the window between the proxy handler shutting down
					// and new packets arriving from the same downstream. Should
					// that happen, we simply spin up a new handler concurrently
					// with the old one shutting down.
				}(conn)
			}
			conn.readCh <- &pkt
		}
	}
}

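// handle routes a single connection (or UDP association wrapped as a
// connection) through the server's compiled route and logs the outcome.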
func (s *Server) handle(conn net.Conn) {
	defer func() { _ = conn.Close() }()

	buf := bufPool.Get().([]byte)
	buf = buf[:0]
	defer bufPool.Put(buf)

	cx := WrapConnection(conn, buf, s.logger)

	start := time.Now()
	err := s.compiledRoute.Handle(cx)
	duration := time.Since(start)
	if err != nil {
		s.logger.Error("handling connection", zap.String("remote", cx.RemoteAddr().String()), zap.Error(err))
	}

	s.logger.Debug("connection stats",
		zap.String("remote", cx.RemoteAddr().String()),
		zap.Uint64("read", cx.bytesRead),
		zap.Uint64("written", cx.bytesWritten),
		zap.Duration("duration", duration),
	)
}

// UnmarshalCaddyfile sets up the Server from Caddyfile tokens. Syntax:
//
//	<address:port> [<address:port>] {
//		matching_timeout <duration>
//		@a <matcher> [<matcher_args>]
//		@b {
//			<matcher> [<matcher_args>]
//			<matcher> [<matcher_args>]
//		}
//		route @a @b {
//			<handler> [<handler_args>]
//		}
//		@c <matcher> {
//			<matcher_option> [<matcher_option_args>]
//		}
//		route @c {
//			<handler> [<handler_args>]
//			<handler> {
//				<handler_option> [<handler_option_args>]
//			}
//		}
//	}
func (s *Server) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
	// Wrapper name and all same-line options are treated as network addresses
	for ok := true; ok; ok = d.NextArg() {
		s.Listen = append(s.Listen, d.Val())
	}

	if err := ParseCaddyfileNestedRoutes(d, &s.Routes, &s.MatchingTimeout); err != nil {
		return err
	}

	return nil
}

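// packet represents a single datagram read from the shared UDP socket,
// along with the downstream address it came from.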
type packet struct {
	// The underlying byte slice obtained from udpBufPool. It is up to
	// packetConn to return it to udpBufPool once it has been consumed.
	pooledBuf []byte
	// Number of bytes read from socket
	n int
	// Error that occurred while reading from socket
	err error
	// Address of downstream
	addr net.Addr
}

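// packetConn adapts a single downstream's share of a net.PacketConn into a
// net.Conn, so that UDP traffic can flow through the same handler chain as
// stream connections.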
type packetConn struct {
	net.PacketConn
	addr    net.Addr
	readCh  chan *packet
	closeCh chan string
	// If not nil, then the previous Read() call didn't consume all the data
	// from the buffer, and this packet will be reused in the next Read()
	// without waiting for readCh.
	lastPacket *packet
	lastBuf    *bytes.Reader

	// Stores the deadline as Unix seconds, since Read may be called concurrently with SetReadDeadline.
	deadline      atomic.Int64
	deadlineTimer *time.Timer
	idleTimer     *time.Timer
}

// SetReadDeadline sets the deadline to wait for data from the underlying net.PacketConn.
func (pc *packetConn) SetReadDeadline(t time.Time) error {
	pc.deadline.Store(t.Unix())
	if pc.deadlineTimer != nil {
		pc.deadlineTimer.Reset(time.Until(t))
	} else {
		pc.deadlineTimer = time.NewTimer(time.Until(t))
	}
	return nil
}

// TODO: idle timeout should be configurable per server
const udpAssociationIdleTimeout = 30 * time.Second

func isDeadlineExceeded(t time.Time) bool {
	return !t.IsZero() && t.Before(time.Now())
}

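// Read first drains any partially consumed packet left over from the previous
// call, then waits for the next packet from readCh, honoring the read deadline
// and the association idle timeout. It returns io.EOF once the association is
// considered closed.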
func (pc *packetConn) Read(b []byte) (n int, err error) {
	if pc.lastPacket != nil {
		// There is a partial buffer to continue reading from the previous
		// packet.
		n, err = pc.lastBuf.Read(b)
		if pc.lastBuf.Len() == 0 {
			udpBufPool.Put(pc.lastPacket.pooledBuf)
			pc.lastPacket = nil
			pc.lastBuf = nil
		}
		return
	}
	// check deadline
	if isDeadlineExceeded(time.Unix(pc.deadline.Load(), 0)) {
		return 0, os.ErrDeadlineExceeded
	}
	// set or refresh idle timeout
	if pc.idleTimer == nil {
		pc.idleTimer = time.NewTimer(udpAssociationIdleTimeout)
	} else {
		pc.idleTimer.Reset(udpAssociationIdleTimeout)
	}
	var done bool
	for !done {
		select {
		case pkt := <-pc.readCh:
			if pkt == nil {
				// Channel is closed. Return EOF below.
				done = true
				break
			}
			buf := bytes.NewReader(pkt.pooledBuf[:pkt.n])
			n, err = buf.Read(b)
			if buf.Len() == 0 {
				// Buffer fully consumed, release it.
				udpBufPool.Put(pkt.pooledBuf)
			} else {
				// Buffer only partially consumed. Keep track of it for
				// next Read() call.
				pc.lastPacket = pkt
				pc.lastBuf = buf
			}
			return
		case <-pc.deadlineTimer.C:
			// The deadline may have changed while we were waiting; recheck it.
			if isDeadlineExceeded(time.Unix(pc.deadline.Load(), 0)) {
				return 0, os.ErrDeadlineExceeded
			}
			// Not exceeded yet, so the loop runs again. We keep waiting here
			// rather than re-entering Read, which would reset the idle timer.
		case <-pc.idleTimer.C:
			done = true
			break
		}
	}
	// Idle timeout simulates socket closure.
	//
	// Although Close() also does this, we inform the server loop early about
	// the closure to ensure that if any new packets are received from this
	// connection in the meantime, a new handler will be started.
	pc.closeCh <- pc.addr.String()
	// Returning EOF here ensures that io.Copy() waiting on the downstream for
	// reads will terminate.
	return 0, io.EOF
}

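// Write sends b to this association's downstream address on the shared net.PacketConn.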
func (pc *packetConn) Write(b []byte) (n int, err error) {
	return pc.PacketConn.WriteTo(b, pc.addr)
}

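// Close releases any buffers still held by this association and notifies the
// server loop that the association is closed. It deliberately does not close
// the shared net.PacketConn.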
func (pc *packetConn) Close() error {
	if pc.lastPacket != nil {
		udpBufPool.Put(pc.lastPacket.pooledBuf)
		pc.lastPacket = nil
	}
	// This will abort any active Read() from another goroutine and return EOF
	close(pc.readCh)
	// Drain pending packets to ensure we release buffers back to the pool
	for pkt := range pc.readCh {
		udpBufPool.Put(pkt.pooledBuf)
	}
	// We may have already done this earlier in Read(), but just in case
	// Read() wasn't being called, (re-)notify server loop we're closed.
	pc.closeCh <- pc.addr.String()
	// We don't call net.PacketConn.Close() here as we would stop the UDP
	// server.
	return nil
}

func (pc *packetConn) RemoteAddr() net.Addr { return pc.addr }

var udpBufPool = sync.Pool{
	New: func() interface{} {
		// Buffers need to be as large as the largest datagram we'll consume,
		// because ReadFrom() can't resume partial reads. (This is standard
		// for UDP sockets on *nix.) So our buffer size is 9000 bytes to
		// accommodate networks with jumbo frames.
		// See also https://github.com/golang/go/issues/18056
		return make([]byte, 9000)
	},
}

// Interface guard
var _ caddyfile.Unmarshaler = (*Server)(nil)