github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/common/pod_test.go (about)

     1  // Copyright 2018 The Kubeflow Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package common
    16  
    17  import (
    18  	"testing"
    19  
    20  	apiv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    21  	"github.com/kubeflow/training-operator/pkg/core"
    22  	testjobv1 "github.com/kubeflow/training-operator/test_job/apis/test_job/v1"
    23  	v12 "github.com/kubeflow/training-operator/test_job/test_util/v1"
    24  
    25  	"github.com/stretchr/testify/assert"
    26  	v1 "k8s.io/api/core/v1"
    27  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    28  )
    29  
    30  func TestSetRestartPolicy(t *testing.T) {
    31  	type tc struct {
    32  		testJob               *testjobv1.TestJob
    33  		expectedRestartPolicy v1.RestartPolicy
    34  		expectedType          testjobv1.TestReplicaType
    35  	}
    36  	testCase := []tc{
    37  		func() tc {
    38  			tj := v12.NewTestJob(2)
    39  			tj.Spec.TestReplicaSpecs[testjobv1.TestReplicaTypeWorker].RestartPolicy = apiv1.RestartPolicyExitCode
    40  			return tc{
    41  				testJob:               tj,
    42  				expectedRestartPolicy: v1.RestartPolicyNever,
    43  				expectedType:          testjobv1.TestReplicaTypeWorker,
    44  			}
    45  		}(),
    46  		func() tc {
    47  			tj := v12.NewTestJob(2)
    48  			tj.Spec.TestReplicaSpecs[testjobv1.TestReplicaTypeWorker].RestartPolicy = apiv1.RestartPolicyNever
    49  			return tc{
    50  				testJob:               tj,
    51  				expectedRestartPolicy: v1.RestartPolicyNever,
    52  				expectedType:          testjobv1.TestReplicaTypeWorker,
    53  			}
    54  		}(),
    55  		func() tc {
    56  			tj := v12.NewTestJob(2)
    57  			tj.Spec.TestReplicaSpecs[testjobv1.TestReplicaTypeWorker].RestartPolicy = apiv1.RestartPolicyAlways
    58  			return tc{
    59  				testJob:               tj,
    60  				expectedRestartPolicy: v1.RestartPolicyAlways,
    61  				expectedType:          testjobv1.TestReplicaTypeWorker,
    62  			}
    63  		}(),
    64  		func() tc {
    65  			tj := v12.NewTestJob(2)
    66  			tj.Spec.TestReplicaSpecs[testjobv1.TestReplicaTypeWorker].RestartPolicy = apiv1.RestartPolicyOnFailure
    67  			return tc{
    68  				testJob:               tj,
    69  				expectedRestartPolicy: v1.RestartPolicyOnFailure,
    70  				expectedType:          testjobv1.TestReplicaTypeWorker,
    71  			}
    72  		}(),
    73  	}
    74  	for _, c := range testCase {
    75  		spec := c.testJob.Spec.TestReplicaSpecs[c.expectedType]
    76  		podTemplate := spec.Template
    77  		core.SetRestartPolicy(&podTemplate, spec)
    78  		if podTemplate.Spec.RestartPolicy != c.expectedRestartPolicy {
    79  			t.Errorf("Expected %s, got %s", c.expectedRestartPolicy, podTemplate.Spec.RestartPolicy)
    80  		}
    81  	}
    82  }
    83  
    84  func TestIsCustomSchedulerSet(t *testing.T) {
    85  	gangSchedulerName := "test-gang-scheduler"
    86  	replicaSpecs := map[apiv1.ReplicaType]*apiv1.ReplicaSpec{}
    87  	assert.False(t, isCustomSchedulerSet(replicaSpecs, gangSchedulerName))
    88  
    89  	replicaSpecs[apiv1.ReplicaType(testjobv1.TestReplicaTypeWorker)] = &apiv1.ReplicaSpec{
    90  		Template: v1.PodTemplateSpec{
    91  			Spec: v1.PodSpec{
    92  				SchedulerName: gangSchedulerName,
    93  			},
    94  		},
    95  	}
    96  	assert.False(t, isCustomSchedulerSet(replicaSpecs, gangSchedulerName))
    97  
    98  	replicaSpecs[apiv1.ReplicaType(testjobv1.TestReplicaTypeWorker)] = &apiv1.ReplicaSpec{
    99  		Template: v1.PodTemplateSpec{
   100  			Spec: v1.PodSpec{
   101  				SchedulerName: "other-scheduler",
   102  			},
   103  		},
   104  	}
   105  	assert.True(t, isCustomSchedulerSet(replicaSpecs, gangSchedulerName))
   106  }
   107  
   108  func TestCalculatePodSliceSize(t *testing.T) {
   109  	type testCase struct {
   110  		pods         []*v1.Pod
   111  		replicas     int
   112  		expectedSize int
   113  	}
   114  
   115  	pods := []*v1.Pod{
   116  		{
   117  			ObjectMeta: metav1.ObjectMeta{
   118  				Labels: map[string]string{apiv1.ReplicaIndexLabel: "0"},
   119  			},
   120  		},
   121  		{
   122  			ObjectMeta: metav1.ObjectMeta{
   123  				Labels: map[string]string{apiv1.ReplicaIndexLabel: "1"},
   124  			},
   125  		},
   126  		{
   127  			ObjectMeta: metav1.ObjectMeta{
   128  				Labels: map[string]string{apiv1.ReplicaIndexLabel: "2"},
   129  			},
   130  		},
   131  	}
   132  
   133  	var testCases = []testCase{
   134  		{
   135  			pods:         pods,
   136  			replicas:     3,
   137  			expectedSize: 3,
   138  		},
   139  		{
   140  			pods:         pods,
   141  			replicas:     4,
   142  			expectedSize: 4,
   143  		},
   144  		{
   145  			pods:         pods,
   146  			replicas:     2,
   147  			expectedSize: 3,
   148  		},
   149  		{
   150  			pods: append(pods, &v1.Pod{
   151  				ObjectMeta: metav1.ObjectMeta{
   152  					Labels: map[string]string{apiv1.ReplicaIndexLabel: "4"},
   153  				},
   154  			}),
   155  			replicas:     3,
   156  			expectedSize: 5,
   157  		},
   158  	}
   159  
   160  	for _, tc := range testCases {
   161  		result := core.CalculatePodSliceSize(tc.pods, tc.replicas)
   162  		assert.Equal(t, tc.expectedSize, result)
   163  	}
   164  }
   165  
   166  func TestFilterPodsForReplicaType(t *testing.T) {
   167  	pods := []*v1.Pod{
   168  		{
   169  			ObjectMeta: metav1.ObjectMeta{
   170  				Name:   "a",
   171  				Labels: map[string]string{apiv1.ReplicaTypeLabel: "foo"},
   172  			},
   173  		},
   174  		{
   175  			ObjectMeta: metav1.ObjectMeta{
   176  				Name:   "b",
   177  				Labels: map[string]string{apiv1.ReplicaTypeLabel: "bar"},
   178  			},
   179  		},
   180  		{
   181  			ObjectMeta: metav1.ObjectMeta{
   182  				Name:   "c",
   183  				Labels: map[string]string{apiv1.ReplicaTypeLabel: "foo"},
   184  			},
   185  		},
   186  		{
   187  			ObjectMeta: metav1.ObjectMeta{
   188  				Name:   "d",
   189  				Labels: map[string]string{apiv1.ReplicaTypeLabel: "bar"},
   190  			},
   191  		},
   192  		{
   193  			ObjectMeta: metav1.ObjectMeta{
   194  				Name: "e",
   195  				Labels: map[string]string{
   196  					apiv1.ReplicaTypeLabel: "foo",
   197  				},
   198  			},
   199  		},
   200  	}
   201  	c := &JobController{}
   202  	got, err := c.FilterPodsForReplicaType(pods, "foo")
   203  	if err != nil {
   204  		t.Fatalf("FilterPodsForReplicaType returned error: %v", err)
   205  	}
   206  	want := []*v1.Pod{pods[0], pods[2], pods[4]}
   207  	assert.Equal(t, want, got)
   208  }