github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/pytorch/pytorchjob_controller_test.go (about)

     1  // Copyright 2021 The Kubeflow Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package pytorch
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  
    21  	. "github.com/onsi/ginkgo/v2"
    22  	. "github.com/onsi/gomega"
    23  	autoscalingv2 "k8s.io/api/autoscaling/v2"
    24  	corev1 "k8s.io/api/core/v1"
    25  	"k8s.io/apimachinery/pkg/api/errors"
    26  	"k8s.io/apimachinery/pkg/api/resource"
    27  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    28  	"k8s.io/apimachinery/pkg/types"
    29  	"k8s.io/utils/pointer"
    30  	"sigs.k8s.io/controller-runtime/pkg/client"
    31  
    32  	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    33  	commonutil "github.com/kubeflow/training-operator/pkg/util"
    34  	"github.com/kubeflow/training-operator/pkg/util/testutil"
    35  )
    36  
    37  var _ = Describe("PyTorchJob controller", func() {
    38  	// Define utility constants for object names.
    39  	const (
    40  		expectedPort = int32(8080)
    41  	)
    42  
    43  	Context("When creating the PyTorchJob", func() {
    44  		const name = "test-job"
    45  		var (
    46  			ns         *corev1.Namespace
    47  			job        *kubeflowv1.PyTorchJob
    48  			jobKey     types.NamespacedName
    49  			masterKey  types.NamespacedName
    50  			worker0Key types.NamespacedName
    51  			ctx        = context.Background()
    52  		)
    53  		BeforeEach(func() {
    54  			ns = &corev1.Namespace{
    55  				ObjectMeta: metav1.ObjectMeta{
    56  					GenerateName: "pytorch-test-",
    57  				},
    58  			}
    59  			Expect(testK8sClient.Create(ctx, ns)).Should(Succeed())
    60  
    61  			job = newPyTorchJobForTest(name, ns.Name)
    62  			jobKey = client.ObjectKeyFromObject(job)
    63  			masterKey = types.NamespacedName{
    64  				Name:      fmt.Sprintf("%s-master-0", name),
    65  				Namespace: ns.Name,
    66  			}
    67  			worker0Key = types.NamespacedName{
    68  				Name:      fmt.Sprintf("%s-worker-0", name),
    69  				Namespace: ns.Name,
    70  			}
    71  			job.Spec.NprocPerNode = nil
    72  			job.Spec.PyTorchReplicaSpecs = map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
    73  				kubeflowv1.PyTorchJobReplicaTypeMaster: {
    74  					Replicas: pointer.Int32(1),
    75  					Template: corev1.PodTemplateSpec{
    76  						Spec: corev1.PodSpec{
    77  							Containers: []corev1.Container{
    78  								{
    79  									Image: "test-image",
    80  									Name:  kubeflowv1.PyTorchJobDefaultContainerName,
    81  									Ports: []corev1.ContainerPort{
    82  										{
    83  											Name:          kubeflowv1.PyTorchJobDefaultPortName,
    84  											ContainerPort: expectedPort,
    85  											Protocol:      corev1.ProtocolTCP,
    86  										},
    87  									},
    88  								},
    89  							},
    90  						},
    91  					},
    92  				},
    93  				kubeflowv1.PyTorchJobReplicaTypeWorker: {
    94  					Replicas: pointer.Int32(2),
    95  					Template: corev1.PodTemplateSpec{
    96  						Spec: corev1.PodSpec{
    97  							Containers: []corev1.Container{
    98  								{
    99  									Image: "test-image",
   100  									Name:  kubeflowv1.PyTorchJobDefaultContainerName,
   101  									Ports: []corev1.ContainerPort{
   102  										{
   103  											Name:          kubeflowv1.PyTorchJobDefaultPortName,
   104  											ContainerPort: expectedPort,
   105  											Protocol:      corev1.ProtocolTCP,
   106  										},
   107  									},
   108  								},
   109  							},
   110  						},
   111  					},
   112  				},
   113  			}
   114  		})
   115  		AfterEach(func() {
   116  			Expect(testK8sClient.Delete(ctx, job)).Should(Succeed())
   117  			Expect(testK8sClient.Delete(ctx, ns)).Should(Succeed())
   118  		})
   119  		It("Should get the corresponding resources successfully", func() {
   120  			By("By creating a new PyTorchJob")
   121  			Expect(testK8sClient.Create(ctx, job)).Should(Succeed())
   122  
   123  			created := &kubeflowv1.PyTorchJob{}
   124  
   125  			// We'll need to retry getting this newly created PyTorchJob, given that creation may not immediately happen.
   126  			Eventually(func() bool {
   127  				err := testK8sClient.Get(ctx, jobKey, created)
   128  				return err == nil
   129  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   130  
   131  			masterPod := &corev1.Pod{}
   132  			Eventually(func() bool {
   133  				err := testK8sClient.Get(ctx, masterKey, masterPod)
   134  				return err == nil
   135  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   136  
   137  			masterSvc := &corev1.Service{}
   138  			Eventually(func() bool {
   139  				err := testK8sClient.Get(ctx, masterKey, masterSvc)
   140  				return err == nil
   141  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   142  
   143  			// Check the pod port.
   144  			Expect(masterPod.Spec.Containers[0].Ports).To(ContainElement(corev1.ContainerPort{
   145  				Name:          kubeflowv1.PyTorchJobDefaultPortName,
   146  				ContainerPort: expectedPort,
   147  				Protocol:      corev1.ProtocolTCP}))
   148  			// Check env variable
   149  			Expect(masterPod.Spec.Containers[0].Env).To(ContainElements(corev1.EnvVar{
   150  				Name:  EnvMasterPort,
   151  				Value: fmt.Sprintf("%d", masterSvc.Spec.Ports[0].Port),
   152  			}, corev1.EnvVar{
   153  				Name:  EnvMasterAddr,
   154  				Value: masterSvc.Name,
   155  			}, corev1.EnvVar{
   156  				Name:  EnvNprocPerNode,
   157  				Value: kubeflowv1.DefaultNprocPerNode,
   158  			}))
   159  			// Check service port.
   160  			Expect(masterSvc.Spec.Ports[0].Port).To(Equal(expectedPort))
   161  			// Check owner reference.
   162  			trueVal := true
   163  			Expect(masterPod.OwnerReferences).To(ContainElement(metav1.OwnerReference{
   164  				APIVersion:         kubeflowv1.SchemeGroupVersion.String(),
   165  				Kind:               kubeflowv1.PyTorchJobKind,
   166  				Name:               name,
   167  				UID:                created.UID,
   168  				Controller:         &trueVal,
   169  				BlockOwnerDeletion: &trueVal,
   170  			}))
   171  			Expect(masterSvc.OwnerReferences).To(ContainElement(metav1.OwnerReference{
   172  				APIVersion:         kubeflowv1.SchemeGroupVersion.String(),
   173  				Kind:               kubeflowv1.PyTorchJobKind,
   174  				Name:               name,
   175  				UID:                created.UID,
   176  				Controller:         &trueVal,
   177  				BlockOwnerDeletion: &trueVal,
   178  			}))
   179  
   180  			// Test job status.
   181  			Eventually(func() error {
   182  				Expect(testK8sClient.Get(ctx, masterKey, masterPod)).Should(Succeed())
   183  				masterPod.Status.Phase = corev1.PodSucceeded
   184  				return testK8sClient.Status().Update(ctx, masterPod)
   185  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   186  			Eventually(func() bool {
   187  				err := testK8sClient.Get(ctx, jobKey, created)
   188  				if err != nil {
   189  					return false
   190  				}
   191  				return created.Status.ReplicaStatuses != nil && created.Status.
   192  					ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeMaster].Succeeded == 1
   193  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   194  			// Check if the job is succeeded.
   195  			cond := getCondition(created.Status, kubeflowv1.JobSucceeded)
   196  			Expect(cond.Status).To(Equal(corev1.ConditionTrue))
   197  		})
   198  
   199  		It("Shouldn't create resources if PyTorchJob is suspended", func() {
   200  			By("By creating a new PyTorchJob with suspend=true")
   201  			job.Spec.RunPolicy.Suspend = pointer.Bool(true)
   202  			job.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas = pointer.Int32(1)
   203  			Expect(testK8sClient.Create(ctx, job)).Should(Succeed())
   204  
   205  			created := &kubeflowv1.PyTorchJob{}
   206  			masterPod := &corev1.Pod{}
   207  			workerPod := &corev1.Pod{}
   208  			masterSvc := &corev1.Service{}
   209  			workerSvc := &corev1.Service{}
   210  
   211  			By("Checking created PyTorchJob")
   212  			Eventually(func() bool {
   213  				err := testK8sClient.Get(ctx, jobKey, created)
   214  				return err == nil
   215  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   216  			By("Checking created PyTorchJob has a nil startTime")
   217  			Consistently(func() *metav1.Time {
   218  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   219  				return created.Status.StartTime
   220  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeNil())
   221  
   222  			By("Checking if the pods and services aren't created")
   223  			Consistently(func() bool {
   224  				errMasterPod := testK8sClient.Get(ctx, masterKey, masterPod)
   225  				errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod)
   226  				errMasterSvc := testK8sClient.Get(ctx, masterKey, masterSvc)
   227  				errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc)
   228  				return errors.IsNotFound(errMasterPod) && errors.IsNotFound(errWorkerPod) &&
   229  					errors.IsNotFound(errMasterSvc) && errors.IsNotFound(errWorkerSvc)
   230  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue())
   231  
   232  			By("Checking if the PyTorchJob has suspended condition")
   233  			Eventually(func() []kubeflowv1.JobCondition {
   234  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   235  				return created.Status.Conditions
   236  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{
   237  				{
   238  					Type:    kubeflowv1.JobCreated,
   239  					Status:  corev1.ConditionTrue,
   240  					Reason:  commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobCreatedReason),
   241  					Message: fmt.Sprintf("PyTorchJob %s is created.", name),
   242  				},
   243  				{
   244  					Type:    kubeflowv1.JobSuspended,
   245  					Status:  corev1.ConditionTrue,
   246  					Reason:  commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobSuspendedReason),
   247  					Message: fmt.Sprintf("PyTorchJob %s is suspended.", name),
   248  				},
   249  			}, testutil.IgnoreJobConditionsTimes))
   250  		})
   251  
   252  		It("Should delete resources after PyTorchJob is suspended; Should resume PyTorchJob after PyTorchJob is unsuspended", func() {
   253  			By("By creating a new PyTorchJob")
   254  			job.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas = pointer.Int32(1)
   255  			Expect(testK8sClient.Create(ctx, job)).Should(Succeed())
   256  
   257  			created := &kubeflowv1.PyTorchJob{}
   258  			masterPod := &corev1.Pod{}
   259  			workerPod := &corev1.Pod{}
   260  			masterSvc := &corev1.Service{}
   261  			workerSvc := &corev1.Service{}
   262  
   263  			// We'll need to retry getting this newly created PyTorchJob, given that creation may not immediately happen.
   264  			By("Checking created PyTorchJob")
   265  			Eventually(func() bool {
   266  				err := testK8sClient.Get(ctx, jobKey, created)
   267  				return err == nil
   268  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   269  
   270  			var startTimeBeforeSuspended *metav1.Time
   271  			Eventually(func() *metav1.Time {
   272  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   273  				startTimeBeforeSuspended = created.Status.StartTime
   274  				return startTimeBeforeSuspended
   275  			}, testutil.Timeout, testutil.Interval).ShouldNot(BeNil())
   276  
   277  			By("Checking the created pods and services")
   278  			Eventually(func() bool {
   279  				errMaster := testK8sClient.Get(ctx, masterKey, masterPod)
   280  				errWorker := testK8sClient.Get(ctx, worker0Key, workerPod)
   281  				return errMaster == nil && errWorker == nil
   282  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   283  			Eventually(func() bool {
   284  				errMaster := testK8sClient.Get(ctx, masterKey, masterSvc)
   285  				errWorker := testK8sClient.Get(ctx, worker0Key, workerSvc)
   286  				return errMaster == nil && errWorker == nil
   287  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   288  
   289  			By("Updating the pod's phase with Running")
   290  			Eventually(func() error {
   291  				Expect(testK8sClient.Get(ctx, masterKey, masterPod)).Should(Succeed())
   292  				masterPod.Status.Phase = corev1.PodRunning
   293  				return testK8sClient.Status().Update(ctx, masterPod)
   294  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   295  			Eventually(func() error {
   296  				Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed())
   297  				workerPod.Status.Phase = corev1.PodRunning
   298  				return testK8sClient.Status().Update(ctx, workerPod)
   299  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   300  
   301  			By("Checking the PyTorchJob's condition")
   302  			Eventually(func() []kubeflowv1.JobCondition {
   303  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   304  				return created.Status.Conditions
   305  			}, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{
   306  				{
   307  					Type:    kubeflowv1.JobCreated,
   308  					Status:  corev1.ConditionTrue,
   309  					Reason:  commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobCreatedReason),
   310  					Message: fmt.Sprintf("PyTorchJob %s is created.", name),
   311  				},
   312  				{
   313  					Type:    kubeflowv1.JobRunning,
   314  					Status:  corev1.ConditionTrue,
   315  					Reason:  commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobRunningReason),
   316  					Message: fmt.Sprintf("PyTorchJob %s is running.", name),
   317  				},
   318  			}, testutil.IgnoreJobConditionsTimes))
   319  
   320  			By("Updating the PyTorchJob with suspend=true")
   321  			Eventually(func() error {
   322  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   323  				created.Spec.RunPolicy.Suspend = pointer.Bool(true)
   324  				return testK8sClient.Update(ctx, created)
   325  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   326  
   327  			By("Checking if the pods and services are removed")
   328  			Eventually(func() bool {
   329  				errMaster := testK8sClient.Get(ctx, masterKey, masterPod)
   330  				errWorker := testK8sClient.Get(ctx, worker0Key, workerPod)
   331  				return errors.IsNotFound(errMaster) && errors.IsNotFound(errWorker)
   332  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   333  			Eventually(func() bool {
   334  				errMaster := testK8sClient.Get(ctx, masterKey, masterSvc)
   335  				errWorker := testK8sClient.Get(ctx, worker0Key, workerSvc)
   336  				return errors.IsNotFound(errMaster) && errors.IsNotFound(errWorker)
   337  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   338  			Consistently(func() bool {
   339  				errMasterPod := testK8sClient.Get(ctx, masterKey, masterPod)
   340  				errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod)
   341  				errMasterSvc := testK8sClient.Get(ctx, masterKey, masterSvc)
   342  				errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc)
   343  				return errors.IsNotFound(errMasterPod) && errors.IsNotFound(errWorkerPod) &&
   344  					errors.IsNotFound(errMasterSvc) && errors.IsNotFound(errWorkerSvc)
   345  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue())
   346  
   347  			By("Checking if the PyTorchJob has a suspended condition")
   348  			Eventually(func() bool {
   349  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   350  				return created.Status.ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeMaster].Active == 0 &&
   351  					created.Status.ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeWorker].Active == 0 &&
   352  					created.Status.StartTime.Equal(startTimeBeforeSuspended)
   353  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   354  			Consistently(func() bool {
   355  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   356  				return created.Status.ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeMaster].Active == 0 &&
   357  					created.Status.ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeWorker].Active == 0 &&
   358  					created.Status.StartTime.Equal(startTimeBeforeSuspended)
   359  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue())
   360  			Expect(created.Status.Conditions).Should(BeComparableTo([]kubeflowv1.JobCondition{
   361  				{
   362  					Type:    kubeflowv1.JobCreated,
   363  					Status:  corev1.ConditionTrue,
   364  					Reason:  commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobCreatedReason),
   365  					Message: fmt.Sprintf("PyTorchJob %s is created.", name),
   366  				},
   367  				{
   368  					Type:    kubeflowv1.JobRunning,
   369  					Status:  corev1.ConditionFalse,
   370  					Reason:  commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobSuspendedReason),
   371  					Message: fmt.Sprintf("PyTorchJob %s is suspended.", name),
   372  				},
   373  				{
   374  					Type:    kubeflowv1.JobSuspended,
   375  					Reason:  commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobSuspendedReason),
   376  					Message: fmt.Sprintf("PyTorchJob %s is suspended.", name),
   377  					Status:  corev1.ConditionTrue,
   378  				},
   379  			}, testutil.IgnoreJobConditionsTimes))
   380  
   381  			By("Unsuspending the PyTorchJob")
   382  			Eventually(func() error {
   383  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   384  				created.Spec.RunPolicy.Suspend = pointer.Bool(false)
   385  				return testK8sClient.Update(ctx, created)
   386  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   387  			Eventually(func() *metav1.Time {
   388  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   389  				return created.Status.StartTime
   390  			}, testutil.Timeout, testutil.Interval).ShouldNot(BeNil())
   391  
   392  			By("Check if the pods and services are created")
   393  			Eventually(func() error {
   394  				return testK8sClient.Get(ctx, masterKey, masterPod)
   395  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   396  			Eventually(func() error {
   397  				return testK8sClient.Get(ctx, worker0Key, workerPod)
   398  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   399  			Eventually(func() error {
   400  				return testK8sClient.Get(ctx, masterKey, masterSvc)
   401  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   402  			Eventually(func() error {
   403  				return testK8sClient.Get(ctx, worker0Key, workerSvc)
   404  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   405  
   406  			By("Updating Pod's condition with running")
   407  			Eventually(func() error {
   408  				Expect(testK8sClient.Get(ctx, masterKey, masterPod)).Should(Succeed())
   409  				masterPod.Status.Phase = corev1.PodRunning
   410  				return testK8sClient.Status().Update(ctx, masterPod)
   411  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   412  			Eventually(func() error {
   413  				Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed())
   414  				workerPod.Status.Phase = corev1.PodRunning
   415  				return testK8sClient.Status().Update(ctx, workerPod)
   416  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   417  
   418  			By("Checking if the PyTorchJob has resumed conditions")
   419  			Eventually(func() []kubeflowv1.JobCondition {
   420  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   421  				return created.Status.Conditions
   422  			}, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{
   423  				{
   424  					Type:    kubeflowv1.JobCreated,
   425  					Status:  corev1.ConditionTrue,
   426  					Reason:  commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobCreatedReason),
   427  					Message: fmt.Sprintf("PyTorchJob %s is created.", name),
   428  				},
   429  				{
   430  					Type:    kubeflowv1.JobSuspended,
   431  					Reason:  commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobResumedReason),
   432  					Message: fmt.Sprintf("PyTorchJob %s is resumed.", name),
   433  					Status:  corev1.ConditionFalse,
   434  				},
   435  				{
   436  					Type:    kubeflowv1.JobRunning,
   437  					Status:  corev1.ConditionTrue,
   438  					Reason:  commonutil.NewReason(kubeflowv1.PyTorchJobKind, commonutil.JobRunningReason),
   439  					Message: fmt.Sprintf("PyTorchJob %s is running.", name),
   440  				},
   441  			}, testutil.IgnoreJobConditionsTimes))
   442  
   443  			By("Checking if the startTime is updated")
   444  			Expect(created.Status.StartTime).ShouldNot(Equal(startTimeBeforeSuspended))
   445  		})
   446  	})
   447  
   448  	Context("When creating the elastic PyTorchJob", func() {
   449  		const name = "elastic-job"
   450  		var (
   451  			ctx         = context.Background()
   452  			ns          *corev1.Namespace
   453  			job         *kubeflowv1.PyTorchJob
   454  			jobKey      types.NamespacedName
   455  			workerKey   types.NamespacedName
   456  			backendC10D = kubeflowv1.BackendC10D
   457  			minReplicas = int32(1)
   458  			maxReplicas = int32(3)
   459  			maxRestarts = int32(3)
   460  		)
   461  		BeforeEach(func() {
   462  			ns = &corev1.Namespace{
   463  				ObjectMeta: metav1.ObjectMeta{
   464  					GenerateName: "elastic-pytorch-test-",
   465  				},
   466  			}
   467  			Expect(testK8sClient.Create(ctx, ns))
   468  
   469  			job = newPyTorchJobForTest(name, ns.Name)
   470  			jobKey = client.ObjectKeyFromObject(job)
   471  			workerKey = types.NamespacedName{
   472  				Name:      fmt.Sprintf("%s-worker-0", name),
   473  				Namespace: ns.Name,
   474  			}
   475  			// Define the expected elastic policy.
   476  			job.Spec.ElasticPolicy = &kubeflowv1.ElasticPolicy{
   477  				RDZVBackend: &backendC10D,
   478  				MinReplicas: &minReplicas,
   479  				MaxReplicas: &maxReplicas,
   480  				MaxRestarts: &maxRestarts,
   481  				Metrics: []autoscalingv2.MetricSpec{
   482  					{
   483  						Type: autoscalingv2.ResourceMetricSourceType,
   484  						Resource: &autoscalingv2.ResourceMetricSource{
   485  							Name: corev1.ResourceCPU,
   486  							Target: autoscalingv2.MetricTarget{
   487  								Type:         autoscalingv2.UtilizationMetricType,
   488  								AverageValue: resource.NewQuantity(80, resource.DecimalSI),
   489  							},
   490  						},
   491  					},
   492  				},
   493  			}
   494  			job.Spec.PyTorchReplicaSpecs = map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
   495  				kubeflowv1.PyTorchJobReplicaTypeWorker: {
   496  					Replicas: pointer.Int32(1),
   497  					Template: corev1.PodTemplateSpec{
   498  						Spec: corev1.PodSpec{
   499  							Containers: []corev1.Container{
   500  								{
   501  									Image: "test-image",
   502  									Name:  kubeflowv1.PyTorchJobDefaultContainerName,
   503  									Ports: []corev1.ContainerPort{
   504  										{
   505  											Name:          kubeflowv1.PyTorchJobDefaultPortName,
   506  											ContainerPort: expectedPort,
   507  											Protocol:      corev1.ProtocolTCP,
   508  										},
   509  									},
   510  								},
   511  							},
   512  						},
   513  					},
   514  				},
   515  			}
   516  		})
   517  		AfterEach(func() {
   518  			Expect(testK8sClient.Delete(ctx, job)).Should(Succeed())
   519  			Expect(testK8sClient.Delete(ctx, ns)).Should(Succeed())
   520  		})
   521  		// TODO(gaocegege): Test with more than 1 worker.
   522  		It("Should get the corresponding resources successfully", func() {
   523  			By("By creating a new PyTorchJob")
   524  			Expect(testK8sClient.Create(ctx, job)).Should(Succeed())
   525  
   526  			created := &kubeflowv1.PyTorchJob{}
   527  
   528  			// We'll need to retry getting this newly created PyTorchJob, given that creation may not immediately happen.
   529  			Eventually(func() bool {
   530  				err := testK8sClient.Get(ctx, jobKey, created)
   531  				return err == nil
   532  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   533  
   534  			pod := &corev1.Pod{}
   535  			Eventually(func() bool {
   536  				err := testK8sClient.Get(ctx, workerKey, pod)
   537  				return err == nil
   538  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   539  
   540  			svc := &corev1.Service{}
   541  			Eventually(func() bool {
   542  				err := testK8sClient.Get(ctx, workerKey, svc)
   543  				return err == nil
   544  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   545  
   546  			hpa := &autoscalingv2.HorizontalPodAutoscaler{}
   547  			Eventually(func() error {
   548  				return testK8sClient.Get(ctx, jobKey, hpa)
   549  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   550  
   551  			// Check pod port.
   552  			Expect(pod.Spec.Containers[0].Ports).To(ContainElement(corev1.ContainerPort{
   553  				Name:          kubeflowv1.PyTorchJobDefaultPortName,
   554  				ContainerPort: expectedPort,
   555  				Protocol:      corev1.ProtocolTCP}))
   556  			// Check environment variables.
   557  			Expect(pod.Spec.Containers[0].Env).To(ContainElements(corev1.EnvVar{
   558  				Name:  EnvRDZVBackend,
   559  				Value: string(backendC10D),
   560  			}, corev1.EnvVar{
   561  				Name:  EnvNnodes,
   562  				Value: fmt.Sprintf("%d:%d", minReplicas, maxReplicas),
   563  			}, corev1.EnvVar{
   564  				Name:  EnvRDZVEndpoint,
   565  				Value: fmt.Sprintf("%s:%d", svc.Name, expectedPort),
   566  			}, corev1.EnvVar{
   567  				Name:  EnvMaxRestarts,
   568  				Value: fmt.Sprintf("%d", maxRestarts),
   569  			}))
   570  			Expect(svc.Spec.Ports[0].Port).To(Equal(expectedPort))
   571  			// Check owner references.
   572  			trueVal := true
   573  			Expect(pod.OwnerReferences).To(ContainElement(metav1.OwnerReference{
   574  				APIVersion:         kubeflowv1.SchemeGroupVersion.String(),
   575  				Kind:               kubeflowv1.PyTorchJobKind,
   576  				Name:               name,
   577  				UID:                created.UID,
   578  				Controller:         &trueVal,
   579  				BlockOwnerDeletion: &trueVal,
   580  			}))
   581  			Expect(svc.OwnerReferences).To(ContainElement(metav1.OwnerReference{
   582  				APIVersion:         kubeflowv1.SchemeGroupVersion.String(),
   583  				Kind:               kubeflowv1.PyTorchJobKind,
   584  				Name:               name,
   585  				UID:                created.UID,
   586  				Controller:         &trueVal,
   587  				BlockOwnerDeletion: &trueVal,
   588  			}))
   589  
   590  			// Test job status.
   591  			Eventually(func() error {
   592  				Expect(testK8sClient.Get(ctx, workerKey, pod)).Should(Succeed())
   593  				pod.Status.Phase = corev1.PodSucceeded
   594  				return testK8sClient.Status().Update(ctx, pod)
   595  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   596  			Eventually(func() bool {
   597  				err := testK8sClient.Get(ctx, jobKey, created)
   598  				if err != nil {
   599  					return false
   600  				}
   601  				return created.Status.ReplicaStatuses != nil && created.Status.
   602  					ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeWorker].Succeeded == 1
   603  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   604  			// Check if the job is succeeded.
   605  			cond := getCondition(created.Status, kubeflowv1.JobSucceeded)
   606  			Expect(cond.Status).To(Equal(corev1.ConditionTrue))
   607  		})
   608  		It("Should delete HPA once the PyTorchJob is suspended", func() {
   609  			By("By creating a new PyTorchJob")
   610  			Expect(testK8sClient.Create(ctx, job)).Should(Succeed())
   611  
   612  			created := &kubeflowv1.PyTorchJob{}
   613  			hpa := &autoscalingv2.HorizontalPodAutoscaler{}
   614  
   615  			By("Checking if the PyTorchJob and HPA are created")
   616  			Eventually(func() error {
   617  				return testK8sClient.Get(ctx, jobKey, created)
   618  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   619  			Eventually(func() error {
   620  				return testK8sClient.Get(ctx, jobKey, hpa)
   621  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   622  
   623  			By("Suspending PyTorchJob")
   624  			Eventually(func() error {
   625  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   626  				created.Spec.RunPolicy.Suspend = pointer.Bool(true)
   627  				return testK8sClient.Update(ctx, created)
   628  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   629  
   630  			By("Checking if the HPA is deleted")
   631  			Eventually(func() bool {
   632  				return errors.IsNotFound(testK8sClient.Get(ctx, jobKey, hpa))
   633  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   634  		})
   635  	})
   636  })
   637  
   638  func newPyTorchJobForTest(name, namespace string) *kubeflowv1.PyTorchJob {
   639  	return &kubeflowv1.PyTorchJob{
   640  		ObjectMeta: metav1.ObjectMeta{
   641  			Name:      name,
   642  			Namespace: namespace,
   643  		},
   644  	}
   645  }
   646  
   647  // getCondition returns the condition with the provided type.
   648  func getCondition(status kubeflowv1.JobStatus, condType kubeflowv1.JobConditionType) *kubeflowv1.JobCondition {
   649  	for _, condition := range status.Conditions {
   650  		if condition.Type == condType {
   651  			return &condition
   652  		}
   653  	}
   654  	return nil
   655  }