github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/mxnet/mxjob_controller_test.go (about)

     1  // Copyright 2023 The Kubeflow Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mxnet
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  
    21  	. "github.com/onsi/ginkgo/v2"
    22  	. "github.com/onsi/gomega"
    23  	corev1 "k8s.io/api/core/v1"
    24  	"k8s.io/apimachinery/pkg/api/errors"
    25  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    26  	"k8s.io/apimachinery/pkg/types"
    27  	"k8s.io/utils/pointer"
    28  	"sigs.k8s.io/controller-runtime/pkg/client"
    29  
    30  	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    31  	commonutil "github.com/kubeflow/training-operator/pkg/util"
    32  	"github.com/kubeflow/training-operator/pkg/util/testutil"
    33  )
    34  
    35  // TODO: we should implement more tests.
    36  var _ = Describe("MXJob controller", func() {
    37  	const (
    38  		expectedPort = int32(9091)
    39  	)
    40  	Context("When creating the MXJob", func() {
    41  		const name = "test-job"
    42  		var (
    43  			ns           *corev1.Namespace
    44  			job          *kubeflowv1.MXJob
    45  			jobKey       types.NamespacedName
    46  			serverKey    types.NamespacedName
    47  			worker0Key   types.NamespacedName
    48  			schedulerKey types.NamespacedName
    49  			ctx          = context.Background()
    50  		)
    51  		BeforeEach(func() {
    52  			ns = &corev1.Namespace{
    53  				ObjectMeta: metav1.ObjectMeta{
    54  					GenerateName: "mxjob-test-",
    55  				},
    56  			}
    57  			Expect(testK8sClient.Create(ctx, ns)).Should(Succeed())
    58  
    59  			job = newMXJobForTest(name, ns.Name)
    60  			jobKey = client.ObjectKeyFromObject(job)
    61  			serverKey = types.NamespacedName{
    62  				Name:      fmt.Sprintf("%s-server-0", name),
    63  				Namespace: ns.Name,
    64  			}
    65  			worker0Key = types.NamespacedName{
    66  				Name:      fmt.Sprintf("%s-worker-0", name),
    67  				Namespace: ns.Name,
    68  			}
    69  			schedulerKey = types.NamespacedName{
    70  				Name:      fmt.Sprintf("%s-scheduler-0", name),
    71  				Namespace: ns.Name,
    72  			}
    73  			job.Spec.MXReplicaSpecs = map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
    74  				kubeflowv1.MXJobReplicaTypeServer: {
    75  					Replicas: pointer.Int32(1),
    76  					Template: corev1.PodTemplateSpec{
    77  						Spec: corev1.PodSpec{
    78  							Containers: []corev1.Container{
    79  								{
    80  									Image: "test-image",
    81  									Name:  kubeflowv1.MXJobDefaultContainerName,
    82  									Ports: []corev1.ContainerPort{
    83  										{
    84  											Name:          kubeflowv1.MXJobDefaultPortName,
    85  											ContainerPort: expectedPort,
    86  											Protocol:      corev1.ProtocolTCP,
    87  										},
    88  									},
    89  								},
    90  							},
    91  						},
    92  					},
    93  				},
    94  				kubeflowv1.MXJobReplicaTypeScheduler: {
    95  					Replicas: pointer.Int32(1),
    96  					Template: corev1.PodTemplateSpec{
    97  						Spec: corev1.PodSpec{
    98  							Containers: []corev1.Container{
    99  								{
   100  									Image: "test-image",
   101  									Name:  kubeflowv1.MXJobDefaultContainerName,
   102  									Ports: []corev1.ContainerPort{
   103  										{
   104  											Name:          kubeflowv1.MXJobDefaultPortName,
   105  											ContainerPort: expectedPort,
   106  											Protocol:      corev1.ProtocolTCP,
   107  										},
   108  									},
   109  								},
   110  							},
   111  						},
   112  					},
   113  				},
   114  				kubeflowv1.MXJobReplicaTypeWorker: {
   115  					Replicas: pointer.Int32(2),
   116  					Template: corev1.PodTemplateSpec{
   117  						Spec: corev1.PodSpec{
   118  							Containers: []corev1.Container{
   119  								{
   120  									Image: "test-image",
   121  									Name:  kubeflowv1.MXJobDefaultContainerName,
   122  									Ports: []corev1.ContainerPort{
   123  										{
   124  											Name:          kubeflowv1.MXJobDefaultPortName,
   125  											ContainerPort: expectedPort,
   126  											Protocol:      corev1.ProtocolTCP,
   127  										},
   128  									},
   129  								},
   130  							},
   131  						},
   132  					},
   133  				},
   134  			}
   135  		})
   136  		AfterEach(func() {
   137  			Expect(testK8sClient.Delete(ctx, job)).Should(Succeed())
   138  			Expect(testK8sClient.Delete(ctx, ns)).Should(Succeed())
   139  		})
   140  		It("Shouldn't create resources when MXJob is suspended; Should create resources once MXJob is unsuspended", func() {
   141  			By("By creating a new MXJob with suspend=true")
   142  			job.Spec.RunPolicy.Suspend = pointer.Bool(true)
   143  			job.Spec.MXReplicaSpecs[kubeflowv1.MXJobReplicaTypeWorker].Replicas = pointer.Int32(1)
   144  			Expect(testK8sClient.Create(ctx, job)).Should(Succeed())
   145  
   146  			created := &kubeflowv1.MXJob{}
   147  			serverPod := &corev1.Pod{}
   148  			workerPod := &corev1.Pod{}
   149  			schedulerPod := &corev1.Pod{}
   150  			serverSvc := &corev1.Service{}
   151  			workerSvc := &corev1.Service{}
   152  			schedulerSvc := &corev1.Service{}
   153  
   154  			By("Checking created MXJob")
   155  			Eventually(func() bool {
   156  				err := testK8sClient.Get(ctx, jobKey, created)
   157  				return err == nil
   158  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   159  			By("Checking created MXJob has a nil startTime")
   160  			Consistently(func() *metav1.Time {
   161  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   162  				return created.Status.StartTime
   163  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeNil())
   164  
   165  			By("Checking if the pods and services aren't created")
   166  			Consistently(func() bool {
   167  				errServerPod := testK8sClient.Get(ctx, serverKey, serverPod)
   168  				errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod)
   169  				errSchedulerPod := testK8sClient.Get(ctx, schedulerKey, schedulerPod)
   170  				errServerSvc := testK8sClient.Get(ctx, serverKey, serverSvc)
   171  				errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc)
   172  				errSchedulerSvc := testK8sClient.Get(ctx, schedulerKey, schedulerSvc)
   173  				return errors.IsNotFound(errServerPod) && errors.IsNotFound(errWorkerPod) && errors.IsNotFound(errSchedulerPod) &&
   174  					errors.IsNotFound(errServerSvc) && errors.IsNotFound(errWorkerSvc) && errors.IsNotFound(errSchedulerSvc)
   175  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue())
   176  
   177  			By("Checking if the MXJob has suspended condition")
   178  			Eventually(func() []kubeflowv1.JobCondition {
   179  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   180  				return created.Status.Conditions
   181  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{
   182  				{
   183  					Type:    kubeflowv1.JobCreated,
   184  					Status:  corev1.ConditionTrue,
   185  					Reason:  commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobCreatedReason),
   186  					Message: fmt.Sprintf("MXJob %s is created.", name),
   187  				},
   188  				{
   189  					Type:    kubeflowv1.JobSuspended,
   190  					Status:  corev1.ConditionTrue,
   191  					Reason:  commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobSuspendedReason),
   192  					Message: fmt.Sprintf("MXJob %s is suspended.", name),
   193  				},
   194  			}, testutil.IgnoreJobConditionsTimes))
   195  		})
   196  
   197  		It("Should delete resources after MXJob is suspended; Should resume MXJob after MXJob is unsuspended", func() {
   198  			By("By creating a new MXJob")
   199  			job.Spec.MXReplicaSpecs[kubeflowv1.MXJobReplicaTypeWorker].Replicas = pointer.Int32(1)
   200  			Expect(testK8sClient.Create(ctx, job)).Should(Succeed())
   201  
   202  			created := &kubeflowv1.MXJob{}
   203  			serverPod := &corev1.Pod{}
   204  			workerPod := &corev1.Pod{}
   205  			schedulerPod := &corev1.Pod{}
   206  			serverSvc := &corev1.Service{}
   207  			workerSvc := &corev1.Service{}
   208  			schedulerSvc := &corev1.Service{}
   209  
   210  			// We'll need to retry getting this newly created MXJob, given that creation may not immediately happen.
   211  			By("Checking created MXJob")
   212  			Eventually(func() bool {
   213  				err := testK8sClient.Get(ctx, jobKey, created)
   214  				return err == nil
   215  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   216  
   217  			var startTimeBeforeSuspended *metav1.Time
   218  			Eventually(func() *metav1.Time {
   219  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   220  				startTimeBeforeSuspended = created.Status.StartTime
   221  				return startTimeBeforeSuspended
   222  			}, testutil.Timeout, testutil.Interval).ShouldNot(BeNil())
   223  
   224  			By("Checking the created pods and services")
   225  			Eventually(func() bool {
   226  				errServerPod := testK8sClient.Get(ctx, serverKey, serverPod)
   227  				errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod)
   228  				errSchedulerPod := testK8sClient.Get(ctx, schedulerKey, schedulerPod)
   229  				errServerSvc := testK8sClient.Get(ctx, serverKey, serverSvc)
   230  				errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc)
   231  				errSchedulerSvc := testK8sClient.Get(ctx, schedulerKey, schedulerSvc)
   232  				return errServerPod == nil && errWorkerPod == nil && errSchedulerPod == nil &&
   233  					errServerSvc == nil && errWorkerSvc == nil && errSchedulerSvc == nil
   234  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   235  
   236  			By("Updating the pod's phase with Running")
   237  			Eventually(func() error {
   238  				Expect(testK8sClient.Get(ctx, serverKey, serverPod)).Should(Succeed())
   239  				serverPod.Status.Phase = corev1.PodRunning
   240  				return testK8sClient.Status().Update(ctx, serverPod)
   241  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   242  			Eventually(func() error {
   243  				Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed())
   244  				workerPod.Status.Phase = corev1.PodRunning
   245  				return testK8sClient.Status().Update(ctx, workerPod)
   246  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   247  			Eventually(func() error {
   248  				Expect(testK8sClient.Get(ctx, schedulerKey, schedulerPod)).Should(Succeed())
   249  				schedulerPod.Status.Phase = corev1.PodRunning
   250  				return testK8sClient.Status().Update(ctx, schedulerPod)
   251  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   252  
   253  			By("Checking the MXJob's condition")
   254  			Eventually(func() []kubeflowv1.JobCondition {
   255  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   256  				return created.Status.Conditions
   257  			}, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{
   258  				{
   259  					Type:    kubeflowv1.JobCreated,
   260  					Status:  corev1.ConditionTrue,
   261  					Reason:  commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobCreatedReason),
   262  					Message: fmt.Sprintf("MXJob %s is created.", name),
   263  				},
   264  				{
   265  					Type:    kubeflowv1.JobRunning,
   266  					Status:  corev1.ConditionTrue,
   267  					Reason:  commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobRunningReason),
   268  					Message: fmt.Sprintf("MXJob %s is running.", name),
   269  				},
   270  			}, testutil.IgnoreJobConditionsTimes))
   271  
   272  			By("Updating the MXJob with suspend=true")
   273  			Eventually(func() error {
   274  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   275  				created.Spec.RunPolicy.Suspend = pointer.Bool(true)
   276  				return testK8sClient.Update(ctx, created)
   277  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   278  
   279  			By("Checking if the pods and services are removed")
   280  			Eventually(func() bool {
   281  				errServer := testK8sClient.Get(ctx, serverKey, serverPod)
   282  				errWorker := testK8sClient.Get(ctx, worker0Key, workerPod)
   283  				errScheduler := testK8sClient.Get(ctx, schedulerKey, schedulerPod)
   284  				return errors.IsNotFound(errServer) && errors.IsNotFound(errWorker) && errors.IsNotFound(errScheduler)
   285  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   286  			Eventually(func() bool {
   287  				errServer := testK8sClient.Get(ctx, serverKey, serverSvc)
   288  				errWorker := testK8sClient.Get(ctx, worker0Key, workerSvc)
   289  				errScheduler := testK8sClient.Get(ctx, schedulerKey, schedulerSvc)
   290  				return errors.IsNotFound(errServer) && errors.IsNotFound(errWorker) && errors.IsNotFound(errScheduler)
   291  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   292  			Consistently(func() bool {
   293  				errServerPod := testK8sClient.Get(ctx, serverKey, serverPod)
   294  				errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod)
   295  				errSchedulerPod := testK8sClient.Get(ctx, schedulerKey, schedulerPod)
   296  				errServerSvc := testK8sClient.Get(ctx, serverKey, serverSvc)
   297  				errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc)
   298  				errSchedulerSvc := testK8sClient.Get(ctx, schedulerKey, schedulerSvc)
   299  				return errors.IsNotFound(errServerPod) && errors.IsNotFound(errWorkerPod) && errors.IsNotFound(errSchedulerPod) &&
   300  					errors.IsNotFound(errServerSvc) && errors.IsNotFound(errWorkerSvc) && errors.IsNotFound(errSchedulerSvc)
   301  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue())
   302  
   303  			By("Checking if the MXJob has a suspended condition")
   304  			Eventually(func() bool {
   305  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   306  				return created.Status.ReplicaStatuses[kubeflowv1.MXJobReplicaTypeServer].Active == 0 &&
   307  					created.Status.ReplicaStatuses[kubeflowv1.MXJobReplicaTypeWorker].Active == 0 &&
   308  					created.Status.ReplicaStatuses[kubeflowv1.MXJobReplicaTypeScheduler].Active == 0 &&
   309  					created.Status.StartTime.Equal(startTimeBeforeSuspended)
   310  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   311  			Consistently(func() bool {
   312  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   313  				return created.Status.ReplicaStatuses[kubeflowv1.MXJobReplicaTypeServer].Active == 0 &&
   314  					created.Status.ReplicaStatuses[kubeflowv1.MXJobReplicaTypeWorker].Active == 0 &&
   315  					created.Status.ReplicaStatuses[kubeflowv1.MXJobReplicaTypeScheduler].Active == 0 &&
   316  					created.Status.StartTime.Equal(startTimeBeforeSuspended)
   317  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue())
   318  			Expect(created.Status.Conditions).Should(BeComparableTo([]kubeflowv1.JobCondition{
   319  				{
   320  					Type:    kubeflowv1.JobCreated,
   321  					Status:  corev1.ConditionTrue,
   322  					Reason:  commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobCreatedReason),
   323  					Message: fmt.Sprintf("MXJob %s is created.", name),
   324  				},
   325  				{
   326  					Type:    kubeflowv1.JobRunning,
   327  					Status:  corev1.ConditionFalse,
   328  					Reason:  commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobSuspendedReason),
   329  					Message: fmt.Sprintf("MXJob %s is suspended.", name),
   330  				},
   331  				{
   332  					Type:    kubeflowv1.JobSuspended,
   333  					Reason:  commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobSuspendedReason),
   334  					Message: fmt.Sprintf("MXJob %s is suspended.", name),
   335  					Status:  corev1.ConditionTrue,
   336  				},
   337  			}, testutil.IgnoreJobConditionsTimes))
   338  
   339  			By("Unsuspending the MXJob")
   340  			Eventually(func() error {
   341  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   342  				created.Spec.RunPolicy.Suspend = pointer.Bool(false)
   343  				return testK8sClient.Update(ctx, created)
   344  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   345  			Eventually(func() *metav1.Time {
   346  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   347  				return created.Status.StartTime
   348  			}, testutil.Timeout, testutil.Interval).ShouldNot(BeNil())
   349  
   350  			By("Check if the pods and services are created")
   351  			Eventually(func() error {
   352  				return testK8sClient.Get(ctx, serverKey, serverPod)
   353  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   354  			Eventually(func() error {
   355  				return testK8sClient.Get(ctx, worker0Key, workerPod)
   356  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   357  			Eventually(func() error {
   358  				return testK8sClient.Get(ctx, schedulerKey, schedulerPod)
   359  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   360  			Eventually(func() error {
   361  				return testK8sClient.Get(ctx, serverKey, serverSvc)
   362  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   363  			Eventually(func() error {
   364  				return testK8sClient.Get(ctx, worker0Key, workerSvc)
   365  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   366  			Eventually(func() error {
   367  				return testK8sClient.Get(ctx, schedulerKey, schedulerSvc)
   368  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   369  
   370  			By("Updating Pod's condition with running")
   371  			Eventually(func() error {
   372  				Expect(testK8sClient.Get(ctx, serverKey, serverPod)).Should(Succeed())
   373  				serverPod.Status.Phase = corev1.PodRunning
   374  				return testK8sClient.Status().Update(ctx, serverPod)
   375  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   376  			Eventually(func() error {
   377  				Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed())
   378  				workerPod.Status.Phase = corev1.PodRunning
   379  				return testK8sClient.Status().Update(ctx, workerPod)
   380  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   381  			Eventually(func() error {
   382  				Expect(testK8sClient.Get(ctx, schedulerKey, schedulerPod)).Should(Succeed())
   383  				schedulerPod.Status.Phase = corev1.PodRunning
   384  				return testK8sClient.Status().Update(ctx, schedulerPod)
   385  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   386  
   387  			By("Checking if the MXJob has resumed conditions")
   388  			Eventually(func() []kubeflowv1.JobCondition {
   389  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   390  				return created.Status.Conditions
   391  			}, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{
   392  				{
   393  					Type:    kubeflowv1.JobCreated,
   394  					Status:  corev1.ConditionTrue,
   395  					Reason:  commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobCreatedReason),
   396  					Message: fmt.Sprintf("MXJob %s is created.", name),
   397  				},
   398  				{
   399  					Type:    kubeflowv1.JobSuspended,
   400  					Reason:  commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobResumedReason),
   401  					Message: fmt.Sprintf("MXJob %s is resumed.", name),
   402  					Status:  corev1.ConditionFalse,
   403  				},
   404  				{
   405  					Type:    kubeflowv1.JobRunning,
   406  					Status:  corev1.ConditionTrue,
   407  					Reason:  commonutil.NewReason(kubeflowv1.MXJobKind, commonutil.JobRunningReason),
   408  					Message: fmt.Sprintf("MXJob %s is running.", name),
   409  				},
   410  			}, testutil.IgnoreJobConditionsTimes))
   411  
   412  			By("Checking if the startTime is updated")
   413  			Expect(created.Status.StartTime).ShouldNot(Equal(startTimeBeforeSuspended))
   414  		})
   415  	})
   416  })
   417  
   418  func newMXJobForTest(name, namespace string) *kubeflowv1.MXJob {
   419  	return &kubeflowv1.MXJob{
   420  		ObjectMeta: metav1.ObjectMeta{
   421  			Name:      name,
   422  			Namespace: namespace,
   423  		},
   424  	}
   425  }