sigs.k8s.io/kueue@v0.6.2/pkg/controller/jobs/mpijob/mpijob_controller.go

/*
Copyright 2023 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mpijob

import (
	"context"
	"strings"

	kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	"k8s.io/apimachinery/pkg/runtime/schema"
	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
	"k8s.io/utils/ptr"
	"sigs.k8s.io/controller-runtime/pkg/client"

	kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
	"sigs.k8s.io/kueue/pkg/controller/jobframework"
	"sigs.k8s.io/kueue/pkg/podset"
)

var (
	gvk = kubeflow.SchemeGroupVersionKind

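	// FrameworkName is the name under which the MPIJob integration is registered with the job framework.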
	FrameworkName = "kubeflow.org/mpijob"
)

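// init registers the MPIJob integration with the job framework, wiring up the index setup,
// reconciler factory, webhook setup, scheme registration, and owner-check callbacks defined below.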
func init() {
	utilruntime.Must(jobframework.RegisterIntegration(FrameworkName, jobframework.IntegrationCallbacks{
		SetupIndexes:           SetupIndexes,
		NewReconciler:          NewReconciler,
		SetupWebhook:           SetupMPIJobWebhook,
		JobType:                &kubeflow.MPIJob{},
		AddToScheme:            kubeflow.AddToScheme,
		IsManagingObjectsOwner: isMPIJob,
	}))
}

// +kubebuilder:rbac:groups=scheduling.k8s.io,resources=priorityclasses,verbs=list;get;watch
// +kubebuilder:rbac:groups="",resources=events,verbs=create;watch;update;patch
// +kubebuilder:rbac:groups=kubeflow.org,resources=mpijobs,verbs=get;list;watch;update;patch
// +kubebuilder:rbac:groups=kubeflow.org,resources=mpijobs/status,verbs=get;update
// +kubebuilder:rbac:groups=kubeflow.org,resources=mpijobs/finalizers,verbs=get;update
// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch;create;update;patch;delete
// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/status,verbs=get;update;patch
// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/finalizers,verbs=update
// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=resourceflavors,verbs=get;list;watch
// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloadpriorityclasses,verbs=get;list;watch

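// NewReconciler is the factory that builds the generic jobframework reconciler for MPIJob objects.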
var NewReconciler = jobframework.NewGenericReconcilerFactory(func() jobframework.GenericJob { return &MPIJob{} })

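// isMPIJob reports whether the owner reference points at a v2 MPIJob from the kubeflow.org API group.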
func isMPIJob(owner *metav1.OwnerReference) bool {
	return owner.Kind == "MPIJob" && strings.HasPrefix(owner.APIVersion, "kubeflow.org/v2")
}

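// MPIJob wraps kubeflow.MPIJob so that the jobframework interfaces can be implemented on it.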
type MPIJob kubeflow.MPIJob

var _ jobframework.GenericJob = (*MPIJob)(nil)
var _ jobframework.JobWithPriorityClass = (*MPIJob)(nil)

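// Object returns the underlying MPIJob as a client.Object.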
func (j *MPIJob) Object() client.Object {
	return (*kubeflow.MPIJob)(j)
}

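// fromObject converts a runtime.Object, which must be a *kubeflow.MPIJob, into the local wrapper type.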
func fromObject(o runtime.Object) *MPIJob {
	return (*MPIJob)(o.(*kubeflow.MPIJob))
}

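// IsSuspended reports whether the job's run policy requests suspension.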
func (j *MPIJob) IsSuspended() bool {
	return j.Spec.RunPolicy.Suspend != nil && *j.Spec.RunPolicy.Suspend
}

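// IsActive reports whether any replica of the job is still active.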
func (j *MPIJob) IsActive() bool {
	for _, replicaStatus := range j.Status.ReplicaStatuses {
		if replicaStatus.Active != 0 {
			return true
		}
	}
	return false
}

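// Suspend marks the job as suspended in its run policy.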
func (j *MPIJob) Suspend() {
	j.Spec.RunPolicy.Suspend = ptr.To(true)
}

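// GVK returns the GroupVersionKind of the kubeflow MPIJob type.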
func (j *MPIJob) GVK() schema.GroupVersionKind {
	return gvk
}

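// PodSets builds one kueue.PodSet per replica type present in the spec (launcher first, then
// workers), copying each replica's pod template and desired count.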
func (j *MPIJob) PodSets() []kueue.PodSet {
	replicaTypes := orderedReplicaTypes(&j.Spec)
	podSets := make([]kueue.PodSet, len(replicaTypes))
	for index, mpiReplicaType := range replicaTypes {
		podSets[index] = kueue.PodSet{
			Name:     strings.ToLower(string(mpiReplicaType)),
			Template: *j.Spec.MPIReplicaSpecs[mpiReplicaType].Template.DeepCopy(),
			Count:    podsCount(&j.Spec, mpiReplicaType),
		}
	}
	return podSets
}

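// RunWithPodSetsInfo unsuspends the job and merges the admission-time pod set info (node
// selectors and other template metadata) back into the replica templates.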
func (j *MPIJob) RunWithPodSetsInfo(podSetsInfo []podset.PodSetInfo) error {
	j.Spec.RunPolicy.Suspend = ptr.To(false)
	orderedReplicaTypes := orderedReplicaTypes(&j.Spec)

	if len(podSetsInfo) != len(orderedReplicaTypes) {
		return podset.BadPodSetsInfoLenError(len(orderedReplicaTypes), len(podSetsInfo))
	}

	// The podSetsInfo entries are provided in the same order as the generated list of
	// podSets; use the same ordering logic to merge them back.
	for index := range podSetsInfo {
		replicaType := orderedReplicaTypes[index]
		info := podSetsInfo[index]
		replica := &j.Spec.MPIReplicaSpecs[replicaType].Template
		if err := podset.Merge(&replica.ObjectMeta, &replica.Spec, info); err != nil {
			return err
		}
	}
	return nil
}

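// RestorePodSetsInfo restores the original pod template metadata and spec for each replica type
// and reports whether any template changed.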
func (j *MPIJob) RestorePodSetsInfo(podSetsInfo []podset.PodSetInfo) bool {
	orderedReplicaTypes := orderedReplicaTypes(&j.Spec)
	changed := false
	for index, info := range podSetsInfo {
		replicaType := orderedReplicaTypes[index]
		replica := &j.Spec.MPIReplicaSpecs[replicaType].Template
		changed = podset.RestorePodSpec(&replica.ObjectMeta, &replica.Spec, info) || changed
	}
	return changed
}

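// Finished reports whether the job has a JobSucceeded or JobFailed condition with status True,
// and returns the corresponding kueue WorkloadFinished condition.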
func (j *MPIJob) Finished() (metav1.Condition, bool) {
	var conditionType kubeflow.JobConditionType
	var finished bool
	for _, c := range j.Status.Conditions {
		if (c.Type == kubeflow.JobSucceeded || c.Type == kubeflow.JobFailed) && c.Status == corev1.ConditionTrue {
			conditionType = c.Type
			finished = true
			break
		}
	}

	message := "Job finished successfully"
	if conditionType == kubeflow.JobFailed {
		message = "Job failed"
	}
	condition := metav1.Condition{
		Type:    kueue.WorkloadFinished,
		Status:  metav1.ConditionTrue,
		Reason:  "JobFinished",
		Message: message,
	}
	return condition, finished
}

// PriorityClass calculates the priorityClass name needed for the workload according to the
// following order of precedence:
//  1. .spec.runPolicy.schedulingPolicy.priorityClass
//  2. .spec.mpiReplicaSpecs[Launcher].template.spec.priorityClassName
//  3. .spec.mpiReplicaSpecs[Worker].template.spec.priorityClassName
//
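// For example, when .spec.runPolicy.schedulingPolicy.priorityClass is set it takes precedence
// over any priorityClassName set on the launcher or worker pod templates.
//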
// This function is inspired by an analogous one in mpi-operator:
// https://github.com/kubeflow/mpi-operator/blob/5946ef4157599a474ab82ff80e780d5c2546c9ee/pkg/controller/podgroup.go#L69-L72
func (j *MPIJob) PriorityClass() string {
	if j.Spec.RunPolicy.SchedulingPolicy != nil && len(j.Spec.RunPolicy.SchedulingPolicy.PriorityClass) != 0 {
		return j.Spec.RunPolicy.SchedulingPolicy.PriorityClass
	} else if l := j.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher]; l != nil && len(l.Template.Spec.PriorityClassName) != 0 {
		return l.Template.Spec.PriorityClassName
	} else if w := j.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker]; w != nil && len(w.Template.Spec.PriorityClassName) != 0 {
		return w.Template.Spec.PriorityClassName
	}
	return ""
}

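// PodsReady reports whether the job has a JobRunning condition with status True.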
func (j *MPIJob) PodsReady() bool {
	for _, c := range j.Status.Conditions {
		if c.Type == kubeflow.JobRunning && c.Status == corev1.ConditionTrue {
			return true
		}
	}
	return false
}

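// SetupIndexes registers the field index that lets the reconciler look up Workloads by their owning MPIJob.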
func SetupIndexes(ctx context.Context, indexer client.FieldIndexer) error {
	return jobframework.SetupWorkloadOwnerIndex(ctx, indexer, gvk)
}

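// orderedReplicaTypes returns the replica types present in the spec in a fixed order:
// launcher first, then worker.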
func orderedReplicaTypes(jobSpec *kubeflow.MPIJobSpec) []kubeflow.MPIReplicaType {
	var result []kubeflow.MPIReplicaType
	if _, ok := jobSpec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher]; ok {
		result = append(result, kubeflow.MPIReplicaTypeLauncher)
	}
	if _, ok := jobSpec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker]; ok {
		result = append(result, kubeflow.MPIReplicaTypeWorker)
	}
	return result
}

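// podsCount returns the desired replica count for the given replica type, defaulting to 1
// when .replicas is not set.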
func podsCount(jobSpec *kubeflow.MPIJobSpec, mpiReplicaType kubeflow.MPIReplicaType) int32 {
	return ptr.Deref(jobSpec.MPIReplicaSpecs[mpiReplicaType].Replicas, 1)
}

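// GetWorkloadNameForMPIJob returns the name of the Workload created for the MPIJob with the given name.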
func GetWorkloadNameForMPIJob(jobName string) string {
	return jobframework.GetWorkloadNameForOwnerWithGVK(jobName, gvk)
}