sigs.k8s.io/kueue@v0.6.2/pkg/controller/jobs/rayjob/rayjob_webhook_test.go (about)

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package rayjob
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"testing"
    23  
    24  	"github.com/google/go-cmp/cmp"
    25  	rayjobapi "github.com/ray-project/kuberay/ray-operator/apis/ray/v1alpha1"
    26  	apivalidation "k8s.io/apimachinery/pkg/api/validation"
    27  	"k8s.io/apimachinery/pkg/util/validation/field"
    28  	"k8s.io/utils/ptr"
    29  
    30  	"sigs.k8s.io/kueue/pkg/controller/constants"
    31  	testingrayutil "sigs.k8s.io/kueue/pkg/util/testingjobs/rayjob"
    32  )
    33  
    34  var (
    35  	labelsPath                    = field.NewPath("metadata", "labels")
    36  	workloadPriorityClassNamePath = labelsPath.Key(constants.WorkloadPriorityClassLabel)
    37  )
    38  
    39  func TestValidateDefault(t *testing.T) {
    40  	testcases := map[string]struct {
    41  		oldJob    *rayjobapi.RayJob
    42  		newJob    *rayjobapi.RayJob
    43  		manageAll bool
    44  	}{
    45  		"unmanaged": {
    46  			oldJob: testingrayutil.MakeJob("job", "ns").
    47  				Suspend(false).
    48  				Obj(),
    49  			newJob: testingrayutil.MakeJob("job", "ns").
    50  				Suspend(false).
    51  				Obj(),
    52  		},
    53  		"managed - by config": {
    54  			oldJob: testingrayutil.MakeJob("job", "ns").
    55  				Suspend(false).
    56  				Obj(),
    57  			newJob: testingrayutil.MakeJob("job", "ns").
    58  				Suspend(true).
    59  				Obj(),
    60  			manageAll: true,
    61  		},
    62  		"managed - by queue": {
    63  			oldJob: testingrayutil.MakeJob("job", "ns").
    64  				Queue("queue").
    65  				Suspend(false).
    66  				Obj(),
    67  			newJob: testingrayutil.MakeJob("job", "ns").
    68  				Queue("queue").
    69  				Suspend(true).
    70  				Obj(),
    71  		},
    72  	}
    73  
    74  	for name, tc := range testcases {
    75  		t.Run(name, func(t *testing.T) {
    76  			wh := &RayJobWebhook{
    77  				manageJobsWithoutQueueName: tc.manageAll,
    78  			}
    79  			result := tc.oldJob.DeepCopy()
    80  			if err := wh.Default(context.Background(), result); err != nil {
    81  				t.Errorf("unexpected Default() error: %s", err)
    82  			}
    83  			if diff := cmp.Diff(tc.newJob, result); diff != "" {
    84  				t.Errorf("Default() mismatch (-want +got):\n%s", diff)
    85  			}
    86  		})
    87  	}
    88  }
    89  
    90  func TestValidateCreate(t *testing.T) {
    91  	worker := rayjobapi.WorkerGroupSpec{}
    92  	bigWorkerGroup := []rayjobapi.WorkerGroupSpec{worker, worker, worker, worker, worker, worker, worker, worker}
    93  
    94  	testcases := map[string]struct {
    95  		job       *rayjobapi.RayJob
    96  		manageAll bool
    97  		wantErr   error
    98  	}{
    99  		"invalid unmanaged": {
   100  			job: testingrayutil.MakeJob("job", "ns").
   101  				ShutdownAfterJobFinishes(false).
   102  				Obj(),
   103  			wantErr: nil,
   104  		},
   105  		"invalid managed - by config": {
   106  			job: testingrayutil.MakeJob("job", "ns").
   107  				ShutdownAfterJobFinishes(false).
   108  				Obj(),
   109  			manageAll: true,
   110  			wantErr: field.ErrorList{
   111  				field.Invalid(field.NewPath("spec", "shutdownAfterJobFinishes"), false, "a kueue managed job should delete the cluster after finishing"),
   112  			}.ToAggregate(),
   113  		},
   114  		"invalid managed - by queue": {
   115  			job: testingrayutil.MakeJob("job", "ns").Queue("queue").
   116  				ShutdownAfterJobFinishes(false).
   117  				Obj(),
   118  			wantErr: field.ErrorList{
   119  				field.Invalid(field.NewPath("spec", "shutdownAfterJobFinishes"), false, "a kueue managed job should delete the cluster after finishing"),
   120  			}.ToAggregate(),
   121  		},
   122  		"invalid managed - has cluster selector": {
   123  			job: testingrayutil.MakeJob("job", "ns").Queue("queue").
   124  				ClusterSelector(map[string]string{
   125  					"k1": "v1",
   126  				}).
   127  				Obj(),
   128  			wantErr: field.ErrorList{
   129  				field.Invalid(field.NewPath("spec", "clusterSelector"), map[string]string{"k1": "v1"}, "a kueue managed job should not use an existing cluster"),
   130  			}.ToAggregate(),
   131  		},
   132  		"invalid managed - has auto scaler": {
   133  			job: testingrayutil.MakeJob("job", "ns").Queue("queue").
   134  				WithEnableAutoscaling(ptr.To(true)).
   135  				Obj(),
   136  			wantErr: field.ErrorList{
   137  				field.Invalid(field.NewPath("spec", "rayClusterSpec", "enableInTreeAutoscaling"), ptr.To(true), "a kueue managed job should not use autoscaling"),
   138  			}.ToAggregate(),
   139  		},
   140  		"invalid managed - too many worker groups": {
   141  			job: testingrayutil.MakeJob("job", "ns").Queue("queue").
   142  				WithWorkerGroups(bigWorkerGroup...).
   143  				Obj(),
   144  			wantErr: field.ErrorList{
   145  				field.TooMany(field.NewPath("spec", "rayClusterSpec", "workerGroupSpecs"), 8, 7),
   146  			}.ToAggregate(),
   147  		},
   148  		"worker group uses head name": {
   149  			job: testingrayutil.MakeJob("job", "ns").Queue("queue").
   150  				WithWorkerGroups(rayjobapi.WorkerGroupSpec{
   151  					GroupName: headGroupPodSetName,
   152  				}).
   153  				Obj(),
   154  			wantErr: field.ErrorList{
   155  				field.Forbidden(field.NewPath("spec", "rayClusterSpec", "workerGroupSpecs").Index(0).Child("groupName"), fmt.Sprintf("%q is reserved for the head group", headGroupPodSetName)),
   156  			}.ToAggregate(),
   157  		},
   158  	}
   159  
   160  	for name, tc := range testcases {
   161  		t.Run(name, func(t *testing.T) {
   162  			wh := &RayJobWebhook{
   163  				manageJobsWithoutQueueName: tc.manageAll,
   164  			}
   165  			_, result := wh.ValidateCreate(context.Background(), tc.job)
   166  			if diff := cmp.Diff(tc.wantErr, result); diff != "" {
   167  				t.Errorf("ValidateCreate() mismatch (-want +got):\n%s", diff)
   168  			}
   169  		})
   170  	}
   171  }
   172  
   173  func TestValidateUpdate(t *testing.T) {
   174  	testcases := map[string]struct {
   175  		oldJob    *rayjobapi.RayJob
   176  		newJob    *rayjobapi.RayJob
   177  		manageAll bool
   178  		wantErr   error
   179  	}{
   180  		"invalid unmanaged": {
   181  			oldJob: testingrayutil.MakeJob("job", "ns").
   182  				ShutdownAfterJobFinishes(true).
   183  				Obj(),
   184  			newJob: testingrayutil.MakeJob("job", "ns").
   185  				ShutdownAfterJobFinishes(false).
   186  				Obj(),
   187  			wantErr: nil,
   188  		},
   189  		"invalid new managed - by config": {
   190  			oldJob: testingrayutil.MakeJob("job", "ns").
   191  				ShutdownAfterJobFinishes(true).
   192  				Obj(),
   193  			newJob: testingrayutil.MakeJob("job", "ns").
   194  				ShutdownAfterJobFinishes(false).
   195  				Obj(),
   196  			manageAll: true,
   197  			wantErr: field.ErrorList{
   198  				field.Invalid(field.NewPath("spec", "shutdownAfterJobFinishes"), false, "a kueue managed job should delete the cluster after finishing"),
   199  			}.ToAggregate(),
   200  		},
   201  		"invalid new managed - by queue": {
   202  			oldJob: testingrayutil.MakeJob("job", "ns").
   203  				Queue("queue").
   204  				ShutdownAfterJobFinishes(true).
   205  				Obj(),
   206  			newJob: testingrayutil.MakeJob("job", "ns").
   207  				Queue("queue").
   208  				ShutdownAfterJobFinishes(false).
   209  				Obj(),
   210  			wantErr: field.ErrorList{
   211  				field.Invalid(field.NewPath("spec", "shutdownAfterJobFinishes"), false, "a kueue managed job should delete the cluster after finishing"),
   212  			}.ToAggregate(),
   213  		},
   214  		"invalid managed - queue name should not change while unsuspended": {
   215  			oldJob: testingrayutil.MakeJob("job", "ns").
   216  				Queue("queue").
   217  				Suspend(false).
   218  				ShutdownAfterJobFinishes(true).
   219  				Obj(),
   220  			newJob: testingrayutil.MakeJob("job", "ns").
   221  				Queue("queue2").
   222  				Suspend(false).
   223  				ShutdownAfterJobFinishes(true).
   224  				Obj(),
   225  			wantErr: field.ErrorList{
   226  				field.Invalid(field.NewPath("metadata", "labels").Key(constants.QueueLabel), "queue", apivalidation.FieldImmutableErrorMsg),
   227  			}.ToAggregate(),
   228  		},
   229  		"managed - queue name can change while suspended": {
   230  			oldJob: testingrayutil.MakeJob("job", "ns").
   231  				Queue("queue").
   232  				Suspend(true).
   233  				ShutdownAfterJobFinishes(true).
   234  				Obj(),
   235  			newJob: testingrayutil.MakeJob("job", "ns").
   236  				Queue("queue2").
   237  				Suspend(true).
   238  				ShutdownAfterJobFinishes(true).
   239  				Obj(),
   240  			wantErr: nil,
   241  		},
   242  		"priorityClassName is immutable": {
   243  			oldJob: testingrayutil.MakeJob("job", "ns").
   244  				Queue("queue").
   245  				WorkloadPriorityClass("test-1").
   246  				Obj(),
   247  			newJob: testingrayutil.MakeJob("job", "ns").
   248  				Queue("queue").
   249  				WorkloadPriorityClass("test-2").
   250  				Obj(),
   251  			wantErr: field.ErrorList{
   252  				field.Invalid(workloadPriorityClassNamePath, "test-1", apivalidation.FieldImmutableErrorMsg),
   253  			}.ToAggregate(),
   254  		},
   255  	}
   256  
   257  	for name, tc := range testcases {
   258  		t.Run(name, func(t *testing.T) {
   259  			wh := &RayJobWebhook{
   260  				manageJobsWithoutQueueName: tc.manageAll,
   261  			}
   262  			_, result := wh.ValidateUpdate(context.Background(), tc.oldJob, tc.newJob)
   263  			if diff := cmp.Diff(tc.wantErr, result); diff != "" {
   264  				t.Errorf("ValidateUpdate() mismatch (-want +got):\n%s", diff)
   265  			}
   266  		})
   267  	}
   268  }