github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/xgboost/xgboostjob_controller_test.go (about)

     1  // Copyright 2023 The Kubeflow Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package xgboost
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  
    21  	. "github.com/onsi/ginkgo/v2"
    22  	. "github.com/onsi/gomega"
    23  	corev1 "k8s.io/api/core/v1"
    24  	"k8s.io/apimachinery/pkg/api/errors"
    25  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    26  	"k8s.io/apimachinery/pkg/types"
    27  	"k8s.io/utils/pointer"
    28  	"sigs.k8s.io/controller-runtime/pkg/client"
    29  
    30  	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    31  	commonutil "github.com/kubeflow/training-operator/pkg/util"
    32  	"github.com/kubeflow/training-operator/pkg/util/testutil"
    33  )
    34  
    35  var _ = Describe("XGBoost controller", func() {
    36  	// Define utility constants for object names.
    37  	const (
    38  		expectedPort = int32(9999)
    39  	)
    40  	Context("When creating the XGBoostJob", func() {
    41  		const name = "test-job"
    42  		var (
    43  			ns         *corev1.Namespace
    44  			job        *kubeflowv1.XGBoostJob
    45  			jobKey     types.NamespacedName
    46  			masterKey  types.NamespacedName
    47  			worker0Key types.NamespacedName
    48  			ctx        = context.Background()
    49  		)
    50  		BeforeEach(func() {
    51  			ns = &corev1.Namespace{
    52  				ObjectMeta: metav1.ObjectMeta{
    53  					GenerateName: "xgboost-test-",
    54  				},
    55  			}
    56  			Expect(testK8sClient.Create(ctx, ns)).Should(Succeed())
    57  
    58  			job = newXGBoostForTest(name, ns.Name)
    59  			jobKey = client.ObjectKeyFromObject(job)
    60  			masterKey = types.NamespacedName{
    61  				Name:      fmt.Sprintf("%s-master-0", name),
    62  				Namespace: ns.Name,
    63  			}
    64  			worker0Key = types.NamespacedName{
    65  				Name:      fmt.Sprintf("%s-worker-0", name),
    66  				Namespace: ns.Name,
    67  			}
    68  			job.Spec.XGBReplicaSpecs = map[kubeflowv1.ReplicaType]*kubeflowv1.ReplicaSpec{
    69  				kubeflowv1.XGBoostJobReplicaTypeMaster: {
    70  					Replicas: pointer.Int32(1),
    71  					Template: corev1.PodTemplateSpec{
    72  						Spec: corev1.PodSpec{
    73  							Containers: []corev1.Container{
    74  								{
    75  									Image: "test-image",
    76  									Name:  kubeflowv1.XGBoostJobDefaultContainerName,
    77  									Ports: []corev1.ContainerPort{
    78  										{
    79  											Name:          kubeflowv1.XGBoostJobDefaultPortName,
    80  											ContainerPort: expectedPort,
    81  											Protocol:      corev1.ProtocolTCP,
    82  										},
    83  									},
    84  								},
    85  							},
    86  						},
    87  					},
    88  				},
    89  				kubeflowv1.XGBoostJobReplicaTypeWorker: {
    90  					Replicas: pointer.Int32(2),
    91  					Template: corev1.PodTemplateSpec{
    92  						Spec: corev1.PodSpec{
    93  							Containers: []corev1.Container{
    94  								{
    95  									Image: "test-image",
    96  									Name:  kubeflowv1.XGBoostJobDefaultContainerName,
    97  									Ports: []corev1.ContainerPort{
    98  										{
    99  											Name:          kubeflowv1.XGBoostJobDefaultPortName,
   100  											ContainerPort: expectedPort,
   101  											Protocol:      corev1.ProtocolTCP,
   102  										},
   103  									},
   104  								},
   105  							},
   106  						},
   107  					},
   108  				},
   109  			}
   110  		})
   111  		AfterEach(func() {
   112  			Expect(testK8sClient.Delete(ctx, job)).Should(Succeed())
   113  			Expect(testK8sClient.Delete(ctx, ns)).Should(Succeed())
   114  		})
   115  		It("Shouldn't create resources if XGBoostJob is suspended", func() {
   116  			By("By creating a new XGBoostJob with suspend=true")
   117  			job.Spec.RunPolicy.Suspend = pointer.Bool(true)
   118  			job.Spec.XGBReplicaSpecs[kubeflowv1.XGBoostJobReplicaTypeWorker].Replicas = pointer.Int32(1)
   119  			Expect(testK8sClient.Create(ctx, job)).Should(Succeed())
   120  
   121  			created := &kubeflowv1.XGBoostJob{}
   122  			masterPod := &corev1.Pod{}
   123  			workerPod := &corev1.Pod{}
   124  			masterSvc := &corev1.Service{}
   125  			workerSvc := &corev1.Service{}
   126  
   127  			By("Checking created XGBoostJob")
   128  			Eventually(func() bool {
   129  				err := testK8sClient.Get(ctx, jobKey, created)
   130  				return err == nil
   131  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   132  			By("Checking created XGBoostJob has a nil startTime")
   133  			Consistently(func() *metav1.Time {
   134  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   135  				return created.Status.StartTime
   136  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeNil())
   137  
   138  			By("Checking if the pods and services aren't created")
   139  			Consistently(func() bool {
   140  				errMasterPod := testK8sClient.Get(ctx, masterKey, masterPod)
   141  				errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod)
   142  				errMasterSvc := testK8sClient.Get(ctx, masterKey, masterSvc)
   143  				errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc)
   144  				return errors.IsNotFound(errMasterPod) && errors.IsNotFound(errWorkerPod) &&
   145  					errors.IsNotFound(errMasterSvc) && errors.IsNotFound(errWorkerSvc)
   146  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue())
   147  
   148  			By("Checking if the XGBoostJob has suspended condition")
   149  			Eventually(func() []kubeflowv1.JobCondition {
   150  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   151  				return created.Status.Conditions
   152  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{
   153  				{
   154  					Type:    kubeflowv1.JobCreated,
   155  					Status:  corev1.ConditionTrue,
   156  					Reason:  commonutil.NewReason(kubeflowv1.XGBoostJobKind, commonutil.JobCreatedReason),
   157  					Message: fmt.Sprintf("XGBoostJob %s is created.", name),
   158  				},
   159  				{
   160  					Type:    kubeflowv1.JobSuspended,
   161  					Status:  corev1.ConditionTrue,
   162  					Reason:  commonutil.NewReason(kubeflowv1.XGBoostJobKind, commonutil.JobSuspendedReason),
   163  					Message: fmt.Sprintf("XGBoostJob %s is suspended.", name),
   164  				},
   165  			}, testutil.IgnoreJobConditionsTimes))
   166  		})
   167  
   168  		It("Should delete resources after XGBoostJob is suspended; Should resume XGBoostJob after XGBoostJob is unsuspended", func() {
   169  			By("By creating a new XGBoostJob")
   170  			job.Spec.XGBReplicaSpecs[kubeflowv1.XGBoostJobReplicaTypeWorker].Replicas = pointer.Int32(1)
   171  			Expect(testK8sClient.Create(ctx, job)).Should(Succeed())
   172  
   173  			created := &kubeflowv1.XGBoostJob{}
   174  			masterPod := &corev1.Pod{}
   175  			workerPod := &corev1.Pod{}
   176  			masterSvc := &corev1.Service{}
   177  			workerSvc := &corev1.Service{}
   178  
   179  			// We'll need to retry getting this newly created XGBoostJob, given that creation may not immediately happen.
   180  			By("Checking created XGBoostJob")
   181  			Eventually(func() bool {
   182  				err := testK8sClient.Get(ctx, jobKey, created)
   183  				return err == nil
   184  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   185  
   186  			var startTimeBeforeSuspended *metav1.Time
   187  			Eventually(func() *metav1.Time {
   188  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   189  				startTimeBeforeSuspended = created.Status.StartTime
   190  				return startTimeBeforeSuspended
   191  			}, testutil.Timeout, testutil.Interval).ShouldNot(BeNil())
   192  
   193  			By("Checking the created pods and services")
   194  			Eventually(func() bool {
   195  				errMaster := testK8sClient.Get(ctx, masterKey, masterPod)
   196  				errWorker := testK8sClient.Get(ctx, worker0Key, workerPod)
   197  				return errMaster == nil && errWorker == nil
   198  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   199  			Eventually(func() bool {
   200  				errMaster := testK8sClient.Get(ctx, masterKey, masterSvc)
   201  				errWorker := testK8sClient.Get(ctx, worker0Key, workerSvc)
   202  				return errMaster == nil && errWorker == nil
   203  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   204  
   205  			By("Updating the Pod's phase with Running")
   206  			Eventually(func() error {
   207  				Expect(testK8sClient.Get(ctx, masterKey, masterPod)).Should(Succeed())
   208  				masterPod.Status.Phase = corev1.PodRunning
   209  				return testK8sClient.Status().Update(ctx, masterPod)
   210  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   211  			Eventually(func() error {
   212  				Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed())
   213  				workerPod.Status.Phase = corev1.PodRunning
   214  				return testK8sClient.Status().Update(ctx, workerPod)
   215  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   216  
   217  			By("Checking the XGBoostJob's condition")
   218  			Eventually(func() []kubeflowv1.JobCondition {
   219  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   220  				return created.Status.Conditions
   221  			}, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{
   222  				{
   223  					Type:    kubeflowv1.JobCreated,
   224  					Status:  corev1.ConditionTrue,
   225  					Reason:  commonutil.NewReason(kubeflowv1.XGBoostJobKind, commonutil.JobCreatedReason),
   226  					Message: fmt.Sprintf("XGBoostJob %s is created.", name),
   227  				},
   228  				{
   229  					Type:    kubeflowv1.JobRunning,
   230  					Status:  corev1.ConditionTrue,
   231  					Reason:  commonutil.NewReason(kubeflowv1.XGBoostJobKind, commonutil.JobRunningReason),
   232  					Message: fmt.Sprintf("XGBoostJob %s is running.", name),
   233  				},
   234  			}, testutil.IgnoreJobConditionsTimes))
   235  
   236  			By("Updating the XGBoostJob with suspend=true")
   237  			Eventually(func() error {
   238  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   239  				created.Spec.RunPolicy.Suspend = pointer.Bool(true)
   240  				return testK8sClient.Update(ctx, created)
   241  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   242  
   243  			By("Checking if the pods and services are removed")
   244  			Eventually(func() bool {
   245  				errMaster := testK8sClient.Get(ctx, masterKey, masterPod)
   246  				errWorker := testK8sClient.Get(ctx, worker0Key, workerPod)
   247  				return errors.IsNotFound(errMaster) && errors.IsNotFound(errWorker)
   248  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   249  			Eventually(func() bool {
   250  				errMaster := testK8sClient.Get(ctx, masterKey, masterSvc)
   251  				errWorker := testK8sClient.Get(ctx, worker0Key, workerSvc)
   252  				return errors.IsNotFound(errMaster) && errors.IsNotFound(errWorker)
   253  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   254  			Consistently(func() bool {
   255  				errMasterPod := testK8sClient.Get(ctx, masterKey, masterPod)
   256  				errWorkerPod := testK8sClient.Get(ctx, worker0Key, workerPod)
   257  				errMasterSvc := testK8sClient.Get(ctx, masterKey, masterSvc)
   258  				errWorkerSvc := testK8sClient.Get(ctx, worker0Key, workerSvc)
   259  				return errors.IsNotFound(errMasterPod) && errors.IsNotFound(errWorkerPod) &&
   260  					errors.IsNotFound(errMasterSvc) && errors.IsNotFound(errWorkerSvc)
   261  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue())
   262  
   263  			By("Checking if the XGBoostJob has a suspended condition")
   264  			Eventually(func() bool {
   265  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   266  				return created.Status.ReplicaStatuses[kubeflowv1.XGBoostJobReplicaTypeMaster].Active == 0 &&
   267  					created.Status.ReplicaStatuses[kubeflowv1.XGBoostJobReplicaTypeWorker].Active == 0 &&
   268  					created.Status.StartTime.Equal(startTimeBeforeSuspended)
   269  			}, testutil.Timeout, testutil.Interval).Should(BeTrue())
   270  			Consistently(func() bool {
   271  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   272  				return created.Status.ReplicaStatuses[kubeflowv1.XGBoostJobReplicaTypeMaster].Active == 0 &&
   273  					created.Status.ReplicaStatuses[kubeflowv1.XGBoostJobReplicaTypeWorker].Active == 0 &&
   274  					created.Status.StartTime.Equal(startTimeBeforeSuspended)
   275  			}, testutil.ConsistentDuration, testutil.Interval).Should(BeTrue())
   276  			Expect(created.Status.Conditions).Should(BeComparableTo([]kubeflowv1.JobCondition{
   277  				{
   278  					Type:    kubeflowv1.JobCreated,
   279  					Status:  corev1.ConditionTrue,
   280  					Reason:  commonutil.NewReason(kubeflowv1.XGBoostJobKind, commonutil.JobCreatedReason),
   281  					Message: fmt.Sprintf("XGBoostJob %s is created.", name),
   282  				},
   283  				{
   284  					Type:    kubeflowv1.JobRunning,
   285  					Status:  corev1.ConditionFalse,
   286  					Reason:  commonutil.NewReason(kubeflowv1.XGBoostJobKind, commonutil.JobSuspendedReason),
   287  					Message: fmt.Sprintf("XGBoostJob %s is suspended.", name),
   288  				},
   289  				{
   290  					Type:    kubeflowv1.JobSuspended,
   291  					Reason:  commonutil.NewReason(kubeflowv1.XGBoostJobKind, commonutil.JobSuspendedReason),
   292  					Message: fmt.Sprintf("XGBoostJob %s is suspended.", name),
   293  					Status:  corev1.ConditionTrue,
   294  				},
   295  			}, testutil.IgnoreJobConditionsTimes))
   296  
   297  			By("Unsuspending the XGBoostJob")
   298  			Eventually(func() error {
   299  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   300  				created.Spec.RunPolicy.Suspend = pointer.Bool(false)
   301  				return testK8sClient.Update(ctx, created)
   302  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   303  			Eventually(func() *metav1.Time {
   304  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   305  				return created.Status.StartTime
   306  			}, testutil.Timeout, testutil.Interval).ShouldNot(BeNil())
   307  
   308  			By("Check if the pods and services are created")
   309  			Eventually(func() error {
   310  				return testK8sClient.Get(ctx, masterKey, masterPod)
   311  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   312  			Eventually(func() error {
   313  				return testK8sClient.Get(ctx, worker0Key, workerPod)
   314  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   315  			Eventually(func() error {
   316  				return testK8sClient.Get(ctx, masterKey, masterSvc)
   317  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   318  			Eventually(func() error {
   319  				return testK8sClient.Get(ctx, worker0Key, workerSvc)
   320  			}, testutil.Timeout, testutil.Interval).Should(BeNil())
   321  
   322  			By("Updating Pod's condition with Running")
   323  			Eventually(func() error {
   324  				Expect(testK8sClient.Get(ctx, masterKey, masterPod)).Should(Succeed())
   325  				masterPod.Status.Phase = corev1.PodRunning
   326  				return testK8sClient.Status().Update(ctx, masterPod)
   327  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   328  			Eventually(func() error {
   329  				Expect(testK8sClient.Get(ctx, worker0Key, workerPod)).Should(Succeed())
   330  				workerPod.Status.Phase = corev1.PodRunning
   331  				return testK8sClient.Status().Update(ctx, workerPod)
   332  			}, testutil.Timeout, testutil.Interval).Should(Succeed())
   333  
   334  			By("Checking if the XGBoostJob has resumed conditions")
   335  			Eventually(func() []kubeflowv1.JobCondition {
   336  				Expect(testK8sClient.Get(ctx, jobKey, created)).Should(Succeed())
   337  				return created.Status.Conditions
   338  			}, testutil.Timeout, testutil.Interval).Should(BeComparableTo([]kubeflowv1.JobCondition{
   339  				{
   340  					Type:    kubeflowv1.JobCreated,
   341  					Status:  corev1.ConditionTrue,
   342  					Reason:  commonutil.NewReason(kubeflowv1.XGBoostJobKind, commonutil.JobCreatedReason),
   343  					Message: fmt.Sprintf("XGBoostJob %s is created.", name),
   344  				},
   345  				{
   346  					Type:    kubeflowv1.JobSuspended,
   347  					Reason:  commonutil.NewReason(kubeflowv1.XGBoostJobKind, commonutil.JobResumedReason),
   348  					Message: fmt.Sprintf("XGBoostJob %s is resumed.", name),
   349  					Status:  corev1.ConditionFalse,
   350  				},
   351  				{
   352  					Type:    kubeflowv1.JobRunning,
   353  					Status:  corev1.ConditionTrue,
   354  					Reason:  commonutil.NewReason(kubeflowv1.XGBoostJobKind, commonutil.JobRunningReason),
   355  					Message: fmt.Sprintf("XGBoostJob %s is running.", name),
   356  				},
   357  			}, testutil.IgnoreJobConditionsTimes))
   358  
   359  			By("Checking if the startTime is updated")
   360  			Expect(created.Status.StartTime).ShouldNot(Equal(startTimeBeforeSuspended))
   361  		})
   362  	})
   363  })
   364  
   365  func newXGBoostForTest(name, namespace string) *kubeflowv1.XGBoostJob {
   366  	return &kubeflowv1.XGBoostJob{
   367  		ObjectMeta: metav1.ObjectMeta{
   368  			Name:      name,
   369  			Namespace: namespace,
   370  		},
   371  	}
   372  }