github.com/cilium/cilium@v1.16.2/pkg/datapath/iptables/reconciler_test.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package iptables
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"maps"
    10  	"net"
    11  	"net/netip"
    12  	"reflect"
    13  	"testing"
    14  
    15  	"github.com/cilium/hive/cell"
    16  	"github.com/cilium/hive/hivetest"
    17  	"github.com/cilium/statedb"
    18  	"github.com/stretchr/testify/assert"
    19  	"go.uber.org/goleak"
    20  	"k8s.io/apimachinery/pkg/util/sets"
    21  
    22  	"github.com/cilium/cilium/pkg/cidr"
    23  	"github.com/cilium/cilium/pkg/datapath/tables"
    24  	"github.com/cilium/cilium/pkg/hive"
    25  	"github.com/cilium/cilium/pkg/lock"
    26  	"github.com/cilium/cilium/pkg/node"
    27  	"github.com/cilium/cilium/pkg/node/addressing"
    28  	"github.com/cilium/cilium/pkg/node/types"
    29  	"github.com/cilium/cilium/pkg/time"
    30  )
    31  
    32  func TestReconciliationLoop(t *testing.T) {
    33  	defer goleak.VerifyNone(t)
    34  
    35  	var (
    36  		db      *statedb.DB
    37  		devices statedb.RWTable[*tables.Device]
    38  		store   *node.LocalNodeStore
    39  		health  cell.Health
    40  		params  *reconcilerParams
    41  	)
    42  	h := hive.New(
    43  		cell.Module(
    44  			"iptables-reconciler-test",
    45  			"iptables-reconciler-test",
    46  
    47  			cell.Provide(
    48  				tables.NewDeviceTable,
    49  				statedb.RWTable[*tables.Device].ToTable,
    50  				func() *node.LocalNodeStore { return node.NewTestLocalNodeStore(node.LocalNode{}) },
    51  			),
    52  			cell.Invoke(func(
    53  				db_ *statedb.DB,
    54  				devices_ statedb.RWTable[*tables.Device],
    55  				store_ *node.LocalNodeStore,
    56  				health_ cell.Health,
    57  			) {
    58  				db = db_
    59  				devices = devices_
    60  				store = store_
    61  				db.RegisterTable(devices_)
    62  				health = health_.NewScope("iptables-reconciler-test")
    63  				params = &reconcilerParams{
    64  					localNodeStore: store_,
    65  					db:             db_,
    66  					devices:        devices_,
    67  					proxies:        make(chan reconciliationRequest[proxyInfo]),
    68  					addNoTrackPod:  make(chan reconciliationRequest[noTrackPodInfo]),
    69  					delNoTrackPod:  make(chan reconciliationRequest[noTrackPodInfo]),
    70  				}
    71  			}),
    72  		),
    73  	)
    74  
    75  	var (
    76  		state desiredState
    77  		mu    lock.Mutex
    78  	)
    79  
    80  	updateFunc := func(newState desiredState, firstInit bool) error {
    81  		mu.Lock()
    82  		defer mu.Unlock()
    83  
    84  		// copy newState to avoid a race with the reconciler mutating it
    85  		// and the test asserting the expected values with Eventually
    86  		state = newState.deepCopy()
    87  
    88  		return nil
    89  	}
    90  	updateProxyFunc := func(proxyPort uint16, name string) error {
    91  		mu.Lock()
    92  		defer mu.Unlock()
    93  		state.proxies[name] = proxyInfo{
    94  			name: name,
    95  			port: proxyPort,
    96  		}
    97  		return nil
    98  	}
    99  	installNoTrackFunc := func(addr netip.Addr, port uint16) error {
   100  		mu.Lock()
   101  		defer mu.Unlock()
   102  		state.noTrackPods.Insert(noTrackPodInfo{
   103  			ip:   addr,
   104  			port: port,
   105  		})
   106  		return nil
   107  	}
   108  	removeNoTrackFunc := func(addr netip.Addr, port uint16) error {
   109  		mu.Lock()
   110  		defer mu.Unlock()
   111  		state.noTrackPods.Delete(noTrackPodInfo{
   112  			ip:   addr,
   113  			port: port,
   114  		})
   115  		return nil
   116  	}
   117  
   118  	testCases := []struct {
   119  		name     string
   120  		action   func()
   121  		expected desiredState
   122  	}{
   123  		{
   124  			name: "initial state",
   125  			action: func() {
   126  				store.Update(func(n *node.LocalNode) {
   127  					n.IPAddresses = []types.Address{
   128  						{
   129  							IP:   netip.MustParseAddr("1.1.1.1").AsSlice(),
   130  							Type: addressing.NodeCiliumInternalIP,
   131  						},
   132  					}
   133  					n.IPv4AllocCIDR = cidr.MustParseCIDR("5.5.5.0/24")
   134  					n.IPv6AllocCIDR = cidr.MustParseCIDR("2001:aaaa::/96")
   135  				})
   136  
   137  				txn := db.WriteTxn(devices)
   138  				if _, _, err := devices.Insert(txn, &tables.Device{
   139  					Index:    1,
   140  					Name:     "test-1",
   141  					Selected: true,
   142  				}); err != nil {
   143  					t.Fatal(err)
   144  				}
   145  				txn.Commit()
   146  			},
   147  			expected: desiredState{
   148  				installRules: true,
   149  				devices:      sets.New("test-1"),
   150  				localNodeInfo: localNodeInfo{
   151  					internalIPv4:  net.ParseIP("1.1.1.1"),
   152  					ipv4AllocCIDR: cidr.MustParseCIDR("5.5.5.0/24").String(),
   153  					ipv6AllocCIDR: cidr.MustParseCIDR("2001:aaaa::/96").String(),
   154  				},
   155  			},
   156  		},
   157  		{
   158  			name: "devices update",
   159  			action: func() {
   160  				txn := db.WriteTxn(devices)
   161  				devices.Insert(txn, &tables.Device{
   162  					Index:    2,
   163  					Name:     "test-2",
   164  					Selected: true,
   165  				})
   166  				txn.Commit()
   167  			},
   168  			expected: desiredState{
   169  				installRules: true,
   170  				devices:      sets.New("test-1", "test-2"),
   171  				localNodeInfo: localNodeInfo{
   172  					internalIPv4:  net.ParseIP("1.1.1.1"),
   173  					ipv4AllocCIDR: cidr.MustParseCIDR("5.5.5.0/24").String(),
   174  					ipv6AllocCIDR: cidr.MustParseCIDR("2001:aaaa::/96").String(),
   175  				},
   176  			},
   177  		},
   178  		{
   179  			name: "local node update",
   180  			action: func() {
   181  				store.Update(func(n *node.LocalNode) {
   182  					n.IPAddresses = []types.Address{
   183  						{
   184  							IP:   netip.MustParseAddr("2.2.2.2").AsSlice(),
   185  							Type: addressing.NodeCiliumInternalIP,
   186  						},
   187  					}
   188  					n.IPv4AllocCIDR = cidr.MustParseCIDR("6.6.6.0/24")
   189  					n.IPv6AllocCIDR = cidr.MustParseCIDR("3002:bbbb::/96")
   190  				})
   191  			},
   192  			expected: desiredState{
   193  				installRules: true,
   194  				devices:      sets.New("test-1", "test-2"),
   195  				localNodeInfo: localNodeInfo{
   196  					internalIPv4:  net.ParseIP("2.2.2.2"),
   197  					ipv4AllocCIDR: cidr.MustParseCIDR("6.6.6.0/24").String(),
   198  					ipv6AllocCIDR: cidr.MustParseCIDR("3002:bbbb::/96").String(),
   199  				},
   200  			},
   201  		},
   202  		{
   203  			name: "add first proxy",
   204  			action: func() {
   205  				params.proxies <- reconciliationRequest[proxyInfo]{
   206  					info: proxyInfo{
   207  						name: "proxy-test-1",
   208  						port: 9090,
   209  					},
   210  					updated: make(chan struct{}),
   211  				}
   212  			},
   213  			expected: desiredState{
   214  				installRules: true,
   215  				devices:      sets.New("test-1", "test-2"),
   216  				localNodeInfo: localNodeInfo{
   217  					internalIPv4:  net.ParseIP("2.2.2.2"),
   218  					ipv4AllocCIDR: cidr.MustParseCIDR("6.6.6.0/24").String(),
   219  					ipv6AllocCIDR: cidr.MustParseCIDR("3002:bbbb::/96").String(),
   220  				},
   221  				proxies: map[string]proxyInfo{
   222  					"proxy-test-1": {
   223  						name: "proxy-test-1",
   224  						port: 9090,
   225  					},
   226  				},
   227  			},
   228  		},
   229  		{
   230  			name: "add second proxy",
   231  			action: func() {
   232  				params.proxies <- reconciliationRequest[proxyInfo]{
   233  					info: proxyInfo{
   234  						name: "proxy-test-2",
   235  						port: 9091,
   236  					},
   237  					updated: make(chan struct{}),
   238  				}
   239  			},
   240  			expected: desiredState{
   241  				installRules: true,
   242  				devices:      sets.New("test-1", "test-2"),
   243  				localNodeInfo: localNodeInfo{
   244  					internalIPv4:  net.ParseIP("2.2.2.2"),
   245  					ipv4AllocCIDR: cidr.MustParseCIDR("6.6.6.0/24").String(),
   246  					ipv6AllocCIDR: cidr.MustParseCIDR("3002:bbbb::/96").String(),
   247  				},
   248  				proxies: map[string]proxyInfo{
   249  					"proxy-test-1": {
   250  						name: "proxy-test-1",
   251  						port: 9090,
   252  					},
   253  					"proxy-test-2": {
   254  						name: "proxy-test-2",
   255  						port: 9091,
   256  					},
   257  				},
   258  			},
   259  		},
   260  		{
   261  			name: "add no track pods",
   262  			action: func() {
   263  				params.addNoTrackPod <- reconciliationRequest[noTrackPodInfo]{
   264  					info: noTrackPodInfo{
   265  						ip:   netip.MustParseAddr("1.2.3.4"),
   266  						port: 10001,
   267  					},
   268  					updated: make(chan struct{}),
   269  				}
   270  				params.addNoTrackPod <- reconciliationRequest[noTrackPodInfo]{
   271  					info: noTrackPodInfo{
   272  						ip:   netip.MustParseAddr("11.22.33.44"),
   273  						port: 10002,
   274  					},
   275  					updated: make(chan struct{}),
   276  				}
   277  			},
   278  			expected: desiredState{
   279  				installRules: true,
   280  				devices:      sets.New("test-1", "test-2"),
   281  				localNodeInfo: localNodeInfo{
   282  					internalIPv4:  net.ParseIP("2.2.2.2"),
   283  					ipv4AllocCIDR: cidr.MustParseCIDR("6.6.6.0/24").String(),
   284  					ipv6AllocCIDR: cidr.MustParseCIDR("3002:bbbb::/96").String(),
   285  				},
   286  				proxies: map[string]proxyInfo{
   287  					"proxy-test-1": {
   288  						name: "proxy-test-1",
   289  						port: 9090,
   290  					},
   291  					"proxy-test-2": {
   292  						name: "proxy-test-2",
   293  						port: 9091,
   294  					},
   295  				},
   296  				noTrackPods: sets.New(
   297  					noTrackPodInfo{netip.MustParseAddr("1.2.3.4"), 10001},
   298  					noTrackPodInfo{netip.MustParseAddr("11.22.33.44"), 10002},
   299  				),
   300  			},
   301  		},
   302  		{
   303  			name: "remove no track pod",
   304  			action: func() {
   305  				params.delNoTrackPod <- reconciliationRequest[noTrackPodInfo]{
   306  					info: noTrackPodInfo{
   307  						ip:   netip.MustParseAddr("1.2.3.4"),
   308  						port: 10001,
   309  					},
   310  					updated: make(chan struct{}),
   311  				}
   312  			},
   313  			expected: desiredState{
   314  				installRules: true,
   315  				devices:      sets.New("test-1", "test-2"),
   316  				localNodeInfo: localNodeInfo{
   317  					internalIPv4:  net.ParseIP("2.2.2.2"),
   318  					ipv4AllocCIDR: cidr.MustParseCIDR("6.6.6.0/24").String(),
   319  					ipv6AllocCIDR: cidr.MustParseCIDR("3002:bbbb::/96").String(),
   320  				},
   321  				proxies: map[string]proxyInfo{
   322  					"proxy-test-1": {
   323  						name: "proxy-test-1",
   324  						port: 9090,
   325  					},
   326  					"proxy-test-2": {
   327  						name: "proxy-test-2",
   328  						port: 9091,
   329  					},
   330  				},
   331  				noTrackPods: sets.New(
   332  					noTrackPodInfo{netip.MustParseAddr("11.22.33.44"), 10002},
   333  				),
   334  			},
   335  		},
   336  	}
   337  
   338  	ctx, cancel := context.WithCancel(context.Background())
   339  	defer cancel()
   340  
   341  	tlog := hivetest.Logger(t)
   342  	assert.NoError(t, h.Start(tlog, ctx))
   343  
   344  	// apply initial state
   345  	testCases[0].action()
   346  
   347  	// start the reconciliation loop
   348  	errs := make(chan error)
   349  	go func() {
   350  		defer close(errs)
   351  		errs <- reconciliationLoop(ctx, log, health, true, params, updateFunc, updateProxyFunc, installNoTrackFunc, removeNoTrackFunc)
   352  	}()
   353  
   354  	// wait for reconciler to react to the initial state
   355  	assert.Eventually(t, func() bool {
   356  		mu.Lock()
   357  		defer mu.Unlock()
   358  		if err := assertIptablesState(state, testCases[0].expected); err != nil {
   359  			t.Logf("assertIptablesState: %s", err)
   360  			return false
   361  		}
   362  		return true
   363  	}, 10*time.Second, 10*time.Millisecond, "initial state not reconciled. %v", testCases[0].expected)
   364  
   365  	// test all the remaining steps
   366  	for _, tc := range testCases[1:] {
   367  		t.Run(tc.name, func(t *testing.T) {
   368  			// apply the action to update the state
   369  			tc.action()
   370  
   371  			// wait for reconciler to react to the update
   372  			assert.Eventuallyf(t, func() bool {
   373  				mu.Lock()
   374  				defer mu.Unlock()
   375  				if err := assertIptablesState(state, tc.expected); err != nil {
   376  					t.Logf("assertIptablesState: %s", err)
   377  					return false
   378  				}
   379  				return true
   380  			}, 10*time.Second, 10*time.Millisecond, "expected state not reached. %v", tc.expected)
   381  		})
   382  	}
   383  
   384  	assert.NoError(t, h.Stop(tlog, ctx))
   385  
   386  	close(params.proxies)
   387  	close(params.addNoTrackPod)
   388  	close(params.delNoTrackPod)
   389  	cancel()
   390  	assert.NoError(t, <-errs)
   391  }
   392  
   393  func assertIptablesState(current, expected desiredState) error {
   394  	if current.installRules != expected.installRules {
   395  		return fmt.Errorf("expected installRules to be %t, found %t",
   396  			expected.installRules, current.installRules)
   397  	}
   398  	if !current.devices.Equal(expected.devices) {
   399  		return fmt.Errorf("expected devices names to be %v, found %v",
   400  			expected.devices.UnsortedList(), current.devices.UnsortedList())
   401  	}
   402  	if !current.localNodeInfo.equal(expected.localNodeInfo) {
   403  		return fmt.Errorf("expected local node info to be %v, found %v",
   404  			expected.localNodeInfo, current.localNodeInfo)
   405  	}
   406  	if len(current.proxies) != 0 && len(expected.proxies) != 0 &&
   407  		!reflect.DeepEqual(current.proxies, expected.proxies) {
   408  		return fmt.Errorf("expected proxies info to be %v, found %v",
   409  			expected.proxies, current.proxies)
   410  	}
   411  	if !current.noTrackPods.Equal(expected.noTrackPods) {
   412  		return fmt.Errorf("expected no tracking pods info to be %v, found %v",
   413  			expected.noTrackPods.UnsortedList(), current.noTrackPods.UnsortedList())
   414  	}
   415  	return nil
   416  }
   417  
   418  func (s desiredState) deepCopy() desiredState {
   419  	ipv4 := make(net.IP, len(s.localNodeInfo.internalIPv4))
   420  	copy(ipv4, s.localNodeInfo.internalIPv4)
   421  	ipv6 := make(net.IP, len(s.localNodeInfo.internalIPv6))
   422  	copy(ipv6, s.localNodeInfo.internalIPv6)
   423  	return desiredState{
   424  		installRules: s.installRules,
   425  		devices:      s.devices.Clone(),
   426  		localNodeInfo: localNodeInfo{
   427  			internalIPv4:          ipv4,
   428  			internalIPv6:          ipv6,
   429  			ipv4AllocCIDR:         s.localNodeInfo.ipv4AllocCIDR,
   430  			ipv6AllocCIDR:         s.localNodeInfo.ipv6AllocCIDR,
   431  			ipv4NativeRoutingCIDR: s.localNodeInfo.ipv4NativeRoutingCIDR,
   432  			ipv6NativeRoutingCIDR: s.localNodeInfo.ipv6NativeRoutingCIDR,
   433  		},
   434  		proxies:     maps.Clone(s.proxies),
   435  		noTrackPods: s.noTrackPods.Clone(),
   436  	}
   437  }