sigs.k8s.io/kueue@v0.6.2/pkg/webhooks/workload_webhook.go (about)

     1  /*
     2  Copyright 2022 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package webhooks
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"strings"
    23  
    24  	corev1 "k8s.io/api/core/v1"
    25  	apivalidation "k8s.io/apimachinery/pkg/api/validation"
    26  	metav1validation "k8s.io/apimachinery/pkg/apis/meta/v1/validation"
    27  	"k8s.io/apimachinery/pkg/runtime"
    28  	"k8s.io/apimachinery/pkg/util/sets"
    29  	"k8s.io/apimachinery/pkg/util/validation"
    30  	"k8s.io/apimachinery/pkg/util/validation/field"
    31  	"k8s.io/klog/v2"
    32  	"k8s.io/utils/ptr"
    33  	ctrl "sigs.k8s.io/controller-runtime"
    34  	"sigs.k8s.io/controller-runtime/pkg/webhook"
    35  	"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
    36  
    37  	kueue "sigs.k8s.io/kueue/apis/kueue/v1beta1"
    38  	"sigs.k8s.io/kueue/pkg/features"
    39  	"sigs.k8s.io/kueue/pkg/util/slices"
    40  	"sigs.k8s.io/kueue/pkg/workload"
    41  )
    42  
    43  type WorkloadWebhook struct{}
    44  
    45  func setupWebhookForWorkload(mgr ctrl.Manager) error {
    46  	return ctrl.NewWebhookManagedBy(mgr).
    47  		For(&kueue.Workload{}).
    48  		WithDefaulter(&WorkloadWebhook{}).
    49  		WithValidator(&WorkloadWebhook{}).
    50  		Complete()
    51  }
    52  
    53  // +kubebuilder:webhook:path=/mutate-kueue-x-k8s-io-v1beta1-workload,mutating=true,failurePolicy=fail,sideEffects=None,groups=kueue.x-k8s.io,resources=workloads,verbs=create,versions=v1beta1,name=mworkload.kb.io,admissionReviewVersions=v1
    54  
    55  var _ webhook.CustomDefaulter = &WorkloadWebhook{}
    56  
    57  // Default implements webhook.CustomDefaulter so a webhook will be registered for the type
    58  func (w *WorkloadWebhook) Default(ctx context.Context, obj runtime.Object) error {
    59  	wl := obj.(*kueue.Workload)
    60  	log := ctrl.LoggerFrom(ctx).WithName("workload-webhook")
    61  	log.V(5).Info("Applying defaults", "workload", klog.KObj(wl))
    62  
    63  	// Only when we have one podSet and its name is empty,
    64  	// we'll set it to the default name `main`.
    65  	if len(wl.Spec.PodSets) == 1 {
    66  		podSet := &wl.Spec.PodSets[0]
    67  		if len(podSet.Name) == 0 {
    68  			podSet.Name = kueue.DefaultPodSetName
    69  		}
    70  	}
    71  
    72  	// drop minCounts if PartialAdmission is not enabled
    73  	if !features.Enabled(features.PartialAdmission) {
    74  		for i := range wl.Spec.PodSets {
    75  			wl.Spec.PodSets[i].MinCount = nil
    76  		}
    77  	}
    78  
    79  	return nil
    80  }
    81  
    82  // +kubebuilder:webhook:path=/validate-kueue-x-k8s-io-v1beta1-workload,mutating=false,failurePolicy=fail,sideEffects=None,groups=kueue.x-k8s.io,resources=workloads;workloads/status,verbs=create;update,versions=v1beta1,name=vworkload.kb.io,admissionReviewVersions=v1
    83  
    84  var _ webhook.CustomValidator = &WorkloadWebhook{}
    85  
    86  // ValidateCreate implements webhook.CustomValidator so a webhook will be registered for the type
    87  func (w *WorkloadWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
    88  	wl := obj.(*kueue.Workload)
    89  	log := ctrl.LoggerFrom(ctx).WithName("workload-webhook")
    90  	log.V(5).Info("Validating create", "workload", klog.KObj(wl))
    91  	return nil, ValidateWorkload(wl).ToAggregate()
    92  }
    93  
    94  // ValidateUpdate implements webhook.CustomValidator so a webhook will be registered for the type
    95  func (w *WorkloadWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) {
    96  	newWL := newObj.(*kueue.Workload)
    97  	oldWL := oldObj.(*kueue.Workload)
    98  	log := ctrl.LoggerFrom(ctx).WithName("workload-webhook")
    99  	log.V(5).Info("Validating update", "workload", klog.KObj(newWL))
   100  	return nil, ValidateWorkloadUpdate(newWL, oldWL).ToAggregate()
   101  }
   102  
   103  // ValidateDelete implements webhook.CustomValidator so a webhook will be registered for the type
   104  func (w *WorkloadWebhook) ValidateDelete(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
   105  	return nil, nil
   106  }
   107  
   108  func ValidateWorkload(obj *kueue.Workload) field.ErrorList {
   109  	var allErrs field.ErrorList
   110  	specPath := field.NewPath("spec")
   111  
   112  	variableCountPosets := 0
   113  	for i := range obj.Spec.PodSets {
   114  		ps := &obj.Spec.PodSets[i]
   115  		allErrs = append(allErrs, validatePodSet(ps, specPath.Child("podSets").Index(i))...)
   116  		if ps.MinCount != nil {
   117  			variableCountPosets++
   118  		}
   119  	}
   120  
   121  	if variableCountPosets > 1 {
   122  		allErrs = append(allErrs, field.Invalid(specPath.Child("podSets"), variableCountPosets, "at most one podSet can use minCount"))
   123  	}
   124  
   125  	if len(obj.Spec.PriorityClassName) > 0 {
   126  		msgs := validation.IsDNS1123Subdomain(obj.Spec.PriorityClassName)
   127  		if len(msgs) > 0 {
   128  			for _, msg := range msgs {
   129  				allErrs = append(allErrs, field.Invalid(specPath.Child("priorityClassName"), obj.Spec.PriorityClassName, msg))
   130  			}
   131  		}
   132  		if obj.Spec.Priority == nil {
   133  			allErrs = append(allErrs, field.Invalid(specPath.Child("priority"), obj.Spec.Priority, "priority should not be nil when priorityClassName is set"))
   134  		}
   135  	}
   136  
   137  	if len(obj.Spec.QueueName) > 0 {
   138  		allErrs = append(allErrs, validateNameReference(obj.Spec.QueueName, specPath.Child("queueName"))...)
   139  	}
   140  
   141  	statusPath := field.NewPath("status")
   142  	if workload.HasQuotaReservation(obj) {
   143  		allErrs = append(allErrs, validateAdmission(obj, statusPath.Child("admission"))...)
   144  	}
   145  
   146  	allErrs = append(allErrs, metav1validation.ValidateConditions(obj.Status.Conditions, statusPath.Child("conditions"))...)
   147  	allErrs = append(allErrs, validateReclaimablePods(obj, statusPath.Child("reclaimablePods"))...)
   148  	allErrs = append(allErrs, validateAdmissionChecks(obj, statusPath.Child("admissionChecks"))...)
   149  
   150  	return allErrs
   151  }
   152  
   153  func validatePodSet(ps *kueue.PodSet, path *field.Path) field.ErrorList {
   154  	var allErrs field.ErrorList
   155  	// Apply the same validation as container names.
   156  	for _, msg := range validation.IsDNS1123Label(ps.Name) {
   157  		allErrs = append(allErrs, field.Invalid(path.Child("name"), ps.Name, msg))
   158  	}
   159  
   160  	// validate initContainers
   161  	icPath := path.Child("template", "spec", "initContainers")
   162  	for ci := range ps.Template.Spec.InitContainers {
   163  		allErrs = append(allErrs, validateContainer(&ps.Template.Spec.InitContainers[ci], icPath.Index(ci))...)
   164  	}
   165  	// validate containers
   166  	cPath := path.Child("template", "spec", "containers")
   167  	for ci := range ps.Template.Spec.Containers {
   168  		allErrs = append(allErrs, validateContainer(&ps.Template.Spec.Containers[ci], cPath.Index(ci))...)
   169  	}
   170  
   171  	if min := ptr.Deref(ps.MinCount, ps.Count); min > ps.Count || min < 0 {
   172  		allErrs = append(allErrs, field.Forbidden(path.Child("minCount"), fmt.Sprintf("%d should be positive and less or equal to %d", min, ps.Count)))
   173  	}
   174  
   175  	return allErrs
   176  }
   177  
   178  func validateContainer(c *corev1.Container, path *field.Path) field.ErrorList {
   179  	var allErrs field.ErrorList
   180  	rPath := path.Child("resources", "requests")
   181  	for name := range c.Resources.Requests {
   182  		if name == corev1.ResourcePods {
   183  			allErrs = append(allErrs, field.Invalid(rPath.Key(string(name)), corev1.ResourcePods, "the key is reserved for internal kueue use"))
   184  		}
   185  	}
   186  	return allErrs
   187  }
   188  
   189  func validateAdmissionChecks(obj *kueue.Workload, basePath *field.Path) field.ErrorList {
   190  	var allErrs field.ErrorList
   191  	for i := range obj.Status.AdmissionChecks {
   192  		admissionChecksPath := basePath.Index(i)
   193  		ac := &obj.Status.AdmissionChecks[i]
   194  		if len(ac.PodSetUpdates) > 0 && len(ac.PodSetUpdates) != len(obj.Spec.PodSets) {
   195  			allErrs = append(allErrs, field.Invalid(admissionChecksPath.Child("podSetUpdates"), field.OmitValueType{}, "must have the same number of podSetUpdates as the podSets"))
   196  		}
   197  		allErrs = append(allErrs, validatePodSetUpdates(ac, obj, admissionChecksPath.Child("podSetUpdates"))...)
   198  	}
   199  	return allErrs
   200  }
   201  
   202  func validatePodSetUpdates(acs *kueue.AdmissionCheckState, obj *kueue.Workload, basePath *field.Path) field.ErrorList {
   203  	var allErrs field.ErrorList
   204  
   205  	knowPodSets := sets.New(slices.Map(obj.Spec.PodSets, func(ps *kueue.PodSet) string {
   206  		return ps.Name
   207  	})...)
   208  
   209  	for i := range acs.PodSetUpdates {
   210  		psu := &acs.PodSetUpdates[i]
   211  		psuPath := basePath.Index(i)
   212  		if !knowPodSets.Has(psu.Name) {
   213  			allErrs = append(allErrs, field.NotSupported(psuPath.Child("name"), psu.Name, sets.List(knowPodSets)))
   214  		}
   215  		allErrs = append(allErrs, validateTolerations(psu.Tolerations, psuPath.Child("tolerations"))...)
   216  		allErrs = append(allErrs, apivalidation.ValidateAnnotations(psu.Annotations, psuPath.Child("annotations"))...)
   217  		allErrs = append(allErrs, metav1validation.ValidateLabels(psu.NodeSelector, psuPath.Child("nodeSelector"))...)
   218  		allErrs = append(allErrs, metav1validation.ValidateLabels(psu.Labels, psuPath.Child("labels"))...)
   219  	}
   220  	return allErrs
   221  }
   222  
   223  func validateImmutablePodSetUpdates(newObj, oldObj *kueue.Workload, basePath *field.Path) field.ErrorList {
   224  	var allErrs field.ErrorList
   225  	newAcs := slices.ToRefMap(newObj.Status.AdmissionChecks, func(f *kueue.AdmissionCheckState) string { return f.Name })
   226  	for i := range oldObj.Status.AdmissionChecks {
   227  		oldAc := &oldObj.Status.AdmissionChecks[i]
   228  		newAc, found := newAcs[oldAc.Name]
   229  		if !found {
   230  			continue
   231  		}
   232  		if oldAc.State == kueue.CheckStateReady && newAc.State == kueue.CheckStateReady {
   233  			allErrs = append(allErrs, apivalidation.ValidateImmutableField(newAc.PodSetUpdates, oldAc.PodSetUpdates, basePath.Index(i).Child("podSetUpdates"))...)
   234  		}
   235  	}
   236  	return allErrs
   237  }
   238  
   239  // validateTolerations is extracted from git.k8s.io/kubernetes/pkg/apis/core/validation/validation.go
   240  // we do not import it as dependency, see the comment:
   241  // https://github.com/kubernetes/kubernetes/issues/79384#issuecomment-505627280
   242  func validateTolerations(tolerations []corev1.Toleration, fldPath *field.Path) field.ErrorList {
   243  	allErrors := field.ErrorList{}
   244  	for i, toleration := range tolerations {
   245  		idxPath := fldPath.Index(i)
   246  		// validate the toleration key
   247  		if len(toleration.Key) > 0 {
   248  			allErrors = append(allErrors, metav1validation.ValidateLabelName(toleration.Key, idxPath.Child("key"))...)
   249  		}
   250  
   251  		// empty toleration key with Exists operator and empty value means match all taints
   252  		if len(toleration.Key) == 0 && toleration.Operator != corev1.TolerationOpExists {
   253  			allErrors = append(allErrors, field.Invalid(idxPath.Child("operator"), toleration.Operator,
   254  				"operator must be Exists when `key` is empty, which means \"match all values and all keys\""))
   255  		}
   256  
   257  		if toleration.TolerationSeconds != nil && toleration.Effect != corev1.TaintEffectNoExecute {
   258  			allErrors = append(allErrors, field.Invalid(idxPath.Child("effect"), toleration.Effect,
   259  				"effect must be 'NoExecute' when `tolerationSeconds` is set"))
   260  		}
   261  
   262  		// validate toleration operator and value
   263  		switch toleration.Operator {
   264  		// empty operator means Equal
   265  		case corev1.TolerationOpEqual, "":
   266  			if errs := validation.IsValidLabelValue(toleration.Value); len(errs) != 0 {
   267  				allErrors = append(allErrors, field.Invalid(idxPath.Child("operator"), toleration.Value, strings.Join(errs, ";")))
   268  			}
   269  		case corev1.TolerationOpExists:
   270  			if len(toleration.Value) > 0 {
   271  				allErrors = append(allErrors, field.Invalid(idxPath.Child("operator"), toleration, "value must be empty when `operator` is 'Exists'"))
   272  			}
   273  		default:
   274  			validValues := []string{string(corev1.TolerationOpEqual), string(corev1.TolerationOpExists)}
   275  			allErrors = append(allErrors, field.NotSupported(idxPath.Child("operator"), toleration.Operator, validValues))
   276  		}
   277  
   278  		// validate toleration effect, empty toleration effect means match all taint effects
   279  		if len(toleration.Effect) > 0 {
   280  			allErrors = append(allErrors, validateTaintEffect(&toleration.Effect, true, idxPath.Child("effect"))...)
   281  		}
   282  	}
   283  	return allErrors
   284  }
   285  
   286  func validateAdmission(obj *kueue.Workload, path *field.Path) field.ErrorList {
   287  	admission := obj.Status.Admission
   288  	var allErrs field.ErrorList
   289  	allErrs = append(allErrs, validateNameReference(string(admission.ClusterQueue), path.Child("clusterQueue"))...)
   290  
   291  	names := sets.New[string]()
   292  	for _, ps := range obj.Spec.PodSets {
   293  		names.Insert(ps.Name)
   294  	}
   295  	assigmentsPath := path.Child("podSetAssignments")
   296  	if names.Len() != len(admission.PodSetAssignments) {
   297  		allErrs = append(allErrs, field.Invalid(assigmentsPath, field.OmitValueType{}, "must have the same number of podSets as the spec"))
   298  	}
   299  
   300  	for i, ps := range admission.PodSetAssignments {
   301  		psaPath := assigmentsPath.Index(i)
   302  		if !names.Has(ps.Name) {
   303  			allErrs = append(allErrs, field.NotFound(psaPath.Child("name"), ps.Name))
   304  		}
   305  		if count := ptr.Deref(ps.Count, 0); count > 0 {
   306  			for k, v := range ps.ResourceUsage {
   307  				if (workload.ResourceValue(k, v) % int64(count)) != 0 {
   308  					allErrs = append(allErrs, field.Invalid(psaPath.Child("resourceUsage").Key(string(k)), v, fmt.Sprintf("is not a multiple of %d", ps.Count)))
   309  				}
   310  			}
   311  		}
   312  	}
   313  
   314  	return allErrs
   315  }
   316  
   317  func validateReclaimablePods(obj *kueue.Workload, basePath *field.Path) field.ErrorList {
   318  	if len(obj.Status.ReclaimablePods) == 0 {
   319  		return nil
   320  	}
   321  	knowPodSets := make(map[string]*kueue.PodSet, len(obj.Spec.PodSets))
   322  	knowPodSetNames := make([]string, len(obj.Spec.PodSets))
   323  	for i := range obj.Spec.PodSets {
   324  		name := obj.Spec.PodSets[i].Name
   325  		knowPodSets[name] = &obj.Spec.PodSets[i]
   326  		knowPodSetNames = append(knowPodSetNames, name)
   327  	}
   328  
   329  	var ret field.ErrorList
   330  	for i := range obj.Status.ReclaimablePods {
   331  		rps := &obj.Status.ReclaimablePods[i]
   332  		ps, found := knowPodSets[rps.Name]
   333  		rpsPath := basePath.Key(rps.Name)
   334  		if !found {
   335  			ret = append(ret, field.NotSupported(rpsPath.Child("name"), rps.Name, knowPodSetNames))
   336  		} else if rps.Count > ps.Count {
   337  			ret = append(ret, field.Invalid(rpsPath.Child("count"), rps.Count, fmt.Sprintf("should be less or equal to %d", ps.Count)))
   338  		}
   339  	}
   340  	return ret
   341  }
   342  
   343  func ValidateWorkloadUpdate(newObj, oldObj *kueue.Workload) field.ErrorList {
   344  	var allErrs field.ErrorList
   345  	specPath := field.NewPath("spec")
   346  	statusPath := field.NewPath("status")
   347  	allErrs = append(allErrs, ValidateWorkload(newObj)...)
   348  
   349  	if workload.HasQuotaReservation(oldObj) {
   350  		allErrs = append(allErrs, apivalidation.ValidateImmutableField(newObj.Spec.PodSets, oldObj.Spec.PodSets, specPath.Child("podSets"))...)
   351  		allErrs = append(allErrs, apivalidation.ValidateImmutableField(newObj.Spec.PriorityClassSource, oldObj.Spec.PriorityClassSource, specPath.Child("priorityClassSource"))...)
   352  		allErrs = append(allErrs, apivalidation.ValidateImmutableField(newObj.Spec.PriorityClassName, oldObj.Spec.PriorityClassName, specPath.Child("priorityClassName"))...)
   353  	}
   354  	if workload.HasQuotaReservation(newObj) && workload.HasQuotaReservation(oldObj) {
   355  		allErrs = append(allErrs, apivalidation.ValidateImmutableField(newObj.Spec.QueueName, oldObj.Spec.QueueName, specPath.Child("queueName"))...)
   356  		allErrs = append(allErrs, validateReclaimablePodsUpdate(newObj, oldObj, field.NewPath("status", "reclaimablePods"))...)
   357  	}
   358  	allErrs = append(allErrs, validateAdmissionUpdate(newObj.Status.Admission, oldObj.Status.Admission, field.NewPath("status", "admission"))...)
   359  	allErrs = append(allErrs, validateImmutablePodSetUpdates(newObj, oldObj, statusPath.Child("admissionChecks"))...)
   360  
   361  	return allErrs
   362  }
   363  
   364  // validateAdmissionUpdate validates that admission can be set or unset, but the
   365  // fields within can't change.
   366  func validateAdmissionUpdate(new, old *kueue.Admission, path *field.Path) field.ErrorList {
   367  	if old == nil || new == nil {
   368  		return nil
   369  	}
   370  	return apivalidation.ValidateImmutableField(new, old, path)
   371  }
   372  
   373  // validateReclaimablePodsUpdate validates that the reclaimable counts do not decrease, this should be checked
   374  // while the workload is admitted.
   375  func validateReclaimablePodsUpdate(newObj, oldObj *kueue.Workload, basePath *field.Path) field.ErrorList {
   376  	if workload.ReclaimablePodsAreEqual(newObj.Status.ReclaimablePods, oldObj.Status.ReclaimablePods) {
   377  		return nil
   378  	}
   379  
   380  	if len(oldObj.Status.ReclaimablePods) == 0 {
   381  		return nil
   382  	}
   383  
   384  	knowPodSets := make(map[string]*kueue.ReclaimablePod, len(oldObj.Status.ReclaimablePods))
   385  	for i := range oldObj.Status.ReclaimablePods {
   386  		name := oldObj.Status.ReclaimablePods[i].Name
   387  		knowPodSets[name] = &oldObj.Status.ReclaimablePods[i]
   388  	}
   389  
   390  	var ret field.ErrorList
   391  	newNames := sets.New[string]()
   392  	for i := range newObj.Status.ReclaimablePods {
   393  		newCount := &newObj.Status.ReclaimablePods[i]
   394  		newNames.Insert(newCount.Name)
   395  		if !workload.HasQuotaReservation(newObj) && newCount.Count == 0 {
   396  			continue
   397  		}
   398  		oldCount, found := knowPodSets[newCount.Name]
   399  		if found && newCount.Count < oldCount.Count {
   400  			ret = append(ret, field.Invalid(basePath.Key(newCount.Name).Child("count"), newCount.Count, fmt.Sprintf("cannot be less then %d", oldCount.Count)))
   401  		}
   402  	}
   403  
   404  	for name := range knowPodSets {
   405  		if workload.HasQuotaReservation(newObj) && !newNames.Has(name) {
   406  			ret = append(ret, field.Required(basePath.Key(name), "cannot be removed"))
   407  		}
   408  	}
   409  	return ret
   410  }