github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/common/job.go

/*
Copyright 2023 The Kubeflow Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package common

import (
	"errors"
	"fmt"
	"reflect"
	"time"

	apiv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
	"github.com/kubeflow/training-operator/pkg/controller.v1/expectation"
	"github.com/kubeflow/training-operator/pkg/core"
	commonutil "github.com/kubeflow/training-operator/pkg/util"
	"github.com/kubeflow/training-operator/pkg/util/k8sutil"
	trainutil "github.com/kubeflow/training-operator/pkg/util/train"

	log "github.com/sirupsen/logrus"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
	"k8s.io/klog/v2"
	schedulerpluginsv1alpha1 "sigs.k8s.io/scheduler-plugins/apis/scheduling/v1alpha1"
	volcanov1beta1 "volcano.sh/apis/pkg/apis/scheduling/v1beta1"
)

// DeletePodsAndServices deletes pods and services, honoring cleanPodPolicy.
// However, if the job does not yet have a Succeeded or Failed condition,
// cleanPodPolicy is ignored and all pods and services are deleted.
func (jc *JobController) DeletePodsAndServices(runtimeObject runtime.Object, runPolicy *apiv1.RunPolicy, jobStatus apiv1.JobStatus, pods []*corev1.Pod) error {
	if len(pods) == 0 {
		return nil
	}

	// Delete nothing when the cleanPodPolicy is None and the job has a Succeeded or Failed condition.
	if commonutil.IsFinished(jobStatus) && *runPolicy.CleanPodPolicy == apiv1.CleanPodPolicyNone {
		return nil
	}

	for _, pod := range pods {
		// Note that a pending pod turns into a running pod once it becomes
		// schedulable; not cleaning it up could leave an orphaned running pod
		// behind, so pending is treated as equivalent to the running phase here.
		if commonutil.IsFinished(jobStatus) && *runPolicy.CleanPodPolicy == apiv1.CleanPodPolicyRunning && pod.Status.Phase != corev1.PodRunning && pod.Status.Phase != corev1.PodPending {
			continue
		}
		if err := jc.PodControl.DeletePod(pod.Namespace, pod.Name, runtimeObject); err != nil {
			return err
		}
		// Pod and service share the same name, so the service can be deleted using the pod's name.
		if err := jc.ServiceControl.DeleteService(pod.Namespace, pod.Name, runtimeObject); err != nil {
			return err
		}
	}
	return nil
}

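// shouldDeletePod is a minimal sketch (not used by the controller) of the
// cleanPodPolicy decision applied above. It assumes the job already has a
// Succeeded or Failed condition; without such a condition, every pod is
// deleted regardless of policy.
func shouldDeletePod(policy apiv1.CleanPodPolicy, phase corev1.PodPhase) bool {
	switch policy {
	case apiv1.CleanPodPolicyNone:
		// Keep all pods.
		return false
	case apiv1.CleanPodPolicyRunning:
		// Pending pods may still turn into running pods, so clean them up too.
		return phase == corev1.PodRunning || phase == corev1.PodPending
	default: // CleanPodPolicyAll
		return true
	}
}
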
// recordAbnormalPods records events for active pods whose latest condition is not in True status.
func (jc *JobController) recordAbnormalPods(activePods []*corev1.Pod, object runtime.Object) {
	core.RecordAbnormalPods(activePods, object, jc.Recorder)
}

// ReconcileJobs checks and updates replicas for each given ReplicaSpec.
// It will requeue the job in case of an error while creating/deleting pods/services.
func (jc *JobController) ReconcileJobs(
	job interface{},
	replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec,
	jobStatus apiv1.JobStatus,
	runPolicy *apiv1.RunPolicy) error {

	metaObject, ok := job.(metav1.Object)
	if !ok {
		return fmt.Errorf("job is not of type metav1.Object")
	}
	jobName := metaObject.GetName()
	runtimeObject, ok := job.(runtime.Object)
	if !ok {
		return fmt.Errorf("job is not of type runtime.Object")
	}
	jobKey, err := KeyFunc(job)
	if err != nil {
		utilruntime.HandleError(fmt.Errorf("couldn't get key for job object %#v: %v", job, err))
		return err
	}
	jobKind := jc.Controller.GetAPIGroupVersionKind().Kind
	// Reset expectations:
	// 1. Since `ReconcileJobs` is called, all previous expectations should be
	//    satisfied, so it is safe to reset them.
	// 2. Resetting expectations avoids dirty data such as `expectedDeletion = -1`
	//    (a pod or service was deleted unexpectedly).
	if err = jc.ResetExpectations(jobKey, replicas); err != nil {
		log.Warnf("Failed to reset expectations: %v", err)
	}

	log.Infof("Reconciling for job %s", jobName)
	pods, err := jc.Controller.GetPodsForJob(job)
	if err != nil {
		log.Warnf("GetPodsForJob error %v", err)
		return err
	}

	services, err := jc.Controller.GetServicesForJob(job)
	if err != nil {
		log.Warnf("GetServicesForJob error %v", err)
		return err
	}

	oldStatus := jobStatus.DeepCopy()
	if commonutil.IsFinished(jobStatus) {
		// If the job has succeeded or failed, delete all pods, services, and the podGroup.
		if err = jc.CleanUpResources(runPolicy, runtimeObject, metaObject, jobStatus, pods); err != nil {
			return err
		}

		// At this point the pods may have been deleted.
		// If the job succeeded, manually fix up the replica statuses:
		// any replicas still counted as active are moved to succeeded.
		if commonutil.IsSucceeded(jobStatus) {
			for rtype := range jobStatus.ReplicaStatuses {
				jobStatus.ReplicaStatuses[rtype].Succeeded += jobStatus.ReplicaStatuses[rtype].Active
				jobStatus.ReplicaStatuses[rtype].Active = 0
			}
		}

		// No need to update the job status if the status hasn't changed since last time.
		if !reflect.DeepEqual(*oldStatus, jobStatus) {
			return jc.Controller.UpdateJobStatusInApiServer(job, &jobStatus)
		}

		return nil
	}

	if trainutil.IsJobSuspended(runPolicy) {
		if err = jc.CleanUpResources(runPolicy, runtimeObject, metaObject, jobStatus, pods); err != nil {
			return err
		}
		for rType := range jobStatus.ReplicaStatuses {
			jobStatus.ReplicaStatuses[rType].Active = 0
		}
		msg := fmt.Sprintf("%s %s is suspended.", jobKind, jobName)
		if commonutil.IsRunning(jobStatus) {
			commonutil.UpdateJobConditions(&jobStatus, apiv1.JobRunning, corev1.ConditionFalse, commonutil.NewReason(jobKind, commonutil.JobSuspendedReason), msg)
		}
		// We add the suspended condition to the job only when the job doesn't have a suspended condition.
		if !commonutil.IsSuspended(jobStatus) {
			commonutil.UpdateJobConditions(&jobStatus, apiv1.JobSuspended, corev1.ConditionTrue, commonutil.NewReason(jobKind, commonutil.JobSuspendedReason), msg)
		}
		jc.Recorder.Event(runtimeObject, corev1.EventTypeNormal, commonutil.NewReason(jobKind, commonutil.JobSuspendedReason), msg)
		if !reflect.DeepEqual(*oldStatus, jobStatus) {
			return jc.Controller.UpdateJobStatusInApiServer(job, &jobStatus)
		}
		return nil
	}
	if commonutil.IsSuspended(jobStatus) {
		msg := fmt.Sprintf("%s %s is resumed.", jobKind, jobName)
		commonutil.UpdateJobConditions(&jobStatus, apiv1.JobSuspended, corev1.ConditionFalse, commonutil.NewReason(jobKind, commonutil.JobResumedReason), msg)
		now := metav1.Now()
		jobStatus.StartTime = &now
		jc.Recorder.Eventf(runtimeObject, corev1.EventTypeNormal, commonutil.NewReason(jobKind, commonutil.JobResumedReason), msg)
	}

	// Retrieve the previous number of retries.
	previousRetry := jc.WorkQueue.NumRequeues(jobKey)

	activePods := k8sutil.FilterActivePods(pods)

	jc.recordAbnormalPods(activePods, runtimeObject)

	active := int32(len(activePods))
	failed := k8sutil.FilterPodCount(pods, corev1.PodFailed)
	totalReplicas := k8sutil.GetTotalReplicas(replicas)
	prevReplicasFailedNum := k8sutil.GetTotalFailedReplicas(jobStatus.ReplicaStatuses)

	var failureMessage string
	jobExceedsLimit := false
	exceedsBackoffLimit := false
	pastBackoffLimit := false

	if runPolicy.BackoffLimit != nil {
		jobHasNewFailure := failed > prevReplicasFailedNum
		// A new failure happens when the status does not yet reflect it and the
		// number of active pods differs from the desired parallelism; otherwise
		// the previous controller loop failed to update the status, so any
		// failure we pick up now is not a new one.
		exceedsBackoffLimit = jobHasNewFailure && (active != totalReplicas) &&
			(int32(previousRetry)+1 > *runPolicy.BackoffLimit)

		pastBackoffLimit, err = jc.PastBackoffLimit(jobName, runPolicy, replicas, pods)
		if err != nil {
			return err
		}
	}

	if exceedsBackoffLimit || pastBackoffLimit {
		// The number of pod restarts exceeds the backoff limit (for restartPolicy
		// OnFailure only), or the number of failed pods increased since the last sync.
		jobExceedsLimit = true
		failureMessage = fmt.Sprintf("Job %s has failed because it has reached the specified backoff limit", jobName)
	} else if jc.PastActiveDeadline(runPolicy, jobStatus) {
		failureMessage = fmt.Sprintf("Job %s has failed because it was active longer than specified deadline", jobName)
		jobExceedsLimit = true
	}

	if jobExceedsLimit {
		// Set the job completion time before resource cleanup.
		if jobStatus.CompletionTime == nil {
			now := metav1.Now()
			jobStatus.CompletionTime = &now
		}

		// If the job exceeds the backoff limit or is past the active deadline,
		// delete all pods and services, then set the status to failed.
		if err := jc.DeletePodsAndServices(runtimeObject, runPolicy, jobStatus, pods); err != nil {
			return err
		}

		if err := jc.CleanupJob(runPolicy, jobStatus, job); err != nil {
			return err
		}

		if jc.Config.EnableGangScheduling() {
			jc.Recorder.Event(runtimeObject, corev1.EventTypeNormal, "JobTerminated", "Job has been terminated. Deleting PodGroup")
			if err := jc.DeletePodGroup(metaObject); err != nil {
				jc.Recorder.Eventf(runtimeObject, corev1.EventTypeWarning, "FailedDeletePodGroup", "Error deleting: %v", err)
				return err
			}
			jc.Recorder.Eventf(runtimeObject, corev1.EventTypeNormal, "SuccessfulDeletePodGroup", "Deleted PodGroup: %v", jobName)
		}

		jc.Recorder.Event(runtimeObject, corev1.EventTypeNormal, commonutil.NewReason(jobKind, commonutil.JobFailedReason), failureMessage)

		commonutil.UpdateJobConditions(&jobStatus, apiv1.JobFailed, corev1.ConditionTrue, commonutil.NewReason(jobKind, commonutil.JobFailedReason), failureMessage)

		return jc.Controller.UpdateJobStatusInApiServer(job, &jobStatus)
	} else {
		// General case: the job still needs to be reconciled.
		if jc.Config.EnableGangScheduling() {
			minMember := totalReplicas
			queue := ""
			priorityClass := ""
			var schedulerTimeout *int32
			var minResources *corev1.ResourceList

			if runPolicy.SchedulingPolicy != nil {
				if minAvailable := runPolicy.SchedulingPolicy.MinAvailable; minAvailable != nil {
					minMember = *minAvailable
				}
				if q := runPolicy.SchedulingPolicy.Queue; len(q) != 0 {
					queue = q
				}
				if pc := runPolicy.SchedulingPolicy.PriorityClass; len(pc) != 0 {
					priorityClass = pc
				}
				if mr := runPolicy.SchedulingPolicy.MinResources; mr != nil {
					minResources = (*corev1.ResourceList)(mr)
				}
				if timeout := runPolicy.SchedulingPolicy.ScheduleTimeoutSeconds; timeout != nil {
					schedulerTimeout = timeout
				}
			}

			if minResources == nil {
				minResources = jc.calcPGMinResources(minMember, replicas)
			}

			var pgSpecFill FillPodGroupSpecFunc
			switch jc.Config.GangScheduling {
			case GangSchedulerVolcano:
				pgSpecFill = func(pg metav1.Object) error {
					volcanoPodGroup, match := pg.(*volcanov1beta1.PodGroup)
					if !match {
						return fmt.Errorf("unable to recognize PodGroup: %v", klog.KObj(pg))
					}
					volcanoPodGroup.Spec = volcanov1beta1.PodGroupSpec{
						MinMember:         minMember,
						Queue:             queue,
						PriorityClassName: priorityClass,
						MinResources:      minResources,
					}
					return nil
				}
			default:
				pgSpecFill = func(pg metav1.Object) error {
					schedulerPluginsPodGroup, match := pg.(*schedulerpluginsv1alpha1.PodGroup)
					if !match {
						return fmt.Errorf("unable to recognize PodGroup: %v", klog.KObj(pg))
					}
					schedulerPluginsPodGroup.Spec = schedulerpluginsv1alpha1.PodGroupSpec{
						MinMember:              minMember,
						MinResources:           *minResources,
						ScheduleTimeoutSeconds: schedulerTimeout,
					}
					return nil
				}
			}

			syncReplicas := true
			pg, err := jc.SyncPodGroup(metaObject, pgSpecFill)
			if err != nil {
				log.Warnf("Sync PodGroup %v: %v", jobKey, err)
				syncReplicas = false
			}

			// Delay pod creation until the PodGroup status is Inqueue.
			if jc.PodGroupControl.DelayPodCreationDueToPodGroup(pg) {
				log.Warnf("PodGroup %v unschedulable", jobKey)
				syncReplicas = false
			}

			if !syncReplicas {
				now := metav1.Now()
				jobStatus.LastReconcileTime = &now

				// Update the job status here to trigger a new reconciliation.
				return jc.Controller.UpdateJobStatusInApiServer(job, &jobStatus)
			}
		}

		// Diff current active pods/services with replicas.
		for rtype, spec := range replicas {
			err := jc.Controller.ReconcilePods(metaObject, &jobStatus, pods, rtype, spec, replicas)
			if err != nil {
				log.Warnf("ReconcilePods error %v", err)
				return err
			}

			err = jc.Controller.ReconcileServices(metaObject, services, rtype, spec)
			if err != nil {
				log.Warnf("ReconcileServices error %v", err)
				return err
			}
		}
	}

	err = jc.Controller.UpdateJobStatus(job, replicas, &jobStatus)
	if err != nil {
		log.Warnf("UpdateJobStatus error %v", err)
		return err
	}
	// No need to update the job status if the status hasn't changed since last time.
	if !reflect.DeepEqual(*oldStatus, jobStatus) {
		return jc.Controller.UpdateJobStatusInApiServer(job, &jobStatus)
	}
	return nil
}

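// exceedsBackoff is a condensed sketch (illustrative only, not used by the
// controller) of the backoff decision computed in ReconcileJobs above: the job
// is past its limit when a new failure is observed while not all replicas are
// active, and one more requeue would exceed runPolicy.BackoffLimit.
func exceedsBackoff(failed, prevFailed, active, total int32, previousRetry int, backoffLimit int32) bool {
	jobHasNewFailure := failed > prevFailed
	return jobHasNewFailure && active != total && int32(previousRetry)+1 > backoffLimit
}
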
func (jc *JobController) CleanUpResources(
	runPolicy *apiv1.RunPolicy,
	runtimeObject runtime.Object,
	metaObject metav1.Object,
	jobStatus apiv1.JobStatus,
	pods []*corev1.Pod,
) error {
	if err := jc.DeletePodsAndServices(runtimeObject, runPolicy, jobStatus, pods); err != nil {
		return err
	}
	if jc.Config.EnableGangScheduling() {
		jc.Recorder.Event(runtimeObject, corev1.EventTypeNormal, "JobTerminated", "Job has been terminated. Deleting PodGroup")
		if err := jc.DeletePodGroup(metaObject); err != nil {
			jc.Recorder.Eventf(runtimeObject, corev1.EventTypeWarning, "FailedDeletePodGroup", "Error deleting: %v", err)
			return err
		}
		jc.Recorder.Eventf(runtimeObject, corev1.EventTypeNormal, "SuccessfulDeletePodGroup", "Deleted PodGroup: %v", metaObject.GetName())
	}
	if err := jc.CleanupJob(runPolicy, jobStatus, runtimeObject); err != nil {
		return err
	}
	return nil
}

// ResetExpectations resets the expected creates and deletes of pods/services to zero.
func (jc *JobController) ResetExpectations(jobKey string, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) error {
	var allErrs error
	for rtype := range replicas {
		expectationPodsKey := expectation.GenExpectationPodsKey(jobKey, string(rtype))
		if err := jc.Expectations.SetExpectations(expectationPodsKey, 0, 0); err != nil {
			allErrs = errors.Join(allErrs, err)
		}
		expectationServicesKey := expectation.GenExpectationServicesKey(jobKey, string(rtype))
		if err := jc.Expectations.SetExpectations(expectationServicesKey, 0, 0); err != nil {
			allErrs = errors.Join(allErrs, err)
		}
	}
	return allErrs
}

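// For context: the expectation keys generated above are also what informer
// callbacks use to record observed creations/deletions. The method names below
// are an assumption based on the Kubernetes controller utilities this
// expectation package adapts, shown only to illustrate the lifecycle:
//
//	podsKey := expectation.GenExpectationPodsKey(jobKey, "worker") // hypothetical replica type
//	_ = jc.Expectations.ExpectCreations(podsKey, 2)                // before creating 2 pods
//	// ... informer add-handlers later record each observed creation, and a
//	// reconcile pass is considered safe once expectations are satisfied.
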
// PastActiveDeadline checks whether the job has the ActiveDeadlineSeconds field set and, if so, whether it has been exceeded.
func (jc *JobController) PastActiveDeadline(runPolicy *apiv1.RunPolicy, jobStatus apiv1.JobStatus) bool {
	return core.PastActiveDeadline(runPolicy, jobStatus)
}

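// pastActiveDeadlineSketch is a minimal sketch of the check delegated to
// core.PastActiveDeadline (an assumption about its behavior, not the upstream
// code): compare the time elapsed since StartTime against ActiveDeadlineSeconds.
func pastActiveDeadlineSketch(runPolicy *apiv1.RunPolicy, jobStatus apiv1.JobStatus, now time.Time) bool {
	if runPolicy.ActiveDeadlineSeconds == nil || jobStatus.StartTime == nil {
		return false
	}
	elapsed := now.Sub(jobStatus.StartTime.Time)
	allowed := time.Duration(*runPolicy.ActiveDeadlineSeconds) * time.Second
	return elapsed >= allowed
}
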
// PastBackoffLimit checks whether the sum of container restart counts exceeds BackoffLimit.
// This method applies only to pods whose restartPolicy is OnFailure, Always, or ExitCode.
func (jc *JobController) PastBackoffLimit(jobName string, runPolicy *apiv1.RunPolicy,
	replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec, pods []*corev1.Pod) (bool, error) {
	return core.PastBackoffLimit(jobName, runPolicy, replicas, pods, jc.FilterPodsForReplicaType)
}

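// totalRestartsSketch is a rough sketch of the counting performed inside
// core.PastBackoffLimit (simplified: the real implementation filters pods per
// replica type and honors each replica's restartPolicy). It sums the restart
// counts of init and regular containers, which is the quantity compared
// against the backoff limit.
func totalRestartsSketch(pods []*corev1.Pod) int32 {
	var restarts int32
	for _, pod := range pods {
		for _, status := range pod.Status.InitContainerStatuses {
			restarts += status.RestartCount
		}
		for _, status := range pod.Status.ContainerStatuses {
			restarts += status.RestartCount
		}
	}
	return restarts
}
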
// CleanupJob deletes the finished job once its TTLSecondsAfterFinished has
// expired, or requeues it to be cleaned up later. A nil TTL means the job is
// never cleaned up automatically.
func (jc *JobController) CleanupJob(runPolicy *apiv1.RunPolicy, jobStatus apiv1.JobStatus, job interface{}) error {
	currentTime := time.Now()
	metaObject, _ := job.(metav1.Object)
	ttl := runPolicy.TTLSecondsAfterFinished
	if ttl == nil {
		return nil
	}
	duration := time.Second * time.Duration(*ttl)
	if jobStatus.CompletionTime == nil {
		return fmt.Errorf("job completion time is nil, cannot cleanup")
	}
	finishTime := jobStatus.CompletionTime
	expireTime := finishTime.Add(duration)
	if currentTime.After(expireTime) {
		err := jc.Controller.DeleteJob(job)
		if err != nil {
			commonutil.LoggerForJob(metaObject).Warnf("Cleanup Job error: %v.", err)
			return err
		}
		return nil
	}
	if finishTime.After(currentTime) {
		commonutil.LoggerForJob(metaObject).Warnf("Found Job finished in the future. This is likely due to time skew in the cluster. Job cleanup will be deferred.")
	}
	remaining := expireTime.Sub(currentTime)
	key, err := KeyFunc(job)
	if err != nil {
		commonutil.LoggerForJob(metaObject).Warnf("Couldn't get key for job object: %v", err)
		return err
	}
	jc.WorkQueue.AddAfter(key, remaining)
	return nil
}

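// ttlExpiredSketch is a minimal sketch of the TTL arithmetic used by
// CleanupJob above, assuming the job has already finished; completionTime and
// ttlSeconds are illustrative inputs, not controller state.
func ttlExpiredSketch(completionTime metav1.Time, ttlSeconds int32, now time.Time) bool {
	expireTime := completionTime.Add(time.Second * time.Duration(ttlSeconds))
	return now.After(expireTime)
}
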
func (jc *JobController) calcPGMinResources(minMember int32, replicas map[apiv1.ReplicaType]*apiv1.ReplicaSpec) *corev1.ResourceList {
	return CalcPGMinResources(minMember, replicas, jc.PriorityClassLister.Get)
}
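
// addResourceListSketch illustrates the kind of ResourceList accumulation a
// min-resources calculation needs (a hedged sketch; the actual aggregation
// lives in CalcPGMinResources, whose algorithm is not reproduced here).
func addResourceListSketch(total, req corev1.ResourceList) {
	for name, quantity := range req {
		if cur, ok := total[name]; ok {
			// Quantities must be added via resource.Quantity arithmetic.
			cur.Add(quantity)
			total[name] = cur
		} else {
			total[name] = quantity.DeepCopy()
		}
	}
}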