sigs.k8s.io/kueue@v0.6.2/pkg/util/testingjobs/pytorchjob/wrappers_pytorchjob.go

/*
Copyright 2023 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package testing

import (
	kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/api/resource"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/types"
	"k8s.io/utils/ptr"

	"sigs.k8s.io/kueue/pkg/controller/constants"
)

// PyTorchJobWrapper wraps a PyTorchJob.
type PyTorchJobWrapper struct{ kftraining.PyTorchJob }

// MakePyTorchJob creates a wrapper for a suspended job with a single container and parallelism=1.
func MakePyTorchJob(name, ns string) *PyTorchJobWrapper {
	return &PyTorchJobWrapper{kftraining.PyTorchJob{
		ObjectMeta: metav1.ObjectMeta{
			Name:        name,
			Namespace:   ns,
			Annotations: make(map[string]string, 1),
		},
		Spec: kftraining.PyTorchJobSpec{
			RunPolicy: kftraining.RunPolicy{
				Suspend: ptr.To(true),
			},
			PyTorchReplicaSpecs: map[kftraining.ReplicaType]*kftraining.ReplicaSpec{
				kftraining.PyTorchJobReplicaTypeMaster: {
					Replicas: ptr.To[int32](1),
					Template: corev1.PodTemplateSpec{
						Spec: corev1.PodSpec{
							RestartPolicy: "Never",
							Containers: []corev1.Container{
								{
									Name:      "c",
									Image:     "pause",
									Command:   []string{},
									Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
								},
							},
							NodeSelector: map[string]string{},
						},
					},
				},
				kftraining.PyTorchJobReplicaTypeWorker: {
					Replicas: ptr.To[int32](1),
					Template: corev1.PodTemplateSpec{
						Spec: corev1.PodSpec{
							RestartPolicy: "Never",
							Containers: []corev1.Container{
								{
									Name:      "c",
									Image:     "pause",
									Command:   []string{},
									Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
								},
							},
							NodeSelector: map[string]string{},
						},
					},
				},
			},
		},
	}}
}

// PriorityClass updates the job's priority class.
func (j *PyTorchJobWrapper) PriorityClass(pc string) *PyTorchJobWrapper {
	if j.Spec.RunPolicy.SchedulingPolicy == nil {
		j.Spec.RunPolicy.SchedulingPolicy = &kftraining.SchedulingPolicy{}
	}
	j.Spec.RunPolicy.SchedulingPolicy.PriorityClass = pc
	return j
}

// WorkloadPriorityClass updates the job's workload priority class.
func (j *PyTorchJobWrapper) WorkloadPriorityClass(wpc string) *PyTorchJobWrapper {
	if j.Labels == nil {
		j.Labels = make(map[string]string)
	}
	j.Labels[constants.WorkloadPriorityClassLabel] = wpc
	return j
}

// Obj returns the inner PyTorchJob.
func (j *PyTorchJobWrapper) Obj() *kftraining.PyTorchJob {
	return &j.PyTorchJob
}

// Queue updates the queue name of the job.
func (j *PyTorchJobWrapper) Queue(queue string) *PyTorchJobWrapper {
	if j.Labels == nil {
		j.Labels = make(map[string]string)
	}
	j.Labels[constants.QueueLabel] = queue
	return j
}

// Request adds a resource request to the default container.
func (j *PyTorchJobWrapper) Request(replicaType kftraining.ReplicaType, r corev1.ResourceName, v string) *PyTorchJobWrapper {
	j.Spec.PyTorchReplicaSpecs[replicaType].Template.Spec.Containers[0].Resources.Requests[r] = resource.MustParse(v)
	return j
}

// Parallelism updates job parallelism.
func (j *PyTorchJobWrapper) Parallelism(p int32) *PyTorchJobWrapper {
	j.Spec.PyTorchReplicaSpecs[kftraining.PyTorchJobReplicaTypeWorker].Replicas = ptr.To(p)
	return j
}

// Suspend updates the suspend status of the job.
func (j *PyTorchJobWrapper) Suspend(s bool) *PyTorchJobWrapper {
	j.Spec.RunPolicy.Suspend = &s
	return j
}

// UID updates the uid of the job.
func (j *PyTorchJobWrapper) UID(uid string) *PyTorchJobWrapper {
	j.ObjectMeta.UID = types.UID(uid)
	return j
}

// PodAnnotation sets an annotation at the pod template level.
func (j *PyTorchJobWrapper) PodAnnotation(replicaType kftraining.ReplicaType, k, v string) *PyTorchJobWrapper {
	if j.Spec.PyTorchReplicaSpecs[replicaType].Template.Annotations == nil {
		j.Spec.PyTorchReplicaSpecs[replicaType].Template.Annotations = make(map[string]string)
	}
	j.Spec.PyTorchReplicaSpecs[replicaType].Template.Annotations[k] = v
	return j
}

// PodLabel sets a label at the pod template level.
func (j *PyTorchJobWrapper) PodLabel(replicaType kftraining.ReplicaType, k, v string) *PyTorchJobWrapper {
	if j.Spec.PyTorchReplicaSpecs[replicaType].Template.Labels == nil {
		j.Spec.PyTorchReplicaSpecs[replicaType].Template.Labels = make(map[string]string)
	}
	j.Spec.PyTorchReplicaSpecs[replicaType].Template.Labels[k] = v
	return j
}
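
// Usage sketch (illustrative, not part of the upstream file): the wrapper is a
// fluent builder for tests, so callers chain the mutators and finish with Obj()
// to get the built PyTorchJob. The import alias, queue name, resource values,
// and parallelism below are assumed placeholders:
//
//	import (
//		kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
//		corev1 "k8s.io/api/core/v1"
//		pytorchjobtesting "sigs.k8s.io/kueue/pkg/util/testingjobs/pytorchjob"
//	)
//
//	pyTorchJob := pytorchjobtesting.MakePyTorchJob("pytorch-job", "default").
//		Queue("main-queue").
//		Request(kftraining.PyTorchJobReplicaTypeMaster, corev1.ResourceCPU, "1").
//		Request(kftraining.PyTorchJobReplicaTypeWorker, corev1.ResourceCPU, "2").
//		Parallelism(3).
//		Obj()
//
// Request can write directly into Resources.Requests because MakePyTorchJob
// initializes the Requests map for both the Master and Worker replica templates.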