github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/source_sink_test.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   */
     9  
    10  package noisysockets
    11  
    12  import (
    13  	"net/netip"
    14  	"testing"
    15  
    16  	"github.com/neilotoole/slogt"
    17  	"github.com/noisysockets/netstack/pkg/tcpip"
    18  	"github.com/noisysockets/netstack/pkg/tcpip/header"
    19  	"github.com/noisysockets/noisysockets/types"
    20  	"github.com/stretchr/testify/require"
    21  )
    22  
    23  func TestValidateSourceAddress(t *testing.T) {
    24  	gwSK, err := types.NewPrivateKey()
    25  	require.NoError(t, err)
    26  
    27  	gwPK := gwSK.Public()
    28  
    29  	peer1SK, err := types.NewPrivateKey()
    30  	require.NoError(t, err)
    31  
    32  	peer1PK := peer1SK.Public()
    33  
    34  	peer2SK, err := types.NewPrivateKey()
    35  	require.NoError(t, err)
    36  
    37  	peer2PK := peer2SK.Public()
    38  
    39  	rt := newRoutingTable(slogt.New(t))
    40  
    41  	ipv4Net, err := netip.ParsePrefix("192.168.2.0/24")
    42  	require.NoError(t, err)
    43  
    44  	ipv6Net, err := netip.ParsePrefix("2001:db9::/64")
    45  	require.NoError(t, err)
    46  
    47  	p := newPeer(nil, "default-gateway", gwPK)
    48  	p.AddAddresses(netip.MustParseAddr("192.168.1.1"), netip.MustParseAddr("2001:db8::1"))
    49  	p.AddDestinationPrefixes(ipv4Net, ipv6Net)
    50  	require.NoError(t, rt.update(p))
    51  
    52  	p = newPeer(nil, "peer1", peer1PK)
    53  	p.AddAddresses(netip.MustParseAddr("192.168.1.2"), netip.MustParseAddr("2001:db8::2"))
    54  	require.NoError(t, rt.update(p))
    55  
    56  	p = newPeer(nil, "peer2", peer2PK)
    57  	p.AddAddresses(netip.MustParseAddr("192.168.1.3"), netip.MustParseAddr("2001:db8::3"))
    58  	require.NoError(t, rt.update(p))
    59  
    60  	ss := sourceSink{
    61  		rt: rt,
    62  	}
    63  
    64  	t.Run("Valid (IPv4)", func(t *testing.T) {
    65  		buf := make([]byte, header.IPv4MinimumSize)
    66  		header.IPv4(buf).Encode(&header.IPv4Fields{
    67  			TotalLength: header.IPv4MinimumSize,
    68  			SrcAddr:     tcpip.AddrFrom4Slice(netip.MustParseAddr("192.168.1.2").AsSlice()),
    69  		})
    70  
    71  		protocolNumber, err := ss.validateSourceAddress(buf, peer1PK)
    72  		require.NoError(t, err)
    73  
    74  		require.Equal(t, header.IPv4ProtocolNumber, protocolNumber)
    75  	})
    76  
    77  	t.Run("Impersonation (IPv4)", func(t *testing.T) {
    78  		buf := make([]byte, header.IPv4MinimumSize)
    79  		header.IPv4(buf).Encode(&header.IPv4Fields{
    80  			TotalLength: header.IPv4MinimumSize,
    81  			SrcAddr:     tcpip.AddrFrom4Slice(netip.MustParseAddr("192.168.1.2").AsSlice()),
    82  		})
    83  
    84  		_, err := ss.validateSourceAddress(buf, peer2PK)
    85  		require.Error(t, err)
    86  	})
    87  
    88  	t.Run("Unknown (IPv4)", func(t *testing.T) {
    89  		buf := make([]byte, header.IPv4MinimumSize)
    90  		header.IPv4(buf).Encode(&header.IPv4Fields{
    91  			TotalLength: header.IPv4MinimumSize,
    92  			SrcAddr:     tcpip.AddrFrom4Slice(netip.MustParseAddr("1.1.1.1").AsSlice()),
    93  		})
    94  
    95  		_, err := ss.validateSourceAddress(buf, peer1PK)
    96  		require.Error(t, err)
    97  	})
    98  
    99  	t.Run("Gateway (IPv4)", func(t *testing.T) {
   100  		buf := make([]byte, header.IPv4MinimumSize)
   101  		header.IPv4(buf).Encode(&header.IPv4Fields{
   102  			TotalLength: header.IPv4MinimumSize,
   103  			SrcAddr:     tcpip.AddrFrom4Slice(netip.MustParseAddr("192.168.2.2").AsSlice()),
   104  		})
   105  
   106  		protocolNumber, err := ss.validateSourceAddress(buf, gwPK)
   107  		require.NoError(t, err)
   108  
   109  		require.Equal(t, header.IPv4ProtocolNumber, protocolNumber)
   110  	})
   111  
   112  	t.Run("Gateway Invalid (IPv4)", func(t *testing.T) {
   113  		buf := make([]byte, header.IPv4MinimumSize)
   114  		header.IPv4(buf).Encode(&header.IPv4Fields{
   115  			TotalLength: header.IPv4MinimumSize,
   116  			SrcAddr:     tcpip.AddrFrom4Slice(netip.MustParseAddr("192.168.1.10").AsSlice()),
   117  		})
   118  
   119  		_, err := ss.validateSourceAddress(buf, gwPK)
   120  		require.Error(t, err)
   121  	})
   122  
   123  	t.Run("Gateway Impersonation (IPv4)", func(t *testing.T) {
   124  		buf := make([]byte, header.IPv4MinimumSize)
   125  		header.IPv4(buf).Encode(&header.IPv4Fields{
   126  			TotalLength: header.IPv4MinimumSize,
   127  			SrcAddr:     tcpip.AddrFrom4Slice(netip.MustParseAddr("192.168.1.2").AsSlice()),
   128  		})
   129  
   130  		_, err := ss.validateSourceAddress(buf, gwPK)
   131  		require.Error(t, err)
   132  	})
   133  
   134  	t.Run("Valid (IPv6)", func(t *testing.T) {
   135  		buf := make([]byte, header.IPv6MinimumSize)
   136  		header.IPv6(buf).Encode(&header.IPv6Fields{
   137  			SrcAddr: tcpip.AddrFrom16Slice(netip.MustParseAddr("2001:db8::2").AsSlice()),
   138  		})
   139  
   140  		protocolNumber, err := ss.validateSourceAddress(buf, peer1PK)
   141  		require.NoError(t, err)
   142  
   143  		require.Equal(t, header.IPv6ProtocolNumber, protocolNumber)
   144  	})
   145  
   146  	t.Run("Impersonation (IPv6)", func(t *testing.T) {
   147  		buf := make([]byte, header.IPv6MinimumSize)
   148  		header.IPv6(buf).Encode(&header.IPv6Fields{
   149  			SrcAddr: tcpip.AddrFrom16Slice(netip.MustParseAddr("2001:db8::2").AsSlice()),
   150  		})
   151  
   152  		_, err := ss.validateSourceAddress(buf, peer2PK)
   153  		require.Error(t, err)
   154  	})
   155  
   156  	t.Run("Unknown (IPv6)", func(t *testing.T) {
   157  		buf := make([]byte, header.IPv6MinimumSize)
   158  		header.IPv6(buf).Encode(&header.IPv6Fields{
   159  			SrcAddr: tcpip.AddrFrom16Slice(netip.MustParseAddr("2001:db8::dead:beef").AsSlice()),
   160  		})
   161  
   162  		_, err := ss.validateSourceAddress(buf, peer1PK)
   163  		require.Error(t, err)
   164  	})
   165  
   166  	t.Run("Gateway (IPv6)", func(t *testing.T) {
   167  		buf := make([]byte, header.IPv6MinimumSize)
   168  		header.IPv6(buf).Encode(&header.IPv6Fields{
   169  			SrcAddr: tcpip.AddrFrom16Slice(netip.MustParseAddr("2001:db9::2").AsSlice()),
   170  		})
   171  
   172  		protocolNumber, err := ss.validateSourceAddress(buf, gwPK)
   173  		require.NoError(t, err)
   174  
   175  		require.Equal(t, header.IPv6ProtocolNumber, protocolNumber)
   176  	})
   177  
   178  	t.Run("Gateway Invalid (IPv6)", func(t *testing.T) {
   179  		buf := make([]byte, header.IPv6MinimumSize)
   180  		header.IPv6(buf).Encode(&header.IPv6Fields{
   181  			SrcAddr: tcpip.AddrFrom16Slice(netip.MustParseAddr("2001:db8::10").AsSlice()),
   182  		})
   183  
   184  		_, err := ss.validateSourceAddress(buf, gwPK)
   185  		require.Error(t, err)
   186  	})
   187  
   188  	t.Run("Gateway Impersonation (IPv6)", func(t *testing.T) {
   189  		buf := make([]byte, header.IPv6MinimumSize)
   190  		header.IPv6(buf).Encode(&header.IPv6Fields{
   191  			SrcAddr: tcpip.AddrFrom16Slice(netip.MustParseAddr("2001:db8::2").AsSlice()),
   192  		})
   193  
   194  		_, err := ss.validateSourceAddress(buf, gwPK)
   195  		require.Error(t, err)
   196  	})
   197  }