sigs.k8s.io/kueue@v0.6.2/pkg/controller/jobs/mpijob/mpijob_controller.go

/*
Copyright 2023 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mpijob

import (
	"context"
	"strings"

	kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	"k8s.io/apimachinery/pkg/runtime/schema"
	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
	"k8s.io/utils/ptr"
	"sigs.k8s.io/controller-runtime/pkg/client"

	kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
	"sigs.k8s.io/kueue/pkg/controller/jobframework"
	"sigs.k8s.io/kueue/pkg/podset"
)

var (
	gvk = kubeflow.SchemeGroupVersionKind

	FrameworkName = "kubeflow.org/mpijob"
)

func init() {
	utilruntime.Must(jobframework.RegisterIntegration(FrameworkName, jobframework.IntegrationCallbacks{
		SetupIndexes:           SetupIndexes,
		NewReconciler:          NewReconciler,
		SetupWebhook:           SetupMPIJobWebhook,
		JobType:                &kubeflow.MPIJob{},
		AddToScheme:            kubeflow.AddToScheme,
		IsManagingObjectsOwner: isMPIJob,
	}))
}

// +kubebuilder:rbac:groups=scheduling.k8s.io,resources=priorityclasses,verbs=list;get;watch
// +kubebuilder:rbac:groups="",resources=events,verbs=create;watch;update;patch
// +kubebuilder:rbac:groups=kubeflow.org,resources=mpijobs,verbs=get;list;watch;update;patch
// +kubebuilder:rbac:groups=kubeflow.org,resources=mpijobs/status,verbs=get;update
// +kubebuilder:rbac:groups=kubeflow.org,resources=mpijobs/finalizers,verbs=get;update
// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch;create;update;patch;delete
// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/status,verbs=get;update;patch
// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/finalizers,verbs=update
// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=resourceflavors,verbs=get;list;watch
// +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloadpriorityclasses,verbs=get;list;watch

var NewReconciler = jobframework.NewGenericReconcilerFactory(func() jobframework.GenericJob { return &MPIJob{} })

func isMPIJob(owner *metav1.OwnerReference) bool {
	return owner.Kind == "MPIJob" && strings.HasPrefix(owner.APIVersion, "kubeflow.org/v2")
}

type MPIJob kubeflow.MPIJob

var _ jobframework.GenericJob = (*MPIJob)(nil)
var _ jobframework.JobWithPriorityClass = (*MPIJob)(nil)

func (j *MPIJob) Object() client.Object {
	return (*kubeflow.MPIJob)(j)
}

func fromObject(o runtime.Object) *MPIJob {
	return (*MPIJob)(o.(*kubeflow.MPIJob))
}

func (j *MPIJob) IsSuspended() bool {
	return j.Spec.RunPolicy.Suspend != nil && *j.Spec.RunPolicy.Suspend
}

func (j *MPIJob) IsActive() bool {
	for _, replicaStatus := range j.Status.ReplicaStatuses {
		if replicaStatus.Active != 0 {
			return true
		}
	}
	return false
}

func (j *MPIJob) Suspend() {
	j.Spec.RunPolicy.Suspend = ptr.To(true)
}

func (j *MPIJob) GVK() schema.GroupVersionKind {
	return gvk
}

func (j *MPIJob) PodSets() []kueue.PodSet {
	replicaTypes := orderedReplicaTypes(&j.Spec)
	podSets := make([]kueue.PodSet, len(replicaTypes))
	for index, mpiReplicaType := range replicaTypes {
		podSets[index] = kueue.PodSet{
			Name:     strings.ToLower(string(mpiReplicaType)),
			Template: *j.Spec.MPIReplicaSpecs[mpiReplicaType].Template.DeepCopy(),
			Count:    podsCount(&j.Spec, mpiReplicaType),
		}
	}
	return podSets
}

func (j *MPIJob) RunWithPodSetsInfo(podSetsInfo []podset.PodSetInfo) error {
	j.Spec.RunPolicy.Suspend = ptr.To(false)
	orderedReplicaTypes := orderedReplicaTypes(&j.Spec)

	if len(podSetsInfo) != len(orderedReplicaTypes) {
		return podset.BadPodSetsInfoLenError(len(orderedReplicaTypes), len(podSetsInfo))
	}

	// The node selectors are provided in the same order as the generated list of
	// podSets; use the same ordering logic to restore them.
	for index := range podSetsInfo {
		replicaType := orderedReplicaTypes[index]
		info := podSetsInfo[index]
		replica := &j.Spec.MPIReplicaSpecs[replicaType].Template
		if err := podset.Merge(&replica.ObjectMeta, &replica.Spec, info); err != nil {
			return err
		}
	}
	return nil
}

func (j *MPIJob) RestorePodSetsInfo(podSetsInfo []podset.PodSetInfo) bool {
	orderedReplicaTypes := orderedReplicaTypes(&j.Spec)
	changed := false
	for index, info := range podSetsInfo {
		replicaType := orderedReplicaTypes[index]
		replica := &j.Spec.MPIReplicaSpecs[replicaType].Template
		changed = podset.RestorePodSpec(&replica.ObjectMeta, &replica.Spec, info) || changed
	}
	return changed
}

func (j *MPIJob) Finished() (metav1.Condition, bool) {
	var conditionType kubeflow.JobConditionType
	var finished bool
	for _, c := range j.Status.Conditions {
		if (c.Type == kubeflow.JobSucceeded || c.Type == kubeflow.JobFailed) && c.Status == corev1.ConditionTrue {
			conditionType = c.Type
			finished = true
			break
		}
	}

	message := "Job finished successfully"
	if conditionType == kubeflow.JobFailed {
		message = "Job failed"
	}
	condition := metav1.Condition{
		Type:    kueue.WorkloadFinished,
		Status:  metav1.ConditionTrue,
		Reason:  "JobFinished",
		Message: message,
	}
	return condition, finished
}

// PriorityClass calculates the priorityClass name needed for the workload according to the following priorities:
// 1. .spec.runPolicy.schedulingPolicy.priorityClass
// 2. .spec.mpiReplicaSpecs[Launcher].template.spec.priorityClassName
// 3. .spec.mpiReplicaSpecs[Worker].template.spec.priorityClassName
//
// This function is inspired by an analogous one in mpi-controller:
// https://github.com/kubeflow/mpi-operator/blob/5946ef4157599a474ab82ff80e780d5c2546c9ee/pkg/controller/podgroup.go#L69-L72
func (j *MPIJob) PriorityClass() string {
	if j.Spec.RunPolicy.SchedulingPolicy != nil && len(j.Spec.RunPolicy.SchedulingPolicy.PriorityClass) != 0 {
		return j.Spec.RunPolicy.SchedulingPolicy.PriorityClass
	} else if l := j.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher]; l != nil && len(l.Template.Spec.PriorityClassName) != 0 {
		return l.Template.Spec.PriorityClassName
	} else if w := j.Spec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker]; w != nil && len(w.Template.Spec.PriorityClassName) != 0 {
		return w.Template.Spec.PriorityClassName
	}
	return ""
}

func (j *MPIJob) PodsReady() bool {
	for _, c := range j.Status.Conditions {
		if c.Type == kubeflow.JobRunning && c.Status == corev1.ConditionTrue {
			return true
		}
	}
	return false
}

func SetupIndexes(ctx context.Context, indexer client.FieldIndexer) error {
	return jobframework.SetupWorkloadOwnerIndex(ctx, indexer, gvk)
}

func orderedReplicaTypes(jobSpec *kubeflow.MPIJobSpec) []kubeflow.MPIReplicaType {
	var result []kubeflow.MPIReplicaType
	if _, ok := jobSpec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeLauncher]; ok {
		result = append(result, kubeflow.MPIReplicaTypeLauncher)
	}
	if _, ok := jobSpec.MPIReplicaSpecs[kubeflow.MPIReplicaTypeWorker]; ok {
		result = append(result, kubeflow.MPIReplicaTypeWorker)
	}
	return result
}

func podsCount(jobSpec *kubeflow.MPIJobSpec, mpiReplicaType kubeflow.MPIReplicaType) int32 {
	return ptr.Deref(jobSpec.MPIReplicaSpecs[mpiReplicaType].Replicas, 1)
}

func GetWorkloadNameForMPIJob(jobName string) string {
	return jobframework.GetWorkloadNameForOwnerWithGVK(jobName, gvk)
}
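
// exampleAdapterUsage is a minimal illustrative sketch, not part of the
// upstream controller: it shows how the GenericJob adapter above is typically
// driven by the job framework. The exampleJob parameter is a hypothetical
// MPIJob assumed to define both Launcher and Worker replica specs.
func exampleAdapterUsage(exampleJob *kubeflow.MPIJob) {
	job := fromObject(exampleJob)

	// PodSets are produced in launcher-then-worker order (see
	// orderedReplicaTypes); kueue builds the Workload's pod sets from them.
	_ = job.PodSets()

	// Suspend sets .spec.runPolicy.suspend to true, while RunWithPodSetsInfo
	// clears it; this flag is how kueue gates execution of the MPIJob.
	job.Suspend()

	// The Workload owned by this MPIJob gets a deterministic name derived
	// from the job name and the MPIJob GroupVersionKind.
	_ = GetWorkloadNameForMPIJob(exampleJob.Name)
}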