sigs.k8s.io/kueue@v0.6.2/pkg/util/testingjobs/pytorchjob/wrappers_pytorchjob.go

/*
Copyright 2023 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package testing

import (
	kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/resource"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/types"
	"k8s.io/utils/ptr"

	"sigs.k8s.io/kueue/pkg/controller/constants"
)

// PyTorchJobWrapper wraps a PyTorchJob.
type PyTorchJobWrapper struct{ kftraining.PyTorchJob }

// MakePyTorchJob creates a wrapper for a suspended PyTorchJob with single-container
// Master and Worker replica specs, each with one replica.
func MakePyTorchJob(name, ns string) *PyTorchJobWrapper {
	return &PyTorchJobWrapper{kftraining.PyTorchJob{
		ObjectMeta: metav1.ObjectMeta{
			Name:        name,
			Namespace:   ns,
			Annotations: make(map[string]string, 1),
		},
		Spec: kftraining.PyTorchJobSpec{
			RunPolicy: kftraining.RunPolicy{
				Suspend: ptr.To(true),
			},
			PyTorchReplicaSpecs: map[kftraining.ReplicaType]*kftraining.ReplicaSpec{
				kftraining.PyTorchJobReplicaTypeMaster: {
					Replicas: ptr.To[int32](1),
					Template: corev1.PodTemplateSpec{
						Spec: corev1.PodSpec{
							RestartPolicy: "Never",
							Containers: []corev1.Container{
								{
									Name:      "c",
									Image:     "pause",
									Command:   []string{},
									Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
								},
							},
							NodeSelector: map[string]string{},
						},
					},
				},
				kftraining.PyTorchJobReplicaTypeWorker: {
					Replicas: ptr.To[int32](1),
					Template: corev1.PodTemplateSpec{
						Spec: corev1.PodSpec{
							RestartPolicy: "Never",
							Containers: []corev1.Container{
								{
									Name:      "c",
									Image:     "pause",
									Command:   []string{},
									Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
								},
							},
							NodeSelector: map[string]string{},
						},
					},
				},
			},
		},
	}}
}

// PriorityClass sets the priority class name in RunPolicy.SchedulingPolicy,
// creating the scheduling policy if it is not set yet.
func (j *PyTorchJobWrapper) PriorityClass(pc string) *PyTorchJobWrapper {
	if j.Spec.RunPolicy.SchedulingPolicy == nil {
		j.Spec.RunPolicy.SchedulingPolicy = &kftraining.SchedulingPolicy{}
	}
	j.Spec.RunPolicy.SchedulingPolicy.PriorityClass = pc
	return j
}

// WorkloadPriorityClass sets the workload priority class label on the job.
func (j *PyTorchJobWrapper) WorkloadPriorityClass(wpc string) *PyTorchJobWrapper {
	if j.Labels == nil {
		j.Labels = make(map[string]string)
	}
	j.Labels[constants.WorkloadPriorityClassLabel] = wpc
	return j
}

// Obj returns the inner PyTorchJob.
func (j *PyTorchJobWrapper) Obj() *kftraining.PyTorchJob {
	return &j.PyTorchJob
}

// Queue updates the queue name of the job.
func (j *PyTorchJobWrapper) Queue(queue string) *PyTorchJobWrapper {
	if j.Labels == nil {
		j.Labels = make(map[string]string)
	}
	j.Labels[constants.QueueLabel] = queue
	return j
}

// Request adds a resource request to the first container of the given replica type.
func (j *PyTorchJobWrapper) Request(replicaType kftraining.ReplicaType, r corev1.ResourceName, v string) *PyTorchJobWrapper {
	j.Spec.PyTorchReplicaSpecs[replicaType].Template.Spec.Containers[0].Resources.Requests[r] = resource.MustParse(v)
	return j
}

// Parallelism sets the number of replicas of the Worker replica type.
func (j *PyTorchJobWrapper) Parallelism(p int32) *PyTorchJobWrapper {
	j.Spec.PyTorchReplicaSpecs[kftraining.PyTorchJobReplicaTypeWorker].Replicas = ptr.To(p)
	return j
}

// Suspend updates the suspend status of the job.
func (j *PyTorchJobWrapper) Suspend(s bool) *PyTorchJobWrapper {
	j.Spec.RunPolicy.Suspend = &s
	return j
}

// UID updates the uid of the job.
func (j *PyTorchJobWrapper) UID(uid string) *PyTorchJobWrapper {
	j.ObjectMeta.UID = types.UID(uid)
	return j
}

// PodAnnotation sets an annotation on the pod template of the given replica type.
func (j *PyTorchJobWrapper) PodAnnotation(replicaType kftraining.ReplicaType, k, v string) *PyTorchJobWrapper {
	if j.Spec.PyTorchReplicaSpecs[replicaType].Template.Annotations == nil {
		j.Spec.PyTorchReplicaSpecs[replicaType].Template.Annotations = make(map[string]string)
	}
	j.Spec.PyTorchReplicaSpecs[replicaType].Template.Annotations[k] = v
	return j
}

// PodLabel sets a label on the pod template of the given replica type.
func (j *PyTorchJobWrapper) PodLabel(replicaType kftraining.ReplicaType, k, v string) *PyTorchJobWrapper {
	if j.Spec.PyTorchReplicaSpecs[replicaType].Template.Labels == nil {
		j.Spec.PyTorchReplicaSpecs[replicaType].Template.Labels = make(map[string]string)
	}
	j.Spec.PyTorchReplicaSpecs[replicaType].Template.Labels[k] = v
	return j
}
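
// The helper below is an illustrative sketch of how these wrapper methods are
// typically chained when building a PyTorchJob test fixture; the function name,
// job name, namespace, queue name, and resource quantities are hypothetical
// values chosen for the example, not ones used elsewhere in kueue.
func examplePyTorchJob() *kftraining.PyTorchJob {
	// Start from the suspended default object, attach it to a (hypothetical)
	// local queue, request CPU for both replica types, and run three workers.
	return MakePyTorchJob("pytorchjob", "default").
		Queue("local-queue").
		Request(kftraining.PyTorchJobReplicaTypeMaster, corev1.ResourceCPU, "1").
		Request(kftraining.PyTorchJobReplicaTypeWorker, corev1.ResourceCPU, "2").
		Parallelism(3).
		Obj()
}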