sigs.k8s.io/kueue@v0.6.2/pkg/controller/jobs/rayjob/rayjob_controller.go (about)

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package rayjob
    18  
    19  import (
    20  	"context"
    21  	"strings"
    22  
    23  	rayjobapi "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1"
    24  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    25  	"k8s.io/apimachinery/pkg/runtime/schema"
    26  	utilruntime "k8s.io/apimachinery/pkg/util/runtime"
    27  	"sigs.k8s.io/controller-runtime/pkg/client"
    28  
    29  	kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
    30  	"sigs.k8s.io/kueue/pkg/controller/jobframework"
    31  	"sigs.k8s.io/kueue/pkg/podset"
    32  )
    33  
    34  var (
    35  	gvk = rayjobapi.GroupVersion.WithKind("RayJob")
    36  )
    37  
    38  const (
    39  	headGroupPodSetName = "head"
    40  	FrameworkName       = "ray.io/rayjob"
    41  )
    42  
    43  func init() {
    44  	utilruntime.Must(jobframework.RegisterIntegration(FrameworkName, jobframework.IntegrationCallbacks{
    45  		SetupIndexes:           SetupIndexes,
    46  		NewReconciler:          NewReconciler,
    47  		SetupWebhook:           SetupRayJobWebhook,
    48  		JobType:                &rayjobapi.RayJob{},
    49  		AddToScheme:            rayjobapi.AddToScheme,
    50  		IsManagingObjectsOwner: isRayJob,
    51  	}))
    52  }
    53  
    54  // +kubebuilder:rbac:groups="",resources=events,verbs=create;watch;update
    55  // +kubebuilder:rbac:groups=ray.io,resources=rayjobs,verbs=get;list;watch;update;patch
    56  // +kubebuilder:rbac:groups=ray.io,resources=rayjobs/status,verbs=get;update
    57  // +kubebuilder:rbac:groups=ray.io,resources=rayjobs/finalizers,verbs=get;update
    58  // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads,verbs=get;list;watch;create;update;patch;delete
    59  // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/status,verbs=get;update;patch
    60  // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloads/finalizers,verbs=update
    61  // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=resourceflavors,verbs=get;list;watch
    62  // +kubebuilder:rbac:groups=kueue.x-k8s.io,resources=workloadpriorityclasses,verbs=get;list;watch
    63  
    64  var NewReconciler = jobframework.NewGenericReconcilerFactory(func() jobframework.GenericJob { return &RayJob{} })
    65  
    66  type RayJob rayjobapi.RayJob
    67  
    68  var _ jobframework.GenericJob = (*RayJob)(nil)
    69  
    70  func (j *RayJob) Object() client.Object {
    71  	return (*rayjobapi.RayJob)(j)
    72  }
    73  
    74  func (j *RayJob) IsSuspended() bool {
    75  	return j.Spec.Suspend
    76  }
    77  
    78  func (j *RayJob) IsActive() bool {
    79  	return j.Status.JobDeploymentStatus != rayjobapi.JobDeploymentStatusSuspended
    80  }
    81  
    82  func (j *RayJob) Suspend() {
    83  	j.Spec.Suspend = true
    84  }
    85  
    86  func (j *RayJob) GVK() schema.GroupVersionKind {
    87  	return gvk
    88  }
    89  
    90  func (j *RayJob) PodSets() []kueue.PodSet {
    91  	// len = workerGroups + head
    92  	podSets := make([]kueue.PodSet, len(j.Spec.RayClusterSpec.WorkerGroupSpecs)+1)
    93  
    94  	// head
    95  	podSets[0] = kueue.PodSet{
    96  		Name:     headGroupPodSetName,
    97  		Template: *j.Spec.RayClusterSpec.HeadGroupSpec.Template.DeepCopy(),
    98  		Count:    1,
    99  	}
   100  
   101  	// workers
   102  	for index := range j.Spec.RayClusterSpec.WorkerGroupSpecs {
   103  		wgs := &j.Spec.RayClusterSpec.WorkerGroupSpecs[index]
   104  		replicas := int32(1)
   105  		if wgs.Replicas != nil {
   106  			replicas = *wgs.Replicas
   107  		}
   108  		podSets[index+1] = kueue.PodSet{
   109  			Name:     strings.ToLower(wgs.GroupName),
   110  			Template: *wgs.Template.DeepCopy(),
   111  			Count:    replicas,
   112  		}
   113  	}
   114  	return podSets
   115  }
   116  
   117  func (j *RayJob) RunWithPodSetsInfo(podSetsInfo []podset.PodSetInfo) error {
   118  	expectedLen := len(j.Spec.RayClusterSpec.WorkerGroupSpecs) + 1
   119  	if len(podSetsInfo) != expectedLen {
   120  		return podset.BadPodSetsInfoLenError(expectedLen, len(podSetsInfo))
   121  	}
   122  
   123  	j.Spec.Suspend = false
   124  
   125  	// head
   126  	headPod := &j.Spec.RayClusterSpec.HeadGroupSpec.Template
   127  	info := podSetsInfo[0]
   128  	if err := podset.Merge(&headPod.ObjectMeta, &headPod.Spec, info); err != nil {
   129  		return err
   130  	}
   131  
   132  	// workers
   133  	for index := range j.Spec.RayClusterSpec.WorkerGroupSpecs {
   134  		workerPod := &j.Spec.RayClusterSpec.WorkerGroupSpecs[index].Template
   135  		info := podSetsInfo[index+1]
   136  		if err := podset.Merge(&workerPod.ObjectMeta, &workerPod.Spec, info); err != nil {
   137  			return err
   138  		}
   139  	}
   140  	return nil
   141  }
   142  
   143  func (j *RayJob) RestorePodSetsInfo(podSetsInfo []podset.PodSetInfo) bool {
   144  	if len(podSetsInfo) != len(j.Spec.RayClusterSpec.WorkerGroupSpecs)+1 {
   145  		return false
   146  	}
   147  
   148  	changed := false
   149  	// head
   150  	headPod := &j.Spec.RayClusterSpec.HeadGroupSpec.Template
   151  	changed = podset.RestorePodSpec(&headPod.ObjectMeta, &headPod.Spec, podSetsInfo[0]) || changed
   152  
   153  	// workers
   154  	for index := range j.Spec.RayClusterSpec.WorkerGroupSpecs {
   155  		workerPod := &j.Spec.RayClusterSpec.WorkerGroupSpecs[index].Template
   156  		info := podSetsInfo[index+1]
   157  		changed = podset.RestorePodSpec(&workerPod.ObjectMeta, &workerPod.Spec, info) || changed
   158  	}
   159  	return changed
   160  }
   161  
   162  func (j *RayJob) Finished() (metav1.Condition, bool) {
   163  	condition := metav1.Condition{
   164  		Type:    kueue.WorkloadFinished,
   165  		Status:  metav1.ConditionTrue,
   166  		Reason:  string(j.Status.JobStatus),
   167  		Message: j.Status.Message,
   168  	}
   169  
   170  	return condition, j.Status.JobStatus == rayjobapi.JobStatusFailed || j.Status.JobStatus == rayjobapi.JobStatusSucceeded
   171  }
   172  
   173  func (j *RayJob) PodsReady() bool {
   174  	return j.Status.RayClusterStatus.State == rayjobapi.Ready
   175  }
   176  
   177  func SetupIndexes(ctx context.Context, indexer client.FieldIndexer) error {
   178  	return jobframework.SetupWorkloadOwnerIndex(ctx, indexer, gvk)
   179  }
   180  
   181  func GetWorkloadNameForRayJob(jobName string) string {
   182  	return jobframework.GetWorkloadNameForOwnerWithGVK(jobName, gvk)
   183  }
   184  
   185  func isRayJob(owner *metav1.OwnerReference) bool {
   186  	return owner.Kind == "RayJob" && (strings.HasPrefix(owner.APIVersion, "ray.io/v1alpha1") || strings.HasPrefix(owner.APIVersion, "ray.io/v1"))
   187  }