sigs.k8s.io/kueue@v0.6.2/pkg/controller/jobs/job/job_webhook_test.go (about)

     1  /*
     2  Copyright 2022 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package job
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"testing"
    23  
    24  	"github.com/google/go-cmp/cmp"
    25  	"github.com/google/go-cmp/cmp/cmpopts"
    26  	kubeflow "github.com/kubeflow/mpi-operator/pkg/apis/kubeflow/v2beta1"
    27  	batchv1 "k8s.io/api/batch/v1"
    28  	apivalidation "k8s.io/apimachinery/pkg/api/validation"
    29  	"k8s.io/apimachinery/pkg/util/validation/field"
    30  	"k8s.io/apimachinery/pkg/version"
    31  	fakediscovery "k8s.io/client-go/discovery/fake"
    32  	fakeclient "k8s.io/client-go/kubernetes/fake"
    33  	"k8s.io/utils/ptr"
    34  
    35  	"sigs.k8s.io/kueue/pkg/controller/constants"
    36  	"sigs.k8s.io/kueue/pkg/controller/jobframework"
    37  	"sigs.k8s.io/kueue/pkg/util/kubeversion"
    38  	testingutil "sigs.k8s.io/kueue/pkg/util/testingjobs/job"
    39  
    40  	// without this only the job framework is registered
    41  	_ "sigs.k8s.io/kueue/pkg/controller/jobs/mpijob"
    42  )
    43  
    44  const (
    45  	invalidRFC1123Message = `a lowercase RFC 1123 subdomain must consist of lower case alphanumeric characters, '-' or '.', and must start and end with an alphanumeric character (e.g. 'example.com', regex used for validation is '[a-z0-9]([-a-z0-9]*[a-z0-9])?(\.[a-z0-9]([-a-z0-9]*[a-z0-9])?)*')`
    46  )
    47  
    48  var (
    49  	annotationsPath               = field.NewPath("metadata", "annotations")
    50  	labelsPath                    = field.NewPath("metadata", "labels")
    51  	parentWorkloadKeyPath         = annotationsPath.Key(constants.ParentWorkloadAnnotation)
    52  	queueNameLabelPath            = labelsPath.Key(constants.QueueLabel)
    53  	prebuiltWlNameLabelPath       = labelsPath.Key(constants.PrebuiltWorkloadLabel)
    54  	queueNameAnnotationsPath      = annotationsPath.Key(constants.QueueAnnotation)
    55  	workloadPriorityClassNamePath = labelsPath.Key(constants.WorkloadPriorityClassLabel)
    56  )
    57  
    58  func TestValidateCreate(t *testing.T) {
    59  	testcases := []struct {
    60  		name          string
    61  		job           *batchv1.Job
    62  		wantErr       field.ErrorList
    63  		serverVersion string
    64  	}{
    65  		{
    66  			name:    "simple",
    67  			job:     testingutil.MakeJob("job", "default").Queue("queue").Obj(),
    68  			wantErr: nil,
    69  		},
    70  		{
    71  			name: "valid parent-workload annotation",
    72  			job: testingutil.MakeJob("job", "default").
    73  				ParentWorkload("parent-workload-name").
    74  				Queue("queue").
    75  				OwnerReference("parent-workload-name", kubeflow.SchemeGroupVersionKind).
    76  				Obj(),
    77  			wantErr: nil,
    78  		},
    79  		{
    80  			name: "invalid parent-workload annotation",
    81  			job: testingutil.MakeJob("job", "default").
    82  				ParentWorkload("parent workload name").
    83  				OwnerReference("parent workload name", kubeflow.SchemeGroupVersionKind).
    84  				Queue("queue").
    85  				Obj(),
    86  			wantErr: field.ErrorList{field.Invalid(parentWorkloadKeyPath, "parent workload name", invalidRFC1123Message)},
    87  		},
    88  		{
    89  			name: "invalid parent-workload annotation (owner is missing)",
    90  			job: testingutil.MakeJob("job", "default").
    91  				ParentWorkload("parent-workload").
    92  				Queue("queue").
    93  				Obj(),
    94  			wantErr: field.ErrorList{
    95  				field.Forbidden(parentWorkloadKeyPath, "must not add a parent workload annotation to job without OwnerReference"),
    96  			},
    97  		},
    98  		{
    99  			name:    "invalid queue-name label",
   100  			job:     testingutil.MakeJob("job", "default").Queue("queue name").Obj(),
   101  			wantErr: field.ErrorList{field.Invalid(queueNameLabelPath, "queue name", invalidRFC1123Message)},
   102  		},
   103  		{
   104  			name:    "invalid queue-name annotation (deprecated)",
   105  			job:     testingutil.MakeJob("job", "default").QueueNameAnnotation("queue name").Obj(),
   106  			wantErr: field.ErrorList{field.Invalid(queueNameAnnotationsPath, "queue name", invalidRFC1123Message)},
   107  		},
   108  		{
   109  			name: "invalid queue-name and parent-workload annotation",
   110  			job: testingutil.MakeJob("job", "default").
   111  				Queue("queue name").
   112  				ParentWorkload("parent workload name").
   113  				OwnerReference("parent workload name", kubeflow.SchemeGroupVersionKind).
   114  				Obj(),
   115  			wantErr: field.ErrorList{
   116  				field.Invalid(parentWorkloadKeyPath, "parent workload name", invalidRFC1123Message),
   117  				field.Invalid(queueNameLabelPath, "queue name", invalidRFC1123Message),
   118  			},
   119  		},
   120  		{
   121  			name: "invalid partial admission annotation (format)",
   122  			job: testingutil.MakeJob("job", "default").
   123  				Parallelism(4).
   124  				Completions(6).
   125  				SetAnnotation(JobMinParallelismAnnotation, "NaN").
   126  				Obj(),
   127  			wantErr: field.ErrorList{
   128  				field.Invalid(minPodsCountAnnotationsPath, "NaN", "strconv.Atoi: parsing \"NaN\": invalid syntax"),
   129  			},
   130  		},
   131  		{
   132  			name: "invalid partial admission annotation (badValue)",
   133  			job: testingutil.MakeJob("job", "default").
   134  				Parallelism(4).
   135  				Completions(6).
   136  				SetAnnotation(JobMinParallelismAnnotation, "5").
   137  				Obj(),
   138  			wantErr: field.ErrorList{
   139  				field.Invalid(minPodsCountAnnotationsPath, 5, "should be between 0 and 3"),
   140  			},
   141  		},
   142  		{
   143  			name: "valid partial admission annotation",
   144  			job: testingutil.MakeJob("job", "default").
   145  				Parallelism(4).
   146  				Completions(6).
   147  				SetAnnotation(JobMinParallelismAnnotation, "3").
   148  				Obj(),
   149  			wantErr: nil,
   150  		},
   151  		{
   152  			name: "invalid sync completions annotation (format)",
   153  			job: testingutil.MakeJob("job", "default").
   154  				Parallelism(4).
   155  				Completions(6).
   156  				SetAnnotation(JobCompletionsEqualParallelismAnnotation, "-").
   157  				Indexed(true).
   158  				Obj(),
   159  			wantErr: field.ErrorList{
   160  				field.Invalid(syncCompletionAnnotationsPath, "-", "strconv.ParseBool: parsing \"-\": invalid syntax"),
   161  			},
   162  		},
   163  		{
   164  			name: "valid sync completions annotation, wrong completions count",
   165  			job: testingutil.MakeJob("job", "default").
   166  				Parallelism(4).
   167  				Completions(6).
   168  				SetAnnotation(JobCompletionsEqualParallelismAnnotation, "true").
   169  				Indexed(true).
   170  				Obj(),
   171  			wantErr: field.ErrorList{
   172  				field.Invalid(field.NewPath("spec", "completions"), ptr.To[int32](6), fmt.Sprintf("should be equal to parallelism when %s is annotation is true", JobCompletionsEqualParallelismAnnotation)),
   173  			},
   174  			serverVersion: "1.27.0",
   175  		},
   176  		{
   177  			name: "valid sync completions annotation, wrong job completions type (default)",
   178  			job: testingutil.MakeJob("job", "default").
   179  				Parallelism(4).
   180  				Completions(4).
   181  				SetAnnotation(JobCompletionsEqualParallelismAnnotation, "true").
   182  				Obj(),
   183  			wantErr: field.ErrorList{
   184  				field.Invalid(syncCompletionAnnotationsPath, "true", "should not be enabled for NonIndexed jobs"),
   185  			},
   186  			serverVersion: "1.27.0",
   187  		},
   188  		{
   189  			name: "valid sync completions annotation, wrong job completions type",
   190  			job: testingutil.MakeJob("job", "default").
   191  				Parallelism(4).
   192  				Completions(4).
   193  				SetAnnotation(JobCompletionsEqualParallelismAnnotation, "true").
   194  				Indexed(false).
   195  				Obj(),
   196  			wantErr: field.ErrorList{
   197  				field.Invalid(syncCompletionAnnotationsPath, "true", "should not be enabled for NonIndexed jobs"),
   198  			},
   199  			serverVersion: "1.27.0",
   200  		},
   201  		{
   202  			name: "valid sync completions annotation, server version less then 1.27",
   203  			job: testingutil.MakeJob("job", "default").
   204  				Parallelism(4).
   205  				Completions(4).
   206  				SetAnnotation(JobCompletionsEqualParallelismAnnotation, "true").
   207  				Indexed(true).
   208  				Obj(),
   209  			wantErr: field.ErrorList{
   210  				field.Invalid(syncCompletionAnnotationsPath, "true", "only supported in Kubernetes 1.27 or newer"),
   211  			},
   212  			serverVersion: "1.26.3",
   213  		},
   214  		{
   215  			name: "valid sync completions annotation, server version wasn't specified",
   216  			job: testingutil.MakeJob("job", "default").
   217  				Parallelism(4).
   218  				Completions(4).
   219  				SetAnnotation(JobCompletionsEqualParallelismAnnotation, "true").
   220  				Indexed(true).
   221  				Obj(),
   222  			wantErr: field.ErrorList{
   223  				field.Invalid(syncCompletionAnnotationsPath, "true", "only supported in Kubernetes 1.27 or newer"),
   224  			},
   225  		},
   226  		{
   227  			name: "valid sync completions annotation",
   228  			job: testingutil.MakeJob("job", "default").
   229  				Parallelism(4).
   230  				Completions(4).
   231  				SetAnnotation(JobCompletionsEqualParallelismAnnotation, "true").
   232  				Indexed(true).
   233  				Obj(),
   234  			wantErr:       nil,
   235  			serverVersion: "1.27.0",
   236  		},
   237  		{
   238  			name: "invalid prebuilt workload",
   239  			job: testingutil.MakeJob("job", "default").
   240  				Parallelism(4).
   241  				Completions(4).
   242  				Label(constants.PrebuiltWorkloadLabel, "workload name").
   243  				Indexed(true).
   244  				Obj(),
   245  			wantErr: field.ErrorList{
   246  				field.Invalid(prebuiltWlNameLabelPath, "workload name", invalidRFC1123Message),
   247  			},
   248  			serverVersion: "1.27.0",
   249  		},
   250  		{
   251  			name: "valid prebuilt workload",
   252  			job: testingutil.MakeJob("job", "default").
   253  				Parallelism(4).
   254  				Completions(4).
   255  				Label(constants.PrebuiltWorkloadLabel, "workload-name").
   256  				Indexed(true).
   257  				Obj(),
   258  			wantErr:       nil,
   259  			serverVersion: "1.27.0",
   260  		},
   261  	}
   262  
   263  	for _, tc := range testcases {
   264  		t.Run(tc.name, func(t *testing.T) {
   265  			jw := &JobWebhook{}
   266  			fakeDiscoveryClient, _ := fakeclient.NewSimpleClientset().Discovery().(*fakediscovery.FakeDiscovery)
   267  			fakeDiscoveryClient.FakedServerVersion = &version.Info{GitVersion: tc.serverVersion}
   268  			jw.kubeServerVersion = kubeversion.NewServerVersionFetcher(fakeDiscoveryClient)
   269  			if err := jw.kubeServerVersion.FetchServerVersion(); err != nil && tc.serverVersion != "" {
   270  				t.Fatalf("Failed fetching server version: %v", err)
   271  			}
   272  
   273  			gotErr := jw.validateCreate((*Job)(tc.job))
   274  
   275  			if diff := cmp.Diff(tc.wantErr, gotErr); diff != "" {
   276  				t.Errorf("validateCreate() mismatch (-want +got):\n%s", diff)
   277  			}
   278  		})
   279  	}
   280  }
   281  
   282  func TestValidateUpdate(t *testing.T) {
   283  	testcases := []struct {
   284  		name    string
   285  		oldJob  *batchv1.Job
   286  		newJob  *batchv1.Job
   287  		wantErr field.ErrorList
   288  	}{
   289  		{
   290  			name:    "normal update",
   291  			oldJob:  testingutil.MakeJob("job", "default").Queue("queue").Obj(),
   292  			newJob:  testingutil.MakeJob("job", "default").Queue("queue").Suspend(false).Obj(),
   293  			wantErr: nil,
   294  		},
   295  		{
   296  			name:   "add queue name with suspend is false",
   297  			oldJob: testingutil.MakeJob("job", "default").Obj(),
   298  			newJob: testingutil.MakeJob("job", "default").Queue("queue").Suspend(false).Obj(),
   299  			wantErr: field.ErrorList{
   300  				field.Invalid(queueNameLabelPath, "", apivalidation.FieldImmutableErrorMsg),
   301  			},
   302  		},
   303  		{
   304  			name:    "add queue name with suspend is true",
   305  			oldJob:  testingutil.MakeJob("job", "default").Obj(),
   306  			newJob:  testingutil.MakeJob("job", "default").Queue("queue").Suspend(true).Obj(),
   307  			wantErr: nil,
   308  		},
   309  		{
   310  			name:   "change queue name with suspend is false",
   311  			oldJob: testingutil.MakeJob("job", "default").Queue("queue").Obj(),
   312  			newJob: testingutil.MakeJob("job", "default").Queue("queue2").Suspend(false).Obj(),
   313  			wantErr: field.ErrorList{
   314  				field.Invalid(queueNameLabelPath, "queue", apivalidation.FieldImmutableErrorMsg),
   315  			},
   316  		},
   317  		{
   318  			name:    "change queue name with suspend is true",
   319  			oldJob:  testingutil.MakeJob("job", "default").Obj(),
   320  			newJob:  testingutil.MakeJob("job", "default").Queue("queue").Suspend(true).Obj(),
   321  			wantErr: nil,
   322  		},
   323  		{
   324  			name:    "change queue name with suspend is true, but invalid value",
   325  			oldJob:  testingutil.MakeJob("job", "default").Obj(),
   326  			newJob:  testingutil.MakeJob("job", "default").Queue("queue name").Suspend(true).Obj(),
   327  			wantErr: field.ErrorList{field.Invalid(queueNameLabelPath, "queue name", invalidRFC1123Message)},
   328  		},
   329  		{
   330  			name:   "update the nil parent workload to non-empty",
   331  			oldJob: testingutil.MakeJob("job", "default").Obj(),
   332  			newJob: testingutil.MakeJob("job", "default").
   333  				ParentWorkload("parent").
   334  				OwnerReference("parent", kubeflow.SchemeGroupVersionKind).
   335  				Obj(),
   336  			wantErr: field.ErrorList{
   337  				field.Invalid(parentWorkloadKeyPath, "parent", apivalidation.FieldImmutableErrorMsg),
   338  			},
   339  		},
   340  		{
   341  			name:   "update the non-empty parent workload to nil",
   342  			oldJob: testingutil.MakeJob("job", "default").ParentWorkload("parent").Obj(),
   343  			newJob: testingutil.MakeJob("job", "default").Obj(),
   344  			wantErr: field.ErrorList{
   345  				field.Invalid(parentWorkloadKeyPath, "", apivalidation.FieldImmutableErrorMsg),
   346  			},
   347  		},
   348  		{
   349  			name:   "invalid queue name and immutable parent",
   350  			oldJob: testingutil.MakeJob("job", "default").Obj(),
   351  			newJob: testingutil.MakeJob("job", "default").
   352  				Queue("queue name").
   353  				ParentWorkload("parent").
   354  				OwnerReference("parent", kubeflow.SchemeGroupVersionKind).
   355  				Obj(),
   356  			wantErr: field.ErrorList{
   357  				field.Invalid(queueNameLabelPath, "queue name", invalidRFC1123Message),
   358  				field.Invalid(parentWorkloadKeyPath, "parent", apivalidation.FieldImmutableErrorMsg),
   359  			},
   360  		},
   361  		{
   362  			name: "immutable parallelism while unsuspended with partial admission enabled",
   363  			oldJob: testingutil.MakeJob("job", "default").
   364  				Suspend(false).
   365  				Parallelism(4).
   366  				Completions(6).
   367  				SetAnnotation(JobMinParallelismAnnotation, "3").
   368  				Obj(),
   369  			newJob: testingutil.MakeJob("job", "default").
   370  				Suspend(false).
   371  				Parallelism(5).
   372  				Completions(6).
   373  				SetAnnotation(JobMinParallelismAnnotation, "3").
   374  				Obj(),
   375  			wantErr: field.ErrorList{
   376  				field.Forbidden(field.NewPath("spec", "parallelism"), "cannot change when partial admission is enabled and the job is not suspended"),
   377  			},
   378  		},
   379  		{
   380  			name: "mutable parallelism while suspended with partial admission enabled",
   381  			oldJob: testingutil.MakeJob("job", "default").
   382  				Parallelism(4).
   383  				Completions(6).
   384  				SetAnnotation(JobMinParallelismAnnotation, "3").
   385  				Obj(),
   386  			newJob: testingutil.MakeJob("job", "default").
   387  				Parallelism(5).
   388  				Completions(6).
   389  				SetAnnotation(JobMinParallelismAnnotation, "3").
   390  				Obj(),
   391  			wantErr: nil,
   392  		},
   393  		{
   394  			name: "immutable sync completion annotation while unsuspended",
   395  			oldJob: testingutil.MakeJob("job", "default").
   396  				Suspend(false).
   397  				Parallelism(4).
   398  				Completions(6).
   399  				SetAnnotation(JobCompletionsEqualParallelismAnnotation, "true").
   400  				Obj(),
   401  			newJob: testingutil.MakeJob("job", "default").
   402  				Suspend(false).
   403  				Parallelism(5).
   404  				Completions(6).
   405  				SetAnnotation(JobCompletionsEqualParallelismAnnotation, "false").
   406  				Obj(),
   407  			wantErr: field.ErrorList{
   408  				field.Forbidden(syncCompletionAnnotationsPath, fmt.Sprintf("%s while the job is not suspended", apivalidation.FieldImmutableErrorMsg)),
   409  			},
   410  		},
   411  		{
   412  			name: "mutable sync completion annotation while suspended",
   413  			oldJob: testingutil.MakeJob("job", "default").
   414  				Suspend(true).
   415  				Parallelism(4).
   416  				Completions(6).
   417  				SetAnnotation(JobCompletionsEqualParallelismAnnotation, "true").
   418  				Obj(),
   419  			newJob: testingutil.MakeJob("job", "default").
   420  				Suspend(false).
   421  				Parallelism(5).
   422  				Completions(6).
   423  				SetAnnotation(JobCompletionsEqualParallelismAnnotation, "false").
   424  				Obj(),
   425  			wantErr: nil,
   426  		},
   427  		{
   428  			name:   "workloadPriorityClassName is immutable",
   429  			oldJob: testingutil.MakeJob("job", "default").WorkloadPriorityClass("test-1").Obj(),
   430  			newJob: testingutil.MakeJob("job", "default").WorkloadPriorityClass("test-2").Obj(),
   431  			wantErr: field.ErrorList{
   432  				field.Invalid(workloadPriorityClassNamePath, "test-1", apivalidation.FieldImmutableErrorMsg),
   433  			},
   434  		},
   435  		{
   436  			name: "immutable prebuilt workload ",
   437  			oldJob: testingutil.MakeJob("job", "default").
   438  				Suspend(true).
   439  				Label(constants.PrebuiltWorkloadLabel, "old-workload").
   440  				Obj(),
   441  			newJob: testingutil.MakeJob("job", "default").
   442  				Suspend(false).
   443  				Label(constants.PrebuiltWorkloadLabel, "new-workload").
   444  				Obj(),
   445  			wantErr: apivalidation.ValidateImmutableField("old-workload", "new-workload", prebuiltWlNameLabelPath),
   446  		},
   447  	}
   448  
   449  	for _, tc := range testcases {
   450  		t.Run(tc.name, func(t *testing.T) {
   451  			gotErr := new(JobWebhook).validateUpdate((*Job)(tc.oldJob), (*Job)(tc.newJob))
   452  			if diff := cmp.Diff(tc.wantErr, gotErr, cmpopts.IgnoreFields(field.Error{})); diff != "" {
   453  				t.Errorf("validateUpdate() mismatch (-want +got):\n%s", diff)
   454  			}
   455  		})
   456  	}
   457  }
   458  
   459  func TestDefault(t *testing.T) {
   460  	testcases := map[string]struct {
   461  		job                        *batchv1.Job
   462  		manageJobsWithoutQueueName bool
   463  		want                       *batchv1.Job
   464  	}{
   465  		"add a parent job name to annotations": {
   466  			job: testingutil.MakeJob("child-job", "default").
   467  				OwnerReference("parent-job", kubeflow.SchemeGroupVersionKind).
   468  				Obj(),
   469  			want: testingutil.MakeJob("child-job", "default").
   470  				OwnerReference("parent-job", kubeflow.SchemeGroupVersionKind).
   471  				ParentWorkload(jobframework.GetWorkloadNameForOwnerWithGVK("parent-job", kubeflow.SchemeGroupVersionKind)).
   472  				Obj(),
   473  		},
   474  		"update the suspend field with 'manageJobsWithoutQueueName=false'": {
   475  			job:  testingutil.MakeJob("job", "default").Queue("queue").Suspend(false).Obj(),
   476  			want: testingutil.MakeJob("job", "default").Queue("queue").Obj(),
   477  		},
   478  		"update the suspend field 'manageJobsWithoutQueueName=true'": {
   479  			job:                        testingutil.MakeJob("job", "default").Suspend(false).Obj(),
   480  			manageJobsWithoutQueueName: true,
   481  			want:                       testingutil.MakeJob("job", "default").Obj(),
   482  		},
   483  		"don't replace parent workload name in annotations": {
   484  			job: testingutil.MakeJob("child-job", "default").
   485  				OwnerReference("parent-job", kubeflow.SchemeGroupVersionKind).
   486  				ParentWorkload("prebuilt-workload").
   487  				Obj(),
   488  			want: testingutil.MakeJob("child-job", "default").
   489  				OwnerReference("parent-job", kubeflow.SchemeGroupVersionKind).
   490  				ParentWorkload("prebuilt-workload").
   491  				Obj(),
   492  		},
   493  	}
   494  	for name, tc := range testcases {
   495  		t.Run(name, func(t *testing.T) {
   496  			w := &JobWebhook{manageJobsWithoutQueueName: tc.manageJobsWithoutQueueName}
   497  			if err := w.Default(context.Background(), tc.job); err != nil {
   498  				t.Errorf("set defaults to a batch/job by a Defaulter")
   499  			}
   500  			if diff := cmp.Diff(tc.want, tc.job); len(diff) != 0 {
   501  				t.Errorf("Default() mismatch (-want,+got):\n%s", diff)
   502  			}
   503  		})
   504  	}
   505  }