sigs.k8s.io/kueue@v0.6.2/pkg/controller/jobs/raycluster/raycluster_controller.go (about)

     1  /*
     2  Copyright 2024 The Kubernetes Authors.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6      http://www.apache.org/licenses/LICENSE-2.0
     7  Unless required by applicable law or agreed to in writing, software
     8  distributed under the License is distributed on an "AS IS" BASIS,
     9  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    10  See the License for the specific language governing permissions and
    11  limitations under the License.
    12  */
    13  
    14  package raycluster
    15  
    16  import (
    17  	"context"
    18  	"strings"
    19  
    20  	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
    21  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    22  	"k8s.io/apimachinery/pkg/runtime"
    23  	"k8s.io/apimachinery/pkg/runtime/schema"
    24  	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
    25  	"k8s.io/utils/ptr"
    26  	"sigs.k8s.io/controller-runtime/pkg/client"
    27  
    28  	kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
    29  	"sigs.k8s.io/kueue/pkg/controller/jobframework"
    30  	"sigs.k8s.io/kueue/pkg/podset"
    31  )
    32  
    33  var (
    34  	gvk = rayv1.GroupVersion.WithKind("RayCluster")
    35  )
    36  
    37  const (
    38  	headGroupPodSetName = "head"
    39  	FrameworkName       = "ray.io/raycluster"
    40  )
    41  
    42  func init() {
    43  	utilruntime.Must(jobframework.RegisterIntegration(FrameworkName, jobframework.IntegrationCallbacks{
    44  		SetupIndexes:           SetupIndexes,
    45  		NewReconciler:          NewReconciler,
    46  		SetupWebhook:           SetupRayClusterWebhook,
    47  		JobType:                &rayv1.RayCluster{},
    48  		AddToScheme:            rayv1.AddToScheme,
    49  		IsManagingObjectsOwner: isRayCluster,
    50  	}))
    51  }
    52  
    53  // +kubebuilder:rbac:groups="",resources=events,verbs=create;watch;update
    54  // +kubebuilder:rbac:groups=ray.io,resources=rayclusters,verbs=get;list;watch;update;patch
    55  // +kubebuilder:rbac:groups=ray.io,resources=rayclusters/status,verbs=get;update
    56  // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch;create;update;patch;delete
    57  // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/status,verbs=get;update;patch
    58  // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/finalizers,verbs=update
    59  // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=resourceflavors,verbs=get;list;watch
    60  // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloadpriorityclasses,verbs=get;list;watch
    61  // +kubebuilder:rbac:groups=ray.io,resources=rayclusters/finalizers,verbs=get;update
    62  
    63  var NewReconciler = jobframework.NewGenericReconcilerFactory(func() jobframework.GenericJob { return &RayCluster{} })
    64  
// RayCluster is a defined type over rayv1.RayCluster. The methods declared on
// it in this file implement jobframework.GenericJob.
type RayCluster rayv1.RayCluster

// Compile-time check that *RayCluster satisfies jobframework.GenericJob.
var _ jobframework.GenericJob = (*RayCluster)(nil)
    68  
    69  func (j *RayCluster) Object() client.Object {
    70  	return (*rayv1.RayCluster)(j)
    71  }
    72  
    73  func (j *RayCluster) IsSuspended() bool {
    74  	return j.Spec.Suspend != nil && *j.Spec.Suspend
    75  }
    76  
    77  func (j *RayCluster) IsActive() bool {
    78  	return j.Status.State == rayv1.Ready
    79  }
    80  
    81  func (j *RayCluster) Suspend() {
    82  	j.Spec.Suspend = ptr.To(true)
    83  }
    84  
