github.com/mholt/caddy-l4@v0.0.0-20241104153248-ec8fae209322/modules/l4wireguard/matcher_test.go (about)

     1  // Copyright 2024 VNXME
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package l4wireguard
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"errors"
    21  	"io"
    22  	"net"
    23  	"testing"
    24  
    25  	"github.com/caddyserver/caddy/v2"
    26  	"go.uber.org/zap"
    27  
    28  	"github.com/mholt/caddy-l4/layer4"
    29  )
    30  
    31  func assertNoError(t *testing.T, err error) {
    32  	t.Helper()
    33  	if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
    34  		t.Fatalf("Unexpected error: %s\n", err)
    35  	}
    36  }
    37  
    38  func Test_MatchWireGuard_ProcessMessageInitiation(t *testing.T) {
    39  	p := [][]byte{
    40  		append(packet00000001, make([]byte, MessageInitiationBytesTotal-len(packet00000001))...),
    41  		append(packet010077FF, make([]byte, MessageInitiationBytesTotal-len(packet010077FF))...),
    42  	}
    43  	for _, b := range p {
    44  		func() {
    45  			s := &MessageInitiation{}
    46  			errFrom := s.FromBytes(b)
    47  			assertNoError(t, errFrom)
    48  			sb, errTo := s.ToBytes()
    49  			assertNoError(t, errTo)
    50  			if !bytes.Equal(b, sb) {
    51  				t.Fatalf("test %T bytes processing: resulting bytes [% x] don't match original bytes [% x]", *s, b, sb)
    52  			}
    53  		}()
    54  	}
    55  }
    56  
    57  func Test_MatchWireGuard_ProcessMessageData(t *testing.T) {
    58  	p := [][]byte{
    59  		append(packet00000004, make([]byte, MessageTransportBytesMin-len(packet00000001))...),
    60  		append(packet00000004, make([]byte, MessageTransportBytesMin-len(packet00000001)+160)...),
    61  	}
    62  	for _, b := range p {
    63  		func() {
    64  			s := &MessageTransport{}
    65  			errFrom := s.FromBytes(b)
    66  			assertNoError(t, errFrom)
    67  			sb, errTo := s.ToBytes()
    68  			assertNoError(t, errTo)
    69  			if !bytes.Equal(b, sb) {
    70  				t.Fatalf("test %T bytes processing: resulting bytes [% x] don't match original bytes [% x]", *s, b, sb)
    71  			}
    72  		}()
    73  	}
    74  }
    75  
    76  func Test_MatchWireGuard_Match(t *testing.T) {
    77  	type test struct {
    78  		matcher     *MatchWireGuard
    79  		data        []byte
    80  		shouldMatch bool
    81  	}
    82  
    83  	tests := []test{
    84  		{matcher: &MatchWireGuard{}, data: packet00000001, shouldMatch: false},
    85  		{matcher: &MatchWireGuard{}, data: append(packet00000001, make([]byte, MessageInitiationBytesTotal-len(packet00000001))...), shouldMatch: true},
    86  		{matcher: &MatchWireGuard{}, data: append(packet00000001, make([]byte, MessageInitiationBytesTotal-len(packet00000001)+1)...), shouldMatch: false},
    87  
    88  		{matcher: &MatchWireGuard{}, data: packet00000002, shouldMatch: false},
    89  		{matcher: &MatchWireGuard{}, data: append(packet00000002, make([]byte, MessageInitiationBytesTotal-len(packet00000002))...), shouldMatch: false},
    90  		{matcher: &MatchWireGuard{}, data: append(packet00000002, make([]byte, MessageResponseBytesTotal-len(packet00000002))...), shouldMatch: false},
    91  
    92  		{matcher: &MatchWireGuard{}, data: packet00000003, shouldMatch: false},
    93  		{matcher: &MatchWireGuard{}, data: append(packet00000003, make([]byte, MessageInitiationBytesTotal-len(packet00000003))...), shouldMatch: false},
    94  		{matcher: &MatchWireGuard{}, data: append(packet00000003, make([]byte, MessageCookieReplyBytesTotal-len(packet00000003))...), shouldMatch: false},
    95  
    96  		{matcher: &MatchWireGuard{}, data: packet00000004, shouldMatch: false},
    97  		{matcher: &MatchWireGuard{}, data: append(packet00000004, make([]byte, MessageInitiationBytesTotal-len(packet00000004))...), shouldMatch: false},
    98  		{matcher: &MatchWireGuard{}, data: append(packet00000004, make([]byte, MessageTransportBytesMin-len(packet00000004))...), shouldMatch: true},
    99  
   100  		{matcher: &MatchWireGuard{}, data: packet010077FF, shouldMatch: false},
   101  		{matcher: &MatchWireGuard{}, data: append(packet010077FF, make([]byte, MessageInitiationBytesTotal-len(packet010077FF))...), shouldMatch: false},
   102  		{matcher: &MatchWireGuard{Zero: 4285988864}, data: append(packet010077FF, make([]byte, MessageInitiationBytesTotal-len(packet010077FF))...), shouldMatch: true},
   103  	}
   104  
   105  	ctx, cancel := caddy.NewContext(caddy.Context{Context: context.Background()})
   106  	defer cancel()
   107  
   108  	for i, tc := range tests {
   109  		func() {
   110  			err := tc.matcher.Provision(ctx)
   111  			assertNoError(t, err)
   112  
   113  			in, out := net.Pipe()
   114  			defer func() {
   115  				_, _ = io.Copy(io.Discard, out)
   116  				_ = out.Close()
   117  			}()
   118  
   119  			cx := layer4.WrapConnection(out, []byte{}, zap.NewNop())
   120  			go func() {
   121  				_, err := in.Write(tc.data)
   122  				assertNoError(t, err)
   123  				_ = in.Close()
   124  			}()
   125  
   126  			matched, err := tc.matcher.Match(cx)
   127  			assertNoError(t, err)
   128  
   129  			if matched != tc.shouldMatch {
   130  				if tc.shouldMatch {
   131  					t.Fatalf("test %d: matcher did not match | %+v\n", i, tc.matcher)
   132  				} else {
   133  					t.Fatalf("test %d: matcher should not match | %+v\n", i, tc.matcher)
   134  				}
   135  			}
   136  		}()
   137  	}
   138  }
   139  
   140  var packet00000001 = []byte{uint8(MessageTypeInitiation), 0x00, 0x00, 0x00}
   141  var packet00000002 = []byte{uint8(MessageTypeResponse), 0x00, 0x00, 0x00}
   142  var packet00000003 = []byte{uint8(MessageTypeCookieReply), 0x00, 0x00, 0x00}
   143  var packet00000004 = []byte{uint8(MessageTypeTransport), 0x00, 0x00, 0x00}
   144  var packet010077FF = []byte{uint8(MessageTypeInitiation), 0x00, 0x77, 0xFF}