sigs.k8s.io/kueue@v0.6.2/pkg/controller/jobs/raycluster/raycluster_controller.go (about) 1 /* 2 Copyright 2024 The Kubernetes Authors. 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 http://www.apache.org/licenses/LICENSE-2.0 7 Unless required by applicable law or agreed to in writing, software 8 distributed under the License is distributed on an "AS IS" BASIS, 9 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 See the License for the specific language governing permissions and 11 limitations under the License. 12 */ 13 14 package raycluster 15 16 import ( 17 "context" 18 "strings" 19 20 rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" 21 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 22 "k8s.io/apimachinery/pkg/runtime" 23 "k8s.io/apimachinery/pkg/runtime/schema" 24 utilruntime "k8s.io/apimachinery/pkg/util/runtime" 25 "k8s.io/utils/ptr" 26 "sigs.k8s.io/controller-runtime/pkg/client" 27 28 kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1" 29 "sigs.k8s.io/kueue/pkg/controller/jobframework" 30 "sigs.k8s.io/kueue/pkg/podset" 31 ) 32 33 var ( 34 gvk = rayv1.GroupVersion.WithKind("RayCluster") 35 ) 36 37 const ( 38 headGroupPodSetName = "head" 39 FrameworkName = "ray.io/raycluster" 40 ) 41 42 func init() { 43 utilruntime.Must(jobframework.RegisterIntegration(FrameworkName, jobframework.IntegrationCallbacks{ 44 SetupIndexes: SetupIndexes, 45 NewReconciler: NewReconciler, 46 SetupWebhook: SetupRayClusterWebhook, 47 JobType: &rayv1.RayCluster{}, 48 AddToScheme: rayv1.AddToScheme, 49 IsManagingObjectsOwner: isRayCluster, 50 })) 51 } 52 53 // +kubebuilder:rbac:groups="",resources=events,verbs=create;watch;update 54 // +kubebuilder:rbac:groups=ray.io,resources=rayclusters,verbs=get;list;watch;update;patch 55 // +kubebuilder:rbac:groups=ray.io,resources=rayclusters/status,verbs=get;update 56 // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch;create;update;patch;delete 57 // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/status,verbs=get;update;patch 58 // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/finalizers,verbs=update 59 // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=resourceflavors,verbs=get;list;watch 60 // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloadpriorityclasses,verbs=get;list;watch 61 // +kubebuilder:rbac:groups=ray.io,resources=rayclusters/finalizers,verbs=get;update 62 63 var NewReconciler = jobframework.NewGenericReconcilerFactory(func() jobframework.GenericJob { return &RayCluster{} }) 64 65 type RayCluster rayv1.RayCluster 66 67 var _ jobframework.GenericJob = (*RayCluster)(nil) 68 69 func (j *RayCluster) Object() client.Object { 70 return (*rayv1.RayCluster)(j) 71 } 72 73 func (j *RayCluster) IsSuspended() bool { 74 return j.Spec.Suspend != nil && *j.Spec.Suspend 75 } 76 77 func (j *RayCluster) IsActive() bool { 78 return j.Status.State == rayv1.Ready 79 } 80 81 func (j *RayCluster) Suspend() { 82 j.Spec.Suspend = ptr.To(true) 83 } 84 85 func (j *RayCluster) GVK() schema.GroupVersionKind { 86 return gvk 87 } 88 89 func (j *RayCluster) PodSets() []kueue.PodSet { 90 // len = workerGroups + head 91 podSets := make([]kueue.PodSet, len(j.Spec.WorkerGroupSpecs)+1) 92 93 // head 94 podSets[0] = kueue.PodSet{ 95 Name: headGroupPodSetName, 96 Template: *j.Spec.HeadGroupSpec.Template.DeepCopy(), 97 Count: 1, 98 } 99 100 // workers 101 for index := range j.Spec.WorkerGroupSpecs { 102 wgs := &j.Spec.WorkerGroupSpecs[index] 103 replicas := int32(1) 104 if wgs.Replicas != nil { 105 replicas = *wgs.Replicas 106 } 107 podSets[index+1] = kueue.PodSet{ 108 Name: strings.ToLower(wgs.GroupName), 109 Template: *wgs.Template.DeepCopy(), 110 Count: replicas, 111 } 112 } 113 return podSets 114 } 115 116 func (j *RayCluster) RunWithPodSetsInfo(podSetsInfo []podset.PodSetInfo) error { 117 expectedLen := len(j.Spec.WorkerGroupSpecs) + 1 118 if len(podSetsInfo) != expectedLen { 119 return podset.BadPodSetsInfoLenError(expectedLen, len(podSetsInfo)) 120 } 121 122 j.Spec.Suspend = ptr.To(false) 123 124 // head 125 headPod := &j.Spec.HeadGroupSpec.Template 126 info := podSetsInfo[0] 127 if err := podset.Merge(&headPod.ObjectMeta, &headPod.Spec, info); err != nil { 128 return err 129 } 130 131 // workers 132 for index := range j.Spec.WorkerGroupSpecs { 133 workerPod := &j.Spec.WorkerGroupSpecs[index].Template 134 135 info := podSetsInfo[index+1] 136 if err := podset.Merge(&workerPod.ObjectMeta, &workerPod.Spec, info); err != nil { 137 return err 138 } 139 } 140 return nil 141 } 142 143 func (j *RayCluster) RestorePodSetsInfo(podSetsInfo []podset.PodSetInfo) bool { 144 if len(podSetsInfo) != len(j.Spec.WorkerGroupSpecs)+1 { 145 return false 146 } 147 148 changed := false 149 // head 150 headPod := &j.Spec.HeadGroupSpec.Template 151 changed = podset.RestorePodSpec(&headPod.ObjectMeta, &headPod.Spec, podSetsInfo[0]) || changed 152 153 // workers 154 for index := range j.Spec.WorkerGroupSpecs { 155 workerPod := &j.Spec.WorkerGroupSpecs[index].Template 156 info := podSetsInfo[index+1] 157 changed = podset.RestorePodSpec(&workerPod.ObjectMeta, &workerPod.Spec, info) || changed 158 } 159 return changed 160 } 161 162 func (j *RayCluster) Finished() (metav1.Condition, bool) { 163 condition := metav1.Condition{ 164 Type: kueue.WorkloadFinished, 165 Status: metav1.ConditionFalse, 166 Reason: string(j.Status.State), 167 Message: string(j.Status.Reason), 168 } 169 // Technically a RayCluster is never "finished" 170 return condition, false 171 } 172 173 func (j *RayCluster) PodsReady() bool { 174 return j.Status.State == rayv1.Ready 175 } 176 177 func SetupIndexes(ctx context.Context, indexer client.FieldIndexer) error { 178 return jobframework.SetupWorkloadOwnerIndex(ctx, indexer, gvk) 179 } 180 181 func GetWorkloadNameForRayCluster(jobName string) string { 182 return jobframework.GetWorkloadNameForOwnerWithGVK(jobName, gvk) 183 } 184 185 func isRayCluster(owner *metav1.OwnerReference) bool { 186 return owner.Kind == "RayCluster" && strings.HasPrefix(owner.APIVersion, "ray.io/v1") 187 } 188 189 func fromObject(o runtime.Object) *RayCluster { 190 return (*RayCluster)(o.(*rayv1.RayCluster)) 191 }