// GVK returns the package-level GroupVersionKind (gvk) used for RayCluster
// objects.
func (j *RayCluster) GVK() schema.GroupVersionKind {
	return gvk
}
    88  
    89  func (j *RayCluster) PodSets() []kueue.PodSet {
    90  	// len = workerGroups + head
    91  	podSets := make([]kueue.PodSet, len(j.Spec.WorkerGroupSpecs)+1)
    92  
    93  	// head
    94  	podSets[0] = kueue.PodSet{
    95  		Name:     headGroupPodSetName,
    96  		Template: *j.Spec.HeadGroupSpec.Template.DeepCopy(),
    97  		Count:    1,
    98  	}
    99  
   100  	// workers
   101  	for index := range j.Spec.WorkerGroupSpecs {
   102  		wgs := &j.Spec.WorkerGroupSpecs[index]
   103  		replicas := int32(1)
   104  		if wgs.Replicas != nil {
   105  			replicas = *wgs.Replicas
   106  		}
   107  		podSets[index+1] = kueue.PodSet{
   108  			Name:     strings.ToLower(wgs.GroupName),
   109  			Template: *wgs.Template.DeepCopy(),
   110  			Count:    replicas,
   111  		}
   112  	}
   113  	return podSets
   114  }
   115  
   116  func (j *RayCluster) RunWithPodSetsInfo(podSetsInfo []podset.PodSetInfo) error {
   117  	expectedLen := len(j.Spec.WorkerGroupSpecs) + 1
   118  	if len(podSetsInfo) != expectedLen {
   119  		return podset.BadPodSetsInfoLenError(expectedLen, len(podSetsInfo))
   120  	}
   121  
   122  	j.Spec.Suspend = ptr.To(false)
   123  
   124  	// head
   125  	headPod := &j.Spec.HeadGroupSpec.Template
   126  	info := podSetsInfo[0]
   127  	if err := podset.Merge(&headPod.ObjectMeta, &headPod.Spec, info); err != nil {
   128  		return err
   129  	}
   130  
   131  	// workers
   132  	for index := range j.Spec.WorkerGroupSpecs {
   133  		workerPod := &j.Spec.WorkerGroupSpecs[index].Template
   134  
   135  		info := podSetsInfo[index+1]
   136  		if err := podset.Merge(&workerPod.ObjectMeta, &workerPod.Spec, info); err != nil {
   137  			return err
   138  		}
   139  	}
   140  	return nil
   141  }
   142  
   143  func (j *RayCluster) RestorePodSetsInfo(podSetsInfo []podset.PodSetInfo) bool {
   144  	if len(podSetsInfo) != len(j.Spec.WorkerGroupSpecs)+1 {
   145  		return false
   146  	}
   147  
   148  	changed := false
   149  	// head
   150  	headPod := &j.Spec.HeadGroupSpec.Template
   151  	changed = podset.RestorePodSpec(&headPod.ObjectMeta, &headPod.Spec, podSetsInfo[0]) || changed
   152  
   153  	// workers
   154  	for index := range j.Spec.WorkerGroupSpecs {
   155  		workerPod := &j.Spec.WorkerGroupSpecs[index].Template
   156  		info := podSetsInfo[index+1]
   157  		changed = podset.RestorePodSpec(&workerPod.ObjectMeta, &workerPod.Spec, info) || changed
   158  	}
   159  	return changed
   160  }
   161  
   162  func (j *RayCluster) Finished() (metav1.Condition, bool) {
   163  	condition := metav1.Condition{
   164  		Type:    kueue.WorkloadFinished,
   165  		Status:  metav1.ConditionFalse,
   166  		Reason:  string(j.Status.State),
   167  		Message: string(j.Status.Reason),
   168  	}
   169  	// Technically a RayCluster is never "finished"
   170  	return condition, false
   171  }
   172  
   173  func (j *RayCluster) PodsReady() bool {
   174  	return j.Status.State == rayv1.Ready
   175  }
   176  
// SetupIndexes delegates to jobframework.SetupWorkloadOwnerIndex, passing the
// given indexer and the RayCluster gvk, and returns its error unchanged.
func SetupIndexes(ctx context.Context, indexer client.FieldIndexer) error {
	return jobframework.SetupWorkloadOwnerIndex(ctx, indexer, gvk)
}
   180  
// GetWorkloadNameForRayCluster returns the workload name that
// jobframework.GetWorkloadNameForOwnerWithGVK produces for jobName and the
// RayCluster gvk.
func GetWorkloadNameForRayCluster(jobName string) string {
	return jobframework.GetWorkloadNameForOwnerWithGVK(jobName, gvk)
}
   184  
   185  func isRayCluster(owner *metav1.OwnerReference) bool {
   186  	return owner.Kind == "RayCluster" && strings.HasPrefix(owner.APIVersion, "ray.io/v1")
   187  }
   188  
   189  func fromObject(o runtime.Object) *RayCluster {
   190  	return (*RayCluster)(o.(*rayv1.RayCluster))
   191  }