istio.io/istio@v0.0.0-20240520182934-d79c90f27776/cni/pkg/nodeagent/server_test.go (about)

     1  // Copyright Istio Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package nodeagent
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"net/netip"
    21  	"sync/atomic"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/stretchr/testify/mock"
    26  	corev1 "k8s.io/api/core/v1"
    27  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    28  	"k8s.io/apimachinery/pkg/types"
    29  	"k8s.io/client-go/kubernetes/fake"
    30  
    31  	"istio.io/istio/pkg/config/constants"
    32  	"istio.io/istio/pkg/test/util/assert"
    33  )
    34  
    35  func TestMeshDataplaneAddsAnnotationOnAdd(t *testing.T) {
    36  	pod := &corev1.Pod{
    37  		ObjectMeta: metav1.ObjectMeta{
    38  			Name:      "test",
    39  			Namespace: "test",
    40  			UID:       types.UID("test"),
    41  		},
    42  	}
    43  
    44  	fakeCtx := context.Background()
    45  	fakeClientSet := fake.NewSimpleClientset(pod)
    46  
    47  	podIP := netip.MustParseAddr("99.9.9.1")
    48  	podIPs := []netip.Addr{podIP}
    49  
    50  	server := &fakeServer{}
    51  	server.On("AddPodToMesh",
    52  		fakeCtx,
    53  		pod,
    54  		podIPs,
    55  		"",
    56  	).Return(nil)
    57  
    58  	server.Start(fakeCtx)
    59  	m := meshDataplane{
    60  		kubeClient: fakeClientSet,
    61  		netServer:  server,
    62  	}
    63  
    64  	err := m.AddPodToMesh(fakeCtx, pod, podIPs, "")
    65  	assert.NoError(t, err)
    66  
    67  	pod, err = fakeClientSet.CoreV1().Pods("test").Get(fakeCtx, "test", metav1.GetOptions{})
    68  	assert.NoError(t, err)
    69  	assert.Equal(t, len(pod.Annotations), 1)
    70  	assert.Equal(t, pod.Annotations[constants.AmbientRedirection], constants.AmbientRedirectionEnabled)
    71  }
    72  
    73  func TestMeshDataplaneAddsAnnotationOnAddWithPartialError(t *testing.T) {
    74  	pod := &corev1.Pod{
    75  		ObjectMeta: metav1.ObjectMeta{
    76  			Name:      "test",
    77  			Namespace: "test",
    78  			UID:       types.UID("test"),
    79  		},
    80  	}
    81  	server := &fakeServer{}
    82  
    83  	podIP := netip.MustParseAddr("99.9.9.1")
    84  	podIPs := []netip.Addr{podIP}
    85  	fakeCtx := context.Background()
    86  
    87  	server.On("AddPodToMesh",
    88  		fakeCtx,
    89  		pod,
    90  		podIPs,
    91  		"",
    92  	).Return(ErrPartialAdd)
    93  
    94  	server.Start(fakeCtx)
    95  	fakeClientSet := fake.NewSimpleClientset(pod)
    96  	m := meshDataplane{
    97  		kubeClient: fakeClientSet,
    98  		netServer:  server,
    99  	}
   100  
   101  	err := m.AddPodToMesh(fakeCtx, pod, podIPs, "")
   102  	assert.Error(t, err)
   103  
   104  	pod, err = fakeClientSet.CoreV1().Pods("test").Get(fakeCtx, "test", metav1.GetOptions{})
   105  	assert.NoError(t, err)
   106  	assert.Equal(t, len(pod.Annotations), 1)
   107  	assert.Equal(t, pod.Annotations[constants.AmbientRedirection], constants.AmbientRedirectionEnabled)
   108  }
   109  
   110  func TestMeshDataplaneDoesntAnnotateOnAddWithRealError(t *testing.T) {
   111  	pod := &corev1.Pod{
   112  		ObjectMeta: metav1.ObjectMeta{
   113  			Name:      "test",
   114  			Namespace: "test",
   115  			UID:       types.UID("test"),
   116  		},
   117  	}
   118  	server := &fakeServer{}
   119  
   120  	podIP := netip.MustParseAddr("99.9.9.1")
   121  	podIPs := []netip.Addr{podIP}
   122  	fakeCtx := context.Background()
   123  
   124  	server.On("AddPodToMesh",
   125  		fakeCtx,
   126  		pod,
   127  		podIPs,
   128  		"",
   129  	).Return(errors.New("not partial error"))
   130  
   131  	server.Start(fakeCtx)
   132  	fakeClientSet := fake.NewSimpleClientset(pod)
   133  	m := meshDataplane{
   134  		kubeClient: fakeClientSet,
   135  		netServer:  server,
   136  	}
   137  
   138  	err := m.AddPodToMesh(fakeCtx, pod, podIPs, "")
   139  	assert.Error(t, err)
   140  
   141  	pod, err = fakeClientSet.CoreV1().Pods("test").Get(fakeCtx, "test", metav1.GetOptions{})
   142  	assert.NoError(t, err)
   143  	assert.Equal(t, len(pod.Annotations), 0)
   144  }
   145  
   146  func TestMeshDataplaneRemovePodRemovesAnnotation(t *testing.T) {
   147  	pod := podWithAnnotation()
   148  	fakeCtx := context.Background()
   149  
   150  	server := &fakeServer{}
   151  	server.Start(fakeCtx)
   152  
   153  	server.On("RemovePodFromMesh",
   154  		fakeCtx,
   155  		pod,
   156  	).Return(nil)
   157  
   158  	fakeClientSet := fake.NewSimpleClientset(pod)
   159  	m := meshDataplane{
   160  		kubeClient: fakeClientSet,
   161  		netServer:  server,
   162  	}
   163  
   164  	err := m.RemovePodFromMesh(fakeCtx, pod)
   165  	assert.NoError(t, err)
   166  
   167  	pod, err = fakeClientSet.CoreV1().Pods("test").Get(fakeCtx, "test", metav1.GetOptions{})
   168  	assert.NoError(t, err)
   169  	assert.Equal(t, len(pod.Annotations), 0)
   170  }
   171  
   172  func TestMeshDataplaneRemovePodErrorDoesntRemoveAnnotation(t *testing.T) {
   173  	pod := podWithAnnotation()
   174  	fakeCtx := context.Background()
   175  	server := &fakeServer{}
   176  	server.Start(fakeCtx)
   177  
   178  	server.On("RemovePodFromMesh",
   179  		fakeCtx,
   180  		pod,
   181  	).Return(errors.New("fake error"))
   182  
   183  	fakeClientSet := fake.NewSimpleClientset(pod)
   184  	m := meshDataplane{
   185  		kubeClient: fakeClientSet,
   186  		netServer:  server,
   187  	}
   188  
   189  	err := m.RemovePodFromMesh(fakeCtx, pod)
   190  	assert.Error(t, err)
   191  
   192  	pod, err = fakeClientSet.CoreV1().Pods("test").Get(fakeCtx, "test", metav1.GetOptions{})
   193  	assert.NoError(t, err)
   194  	assert.Equal(t, pod.Annotations[constants.AmbientRedirection], constants.AmbientRedirectionEnabled)
   195  }
   196  
   197  func TestMeshDataplaneDelPod(t *testing.T) {
   198  	pod := podWithAnnotation()
   199  
   200  	fakeCtx := context.Background()
   201  	server := &fakeServer{}
   202  	server.Start(fakeCtx)
   203  
   204  	server.On("DelPodFromMesh",
   205  		fakeCtx,
   206  		pod,
   207  	).Return(nil)
   208  
   209  	fakeClientSet := fake.NewSimpleClientset()
   210  	m := meshDataplane{
   211  		kubeClient: fakeClientSet,
   212  		netServer:  server,
   213  	}
   214  
   215  	// pod is not in fake client, so if this will try to remove annotation, it will fail.
   216  	err := m.DelPodFromMesh(fakeCtx, pod)
   217  	assert.NoError(t, err)
   218  }
   219  
   220  func TestMeshDataplaneDelPodErrorDoesntPatchPod(t *testing.T) {
   221  	pod := podWithAnnotation()
   222  
   223  	fakeCtx := context.Background()
   224  	server := &fakeServer{}
   225  	server.Start(fakeCtx)
   226  
   227  	server.On("DelPodFromMesh",
   228  		fakeCtx,
   229  		pod,
   230  	).Return(errors.New("fake error"))
   231  
   232  	fakeClientSet := fake.NewSimpleClientset()
   233  	m := meshDataplane{
   234  		kubeClient: fakeClientSet,
   235  		netServer:  server,
   236  	}
   237  
   238  	// pod is not in fake client, so if this will try to remove annotation, it will fail.
   239  	err := m.DelPodFromMesh(fakeCtx, pod)
   240  	assert.Error(t, err)
   241  }
   242  
   243  func podWithAnnotation() *corev1.Pod {
   244  	return &corev1.Pod{
   245  		ObjectMeta: metav1.ObjectMeta{
   246  			Name:      "test",
   247  			Namespace: "test",
   248  			UID:       types.UID("test"),
   249  			Annotations: map[string]string{
   250  				constants.AmbientRedirection: constants.AmbientRedirectionEnabled,
   251  			},
   252  		},
   253  	}
   254  }
   255  
   256  type fakeServer struct {
   257  	mock.Mock
   258  	testWG *WaitGroup // optional waitgroup, if code under test makes a number of async calls to fakeServer
   259  }
   260  
   261  func (f *fakeServer) AddPodToMesh(ctx context.Context, pod *corev1.Pod, podIPs []netip.Addr, netNs string) error {
   262  	if f.testWG != nil {
   263  		defer f.testWG.Done()
   264  	}
   265  	args := f.Called(ctx, pod, podIPs, netNs)
   266  	return args.Error(0)
   267  }
   268  
   269  func (f *fakeServer) RemovePodFromMesh(ctx context.Context, pod *corev1.Pod) error {
   270  	if f.testWG != nil {
   271  		defer f.testWG.Done()
   272  	}
   273  	args := f.Called(ctx, pod)
   274  	return args.Error(0)
   275  }
   276  
   277  func (f *fakeServer) DelPodFromMesh(ctx context.Context, pod *corev1.Pod) error {
   278  	if f.testWG != nil {
   279  		defer f.testWG.Done()
   280  	}
   281  	args := f.Called(ctx, pod)
   282  	return args.Error(0)
   283  }
   284  
   285  func (f *fakeServer) Start(ctx context.Context) {
   286  }
   287  
   288  func (f *fakeServer) Stop() {
   289  }
   290  
   291  func (f *fakeServer) ConstructInitialSnapshot(ambientPods []*corev1.Pod) error {
   292  	if f.testWG != nil {
   293  		defer f.testWG.Done()
   294  	}
   295  	args := f.Called(ambientPods)
   296  	return args.Error(0)
   297  }
   298  
   299  // Custom "wait group with timeout" for waiting for fakeServer calls in a goroutine to finish
   300  type WaitGroup struct {
   301  	count int32
   302  	done  chan struct{}
   303  }
   304  
   305  func NewWaitGroup() *WaitGroup {
   306  	return &WaitGroup{
   307  		done: make(chan struct{}),
   308  	}
   309  }
   310  
   311  func NewWaitForNCalls(t *testing.T, n int32) (*WaitGroup, func()) {
   312  	wg := &WaitGroup{
   313  		done: make(chan struct{}),
   314  	}
   315  
   316  	wg.Add(n)
   317  	return wg, func() {
   318  		select {
   319  		case <-wg.C():
   320  			return
   321  		case <-time.After(time.Second):
   322  			t.Fatal("Wait group timed out!\n")
   323  		}
   324  	}
   325  }
   326  
   327  func (wg *WaitGroup) Add(i int32) {
   328  	select {
   329  	case <-wg.done:
   330  		panic("use of an already closed WaitGroup")
   331  	default:
   332  	}
   333  	atomic.AddInt32(&wg.count, i)
   334  }
   335  
   336  func (wg *WaitGroup) Done() {
   337  	i := atomic.AddInt32(&wg.count, -1)
   338  	if i == 0 {
   339  		close(wg.done)
   340  	}
   341  }
   342  
   343  func (wg *WaitGroup) C() <-chan struct{} {
   344  	return wg.done
   345  }