sigs.k8s.io/kueue@v0.6.2/pkg/controller/jobs/rayjob/rayjob_webhook_test.go (about) 1 /* 2 Copyright 2023 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package rayjob 18 19 import ( 20 "context" 21 "fmt" 22 "testing" 23 24 "github.com/google/go-cmp/cmp" 25 rayjobapi "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1" 26 apivalidation "k8s.io/apimachinery/pkg/api/validation" 27 "k8s.io/apimachinery/pkg/util/validation/field" 28 "k8s.io/utils/ptr" 29 30 "sigs.k8s.io/kueue/pkg/controller/constants" 31 testingrayutil "sigs.k8s.io/kueue/pkg/util/testingjobs/rayjob" 32 ) 33 34 var ( 35 labelsPath = field.NewPath("metadata", "labels") 36 workloadPriorityClassNamePath = labelsPath.Key(constants.WorkloadPriorityClassLabel) 37 ) 38 39 func TestValidateDefault(t *testing.T) { 40 testcases := map[string]struct { 41 oldJob *rayjobapi.RayJob 42 newJob *rayjobapi.RayJob 43 manageAll bool 44 }{ 45 "unmanaged": { 46 oldJob: testingrayutil.MakeJob("job", "ns"). 47 Suspend(false). 48 Obj(), 49 newJob: testingrayutil.MakeJob("job", "ns"). 50 Suspend(false). 51 Obj(), 52 }, 53 "managed - by config": { 54 oldJob: testingrayutil.MakeJob("job", "ns"). 55 Suspend(false). 56 Obj(), 57 newJob: testingrayutil.MakeJob("job", "ns"). 58 Suspend(true). 59 Obj(), 60 manageAll: true, 61 }, 62 "managed - by queue": { 63 oldJob: testingrayutil.MakeJob("job", "ns"). 64 Queue("queue"). 65 Suspend(false). 66 Obj(), 67 newJob: testingrayutil.MakeJob("job", "ns"). 68 Queue("queue"). 69 Suspend(true). 70 Obj(), 71 }, 72 } 73 74 for name, tc := range testcases { 75 t.Run(name, func(t *testing.T) { 76 wh := &RayJobWebhook{ 77 manageJobsWithoutQueueName: tc.manageAll, 78 } 79 result := tc.oldJob.DeepCopy() 80 if err := wh.Default(context.Background(), result); err != nil { 81 t.Errorf("unexpected Default() error: %s", err) 82 } 83 if diff := cmp.Diff(tc.newJob, result); diff != "" { 84 t.Errorf("Default() mismatch (-want +got):\n%s", diff) 85 } 86 }) 87 } 88 } 89 90 func TestValidateCreate(t *testing.T) { 91 worker := rayjobapi.WorkerGroupSpec{} 92 bigWorkerGroup := []rayjobapi.WorkerGroupSpec{worker, worker, worker, worker, worker, worker, worker, worker} 93 94 testcases := map[string]struct { 95 job *rayjobapi.RayJob 96 manageAll bool 97 wantErr error 98 }{ 99 "invalid unmanaged": { 100 job: testingrayutil.MakeJob("job", "ns"). 101 ShutdownAfterJobFinishes(false). 102 Obj(), 103 wantErr: nil, 104 }, 105 "invalid managed - by config": { 106 job: testingrayutil.MakeJob("job", "ns"). 107 ShutdownAfterJobFinishes(false). 108 Obj(), 109 manageAll: true, 110 wantErr: field.ErrorList{ 111 field.Invalid(field.NewPath("spec", "shutdownAfterJobFinishes"), false, "a kueue managed job should delete the cluster after finishing"), 112 }.ToAggregate(), 113 }, 114 "invalid managed - by queue": { 115 job: testingrayutil.MakeJob("job", "ns").Queue("queue"). 116 ShutdownAfterJobFinishes(false). 117 Obj(), 118 wantErr: field.ErrorList{ 119 field.Invalid(field.NewPath("spec", "shutdownAfterJobFinishes"), false, "a kueue managed job should delete the cluster after finishing"), 120 }.ToAggregate(), 121 }, 122 "invalid managed - has cluster selector": { 123 job: testingrayutil.MakeJob("job", "ns").Queue("queue"). 124 ClusterSelector(map[string]string{ 125 "k1": "v1", 126 }). 127 Obj(), 128 wantErr: field.ErrorList{ 129 field.Invalid(field.NewPath("spec", "clusterSelector"), map[string]string{"k1": "v1"}, "a kueue managed job should not use an existing cluster"), 130 }.ToAggregate(), 131 }, 132 "invalid managed - has auto scaler": { 133 job: testingrayutil.MakeJob("job", "ns").Queue("queue"). 134 WithEnableAutoscaling(ptr.To(true)). 135 Obj(), 136 wantErr: field.ErrorList{ 137 field.Invalid(field.NewPath("spec", "rayClusterSpec", "enableInTreeAutoscaling"), ptr.To(true), "a kueue managed job should not use autoscaling"), 138 }.ToAggregate(), 139 }, 140 "invalid managed - too many worker groups": { 141 job: testingrayutil.MakeJob("job", "ns").Queue("queue"). 142 WithWorkerGroups(bigWorkerGroup...). 143 Obj(), 144 wantErr: field.ErrorList{ 145 field.TooMany(field.NewPath("spec", "rayClusterSpec", "workerGroupSpecs"), 8, 7), 146 }.ToAggregate(), 147 }, 148 "worker group uses head name": { 149 job: testingrayutil.MakeJob("job", "ns").Queue("queue"). 150 WithWorkerGroups(rayjobapi.WorkerGroupSpec{ 151 GroupName: headGroupPodSetName, 152 }). 153 Obj(), 154 wantErr: field.ErrorList{ 155 field.Forbidden(field.NewPath("spec", "rayClusterSpec", "workerGroupSpecs").Index(0).Child("groupName"), fmt.Sprintf("%q is reserved for the head group", headGroupPodSetName)), 156 }.ToAggregate(), 157 }, 158 } 159 160 for name, tc := range testcases { 161 t.Run(name, func(t *testing.T) { 162 wh := &RayJobWebhook{ 163 manageJobsWithoutQueueName: tc.manageAll, 164 } 165 _, result := wh.ValidateCreate(context.Background(), tc.job) 166 if diff := cmp.Diff(tc.wantErr, result); diff != "" { 167 t.Errorf("ValidateCreate() mismatch (-want +got):\n%s", diff) 168 } 169 }) 170 } 171 } 172 173 func TestValidateUpdate(t *testing.T) { 174 testcases := map[string]struct { 175 oldJob *rayjobapi.RayJob 176 newJob *rayjobapi.RayJob 177 manageAll bool 178 wantErr error 179 }{ 180 "invalid unmanaged": { 181 oldJob: testingrayutil.MakeJob("job", "ns"). 182 ShutdownAfterJobFinishes(true). 183 Obj(), 184 newJob: testingrayutil.MakeJob("job", "ns"). 185 ShutdownAfterJobFinishes(false). 186 Obj(), 187 wantErr: nil, 188 }, 189 "invalid new managed - by config": { 190 oldJob: testingrayutil.MakeJob("job", "ns"). 191 ShutdownAfterJobFinishes(true). 192 Obj(), 193 newJob: testingrayutil.MakeJob("job", "ns"). 194 ShutdownAfterJobFinishes(false). 195 Obj(), 196 manageAll: true, 197 wantErr: field.ErrorList{ 198 field.Invalid(field.NewPath("spec", "shutdownAfterJobFinishes"), false, "a kueue managed job should delete the cluster after finishing"), 199 }.ToAggregate(), 200 }, 201 "invalid new managed - by queue": { 202 oldJob: testingrayutil.MakeJob("job", "ns"). 203 Queue("queue"). 204 ShutdownAfterJobFinishes(true). 205 Obj(), 206 newJob: testingrayutil.MakeJob("job", "ns"). 207 Queue("queue"). 208 ShutdownAfterJobFinishes(false). 209 Obj(), 210 wantErr: field.ErrorList{ 211 field.Invalid(field.NewPath("spec", "shutdownAfterJobFinishes"), false, "a kueue managed job should delete the cluster after finishing"), 212 }.ToAggregate(), 213 }, 214 "invalid managed - queue name should not change while unsuspended": { 215 oldJob: testingrayutil.MakeJob("job", "ns"). 216 Queue("queue"). 217 Suspend(false). 218 ShutdownAfterJobFinishes(true). 219 Obj(), 220 newJob: testingrayutil.MakeJob("job", "ns"). 221 Queue("queue2"). 222 Suspend(false). 223 ShutdownAfterJobFinishes(true). 224 Obj(), 225 wantErr: field.ErrorList{ 226 field.Invalid(field.NewPath("metadata", "labels").Key(constants.QueueLabel), "queue", apivalidation.FieldImmutableErrorMsg), 227 }.ToAggregate(), 228 }, 229 "managed - queue name can change while suspended": { 230 oldJob: testingrayutil.MakeJob("job", "ns"). 231 Queue("queue"). 232 Suspend(true). 233 ShutdownAfterJobFinishes(true). 234 Obj(), 235 newJob: testingrayutil.MakeJob("job", "ns"). 236 Queue("queue2"). 237 Suspend(true). 238 ShutdownAfterJobFinishes(true). 239 Obj(), 240 wantErr: nil, 241 }, 242 "priorityClassName is immutable": { 243 oldJob: testingrayutil.MakeJob("job", "ns"). 244 Queue("queue"). 245 WorkloadPriorityClass("test-1"). 246 Obj(), 247 newJob: testingrayutil.MakeJob("job", "ns"). 248 Queue("queue"). 249 WorkloadPriorityClass("test-2"). 250 Obj(), 251 wantErr: field.ErrorList{ 252 field.Invalid(workloadPriorityClassNamePath, "test-1", apivalidation.FieldImmutableErrorMsg), 253 }.ToAggregate(), 254 }, 255 } 256 257 for name, tc := range testcases { 258 t.Run(name, func(t *testing.T) { 259 wh := &RayJobWebhook{ 260 manageJobsWithoutQueueName: tc.manageAll, 261 } 262 _, result := wh.ValidateUpdate(context.Background(), tc.oldJob, tc.newJob) 263 if diff := cmp.Diff(tc.wantErr, result); diff != "" { 264 t.Errorf("ValidateUpdate() mismatch (-want +got):\n%s", diff) 265 } 266 }) 267 } 268 }