github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/common/pod_test.go (about) 1 // Copyright 2018 The Kubeflow Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package common 16 17 import ( 18 "testing" 19 20 apiv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" 21 "github.com/kubeflow/training-operator/pkg/core" 22 testjobv1 "github.com/kubeflow/training-operator/test_job/apis/test_job/v1" 23 v12 "github.com/kubeflow/training-operator/test_job/test_util/v1" 24 25 "github.com/stretchr/testify/assert" 26 v1 "k8s.io/api/core/v1" 27 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 28 ) 29 30 func TestSetRestartPolicy(t *testing.T) { 31 type tc struct { 32 testJob *testjobv1.TestJob 33 expectedRestartPolicy v1.RestartPolicy 34 expectedType testjobv1.TestReplicaType 35 } 36 testCase := []tc{ 37 func() tc { 38 tj := v12.NewTestJob(2) 39 tj.Spec.TestReplicaSpecs[testjobv1.TestReplicaTypeWorker].RestartPolicy = apiv1.RestartPolicyExitCode 40 return tc{ 41 testJob: tj, 42 expectedRestartPolicy: v1.RestartPolicyNever, 43 expectedType: testjobv1.TestReplicaTypeWorker, 44 } 45 }(), 46 func() tc { 47 tj := v12.NewTestJob(2) 48 tj.Spec.TestReplicaSpecs[testjobv1.TestReplicaTypeWorker].RestartPolicy = apiv1.RestartPolicyNever 49 return tc{ 50 testJob: tj, 51 expectedRestartPolicy: v1.RestartPolicyNever, 52 expectedType: testjobv1.TestReplicaTypeWorker, 53 } 54 }(), 55 func() tc { 56 tj := v12.NewTestJob(2) 57 tj.Spec.TestReplicaSpecs[testjobv1.TestReplicaTypeWorker].RestartPolicy = apiv1.RestartPolicyAlways 58 return tc{ 59 testJob: tj, 60 expectedRestartPolicy: v1.RestartPolicyAlways, 61 expectedType: testjobv1.TestReplicaTypeWorker, 62 } 63 }(), 64 func() tc { 65 tj := v12.NewTestJob(2) 66 tj.Spec.TestReplicaSpecs[testjobv1.TestReplicaTypeWorker].RestartPolicy = apiv1.RestartPolicyOnFailure 67 return tc{ 68 testJob: tj, 69 expectedRestartPolicy: v1.RestartPolicyOnFailure, 70 expectedType: testjobv1.TestReplicaTypeWorker, 71 } 72 }(), 73 } 74 for _, c := range testCase { 75 spec := c.testJob.Spec.TestReplicaSpecs[c.expectedType] 76 podTemplate := spec.Template 77 core.SetRestartPolicy(&podTemplate, spec) 78 if podTemplate.Spec.RestartPolicy != c.expectedRestartPolicy { 79 t.Errorf("Expected %s, got %s", c.expectedRestartPolicy, podTemplate.Spec.RestartPolicy) 80 } 81 } 82 } 83 84 func TestIsCustomSchedulerSet(t *testing.T) { 85 gangSchedulerName := "test-gang-scheduler" 86 replicaSpecs := map[apiv1.ReplicaType]*apiv1.ReplicaSpec{} 87 assert.False(t, isCustomSchedulerSet(replicaSpecs, gangSchedulerName)) 88 89 replicaSpecs[apiv1.ReplicaType(testjobv1.TestReplicaTypeWorker)] = &apiv1.ReplicaSpec{ 90 Template: v1.PodTemplateSpec{ 91 Spec: v1.PodSpec{ 92 SchedulerName: gangSchedulerName, 93 }, 94 }, 95 } 96 assert.False(t, isCustomSchedulerSet(replicaSpecs, gangSchedulerName)) 97 98 replicaSpecs[apiv1.ReplicaType(testjobv1.TestReplicaTypeWorker)] = &apiv1.ReplicaSpec{ 99 Template: v1.PodTemplateSpec{ 100 Spec: v1.PodSpec{ 101 SchedulerName: "other-scheduler", 102 }, 103 }, 104 } 105 assert.True(t, isCustomSchedulerSet(replicaSpecs, gangSchedulerName)) 106 } 107 108 func TestCalculatePodSliceSize(t *testing.T) { 109 type testCase struct { 110 pods []*v1.Pod 111 replicas int 112 expectedSize int 113 } 114 115 pods := []*v1.Pod{ 116 { 117 ObjectMeta: metav1.ObjectMeta{ 118 Labels: map[string]string{apiv1.ReplicaIndexLabel: "0"}, 119 }, 120 }, 121 { 122 ObjectMeta: metav1.ObjectMeta{ 123 Labels: map[string]string{apiv1.ReplicaIndexLabel: "1"}, 124 }, 125 }, 126 { 127 ObjectMeta: metav1.ObjectMeta{ 128 Labels: map[string]string{apiv1.ReplicaIndexLabel: "2"}, 129 }, 130 }, 131 } 132 133 var testCases = []testCase{ 134 { 135 pods: pods, 136 replicas: 3, 137 expectedSize: 3, 138 }, 139 { 140 pods: pods, 141 replicas: 4, 142 expectedSize: 4, 143 }, 144 { 145 pods: pods, 146 replicas: 2, 147 expectedSize: 3, 148 }, 149 { 150 pods: append(pods, &v1.Pod{ 151 ObjectMeta: metav1.ObjectMeta{ 152 Labels: map[string]string{apiv1.ReplicaIndexLabel: "4"}, 153 }, 154 }), 155 replicas: 3, 156 expectedSize: 5, 157 }, 158 } 159 160 for _, tc := range testCases { 161 result := core.CalculatePodSliceSize(tc.pods, tc.replicas) 162 assert.Equal(t, tc.expectedSize, result) 163 } 164 } 165 166 func TestFilterPodsForReplicaType(t *testing.T) { 167 pods := []*v1.Pod{ 168 { 169 ObjectMeta: metav1.ObjectMeta{ 170 Name: "a", 171 Labels: map[string]string{apiv1.ReplicaTypeLabel: "foo"}, 172 }, 173 }, 174 { 175 ObjectMeta: metav1.ObjectMeta{ 176 Name: "b", 177 Labels: map[string]string{apiv1.ReplicaTypeLabel: "bar"}, 178 }, 179 }, 180 { 181 ObjectMeta: metav1.ObjectMeta{ 182 Name: "c", 183 Labels: map[string]string{apiv1.ReplicaTypeLabel: "foo"}, 184 }, 185 }, 186 { 187 ObjectMeta: metav1.ObjectMeta{ 188 Name: "d", 189 Labels: map[string]string{apiv1.ReplicaTypeLabel: "bar"}, 190 }, 191 }, 192 { 193 ObjectMeta: metav1.ObjectMeta{ 194 Name: "e", 195 Labels: map[string]string{ 196 apiv1.ReplicaTypeLabel: "foo", 197 }, 198 }, 199 }, 200 } 201 c := &JobController{} 202 got, err := c.FilterPodsForReplicaType(pods, "foo") 203 if err != nil { 204 t.Fatalf("FilterPodsForReplicaType returned error: %v", err) 205 } 206 want := []*v1.Pod{pods[0], pods[2], pods[4]} 207 assert.Equal(t, want, got) 208 }