github.com/mholt/caddy-l4@v0.0.0-20241104153248-ec8fae209322/layer4/routes_test.go (about)

     1  package layer4
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"github.com/caddyserver/caddy/v2/modules/caddyhttp"
     8  	"io"
     9  	"net"
    10  	"os"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/caddyserver/caddy/v2"
    15  	"go.uber.org/zap"
    16  	"go.uber.org/zap/zapcore"
    17  	"go.uber.org/zap/zaptest/observer"
    18  )
    19  
    20  type testIoMatcher struct {
    21  }
    22  
    23  func (testIoMatcher) CaddyModule() caddy.ModuleInfo {
    24  	return caddy.ModuleInfo{
    25  		ID:  "layer4.matchers.testIoMatcher",
    26  		New: func() caddy.Module { return new(testIoMatcher) },
    27  	}
    28  }
    29  
    30  func (m *testIoMatcher) Match(cx *Connection) (bool, error) {
    31  	buf := make([]byte, 1)
    32  	n, err := io.ReadFull(cx, buf)
    33  	return n > 0, err
    34  }
    35  
    36  func TestMatchingTimeoutWorks(t *testing.T) {
    37  	ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
    38  	defer cancel()
    39  
    40  	caddy.RegisterModule(testIoMatcher{})
    41  
    42  	routes := RouteList{&Route{
    43  		MatcherSetsRaw: caddyhttp.RawMatcherSets{
    44  			caddy.ModuleMap{"testIoMatcher": json.RawMessage("{}")}, // any io using matcher
    45  		},
    46  	}}
    47  
    48  	err := routes.Provision(ctx)
    49  	if err != nil {
    50  		t.Fatalf("provision failed | %s", err)
    51  	}
    52  
    53  	matched := false
    54  	loggerCore, logs := observer.New(zapcore.WarnLevel)
    55  	compiledRoutes := routes.Compile(zap.New(loggerCore), 5*time.Millisecond,
    56  		HandlerFunc(func(con *Connection) error {
    57  			matched = true
    58  			return nil
    59  		}))
    60  
    61  	in, out := net.Pipe()
    62  	defer func() { _ = in.Close() }()
    63  	defer func() { _ = out.Close() }()
    64  
    65  	cx := WrapConnection(out, []byte{}, zap.NewNop())
    66  	defer func() { _ = cx.Close() }()
    67  
    68  	err = compiledRoutes.Handle(cx)
    69  	if err != nil {
    70  		t.Fatalf("handle failed | %s", err)
    71  	}
    72  
    73  	// verify the matching aborted error was logged
    74  	if logs.Len() != 1 {
    75  		t.Fatalf("logs should contain 1 entry but has %d", logs.Len())
    76  	}
    77  	logEntry := logs.All()[0]
    78  	if logEntry.Level != zapcore.WarnLevel {
    79  		t.Fatalf("wrong log level | %s", logEntry.Level)
    80  	}
    81  	if logEntry.Message != "matching connection" {
    82  		t.Fatalf("wrong log message | %s", logEntry.Message)
    83  	}
    84  	if !(logEntry.Context[1].Key == "error" && errors.Is(logEntry.Context[1].Interface.(error), ErrMatchingTimeout)) {
    85  		t.Fatalf("wrong error | %v", logEntry.Context[1].Interface)
    86  	}
    87  
    88  	// since matching failed no handler should be called
    89  	if matched {
    90  		t.Fatal("handler was called but should not")
    91  	}
    92  }
    93  
    94  // used to test the timeout of udp associations
    95  type testIoUdpMatcher struct {
    96  }
    97  
    98  func (testIoUdpMatcher) CaddyModule() caddy.ModuleInfo {
    99  	return caddy.ModuleInfo{
   100  		ID:  "layer4.matchers.testIoUdpMatcher",
   101  		New: func() caddy.Module { return new(testIoUdpMatcher) },
   102  	}
   103  }
   104  
   105  var (
   106  	testConnection *Connection
   107  	handlingDone   chan struct{}
   108  )
   109  
   110  func (m *testIoUdpMatcher) Match(cx *Connection) (bool, error) {
   111  	// normally deadline exceeded error is handled during prefetch, and custom matcher can't
   112  	// read more than what's prefetched, but it's a test.
   113  	cx.matching = false
   114  	buf := make([]byte, 10)
   115  	n, err := io.ReadFull(cx, buf)
   116  	if err != nil {
   117  		cx.SetVar("time", time.Now())
   118  		cx.SetVar("err", err)
   119  		testConnection = cx
   120  		close(handlingDone)
   121  	}
   122  	return n > 0, err
   123  }
   124  
   125  func TestMatchingTimeoutWorksUDP(t *testing.T) {
   126  	ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
   127  	defer cancel()
   128  
   129  	caddy.RegisterModule(testIoUdpMatcher{})
   130  
   131  	routes := RouteList{&Route{
   132  		MatcherSetsRaw: caddyhttp.RawMatcherSets{
   133  			caddy.ModuleMap{"testIoUdpMatcher": json.RawMessage("{}")}, // any io using matcher
   134  		},
   135  	}}
   136  
   137  	err := routes.Provision(ctx)
   138  	if err != nil {
   139  		t.Fatalf("provision failed | %s", err)
   140  	}
   141  
   142  	matchingTimeout := time.Second
   143  
   144  	compiledRoutes := routes.Compile(zap.NewNop(), matchingTimeout,
   145  		HandlerFunc(func(con *Connection) error {
   146  			return nil
   147  		}))
   148  
   149  	handlingDone = make(chan struct{})
   150  
   151  	// Because udp is connectionless and every read can be from different addresses. A mapping between
   152  	// addresses and data read is created. A virtual connection can only read data from a certain address.
   153  	// Using real udp sockets and server to test timeout.
   154  	// We can't wait for the handler to finish this way, but that is tested above.
   155  	pc, err := net.ListenPacket("udp", "127.0.0.1:0")
   156  	if err != nil {
   157  		t.Fatalf("failed to listen | %s", err)
   158  	}
   159  	defer func() { _ = pc.Close() }()
   160  
   161  	server := new(Server)
   162  	server.compiledRoute = compiledRoutes
   163  	server.logger = zap.NewNop()
   164  	go server.servePacket(pc)
   165  
   166  	now := time.Now()
   167  
   168  	client, err := net.Dial("udp", pc.LocalAddr().String())
   169  	if err != nil {
   170  		t.Fatalf("failed to dial | %s", err)
   171  	}
   172  	defer func() { _ = client.Close() }()
   173  
   174  	_, err = client.Write([]byte("hello"))
   175  	if err != nil {
   176  		t.Fatalf("failed to write | %s", err)
   177  	}
   178  
   179  	// only wait for the matcher to return
   180  	<-handlingDone
   181  	if !errors.Is(testConnection.GetVar("err").(error), os.ErrDeadlineExceeded) {
   182  		t.Fatalf("expected deadline exceeded error but got %s", testConnection.GetVar("err"))
   183  	}
   184  
   185  	elasped := testConnection.GetVar("time").(time.Time).Sub(now)
   186  	if !(matchingTimeout <= elasped && elasped <= 2*matchingTimeout) {
   187  		t.Fatalf("timeout takes too long %s", elasped)
   188  	}
   189  }