sigs.k8s.io/kueue@v0.6.2/pkg/controller/jobs/raycluster/raycluster_webhook_test.go (about) 1 /* 2 Copyright 2024 The Kubernetes Authors. 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 http://www.apache.org/licenses/LICENSE-2.0 7 Unless required by applicable law or agreed to in writing, software 8 distributed under the License is distributed on an "AS IS" BASIS, 9 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 See the License for the specific language governing permissions and 11 limitations under the License. 12 */ 13 14 package raycluster 15 16 import ( 17 "context" 18 "fmt" 19 "testing" 20 21 "github.com/google/go-cmp/cmp" 22 rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" 23 apivalidation "k8s.io/apimachinery/pkg/api/validation" 24 "k8s.io/apimachinery/pkg/util/validation/field" 25 "k8s.io/utils/ptr" 26 27 "sigs.k8s.io/kueue/pkg/controller/constants" 28 testingrayutil "sigs.k8s.io/kueue/pkg/util/testingjobs/raycluster" 29 ) 30 31 var ( 32 labelsPath = field.NewPath("metadata", "labels") 33 workloadPriorityClassNamePath = labelsPath.Key(constants.WorkloadPriorityClassLabel) 34 ) 35 36 func TestValidateDefault(t *testing.T) { 37 testcases := map[string]struct { 38 oldJob *rayv1.RayCluster 39 newJob *rayv1.RayCluster 40 manageAll bool 41 }{ 42 "unmanaged": { 43 oldJob: testingrayutil.MakeCluster("job", "ns"). 44 Suspend(false). 45 Obj(), 46 newJob: testingrayutil.MakeCluster("job", "ns"). 47 Suspend(false). 48 Obj(), 49 }, 50 "managed - by config": { 51 oldJob: testingrayutil.MakeCluster("job", "ns"). 52 Suspend(false). 53 Obj(), 54 newJob: testingrayutil.MakeCluster("job", "ns"). 55 Suspend(true). 56 Obj(), 57 manageAll: true, 58 }, 59 "managed - by queue": { 60 oldJob: testingrayutil.MakeCluster("job", "ns"). 61 Queue("queue"). 62 Suspend(false). 63 Obj(), 64 newJob: testingrayutil.MakeCluster("job", "ns"). 65 Queue("queue"). 66 Suspend(true). 67 Obj(), 68 }, 69 } 70 71 for name, tc := range testcases { 72 t.Run(name, func(t *testing.T) { 73 wh := &RayClusterWebhook{ 74 manageJobsWithoutQueueName: tc.manageAll, 75 } 76 result := tc.oldJob.DeepCopy() 77 if err := wh.Default(context.Background(), result); err != nil { 78 t.Errorf("unexpected Default() error: %s", err) 79 } 80 if diff := cmp.Diff(tc.newJob, result); diff != "" { 81 t.Errorf("Default() mismatch (-want +got):\n%s", diff) 82 } 83 }) 84 } 85 } 86 87 func TestValidateCreate(t *testing.T) { 88 worker := rayv1.WorkerGroupSpec{} 89 bigWorkerGroup := []rayv1.WorkerGroupSpec{worker, worker, worker, worker, worker, worker, worker, worker} 90 91 testcases := map[string]struct { 92 job *rayv1.RayCluster 93 manageAll bool 94 wantErr error 95 }{ 96 "invalid unmanaged": { 97 job: testingrayutil.MakeCluster("job", "ns"). 98 Obj(), 99 wantErr: nil, 100 }, 101 "invalid managed - has auto scaler": { 102 job: testingrayutil.MakeCluster("job", "ns").Queue("queue"). 103 WithEnableAutoscaling(ptr.To(true)). 104 Obj(), 105 wantErr: field.ErrorList{ 106 field.Invalid(field.NewPath("spec", "enableInTreeAutoscaling"), ptr.To(true), "a kueue managed job should not use autoscaling"), 107 }.ToAggregate(), 108 }, 109 "invalid managed - too many worker groups": { 110 job: testingrayutil.MakeCluster("job", "ns").Queue("queue"). 111 WithWorkerGroups(bigWorkerGroup...). 112 Obj(), 113 wantErr: field.ErrorList{ 114 field.TooMany(field.NewPath("spec", "workerGroupSpecs"), 8, 7), 115 }.ToAggregate(), 116 }, 117 "worker group uses head name": { 118 job: testingrayutil.MakeCluster("job", "ns").Queue("queue"). 119 WithWorkerGroups(rayv1.WorkerGroupSpec{ 120 GroupName: headGroupPodSetName, 121 }). 122 Obj(), 123 wantErr: field.ErrorList{ 124 field.Forbidden(field.NewPath("spec", "workerGroupSpecs").Index(0).Child("groupName"), fmt.Sprintf("%q is reserved for the head group", headGroupPodSetName)), 125 }.ToAggregate(), 126 }, 127 } 128 129 for name, tc := range testcases { 130 t.Run(name, func(t *testing.T) { 131 wh := &RayClusterWebhook{ 132 manageJobsWithoutQueueName: tc.manageAll, 133 } 134 _, result := wh.ValidateCreate(context.Background(), tc.job) 135 if diff := cmp.Diff(tc.wantErr, result); diff != "" { 136 t.Errorf("ValidateCreate() mismatch (-want +got):\n%s", diff) 137 } 138 }) 139 } 140 } 141 142 func TestValidateUpdate(t *testing.T) { 143 testcases := map[string]struct { 144 oldJob *rayv1.RayCluster 145 newJob *rayv1.RayCluster 146 manageAll bool 147 wantErr error 148 }{ 149 "invalid unmanaged": { 150 oldJob: testingrayutil.MakeCluster("job", "ns"). 151 Obj(), 152 newJob: testingrayutil.MakeCluster("job", "ns"). 153 Obj(), 154 wantErr: nil, 155 }, 156 "invalid managed - queue name should not change while unsuspended": { 157 oldJob: testingrayutil.MakeCluster("job", "ns"). 158 Queue("queue"). 159 Suspend(false). 160 Obj(), 161 newJob: testingrayutil.MakeCluster("job", "ns"). 162 Queue("queue2"). 163 Suspend(false). 164 Obj(), 165 wantErr: field.ErrorList{ 166 field.Invalid(field.NewPath("metadata", "labels").Key(constants.QueueLabel), "queue", apivalidation.FieldImmutableErrorMsg), 167 }.ToAggregate(), 168 }, 169 "managed - queue name can change while suspended": { 170 oldJob: testingrayutil.MakeCluster("job", "ns"). 171 Queue("queue"). 172 Suspend(true). 173 Obj(), 174 newJob: testingrayutil.MakeCluster("job", "ns"). 175 Queue("queue2"). 176 Suspend(true). 177 Obj(), 178 wantErr: nil, 179 }, 180 "priorityClassName is immutable": { 181 oldJob: testingrayutil.MakeCluster("job", "ns"). 182 Queue("queue"). 183 WorkloadPriorityClass("test-1"). 184 Obj(), 185 newJob: testingrayutil.MakeCluster("job", "ns"). 186 Queue("queue"). 187 WorkloadPriorityClass("test-2"). 188 Obj(), 189 wantErr: field.ErrorList{ 190 field.Invalid(workloadPriorityClassNamePath, "test-1", apivalidation.FieldImmutableErrorMsg), 191 }.ToAggregate(), 192 }, 193 } 194 195 for name, tc := range testcases { 196 t.Run(name, func(t *testing.T) { 197 wh := &RayClusterWebhook{ 198 manageJobsWithoutQueueName: tc.manageAll, 199 } 200 _, result := wh.ValidateUpdate(context.Background(), tc.oldJob, tc.newJob) 201 if diff := cmp.Diff(tc.wantErr, result); diff != "" { 202 t.Errorf("ValidateUpdate() mismatch (-want +got):\n%s", diff) 203 } 204 }) 205 } 206 }