gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/stack/iptables_test.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package stack
    16  
    17  import (
    18  	"math/rand"
    19  	"testing"
    20  
    21  	"gvisor.dev/gvisor/pkg/tcpip"
    22  	"gvisor.dev/gvisor/pkg/tcpip/faketime"
    23  	"gvisor.dev/gvisor/pkg/tcpip/header"
    24  	"gvisor.dev/gvisor/pkg/tcpip/testutil"
    25  )
    26  
const (
	// Ports used by the tests: nattedPort is the port the NAT rules in these
	// tests rewrite packets to, while srcPort and dstPort are the UDP source
	// and destination ports of packets built by v6PacketBufferWithSrcAddr.
	nattedPort = 1
	srcPort    = 2
	dstPort    = 3

	// The network protocol used for these tests doesn't matter as the tests are
	// not targeting anything protocol specific.
	//
	// ipv6 is passed as the "IPv6 table" flag to ReplaceTable/GetTable and
	// netProto is the protocol given to the NAT targets.
	ipv6     = true
	netProto = header.IPv6ProtocolNumber
)

// Addresses used by the tests: nattedAddr is the address the NAT rules rewrite
// packets to; srcAddr is the default source and dstAddr the destination of
// packets built by v6PacketBufferWithSrcAddr.
var (
	nattedAddr = testutil.MustParse6("a::1")
	srcAddr    = testutil.MustParse6("b::2")
	dstAddr    = testutil.MustParse6("c::3")
)
    43  
    44  func v6PacketBufferWithSrcAddr(srcAddr tcpip.Address) *PacketBuffer {
    45  	pkt := NewPacketBuffer(PacketBufferOptions{
    46  		ReserveHeaderBytes: header.IPv6MinimumSize + header.UDPMinimumSize,
    47  	})
    48  	udp := header.UDP(pkt.TransportHeader().Push(header.UDPMinimumSize))
    49  	udp.SetSourcePort(srcPort)
    50  	udp.SetDestinationPort(dstPort)
    51  	udp.SetLength(uint16(len(udp)))
    52  	udp.SetChecksum(0)
    53  	udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum(
    54  		header.UDPProtocolNumber,
    55  		srcAddr,
    56  		dstAddr,
    57  		uint16(len(udp)),
    58  	)))
    59  	pkt.TransportProtocolNumber = header.UDPProtocolNumber
    60  	ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize))
    61  	ip.Encode(&header.IPv6Fields{
    62  		PayloadLength:     uint16(len(udp)),
    63  		TransportProtocol: header.UDPProtocolNumber,
    64  		HopLimit:          64,
    65  		SrcAddr:           srcAddr,
    66  		DstAddr:           dstAddr,
    67  	})
    68  	pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber
    69  	return pkt
    70  }
    71  
    72  func v6PacketBuffer() *PacketBuffer {
    73  	return v6PacketBufferWithSrcAddr(srcAddr)
    74  }
    75  
    76  // TestNATedConnectionReap tests that NATed connections are properly reaped.
    77  func TestNATedConnectionReap(t *testing.T) {
    78  	clock := faketime.NewManualClock()
    79  	iptables := DefaultTables(clock, rand.New(rand.NewSource(0 /* seed */)))
    80  
    81  	table := Table{
    82  		Rules: []Rule{
    83  			// Prerouting
    84  			{
    85  				Target: &DNATTarget{NetworkProtocol: netProto, Addr: nattedAddr, Port: nattedPort, ChangePort: true, ChangeAddress: true},
    86  			},
    87  			{
    88  				Target: &AcceptTarget{},
    89  			},
    90  
    91  			// Input
    92  			{
    93  				Target: &AcceptTarget{},
    94  			},
    95  
    96  			// Forward
    97  			{
    98  				Target: &AcceptTarget{},
    99  			},
   100  
   101  			// Output
   102  			{
   103  				Target: &AcceptTarget{},
   104  			},
   105  
   106  			// Postrouting
   107  			{
   108  				Target: &AcceptTarget{},
   109  			},
   110  		},
   111  		BuiltinChains: [NumHooks]int{
   112  			Prerouting:  0,
   113  			Input:       2,
   114  			Forward:     3,
   115  			Output:      4,
   116  			Postrouting: 5,
   117  		},
   118  	}
   119  	iptables.ReplaceTable(NATID, table, ipv6)
   120  
   121  	// Stop the reaper if it is running so we can reap manually as it is started
   122  	// on the first change to IPTables.
   123  	if !iptables.reaper.Stop() {
   124  		t.Fatal("failed to stop reaper")
   125  	}
   126  
   127  	pkt := v6PacketBuffer()
   128  
   129  	originalTID, res := getTupleID(pkt)
   130  	if res != getTupleIDOKAndAllowNewConn {
   131  		t.Fatalf("got getTupleID(...) = (%#v, %d), want = (_, %d)", originalTID, res, getTupleIDOKAndAllowNewConn)
   132  	}
   133  
   134  	if !iptables.CheckPrerouting(pkt, nil /* addressEP */, "" /* inNicName */) {
   135  		t.Fatal("got ipt.CheckPrerouting(...) = false, want = true")
   136  	}
   137  	if !iptables.CheckInput(pkt, "" /* inNicName */) {
   138  		t.Fatal("got ipt.CheckInput(...) = false, want = true")
   139  	}
   140  
   141  	invertedReplyTID, res := getTupleID(pkt)
   142  	if res != getTupleIDOKAndAllowNewConn {
   143  		t.Fatalf("got getTupleID(...) = (%#v, %d), want = (_, %d)", invertedReplyTID, res, getTupleIDOKAndAllowNewConn)
   144  	}
   145  	if invertedReplyTID == originalTID {
   146  		t.Fatalf("NAT not performed; got invertedReplyTID = %#v", invertedReplyTID)
   147  	}
   148  	replyTID := invertedReplyTID.reply()
   149  
   150  	iptables.connections.mu.RLock()
   151  	originalBktID := iptables.connections.bucket(originalTID)
   152  	replyBktID := iptables.connections.bucket(replyTID)
   153  	iptables.connections.mu.RUnlock()
   154  
   155  	// This test depends on the original and reply tuples mapping to different
   156  	// buckets.
   157  	if originalBktID == replyBktID {
   158  		t.Fatalf("expected bucket IDs to be different; got = %d", originalBktID)
   159  	}
   160  
   161  	lowerBktID := originalBktID
   162  	if lowerBktID > replyBktID {
   163  		lowerBktID = replyBktID
   164  	}
   165  
   166  	runReaper := func() {
   167  		// Reaping the bucket with the lower ID should reap both tuples of the
   168  		// connection if it has timed out.
   169  		//
   170  		// We will manually pick the next start bucket ID and don't use the
   171  		// interval so we ignore the return values.
   172  		_, _ = iptables.connections.reapUnused(lowerBktID, 0 /* prevInterval */)
   173  	}
   174  
   175  	iptables.connections.mu.RLock()
   176  	buckets := iptables.connections.buckets
   177  	iptables.connections.mu.RUnlock()
   178  
   179  	originalBkt := &buckets[originalBktID]
   180  	replyBkt := &buckets[replyBktID]
   181  
   182  	// Run the reaper and make sure the tuples were not reaped.
   183  	reapAndCheckForConnections := func() {
   184  		t.Helper()
   185  
   186  		runReaper()
   187  
   188  		now := clock.NowMonotonic()
   189  		if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple == nil {
   190  			t.Error("expected to get original tuple")
   191  		}
   192  
   193  		if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple == nil {
   194  			t.Error("expected to get reply tuple")
   195  		}
   196  
   197  		if t.Failed() {
   198  			t.FailNow()
   199  		}
   200  	}
   201  
   202  	// Connection was just added and no time has passed - it should not be reaped.
   203  	reapAndCheckForConnections()
   204  
   205  	// Time must advance past the unestablished timeout for a connection to be
   206  	// reaped.
   207  	clock.Advance(unestablishedTimeout)
   208  	reapAndCheckForConnections()
   209  
   210  	// Connection should now be reaped.
   211  	clock.Advance(1)
   212  	runReaper()
   213  	now := clock.NowMonotonic()
   214  	if originalTuple := originalBkt.connForTID(originalTID, now); originalTuple != nil {
   215  		t.Errorf("got originalBkt.connForTID(%#v, %#v) = %#v, want = nil", originalTID, now, originalTuple)
   216  	}
   217  	if replyTuple := replyBkt.connForTID(replyTID, now); replyTuple != nil {
   218  		t.Errorf("got replyBkt.connForTID(%#v, %#v) = %#v, want = nil", replyTID, now, replyTuple)
   219  	}
   220  	// Make sure we don't have stale tuples just lying around.
   221  	//
   222  	// We manually check the buckets as connForTID will skip over tuples that
   223  	// have timed out.
   224  	checkNoTupleInBucket := func(bkt *bucket, tid tupleID, reply bool) {
   225  		t.Helper()
   226  
   227  		bkt.mu.RLock()
   228  		defer bkt.mu.RUnlock()
   229  		for tuple := bkt.tuples.Front(); tuple != nil; tuple = tuple.Next() {
   230  			if tuple.tupleID == tid {
   231  				t.Errorf("unexpectedly found tuple with ID = %#v; reply = %t", tid, reply)
   232  			}
   233  		}
   234  	}
   235  	checkNoTupleInBucket(originalBkt, originalTID, false /* reply */)
   236  	checkNoTupleInBucket(replyBkt, replyTID, true /* reply */)
   237  }
   238  
   239  // TestNATAlwaysPerformed tests that a connection will have a noop-NAT
   240  // performed on it when no rule matches its associated packet. (Note that SNAT
   241  // is performed on all connections to ensure that ports used by locally
   242  // generated traffic do not clash with ports used by forwarded traffic.
   243  func TestNATAlwaysPerformed(t *testing.T) {
   244  	tests := []struct {
   245  		name     string
   246  		dnatHook func(*testing.T, *IPTables, *PacketBuffer)
   247  		snatHook func(*testing.T, *IPTables, *PacketBuffer)
   248  	}{
   249  		{
   250  			name: "Prerouting and Input",
   251  			dnatHook: func(t *testing.T, iptables *IPTables, pkt *PacketBuffer) {
   252  				t.Helper()
   253  
   254  				if !iptables.CheckPrerouting(pkt, nil /* addressEP */, "" /* inNicName */) {
   255  					t.Fatal("got iptables.CheckPrerouting(...) = false, want = true")
   256  				}
   257  			},
   258  			snatHook: func(t *testing.T, iptables *IPTables, pkt *PacketBuffer) {
   259  				t.Helper()
   260  
   261  				if !iptables.CheckInput(pkt, "" /* inNicName */) {
   262  					t.Fatal("got iptables.CheckInput(...) = false, want = true")
   263  				}
   264  			},
   265  		},
   266  		{
   267  			name: "Output and Postrouting",
   268  			dnatHook: func(t *testing.T, iptables *IPTables, pkt *PacketBuffer) {
   269  				t.Helper()
   270  
   271  				// Output hook depends on a route but if the route is local, we don't
   272  				// need anything else from it.
   273  				r := Route{
   274  					routeInfo: routeInfo{
   275  						Loop: PacketLoop,
   276  					},
   277  				}
   278  				if !iptables.CheckOutput(pkt, &r, "" /* outNicName */) {
   279  					t.Fatal("got iptables.CheckOutput(...) = false, want = true")
   280  				}
   281  			},
   282  			snatHook: func(t *testing.T, iptables *IPTables, pkt *PacketBuffer) {
   283  				t.Helper()
   284  
   285  				// Postrouting hook depends on a route but if the route is local, we
   286  				// don't need anything else from it.
   287  				r := Route{
   288  					routeInfo: routeInfo{
   289  						Loop: PacketLoop,
   290  					},
   291  				}
   292  				if !iptables.CheckPostrouting(pkt, &r, nil /* addressEP */, "" /* outNicName */) {
   293  					t.Fatal("got iptables.CheckPostrouting(...) = false, want = true")
   294  				}
   295  			},
   296  		},
   297  	}
   298  
   299  	for _, test := range tests {
   300  		t.Run(test.name, func(t *testing.T) {
   301  			clock := faketime.NewManualClock()
   302  			iptables := DefaultTables(clock, rand.New(rand.NewSource(0 /* seed */)))
   303  
   304  			// Just to make sure the iptables is not short circuited.
   305  			iptables.ForceReplaceTable(NATID, iptables.GetTable(NATID, ipv6), ipv6)
   306  
   307  			pkt := v6PacketBuffer()
   308  
   309  			test.dnatHook(t, iptables, pkt)
   310  			conn := pkt.tuple.conn
   311  			conn.mu.RLock()
   312  			destManip := conn.destinationManip
   313  			conn.mu.RUnlock()
   314  			if destManip != manipPerformedNoop {
   315  				t.Errorf("got destManip = %d, want = %d", destManip, manipPerformedNoop)
   316  			}
   317  
   318  			test.snatHook(t, iptables, pkt)
   319  			conn.mu.RLock()
   320  			srcManip := conn.sourceManip
   321  			conn.mu.RUnlock()
   322  			if srcManip != manipPerformed {
   323  				t.Errorf("got srcManip = %d, want = %d", srcManip, manipPerformed)
   324  			}
   325  		})
   326  	}
   327  }
   328  
   329  func TestNATConflict(t *testing.T) {
   330  	otherSrcAddr := testutil.MustParse6("d::4")
   331  
   332  	tests := []struct {
   333  		name          string
   334  		checkIPTables func(*testing.T, *IPTables, *PacketBuffer, bool)
   335  	}{
   336  		{
   337  			name: "Prerouting and Input",
   338  			checkIPTables: func(t *testing.T, iptables *IPTables, pkt *PacketBuffer, lastHookOK bool) {
   339  				t.Helper()
   340  
   341  				if !iptables.CheckPrerouting(pkt, nil /* addressEP */, "" /* inNicName */) {
   342  					t.Fatal("got ipt.CheckPrerouting(...) = false, want = true")
   343  				}
   344  				if got := iptables.CheckInput(pkt, "" /* inNicName */); got != lastHookOK {
   345  					t.Fatalf("got ipt.CheckInput(...) = %t, want = %t", got, lastHookOK)
   346  				}
   347  			},
   348  		},
   349  		{
   350  			name: "Output and Postrouting",
   351  			checkIPTables: func(t *testing.T, iptables *IPTables, pkt *PacketBuffer, lastHookOK bool) {
   352  				t.Helper()
   353  
   354  				// Output and Postrouting hooks depends on a route but if the route is
   355  				// local, we don't need anything else from it.
   356  				r := Route{
   357  					routeInfo: routeInfo{
   358  						Loop: PacketLoop,
   359  					},
   360  				}
   361  				if !iptables.CheckOutput(pkt, &r, "" /* outNicName */) {
   362  					t.Fatal("got iptables.CheckOutput(...) = false, want = true")
   363  				}
   364  				if got := iptables.CheckPostrouting(pkt, &r, nil /* addressEP */, "" /* outNicName */); got != lastHookOK {
   365  					t.Fatalf("got iptables.CheckPostrouting(...) = %t, want = %t", got, lastHookOK)
   366  				}
   367  			},
   368  		},
   369  	}
   370  
   371  	for _, test := range tests {
   372  		t.Run(test.name, func(t *testing.T) {
   373  
   374  			clock := faketime.NewManualClock()
   375  			iptables := DefaultTables(clock, rand.New(rand.NewSource(0 /* seed */)))
   376  
   377  			table := Table{
   378  				Rules: []Rule{
   379  					// Prerouting
   380  					{
   381  						Target: &AcceptTarget{},
   382  					},
   383  
   384  					// Input
   385  					{
   386  						Target: &SNATTarget{NetworkProtocol: header.IPv6ProtocolNumber, Addr: nattedAddr, Port: nattedPort, ChangeAddress: true, ChangePort: true},
   387  					},
   388  					{
   389  						Target: &AcceptTarget{},
   390  					},
   391  
   392  					// Forward
   393  					{
   394  						Target: &AcceptTarget{},
   395  					},
   396  
   397  					// Output
   398  					{
   399  						Target: &AcceptTarget{},
   400  					},
   401  
   402  					// Postrouting
   403  					{
   404  						Target: &SNATTarget{NetworkProtocol: header.IPv6ProtocolNumber, Addr: nattedAddr, Port: nattedPort, ChangeAddress: true, ChangePort: true},
   405  					},
   406  					{
   407  						Target: &AcceptTarget{},
   408  					},
   409  				},
   410  				BuiltinChains: [NumHooks]int{
   411  					Prerouting:  0,
   412  					Input:       1,
   413  					Forward:     3,
   414  					Output:      4,
   415  					Postrouting: 5,
   416  				},
   417  			}
   418  			iptables.ReplaceTable(NATID, table, ipv6)
   419  
   420  			// Create and finalize the connection.
   421  			test.checkIPTables(t, iptables, v6PacketBufferWithSrcAddr(srcAddr), true /* lastHookOK */)
   422  
   423  			// A packet from a different source that get NATed to the same tuple as
   424  			// the connection created above should be dropped when finalizing.
   425  			test.checkIPTables(t, iptables, v6PacketBufferWithSrcAddr(otherSrcAddr), false /* lastHookOK */)
   426  
   427  			// A packet from the original source should be NATed as normal.
   428  			test.checkIPTables(t, iptables, v6PacketBufferWithSrcAddr(srcAddr), true /* lastHookOK */)
   429  		})
   430  	}
   431  }