sigs.k8s.io/kueue@v0.6.2/pkg/controller/jobs/raycluster/raycluster_webhook_test.go (about)

     1  /*
     2  Copyright 2024 The Kubernetes Authors.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6      http://www.apache.org/licenses/LICENSE-2.0
     7  Unless required by applicable law or agreed to in writing, software
     8  distributed under the License is distributed on an "AS IS" BASIS,
     9  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    10  See the License for the specific language governing permissions and
    11  limitations under the License.
    12  */
    13  
    14  package raycluster
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"testing"
    20  
    21  	"github.com/google/go-cmp/cmp"
    22  	rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1"
    23  	apivalidation "k8s.io/apimachinery/pkg/api/validation"
    24  	"k8s.io/apimachinery/pkg/util/validation/field"
    25  	"k8s.io/utils/ptr"
    26  
    27  	"sigs.k8s.io/kueue/pkg/controller/constants"
    28  	testingrayutil "sigs.k8s.io/kueue/pkg/util/testingjobs/raycluster"
    29  )
    30  
    31  var (
    32  	labelsPath                    = field.NewPath("metadata", "labels")
    33  	workloadPriorityClassNamePath = labelsPath.Key(constants.WorkloadPriorityClassLabel)
    34  )
    35  
    36  func TestValidateDefault(t *testing.T) {
    37  	testcases := map[string]struct {
    38  		oldJob    *rayv1.RayCluster
    39  		newJob    *rayv1.RayCluster
    40  		manageAll bool
    41  	}{
    42  		"unmanaged": {
    43  			oldJob: testingrayutil.MakeCluster("job", "ns").
    44  				Suspend(false).
    45  				Obj(),
    46  			newJob: testingrayutil.MakeCluster("job", "ns").
    47  				Suspend(false).
    48  				Obj(),
    49  		},
    50  		"managed - by config": {
    51  			oldJob: testingrayutil.MakeCluster("job", "ns").
    52  				Suspend(false).
    53  				Obj(),
    54  			newJob: testingrayutil.MakeCluster("job", "ns").
    55  				Suspend(true).
    56  				Obj(),
    57  			manageAll: true,
    58  		},
    59  		"managed - by queue": {
    60  			oldJob: testingrayutil.MakeCluster("job", "ns").
    61  				Queue("queue").
    62  				Suspend(false).
    63  				Obj(),
    64  			newJob: testingrayutil.MakeCluster("job", "ns").
    65  				Queue("queue").
    66  				Suspend(true).
    67  				Obj(),
    68  		},
    69  	}
    70  
    71  	for name, tc := range testcases {
    72  		t.Run(name, func(t *testing.T) {
    73  			wh := &RayClusterWebhook{
    74  				manageJobsWithoutQueueName: tc.manageAll,
    75  			}
    76  			result := tc.oldJob.DeepCopy()
    77  			if err := wh.Default(context.Background(), result); err != nil {
    78  				t.Errorf("unexpected Default() error: %s", err)
    79  			}
    80  			if diff := cmp.Diff(tc.newJob, result); diff != "" {
    81  				t.Errorf("Default() mismatch (-want +got):\n%s", diff)
    82  			}
    83  		})
    84  	}
    85  }
    86  
    87  func TestValidateCreate(t *testing.T) {
    88  	worker := rayv1.WorkerGroupSpec{}
    89  	bigWorkerGroup := []rayv1.WorkerGroupSpec{worker, worker, worker, worker, worker, worker, worker, worker}
    90  
    91  	testcases := map[string]struct {
    92  		job       *rayv1.RayCluster
    93  		manageAll bool
    94  		wantErr   error
    95  	}{
    96  		"invalid unmanaged": {
    97  			job: testingrayutil.MakeCluster("job", "ns").
    98  				Obj(),
    99  			wantErr: nil,
   100  		},
   101  		"invalid managed - has auto scaler": {
   102  			job: testingrayutil.MakeCluster("job", "ns").Queue("queue").
   103  				WithEnableAutoscaling(ptr.To(true)).
   104  				Obj(),
   105  			wantErr: field.ErrorList{
   106  				field.Invalid(field.NewPath("spec", "enableInTreeAutoscaling"), ptr.To(true), "a kueue managed job should not use autoscaling"),
   107  			}.ToAggregate(),
   108  		},
   109  		"invalid managed - too many worker groups": {
   110  			job: testingrayutil.MakeCluster("job", "ns").Queue("queue").
   111  				WithWorkerGroups(bigWorkerGroup...).
   112  				Obj(),
   113  			wantErr: field.ErrorList{
   114  				field.TooMany(field.NewPath("spec", "workerGroupSpecs"), 8, 7),
   115  			}.ToAggregate(),
   116  		},
   117  		"worker group uses head name": {
   118  			job: testingrayutil.MakeCluster("job", "ns").Queue("queue").
   119  				WithWorkerGroups(rayv1.WorkerGroupSpec{
   120  					GroupName: headGroupPodSetName,
   121  				}).
   122  				Obj(),
   123  			wantErr: field.ErrorList{
   124  				field.Forbidden(field.NewPath("spec", "workerGroupSpecs").Index(0).Child("groupName"), fmt.Sprintf("%q is reserved for the head group", headGroupPodSetName)),
   125  			}.ToAggregate(),
   126  		},
   127  	}
   128  
   129  	for name, tc := range testcases {
   130  		t.Run(name, func(t *testing.T) {
   131  			wh := &RayClusterWebhook{
   132  				manageJobsWithoutQueueName: tc.manageAll,
   133  			}
   134  			_, result := wh.ValidateCreate(context.Background(), tc.job)
   135  			if diff := cmp.Diff(tc.wantErr, result); diff != "" {
   136  				t.Errorf("ValidateCreate() mismatch (-want +got):\n%s", diff)
   137  			}
   138  		})
   139  	}
   140  }
   141  
   142  func TestValidateUpdate(t *testing.T) {
   143  	testcases := map[string]struct {
   144  		oldJob    *rayv1.RayCluster
   145  		newJob    *rayv1.RayCluster
   146  		manageAll bool
   147  		wantErr   error
   148  	}{
   149  		"invalid unmanaged": {
   150  			oldJob: testingrayutil.MakeCluster("job", "ns").
   151  				Obj(),
   152  			newJob: testingrayutil.MakeCluster("job", "ns").
   153  				Obj(),
   154  			wantErr: nil,
   155  		},
   156  		"invalid managed - queue name should not change while unsuspended": {
   157  			oldJob: testingrayutil.MakeCluster("job", "ns").
   158  				Queue("queue").
   159  				Suspend(false).
   160  				Obj(),
   161  			newJob: testingrayutil.MakeCluster("job", "ns").
   162  				Queue("queue2").
   163  				Suspend(false).
   164  				Obj(),
   165  			wantErr: field.ErrorList{
   166  				field.Invalid(field.NewPath("metadata", "labels").Key(constants.QueueLabel), "queue", apivalidation.FieldImmutableErrorMsg),
   167  			}.ToAggregate(),
   168  		},
   169  		"managed - queue name can change while suspended": {
   170  			oldJob: testingrayutil.MakeCluster("job", "ns").
   171  				Queue("queue").
   172  				Suspend(true).
   173  				Obj(),
   174  			newJob: testingrayutil.MakeCluster("job", "ns").
   175  				Queue("queue2").
   176  				Suspend(true).
   177  				Obj(),
   178  			wantErr: nil,
   179  		},
   180  		"priorityClassName is immutable": {
   181  			oldJob: testingrayutil.MakeCluster("job", "ns").
   182  				Queue("queue").
   183  				WorkloadPriorityClass("test-1").
   184  				Obj(),
   185  			newJob: testingrayutil.MakeCluster("job", "ns").
   186  				Queue("queue").
   187  				WorkloadPriorityClass("test-2").
   188  				Obj(),
   189  			wantErr: field.ErrorList{
   190  				field.Invalid(workloadPriorityClassNamePath, "test-1", apivalidation.FieldImmutableErrorMsg),
   191  			}.ToAggregate(),
   192  		},
   193  	}
   194  
   195  	for name, tc := range testcases {
   196  		t.Run(name, func(t *testing.T) {
   197  			wh := &RayClusterWebhook{
   198  				manageJobsWithoutQueueName: tc.manageAll,
   199  			}
   200  			_, result := wh.ValidateUpdate(context.Background(), tc.oldJob, tc.newJob)
   201  			if diff := cmp.Diff(tc.wantErr, result); diff != "" {
   202  				t.Errorf("ValidateUpdate() mismatch (-want +got):\n%s", diff)
   203  			}
   204  		})
   205  	}
   206  }