github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/pytorch/pytorchjob_controller_test.go (about) 1 // Copyright 2021 The Kubeflow Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package pytorch 16 17 import ( 18 "context" 19 "fmt" 20 21 . "github.com/onsi/ginkgo/v2" 22 . "github.com/onsi/gomega" 23 autoscalingv2 "k8s.io/api/autoscaling/v2" 24 corev1 "k8s.io/api/core/v1" 25 "k8s.io/apimachinery/pkg/api/errors" 26 "k8s.io/apimachinery/pkg/api/resource" 27 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 28 "k8s.io/apimachinery/pkg/types" 29 "k8s.io/utils/pointer" 30 "sigs.k8s.io/controller-runtime/pkg/client" 31 32 kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" 33 commonutil "github.com/kubeflow/training-operator/pkg/util" 34 "github.com/kubeflow/training-operator/pkg/util/testutil" 35 ) 36 37 var _ = Describe("PyTorchJob controller", func() { 38 // Define utility constants for object names. 39 const ( 40 expectedPort = int32(8080) 41 ) 42 43 Context("When creating the PyTorchJob", func() { 44 const name = "test-job" 45 var ( 46 ns *corev1.Namespace 47 job *kubeflowv1.PyTorchJob 48 jobKey types.NamespacedName 49 masterKey types.NamespacedName 50 worker0Key types.NamespacedName 51 ctx = context.Background() 52 ) 53 BeforeEach(func() { 54 ns = &corev1.Namespace{ 55 ObjectMeta: metav1.ObjectMeta{ 56 GenerateName: "pytorch-test-", 57 }, 58 } 59 Expect(testK8sClient.Create(ctx, ns)).Should(Succeed()) 60 61 job = newPyTorchJobForTest(name, ns.Name) 62 jobKey = client.ObjectKeyFromObject(job) 63 masterKey = types.NamespacedName{ 64 Name: fmt.Sprintf("%s-master-0", name), 65 Namespace: ns.Name, 66 } 67 worker0Key = types.NamespacedName{ 68 Name: fmt.Sprintf("%s-worker-0", name), 69 Namespace: ns.Name, 70 } 71 job.Spec.NprocPerNode = nil 72 job.Spec.PyTorchReplicaSpecs = map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{ 73 kubeflowv1.PyTorchJobReplicaTypeMaster: { 74 Replicas: pointer.Int32(1), 75 Template: corev1.PodTemplateSpec{ 76 Spec: corev1.PodSpec{ 77 Containers: []corev1.Container{ 78 { 79 Image: "test-image", 80 Name: kubeflowv1.PyTorchJobDefaultContainerName, 81 Ports: []corev1.ContainerPort{ 82 { 83 Name: kubeflowv1.PyTorchJobDefaultPortName, 84 ContainerPort: expectedPort, 85 Protocol: corev1.ProtocolTCP, 86 }, 87 }, 88 }, 89 }, 90 }, 91 }, 92 }, 93 kubeflowv1.PyTorchJobReplicaTypeWorker: { 94 Replicas: pointer.Int32(2), 95 Template: corev1.PodTemplateSpec{ 96 Spec: corev1.PodSpec{ 97 Containers: []corev1.Container{ 98 { 99 Image: "test-image", 100 Name: kubeflowv1.PyTorchJobDefaultContainerName, 101 Ports: []corev1.ContainerPort{ 102 { 103 Name: kubeflowv1.PyTorchJobDefaultPortName, 104 ContainerPort: expectedPort, 105 Protocol: corev1.ProtocolTCP, 106 }, 107 }, 108 }, 109 }, 110 }, 111 }, 112 }, 113 } 114 }) 115 AfterEach(func() { 116 Expect(testK8sClient.Delete(ctx, job)).Should(Succeed()) 117 Expect(testK8sClient.Delete(ctx, ns)).Should(Succeed()) 118 }) 119 It("Should get the corresponding resources successfully", func() { 120 By("By creating a new PyTorchJob") 121 Expect(testK8sClient.Create(ctx, job)).Should(Succeed()) 122 123 created := &kubeflowv1.PyTorchJob{} 124 125 // We'll need to retry getting this newly created PyTorchJob, given that creation may not immediately happen. 126 Eventually(func() bool { 127 err := testK8sClient.Get(ctx, jobKey, created) 128 return err == nil 129 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 130 131 masterPod := &corev1.Pod{} 132 Eventually(func() bool { 133 err := testK8sClient.Get(ctx, masterKey, masterPod) 134 return err == nil 135 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 136 137 masterSvc := &corev1.Service{} 138 Eventually(func() bool { 139 err := testK8sClient.Get(ctx, masterKey, masterSvc) 140 return err == nil 141 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 142 143 // Check the pod port. 144 Expect(masterPod.Spec.Containers[0].Ports).To(ContainElement(corev1.ContainerPort{ 145 Name: kubeflowv1.PyTorchJobDefaultPortName, 146 ContainerPort: expectedPort, 147 Protocol: corev1.ProtocolTCP})) 148 // Check env variable 149 Expect(masterPod.Spec.Containers[0].Env).To(ContainElements(corev1.EnvVar{ 150 Name: EnvMasterPort, 151 Value: fmt.Sprintf("%d", masterSvc.Spec.Ports[0].Port), 152 }, corev1.EnvVar{ 153 Name: EnvMasterAddr, 154 Value: masterSvc.Name, 155 }, corev1.EnvVar{ 156 Name: EnvNprocPerNode, 157 Value: kubeflowv1.DefaultNprocPerNode, 158 })) 159 // Check service port. 160 Expect(masterSvc.Spec.Ports[0].Port).To(Equal(expectedPort)) 161 // Check owner reference. 162 trueVal := true 163 Expect(masterPod.OwnerReferences).To(ContainElement(metav1.OwnerReference{ 164 APIVersion: kubeflowv1.SchemeGroupVersion.String(), 165 Kind: kubeflowv1.PyTorchJobKind, 166 Name: name, 167 UID: created.UID, 168 Controller: &trueVal, 169 BlockOwnerDeletion: &trueVal, 170 })) 171 Expect(masterSvc.OwnerReferences).To(ContainElement(metav1.OwnerReference{ 172 APIVersion: kubeflowv1.SchemeGroupVersion.String(), 173 Kind: kubeflowv1.PyTorchJobKind, 174 Name: name, 175 UID: created.UID, 176 Controller: &trueVal, 177 BlockOwnerDeletion: &trueVal, 178 })) 179 180 // Test job status. 181 Eventually(func() error { 182 Expect(testK8sClient.Get(ctx, masterKey, masterPod)).Should(Succeed()) 183 masterPod.Status.Phase = corev1.PodSucceeded 184 return testK8sClient.Status().Update(ctx, masterPod) 185 }, testutil.Timeout, testutil.Interval).Should(Succeed()) 186 Eventually(func() bool { 187 err := testK8sClient.Get(ctx, jobKey, created) 188 if err != nil { 189 return false 190 } 191 return created.Status.ReplicaStatuses != nil && created.Status. 192 ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeMaster].Succeeded == 1 193 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 194 // Check if the job is succeeded. 195 cond := getCondition(created.Status, kubeflowv1.JobSucceeded) 196 Expect(cond.Status).To(Equal(corev1.ConditionTrue)) 197 }) 198 199 It("Shouldn't create resources if PyTorchJob is suspended", func() { 200 By("By creating a new PyTorchJob with suspend=true") 201 job.Spec.RunPolicy.Suspend = pointer.Bool(true) 202 job.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas = pointer.Int32(1) 203 Expect(testK8sClient.Create(ctx, job)).Should(Succeed()) 204 205 created := &kubeflowv1.PyTorchJob{} 206 masterPod := &corev1.Pod{} 207 workerPod := &corev1.Pod{} 208 masterSvc := &corev1.Service{} 209 workerSvc := &corev1.Service{} 210 211 By("Checking created PyTorchJob") 212 Eventually(func() bool { 213 err := testK8sClient.Get(ctx, jobKey, created) 214 return err == nil 215 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 216 By("Checking created PyTorchJob has a nil startTime") 217 Consistently(func() *metav1.Time { 218 Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) 219 return created.Status.StartTime 220 }, testutil.ConsistentDuration, testutil.Interval).Should(BeNil()) 221 222 By("Checking if the pods and services aren't created") 223 Consistently(func() bool { 224 errMasterPod := testK8sClient.Get(ctx, masterKey, masterPod) 225 errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod) 226 errMasterSvc := testK8sClient.Get(ctx, masterKey, masterSvc) 227 errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc) 228 return errors.IsNotFound(errMasterPod) && errors.IsNotFound(errWorkerPod) && 229 errors.IsNotFound(errMasterSvc) && errors.IsNotFound(errWorkerSvc) 230 }, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue()) 231 232 By("Checking if the PyTorchJob has suspended condition") 233 Eventually(func() []kubeflowv1.JobCondition { 234 Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) 235 return created.Status.Conditions 236 }, testutil.ConsistentDuration, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{ 237 { 238 Type: kubeflowv1.JobCreated, 239 Status: corev1.ConditionTrue, 240 Reason: commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobCreatedReason), 241 Message: fmt.Sprintf("PyTorchJob %s is created.", name), 242 }, 243 { 244 Type: kubeflowv1.JobSuspended, 245 Status: corev1.ConditionTrue, 246 Reason: commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobSuspendedReason), 247 Message: fmt.Sprintf("PyTorchJob %s is suspended.", name), 248 }, 249 }, testutil.IgnoreJobConditionsTimes)) 250 }) 251 252 It("Should delete resources after PyTorchJob is suspended; Should resume PyTorchJob after PyTorchJob is unsuspended", func() { 253 By("By creating a new PyTorchJob") 254 job.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas = pointer.Int32(1) 255 Expect(testK8sClient.Create(ctx, job)).Should(Succeed()) 256 257 created := &kubeflowv1.PyTorchJob{} 258 masterPod := &corev1.Pod{} 259 workerPod := &corev1.Pod{} 260 masterSvc := &corev1.Service{} 261 workerSvc := &corev1.Service{} 262 263 // We'll need to retry getting this newly created PyTorchJob, given that creation may not immediately happen. 264 By("Checking created PyTorchJob") 265 Eventually(func() bool { 266 err := testK8sClient.Get(ctx, jobKey, created) 267 return err == nil 268 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 269 270 var startTimeBeforeSuspended *metav1.Time 271 Eventually(func() *metav1.Time { 272 Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) 273 startTimeBeforeSuspended = created.Status.StartTime 274 return startTimeBeforeSuspended 275 }, testutil.Timeout, testutil.Interval).ShouldNot(BeNil()) 276 277 By("Checking the created pods and services") 278 Eventually(func() bool { 279 errMaster := testK8sClient.Get(ctx, masterKey, masterPod) 280 errWorker := testK8sClient.Get(ctx, worker0Key, workerPod) 281 return errMaster == nil && errWorker == nil 282 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 283 Eventually(func() bool { 284 errMaster := testK8sClient.Get(ctx, masterKey, masterSvc) 285 errWorker := testK8sClient.Get(ctx, worker0Key, workerSvc) 286 return errMaster == nil && errWorker == nil 287 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 288 289 By("Updating the pod's phase with Running") 290 Eventually(func() error { 291 Expect(testK8sClient.Get(ctx, masterKey, masterPod)).Should(Succeed()) 292 masterPod.Status.Phase = corev1.PodRunning 293 return testK8sClient.Status().Update(ctx, masterPod) 294 }, testutil.Timeout, testutil.Interval).Should(Succeed()) 295 Eventually(func() error { 296 Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed()) 297 workerPod.Status.Phase = corev1.PodRunning 298 return testK8sClient.Status().Update(ctx, workerPod) 299 }, testutil.Timeout, testutil.Interval).Should(Succeed()) 300 301 By("Checking the PyTorchJob's condition") 302 Eventually(func() []kubeflowv1.JobCondition { 303 Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) 304 return created.Status.Conditions 305 }, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{ 306 { 307 Type: kubeflowv1.JobCreated, 308 Status: corev1.ConditionTrue, 309 Reason: commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobCreatedReason), 310 Message: fmt.Sprintf("PyTorchJob %s is created.", name), 311 }, 312 { 313 Type: kubeflowv1.JobRunning, 314 Status: corev1.ConditionTrue, 315 Reason: commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobRunningReason), 316 Message: fmt.Sprintf("PyTorchJob %s is running.", name), 317 }, 318 }, testutil.IgnoreJobConditionsTimes)) 319 320 By("Updating the PyTorchJob with suspend=true") 321 Eventually(func() error { 322 Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) 323 created.Spec.RunPolicy.Suspend = pointer.Bool(true) 324 return testK8sClient.Update(ctx, created) 325 }, testutil.Timeout, testutil.Interval).Should(Succeed()) 326 327 By("Checking if the pods and services are removed") 328 Eventually(func() bool { 329 errMaster := testK8sClient.Get(ctx, masterKey, masterPod) 330 errWorker := testK8sClient.Get(ctx, worker0Key, workerPod) 331 return errors.IsNotFound(errMaster) && errors.IsNotFound(errWorker) 332 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 333 Eventually(func() bool { 334 errMaster := testK8sClient.Get(ctx, masterKey, masterSvc) 335 errWorker := testK8sClient.Get(ctx, worker0Key, workerSvc) 336 return errors.IsNotFound(errMaster) && errors.IsNotFound(errWorker) 337 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 338 Consistently(func() bool { 339 errMasterPod := testK8sClient.Get(ctx, masterKey, masterPod) 340 errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod) 341 errMasterSvc := testK8sClient.Get(ctx, masterKey, masterSvc) 342 errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc) 343 return errors.IsNotFound(errMasterPod) && errors.IsNotFound(errWorkerPod) && 344 errors.IsNotFound(errMasterSvc) && errors.IsNotFound(errWorkerSvc) 345 }, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue()) 346 347 By("Checking if the PyTorchJob has a suspended condition") 348 Eventually(func() bool { 349 Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) 350 return created.Status.ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeMaster].Active == 0 && 351 created.Status.ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeWorker].Active == 0 && 352 created.Status.StartTime.Equal(startTimeBeforeSuspended) 353 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 354 Consistently(func() bool { 355 Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) 356 return created.Status.ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeMaster].Active == 0 && 357 created.Status.ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeWorker].Active == 0 && 358 created.Status.StartTime.Equal(startTimeBeforeSuspended) 359 }, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue()) 360 Expect(created.Status.Conditions).Should(BeComparableTo([]kubeflowv1.JobCondition{ 361 { 362 Type: kubeflowv1.JobCreated, 363 Status: corev1.ConditionTrue, 364 Reason: commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobCreatedReason), 365 Message: fmt.Sprintf("PyTorchJob %s is created.", name), 366 }, 367 { 368 Type: kubeflowv1.JobRunning, 369 Status: corev1.ConditionFalse, 370 Reason: commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobSuspendedReason), 371 Message: fmt.Sprintf("PyTorchJob %s is suspended.", name), 372 }, 373 { 374 Type: kubeflowv1.JobSuspended, 375 Reason: commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobSuspendedReason), 376 Message: fmt.Sprintf("PyTorchJob %s is suspended.", name), 377 Status: corev1.ConditionTrue, 378 }, 379 }, testutil.IgnoreJobConditionsTimes)) 380 381 By("Unsuspending the PyTorchJob") 382 Eventually(func() error { 383 Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) 384 created.Spec.RunPolicy.Suspend = pointer.Bool(false) 385 return testK8sClient.Update(ctx, created) 386 }, testutil.Timeout, testutil.Interval).Should(Succeed()) 387 Eventually(func() *metav1.Time { 388 Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) 389 return created.Status.StartTime 390 }, testutil.Timeout, testutil.Interval).ShouldNot(BeNil()) 391 392 By("Check if the pods and services are created") 393 Eventually(func() error { 394 return testK8sClient.Get(ctx, masterKey, masterPod) 395 }, testutil.Timeout, testutil.Interval).Should(BeNil()) 396 Eventually(func() error { 397 return testK8sClient.Get(ctx, worker0Key, workerPod) 398 }, testutil.Timeout, testutil.Interval).Should(BeNil()) 399 Eventually(func() error { 400 return testK8sClient.Get(ctx, masterKey, masterSvc) 401 }, testutil.Timeout, testutil.Interval).Should(BeNil()) 402 Eventually(func() error { 403 return testK8sClient.Get(ctx, worker0Key, workerSvc) 404 }, testutil.Timeout, testutil.Interval).Should(BeNil()) 405 406 By("Updating Pod's condition with running") 407 Eventually(func() error { 408 Expect(testK8sClient.Get(ctx, masterKey, masterPod)).Should(Succeed()) 409 masterPod.Status.Phase = corev1.PodRunning 410 return testK8sClient.Status().Update(ctx, masterPod) 411 }, testutil.Timeout, testutil.Interval).Should(Succeed()) 412 Eventually(func() error { 413 Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed()) 414 workerPod.Status.Phase = corev1.PodRunning 415 return testK8sClient.Status().Update(ctx, workerPod) 416 }, testutil.Timeout, testutil.Interval).Should(Succeed()) 417 418 By("Checking if the PyTorchJob has resumed conditions") 419 Eventually(func() []kubeflowv1.JobCondition { 420 Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) 421 return created.Status.Conditions 422 }, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{ 423 { 424 Type: kubeflowv1.JobCreated, 425 Status: corev1.ConditionTrue, 426 Reason: commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobCreatedReason), 427 Message: fmt.Sprintf("PyTorchJob %s is created.", name), 428 }, 429 { 430 Type: kubeflowv1.JobSuspended, 431 Reason: commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobResumedReason), 432 Message: fmt.Sprintf("PyTorchJob %s is resumed.", name), 433 Status: corev1.ConditionFalse, 434 }, 435 { 436 Type: kubeflowv1.JobRunning, 437 Status: corev1.ConditionTrue, 438 Reason: commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobRunningReason), 439 Message: fmt.Sprintf("PyTorchJob %s is running.", name), 440 }, 441 }, testutil.IgnoreJobConditionsTimes)) 442 443 By("Checking if the startTime is updated") 444 Expect(created.Status.StartTime).ShouldNot(Equal(startTimeBeforeSuspended)) 445 }) 446 }) 447 448 Context("When creating the elastic PyTorchJob", func() { 449 const name = "elastic-job" 450 var ( 451 ctx = context.Background() 452 ns *corev1.Namespace 453 job *kubeflowv1.PyTorchJob 454 jobKey types.NamespacedName 455 workerKey types.NamespacedName 456 backendC10D = kubeflowv1.BackendC10D 457 minReplicas = int32(1) 458 maxReplicas = int32(3) 459 maxRestarts = int32(3) 460 ) 461 BeforeEach(func() { 462 ns = &corev1.Namespace{ 463 ObjectMeta: metav1.ObjectMeta{ 464 GenerateName: "elastic-pytorch-test-", 465 }, 466 } 467 Expect(testK8sClient.Create(ctx, ns)) 468 469 job = newPyTorchJobForTest(name, ns.Name) 470 jobKey = client.ObjectKeyFromObject(job) 471 workerKey = types.NamespacedName{ 472 Name: fmt.Sprintf("%s-worker-0", name), 473 Namespace: ns.Name, 474 } 475 // Define the expected elastic policy. 476 job.Spec.ElasticPolicy = &kubeflowv1.ElasticPolicy{ 477 RDZVBackend: &backendC10D, 478 MinReplicas: &minReplicas, 479 MaxReplicas: &maxReplicas, 480 MaxRestarts: &maxRestarts, 481 Metrics: []autoscalingv2.MetricSpec{ 482 { 483 Type: autoscalingv2.ResourceMetricSourceType, 484 Resource: &autoscalingv2.ResourceMetricSource{ 485 Name: corev1.ResourceCPU, 486 Target: autoscalingv2.MetricTarget{ 487 Type: autoscalingv2.UtilizationMetricType, 488 AverageValue: resource.NewQuantity(80, resource.DecimalSI), 489 }, 490 }, 491 }, 492 }, 493 } 494 job.Spec.PyTorchReplicaSpecs = map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{ 495 kubeflowv1.PyTorchJobReplicaTypeWorker: { 496 Replicas: pointer.Int32(1), 497 Template: corev1.PodTemplateSpec{ 498 Spec: corev1.PodSpec{ 499 Containers: []corev1.Container{ 500 { 501 Image: "test-image", 502 Name: kubeflowv1.PyTorchJobDefaultContainerName, 503 Ports: []corev1.ContainerPort{ 504 { 505 Name: kubeflowv1.PyTorchJobDefaultPortName, 506 ContainerPort: expectedPort, 507 Protocol: corev1.ProtocolTCP, 508 }, 509 }, 510 }, 511 }, 512 }, 513 }, 514 }, 515 } 516 }) 517 AfterEach(func() { 518 Expect(testK8sClient.Delete(ctx, job)).Should(Succeed()) 519 Expect(testK8sClient.Delete(ctx, ns)).Should(Succeed()) 520 }) 521 // TODO(gaocegege): Test with more than 1 worker. 522 It("Should get the corresponding resources successfully", func() { 523 By("By creating a new PyTorchJob") 524 Expect(testK8sClient.Create(ctx, job)).Should(Succeed()) 525 526 created := &kubeflowv1.PyTorchJob{} 527 528 // We'll need to retry getting this newly created PyTorchJob, given that creation may not immediately happen. 529 Eventually(func() bool { 530 err := testK8sClient.Get(ctx, jobKey, created) 531 return err == nil 532 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 533 534 pod := &corev1.Pod{} 535 Eventually(func() bool { 536 err := testK8sClient.Get(ctx, workerKey, pod) 537 return err == nil 538 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 539 540 svc := &corev1.Service{} 541 Eventually(func() bool { 542 err := testK8sClient.Get(ctx, workerKey, svc) 543 return err == nil 544 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 545 546 hpa := &autoscalingv2.HorizontalPodAutoscaler{} 547 Eventually(func() error { 548 return testK8sClient.Get(ctx, jobKey, hpa) 549 }, testutil.Timeout, testutil.Interval).Should(BeNil()) 550 551 // Check pod port. 552 Expect(pod.Spec.Containers[0].Ports).To(ContainElement(corev1.ContainerPort{ 553 Name: kubeflowv1.PyTorchJobDefaultPortName, 554 ContainerPort: expectedPort, 555 Protocol: corev1.ProtocolTCP})) 556 // Check environment variables. 557 Expect(pod.Spec.Containers[0].Env).To(ContainElements(corev1.EnvVar{ 558 Name: EnvRDZVBackend, 559 Value: string(backendC10D), 560 }, corev1.EnvVar{ 561 Name: EnvNnodes, 562 Value: fmt.Sprintf("%d:%d", minReplicas, maxReplicas), 563 }, corev1.EnvVar{ 564 Name: EnvRDZVEndpoint, 565 Value: fmt.Sprintf("%s:%d", svc.Name, expectedPort), 566 }, corev1.EnvVar{ 567 Name: EnvMaxRestarts, 568 Value: fmt.Sprintf("%d", maxRestarts), 569 })) 570 Expect(svc.Spec.Ports[0].Port).To(Equal(expectedPort)) 571 // Check owner references. 572 trueVal := true 573 Expect(pod.OwnerReferences).To(ContainElement(metav1.OwnerReference{ 574 APIVersion: kubeflowv1.SchemeGroupVersion.String(), 575 Kind: kubeflowv1.PyTorchJobKind, 576 Name: name, 577 UID: created.UID, 578 Controller: &trueVal, 579 BlockOwnerDeletion: &trueVal, 580 })) 581 Expect(svc.OwnerReferences).To(ContainElement(metav1.OwnerReference{ 582 APIVersion: kubeflowv1.SchemeGroupVersion.String(), 583 Kind: kubeflowv1.PyTorchJobKind, 584 Name: name, 585 UID: created.UID, 586 Controller: &trueVal, 587 BlockOwnerDeletion: &trueVal, 588 })) 589 590 // Test job status. 591 Eventually(func() error { 592 Expect(testK8sClient.Get(ctx, workerKey, pod)).Should(Succeed()) 593 pod.Status.Phase = corev1.PodSucceeded 594 return testK8sClient.Status().Update(ctx, pod) 595 }, testutil.Timeout, testutil.Interval).Should(Succeed()) 596 Eventually(func() bool { 597 err := testK8sClient.Get(ctx, jobKey, created) 598 if err != nil { 599 return false 600 } 601 return created.Status.ReplicaStatuses != nil && created.Status. 602 ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeWorker].Succeeded == 1 603 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 604 // Check if the job is succeeded. 605 cond := getCondition(created.Status, kubeflowv1.JobSucceeded) 606 Expect(cond.Status).To(Equal(corev1.ConditionTrue)) 607 }) 608 It("Should delete HPA once the PyTorchJob is suspended", func() { 609 By("By creating a new PyTorchJob") 610 Expect(testK8sClient.Create(ctx, job)).Should(Succeed()) 611 612 created := &kubeflowv1.PyTorchJob{} 613 hpa := &autoscalingv2.HorizontalPodAutoscaler{} 614 615 By("Checking if the PyTorchJob and HPA are created") 616 Eventually(func() error { 617 return testK8sClient.Get(ctx, jobKey, created) 618 }, testutil.Timeout, testutil.Interval).Should(BeNil()) 619 Eventually(func() error { 620 return testK8sClient.Get(ctx, jobKey, hpa) 621 }, testutil.Timeout, testutil.Interval).Should(BeNil()) 622 623 By("Suspending PyTorchJob") 624 Eventually(func() error { 625 Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed()) 626 created.Spec.RunPolicy.Suspend = pointer.Bool(true) 627 return testK8sClient.Update(ctx, created) 628 }, testutil.Timeout, testutil.Interval).Should(Succeed()) 629 630 By("Checking if the HPA is deleted") 631 Eventually(func() bool { 632 return errors.IsNotFound(testK8sClient.Get(ctx, jobKey, hpa)) 633 }, testutil.Timeout, testutil.Interval).Should(BeTrue()) 634 }) 635 }) 636 }) 637 638 func newPyTorchJobForTest(name, namespace string) *kubeflowv1.PyTorchJob { 639 return &kubeflowv1.PyTorchJob{ 640 ObjectMeta: metav1.ObjectMeta{ 641 Name: name, 642 Namespace: namespace, 643 }, 644 } 645 } 646 647 // getCondition returns the condition with the provided type. 648 func getCondition(status kubeflowv1.JobStatus, condType kubeflowv1.JobConditionType) *kubeflowv1.JobCondition { 649 for _, condition := range status.Conditions { 650 if condition.Type == condType { 651 return &condition 652 } 653 } 654 return nil 655 }