github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/tensorflow/pod_test.go (about)

     1  // Copyright 2021 The Kubeflow Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tensorflow
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"os"
    21  
    22  	. "github.com/onsi/ginkgo/v2"
    23  	. "github.com/onsi/gomega"
    24  	corev1 "k8s.io/api/core/v1"
    25  	"k8s.io/apimachinery/pkg/api/errors"
    26  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    27  	"k8s.io/apimachinery/pkg/types"
    28  	"k8s.io/apimachinery/pkg/util/uuid"
    29  	"sigs.k8s.io/controller-runtime/pkg/client"
    30  
    31  	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    32  	tftestutil "github.com/kubeflow/training-operator/pkg/controller.v1/tensorflow/testutil"
    33  	"github.com/kubeflow/training-operator/pkg/core"
    34  	"github.com/kubeflow/training-operator/pkg/util/testutil"
    35  )
    36  
    37  var _ = Describe("TFJob controller", func() {
    38  	Context("Test ClusterSpec", func() {
    39  		It("should generate desired cluster spec", func() {
    40  			type tc struct {
    41  				tfJob               *kubeflowv1.TFJob
    42  				rt                  string
    43  				index               string
    44  				customClusterDomain string
    45  				expectedClusterSpec string
    46  			}
    47  			testCase := []tc{
    48  				{
    49  					tfJob:               tftestutil.NewTFJobWithNamespace(1, 0, "ns0"),
    50  					rt:                  "worker",
    51  					index:               "0",
    52  					customClusterDomain: "",
    53  					expectedClusterSpec: "",
    54  				},
    55  				{
    56  					tfJob:               tftestutil.NewTFJobWithNamespace(1, 0, "ns1"),
    57  					rt:                  "worker",
    58  					index:               "0",
    59  					customClusterDomain: "tf.training.com",
    60  					expectedClusterSpec: "",
    61  				},
    62  				{
    63  					tfJob:               tftestutil.NewTFJobWithNamespace(1, 1, "ns2"),
    64  					rt:                  "worker",
    65  					index:               "0",
    66  					customClusterDomain: "tf.training.org",
    67  					expectedClusterSpec: `{"cluster":{"ps":["` + tftestutil.TestTFJobName +
    68  						`-ps-0.ns2.svc.tf.training.org:2222"],"worker":["` + tftestutil.TestTFJobName +
    69  						`-worker-0.ns2.svc.tf.training.org:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`,
    70  				},
    71  				{
    72  					tfJob:               tftestutil.NewTFJobWithEvaluatorAndNamespace(1, 1, 1, "ns3"),
    73  					rt:                  "worker",
    74  					index:               "0",
    75  					customClusterDomain: "tf.training.io",
    76  					expectedClusterSpec: `{"cluster":{"evaluator":["` + tftestutil.TestTFJobName +
    77  						`-evaluator-0.ns3.svc.tf.training.io:2222"],"ps":["` + tftestutil.TestTFJobName +
    78  						`-ps-0.ns3.svc.tf.training.io:2222"],"worker":["` + tftestutil.TestTFJobName +
    79  						`-worker-0.ns3.svc.tf.training.io:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`,
    80  				},
    81  				{
    82  					tfJob:               tftestutil.NewTFJobWithEvaluatorAndNamespace(1, 1, 1, "ns3"),
    83  					rt:                  "worker",
    84  					index:               "0",
    85  					customClusterDomain: "",
    86  					expectedClusterSpec: `{"cluster":{"evaluator":["` + tftestutil.TestTFJobName +
    87  						`-evaluator-0.ns3.svc:2222"],"ps":["` + tftestutil.TestTFJobName +
    88  						`-ps-0.ns3.svc:2222"],"worker":["` + tftestutil.TestTFJobName +
    89  						`-worker-0.ns3.svc:2222"]},"task":{"type":"worker","index":0},"environment":"cloud"}`,
    90  				},
    91  			}
    92  
    93  			for _, c := range testCase {
    94  				c.tfJob.SetName(tftestutil.TestTFJobName)
    95  				c.tfJob.SetUID(uuid.NewUUID())
    96  				_ = os.Setenv(EnvCustomClusterDomain, c.customClusterDomain)
    97  
    98  				podTemplate := c.tfJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].Template.DeepCopy()
    99  
   100  				podTemplate.Name = core.GenGeneralName(c.tfJob.GetName(), c.rt, c.index)
   101  
   102  				if podTemplate.Labels == nil {
   103  					podTemplate.Labels = map[string]string{}
   104  				}
   105  
   106  				jobName := c.tfJob.GetName()
   107  				labels := reconciler.GenLabels(jobName)
   108  				labels[kubeflowv1.ReplicaTypeLabel] = c.rt
   109  				labels[kubeflowv1.ReplicaIndexLabel] = c.index
   110  
   111  				Expect(reconciler.SetClusterSpec(c.tfJob, podTemplate, c.rt, c.index)).Should(Succeed())
   112  
   113  				if c.expectedClusterSpec == "" {
   114  					Expect(len(podTemplate.Spec.Containers[0].Env)).Should(Equal(0))
   115  				} else {
   116  					actual := podTemplate.Spec.Containers[0].Env[0].Value
   117  					reconciler.Log.Info("printing cluster spec", "expected", c.expectedClusterSpec, "actual pod", podTemplate)
   118  					Expect(actual).Should(Equal(c.expectedClusterSpec))
   119  				}
   120  			}
   121  		})
   122  	})
   123  
   124  	Context("Test IsDistributed", func() {
   125  		It("should returns correctly", func() {
   126  			type tc struct {
   127  				tfJob    *kubeflowv1.TFJob
   128  				expected bool
   129  			}
   130  			testCase := []tc{
   131  				{
   132  					tfJob:    tftestutil.NewTFJob(1, 0),
   133  					expected: false,
   134  				},
   135  				{
   136  					tfJob:    tftestutil.NewTFJob(1, 1),
   137  					expected: true,
   138  				},
   139  				{
   140  					tfJob:    tftestutil.NewTFJob(0, 1),
   141  					expected: false,
   142  				},
   143  				{
   144  					tfJob:    tftestutil.NewTFJobWithChief(1, 0),
   145  					expected: true,
   146  				},
   147  			}
   148  			for _, c := range testCase {
   149  				Expect(isDistributed(c.tfJob)).To(Equal(c.expected))
   150  			}
   151  		})
   152  	})
   153  
   154  	Context("Test Restart Policy", func() {
   155  		It("should assign proper restart policy to pod", func() {
   156  			type tc struct {
   157  				tfJob                 *kubeflowv1.TFJob
   158  				expectedRestartPolicy corev1.RestartPolicy
   159  				expectedType          kubeflowv1.ReplicaType
   160  			}
   161  			testCase := []tc{
   162  				func() tc {
   163  					tfJob := tftestutil.NewTFJob(1, 0)
   164  					specRestartPolicy := kubeflowv1.RestartPolicyExitCode
   165  					tfJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].RestartPolicy = specRestartPolicy
   166  					return tc{
   167  						tfJob:                 tfJob,
   168  						expectedRestartPolicy: corev1.RestartPolicyNever,
   169  						expectedType:          kubeflowv1.TFJobReplicaTypeWorker,
   170  					}
   171  				}(),
   172  				func() tc {
   173  					tfJob := tftestutil.NewTFJob(1, 0)
   174  					specRestartPolicy := kubeflowv1.RestartPolicyNever
   175  					tfJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].RestartPolicy = specRestartPolicy
   176  					return tc{
   177  						tfJob:                 tfJob,
   178  						expectedRestartPolicy: corev1.RestartPolicyNever,
   179  						expectedType:          kubeflowv1.TFJobReplicaTypeWorker,
   180  					}
   181  				}(),
   182  				func() tc {
   183  					tfJob := tftestutil.NewTFJob(1, 0)
   184  					specRestartPolicy := kubeflowv1.RestartPolicyAlways
   185  					tfJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].RestartPolicy = specRestartPolicy
   186  					return tc{
   187  						tfJob:                 tfJob,
   188  						expectedRestartPolicy: corev1.RestartPolicyAlways,
   189  						expectedType:          kubeflowv1.TFJobReplicaTypeWorker,
   190  					}
   191  				}(),
   192  				func() tc {
   193  					tfJob := tftestutil.NewTFJob(1, 0)
   194  					specRestartPolicy := kubeflowv1.RestartPolicyOnFailure
   195  					tfJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].RestartPolicy = specRestartPolicy
   196  					return tc{
   197  						tfJob:                 tfJob,
   198  						expectedRestartPolicy: corev1.RestartPolicyOnFailure,
   199  						expectedType:          kubeflowv1.TFJobReplicaTypeWorker,
   200  					}
   201  				}(),
   202  			}
   203  			for _, c := range testCase {
   204  				spec := c.tfJob.Spec.TFReplicaSpecs[c.expectedType]
   205  				podTemplate := spec.Template
   206  				setRestartPolicy(&podTemplate, spec)
   207  				Expect(podTemplate.Spec.RestartPolicy).To(Equal(c.expectedRestartPolicy))
   208  			}
   209  		})
   210  	})
   211  
   212  	Context("Test Exit Code", func() {
   213  		It("should delete designated Pod", func() {
   214  			By("Creating TFJob \"test-exit-code\" with 1 worker only")
   215  			ctx := context.Background()
   216  
   217  			tfJob := tftestutil.NewTFJob(1, 0)
   218  			tfJob.SetName("test-exit-code")
   219  			tfJob.SetUID(uuid.NewUUID())
   220  			tfJob.Spec.TFReplicaSpecs[kubeflowv1.TFJobReplicaTypeWorker].RestartPolicy = kubeflowv1.RestartPolicyExitCode
   221  
   222  			refs := []metav1.OwnerReference{
   223  				*reconciler.GenOwnerReference(tfJob),
   224  			}
   225  			By("creating worker Pod")
   226  			pod := tftestutil.NewPod(tfJob, kubeflowv1.TFJobReplicaTypeWorker, 0, refs)
   227  			basicLabels := reconciler.GenLabels(tfJob.GetName())
   228  			for k, v := range basicLabels {
   229  				pod.Labels[k] = v
   230  			}
   231  			pod.Spec.Containers = append(pod.Spec.Containers, corev1.Container{
   232  				Name:  kubeflowv1.TFJobDefaultContainerName,
   233  				Image: tftestutil.DummyContainerImage,
   234  			})
   235  			Expect(testK8sClient.Create(ctx, pod)).Should(Succeed())
   236  
   237  			created := &corev1.Pod{}
   238  			key := types.NamespacedName{Namespace: metav1.NamespaceDefault, Name: pod.GetName()}
   239  			Expect(testK8sClient.Get(ctx, key, created)).Should(Succeed())
   240  			created.Status.Phase = corev1.PodFailed
   241  			created.Status.ContainerStatuses = append(created.Status.ContainerStatuses, corev1.ContainerStatus{
   242  				Name: kubeflowv1.TFJobDefaultContainerName,
   243  				State: corev1.ContainerState{
   244  					Terminated: &corev1.ContainerStateTerminated{
   245  						ExitCode: 130,
   246  					},
   247  				},
   248  			})
   249  			Expect(testK8sClient.Status().Update(ctx, created))
   250  
   251  			// Make sure the version of pod created is updated with desired status
   252  			Eventually(func() error {
   253  				updated := &corev1.Pod{}
   254  				if err := testK8sClient.Get(ctx, key, updated); err != nil {
   255  					return err
   256  				}
   257  				if updated.Status.Phase != corev1.PodFailed {
   258  					return fmt.Errorf("pod status is not Failed")
   259  				}
   260  				return nil
   261  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   262  
   263  			_ = reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy)
   264  
   265  			Eventually(func() bool {
   266  				noPod := &corev1.Pod{}
   267  				err := testK8sClient.Get(ctx, key, noPod)
   268  				if err == nil {
   269  					reconciler.Log.Info("still got pod", "jobName", tfJob.GetName(), "pod", noPod)
   270  					return noPod.GetDeletionTimestamp() != nil
   271  				}
   272  				return errors.IsNotFound(err)
   273  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   274  		})
   275  	})
   276  
   277  	Describe("Test Scale Down", func() {
   278  		It("should delete redundant Pods", func() {
   279  			ctx := context.Background()
   280  
   281  			tfJob := tftestutil.NewTFJob(2, 0)
   282  			//tfJob.SelfLink = "/api/v1/namespaces/default/tfjob/test-tfjob"
   283  			tfJob.SetName("test-scale-down")
   284  			tfJob.SetUID(uuid.NewUUID())
   285  			tfJob.Spec.EnableDynamicWorker = true
   286  
   287  			refs := []metav1.OwnerReference{*reconciler.GenOwnerReference(tfJob)}
   288  
   289  			pods := []*corev1.Pod{
   290  				tftestutil.NewPod(tfJob, kubeflowv1.TFJobReplicaTypeWorker, 0, refs),
   291  				tftestutil.NewPod(tfJob, kubeflowv1.TFJobReplicaTypeWorker, 1, refs),
   292  				tftestutil.NewPod(tfJob, kubeflowv1.TFJobReplicaTypeWorker, 2, refs),
   293  			}
   294  
   295  			for i := range pods {
   296  				pod := pods[i]
   297  				for k, v := range reconciler.GenLabels(tfJob.GetName()) {
   298  					pod.Labels[k] = v
   299  				}
   300  				Expect(testK8sClient.Create(ctx, pod)).Should(Succeed())
   301  			}
   302  
   303  			// Ensure the created Pods are all in cache
   304  			Eventually(func() error {
   305  				podList := &corev1.PodList{}
   306  				selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
   307  					MatchLabels: reconciler.GenLabels(tfJob.GetName()),
   308  				})
   309  				if err != nil {
   310  					return err
   311  				}
   312  				listOpt := client.MatchingLabelsSelector{
   313  					Selector: selector,
   314  				}
   315  				err = testK8sClient.List(ctx, podList, listOpt)
   316  				if err != nil {
   317  					return err
   318  				}
   319  				if len(podList.Items) != 3 {
   320  					return fmt.Errorf("expecting %d Pods while got %d", 3, len(podList.Items))
   321  				}
   322  				return nil
   323  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   324  
   325  			_ = reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy)
   326  
   327  			noKey := types.NamespacedName{
   328  				Namespace: metav1.NamespaceDefault,
   329  				Name:      pods[2].GetName(),
   330  			}
   331  			Eventually(func() bool {
   332  				noPod := &corev1.Pod{}
   333  				err := testK8sClient.Get(ctx, noKey, noPod)
   334  				if err == nil {
   335  					return false
   336  				}
   337  				return errors.IsNotFound(err)
   338  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   339  		})
   340  	})
   341  
   342  	Describe("Test Scale Up", func() {
   343  		It("should create missing Pods", func() {
   344  			ctx := context.Background()
   345  
   346  			tfJob := tftestutil.NewTFJob(3, 0)
   347  			tfJob.SetName("test-scale-up")
   348  			tfJob.SetUID(uuid.NewUUID())
   349  			tfJob.Spec.EnableDynamicWorker = true
   350  
   351  			refs := []metav1.OwnerReference{*reconciler.GenOwnerReference(tfJob)}
   352  
   353  			pods := []*corev1.Pod{
   354  				tftestutil.NewPod(tfJob, kubeflowv1.TFJobReplicaTypeWorker, 0, refs),
   355  			}
   356  
   357  			for i := range pods {
   358  				pod := pods[i]
   359  				for k, v := range reconciler.GenLabels(tfJob.GetName()) {
   360  					pod.Labels[k] = v
   361  				}
   362  				Expect(testK8sClient.Create(ctx, pod)).Should(Succeed())
   363  			}
   364  
   365  			// Ensure the created Pods are all in cache
   366  			Eventually(func() error {
   367  				podList := &corev1.PodList{}
   368  				selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
   369  					MatchLabels: reconciler.GenLabels(tfJob.GetName()),
   370  				})
   371  				if err != nil {
   372  					return err
   373  				}
   374  				listOpt := client.MatchingLabelsSelector{
   375  					Selector: selector,
   376  				}
   377  				err = testK8sClient.List(ctx, podList, listOpt)
   378  				if err != nil {
   379  					return err
   380  				}
   381  				if len(podList.Items) != 1 {
   382  					return fmt.Errorf("before reconciling, expecting %d Pods while got %d", 1, len(podList.Items))
   383  				}
   384  				return nil
   385  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   386  
   387  			_ = reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy)
   388  
   389  			// Check if there are two more Pods created
   390  			Eventually(func() error {
   391  				podList := &corev1.PodList{}
   392  				selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
   393  					MatchLabels: reconciler.GenLabels(tfJob.GetName()),
   394  				})
   395  				if err != nil {
   396  					return err
   397  				}
   398  				listOpt := client.MatchingLabelsSelector{
   399  					Selector: selector,
   400  				}
   401  				err = testK8sClient.List(ctx, podList, listOpt)
   402  				if err != nil {
   403  					return err
   404  				}
   405  				if len(podList.Items) != 3 {
   406  					return fmt.Errorf("after reconciling, expecting %d Pods while got %d", 3, len(podList.Items))
   407  				}
   408  				return nil
   409  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   410  		})
   411  	})
   412  
   413  	Describe("TestIsWorker0Completed", func() {
   414  		It("should match expected result", func() {
   415  			newInt32 := func(in int32) *int32 {
   416  				return &in
   417  			}
   418  			tests := []struct {
   419  				// worker failed, succeeded, running num
   420  				workers     [3]int32
   421  				tfJob       *kubeflowv1.TFJob
   422  				replicas    map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec
   423  				expected    bool
   424  				expectedErr bool
   425  			}{
   426  				{
   427  					workers:     [3]int32{0, 0, 1},
   428  					tfJob:       tftestutil.NewTFJobV2(1, 1, 0, 0, 0),
   429  					expected:    false,
   430  					expectedErr: false,
   431  					replicas: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
   432  						kubeflowv1.TFJobReplicaTypeWorker: {
   433  							Replicas: newInt32(1),
   434  							Template: tftestutil.NewTFReplicaSpecTemplate(),
   435  						},
   436  						kubeflowv1.TFJobReplicaTypePS: {
   437  							Replicas: newInt32(1),
   438  							Template: tftestutil.NewTFReplicaSpecTemplate(),
   439  						},
   440  					},
   441  				},
   442  				{
   443  					workers:     [3]int32{0, 1, 0},
   444  					tfJob:       tftestutil.NewTFJobV2(1, 0, 0, 0, 0),
   445  					expected:    true,
   446  					expectedErr: false,
   447  					replicas: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
   448  						kubeflowv1.TFJobReplicaTypeWorker: {
   449  							Replicas: newInt32(1),
   450  							Template: tftestutil.NewTFReplicaSpecTemplate(),
   451  						},
   452  					},
   453  				},
   454  				{
   455  					workers:     [3]int32{0, 0, 0},
   456  					tfJob:       tftestutil.NewTFJobV2(0, 0, 1, 0, 0),
   457  					expected:    true,
   458  					expectedErr: false,
   459  					replicas: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
   460  						kubeflowv1.TFJobReplicaTypeMaster: {
   461  							Replicas: newInt32(1),
   462  							Template: tftestutil.NewTFReplicaSpecTemplate(),
   463  						},
   464  					},
   465  				},
   466  				{
   467  					workers:     [3]int32{0, 0, 0},
   468  					tfJob:       tftestutil.NewTFJobV2(0, 0, 0, 1, 0),
   469  					expected:    true,
   470  					expectedErr: false,
   471  					replicas: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
   472  						kubeflowv1.TFJobReplicaTypeChief: {
   473  							Replicas: newInt32(1),
   474  							Template: tftestutil.NewTFReplicaSpecTemplate(),
   475  						},
   476  					},
   477  				},
   478  				{
   479  					workers:     [3]int32{1, 1, 0},
   480  					tfJob:       tftestutil.NewTFJobV2(2, 0, 0, 0, 0),
   481  					expected:    true,
   482  					expectedErr: false,
   483  					replicas: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
   484  						kubeflowv1.TFJobReplicaTypeWorker: {
   485  							Replicas: newInt32(2),
   486  							Template: tftestutil.NewTFReplicaSpecTemplate(),
   487  						},
   488  					},
   489  				},
   490  				{
   491  					workers:     [3]int32{1, 0, 1},
   492  					tfJob:       tftestutil.NewTFJobV2(2, 0, 0, 0, 0),
   493  					expected:    false,
   494  					expectedErr: false,
   495  					replicas: map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
   496  						kubeflowv1.TFJobReplicaTypeWorker: {
   497  							Replicas: newInt32(2),
   498  							Template: tftestutil.NewTFReplicaSpecTemplate(),
   499  						},
   500  					},
   501  				},
   502  			}
   503  
   504  			jobNameTemplate := "test-worker0-complete-%d"
   505  			for i, tt := range tests {
   506  				tt.tfJob.SetName(fmt.Sprintf(jobNameTemplate, i))
   507  				tt.tfJob.SetUID(uuid.NewUUID())
   508  				// only related to worker status
   509  				initializeReplicaStatuses(&tt.tfJob.Status, kubeflowv1.TFJobReplicaTypeWorker)
   510  				// set status and add pod to indexer
   511  				setStatusForTest(tt.tfJob, kubeflowv1.TFJobReplicaTypeWorker, tt.workers[0], tt.workers[1], tt.workers[2], false, true, testK8sClient)
   512  
   513  				// Adding this section to make sure all pods are created and cached
   514  				Eventually(func() error {
   515  					podList := &corev1.PodList{}
   516  					selector, err := metav1.LabelSelectorAsSelector(&metav1.LabelSelector{
   517  						MatchLabels: reconciler.GenLabels(tt.tfJob.GetName()),
   518  					})
   519  					if err != nil {
   520  						return err
   521  					}
   522  					listOpt := client.MatchingLabelsSelector{
   523  						Selector: selector,
   524  					}
   525  					err = testK8sClient.List(context.Background(), podList, listOpt)
   526  					if err != nil {
   527  						return nil
   528  					}
   529  					totalExpectedPodCount := tt.workers[0] + tt.workers[1] + tt.workers[2]
   530  					if len(podList.Items) != int(totalExpectedPodCount) {
   531  						return fmt.Errorf("pod number (%d) for %s not match for expected pod number %d",
   532  							len(podList.Items), tt.tfJob.GetName(), totalExpectedPodCount)
   533  					}
   534  					return nil
   535  				}, testutil.Timeout, testutil.Interval).Should(BeNil())
   536  
   537  				got, err := reconciler.IsWorker0Completed(tt.tfJob, tt.replicas)
   538  
   539  				if err != nil {
   540  					Expect(err).To(Equal(tt.expectedErr))
   541  				} else {
   542  					Expect(got).To(Equal(tt.expected))
   543  				}
   544  			}
   545  		})
   546  	})
   547  })