sigs.k8s.io/kueue@v0.6.2/pkg/controller/jobs/pod/pod_webhook.go (about)

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package pod
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  
    24  	corev1 "k8s.io/api/core/v1"
    25  	"k8s.io/apimachinery/pkg/api/validation"
    26  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    27  	"k8s.io/apimachinery/pkg/labels"
    28  	"k8s.io/apimachinery/pkg/runtime"
    29  	"k8s.io/apimachinery/pkg/util/validation/field"
    30  	"k8s.io/klog/v2"
    31  	ctrl "sigs.k8s.io/controller-runtime"
    32  	"sigs.k8s.io/controller-runtime/pkg/client"
    33  	"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
    34  	"sigs.k8s.io/controller-runtime/pkg/webhook"
    35  	"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
    36  
    37  	configapi "sigs.k8s.io/kueue/apis/config/v1beta1"
    38  	"sigs.k8s.io/kueue/pkg/controller/jobframework"
    39  )
    40  
    41  const (
    42  	ManagedLabelKey            = "kueue.x-k8s.io/managed"
    43  	ManagedLabelValue          = "true"
    44  	PodFinalizer               = ManagedLabelKey
    45  	GroupNameLabel             = "kueue.x-k8s.io/pod-group-name"
    46  	GroupTotalCountAnnotation  = "kueue.x-k8s.io/pod-group-total-count"
    47  	RoleHashAnnotation         = "kueue.x-k8s.io/role-hash"
    48  	RetriableInGroupAnnotation = "kueue.x-k8s.io/retriable-in-group"
    49  )
    50  
    51  var (
    52  	labelsPath                     = field.NewPath("metadata", "labels")
    53  	annotationsPath                = field.NewPath("metadata", "annotations")
    54  	managedLabelPath               = labelsPath.Key(ManagedLabelKey)
    55  	groupNameLabelPath             = labelsPath.Key(GroupNameLabel)
    56  	groupTotalCountAnnotationPath  = annotationsPath.Key(GroupTotalCountAnnotation)
    57  	retriableInGroupAnnotationPath = annotationsPath.Key(RetriableInGroupAnnotation)
    58  
    59  	errPodOptsTypeAssertion = errors.New("options are not of type PodIntegrationOptions")
    60  	errPodOptsNotFound      = errors.New("podIntegrationOptions not found in options")
    61  )
    62  
    63  type PodWebhook struct {
    64  	client                     client.Client
    65  	manageJobsWithoutQueueName bool
    66  	namespaceSelector          *metav1.LabelSelector
    67  	podSelector                *metav1.LabelSelector
    68  }
    69  
    70  // SetupWebhook configures the webhook for pods.
    71  func SetupWebhook(mgr ctrl.Manager, opts ...jobframework.Option) error {
    72  	options := jobframework.ProcessOptions(opts...)
    73  	podOpts, err := getPodOptions(options.IntegrationOptions)
    74  	if err != nil {
    75  		return err
    76  	}
    77  	wh := &PodWebhook{
    78  		client:                     mgr.GetClient(),
    79  		manageJobsWithoutQueueName: options.ManageJobsWithoutQueueName,
    80  		namespaceSelector:          podOpts.NamespaceSelector,
    81  		podSelector:                podOpts.PodSelector,
    82  	}
    83  	return ctrl.NewWebhookManagedBy(mgr).
    84  		For(&corev1.Pod{}).
    85  		WithDefaulter(wh).
    86  		WithValidator(wh).
    87  		Complete()
    88  }
    89  
    90  func getPodOptions(integrationOpts map[string]any) (configapi.PodIntegrationOptions, error) {
    91  	opts, ok := integrationOpts[corev1.SchemeGroupVersion.WithKind("Pod").String()]
    92  	if !ok {
    93  		return configapi.PodIntegrationOptions{}, errPodOptsNotFound
    94  	}
    95  	podOpts, ok := opts.(*configapi.PodIntegrationOptions)
    96  	if !ok {
    97  		return configapi.PodIntegrationOptions{}, fmt.Errorf("%w, got %T", errPodOptsTypeAssertion, opts)
    98  	}
    99  	return *podOpts, nil
   100  }
   101  
   102  // +kubebuilder:webhook:path=/mutate--v1-pod,mutating=true,failurePolicy=fail,sideEffects=None,groups="",resources=pods,verbs=create,versions=v1,name=mpod.kb.io,admissionReviewVersions=v1
   103  // +kubebuilder:rbac:groups="",resources=namespaces,verbs=get;list;watch
   104  
   105  var _ webhook.CustomDefaulter = &PodWebhook{}
   106  
   107  func containersShape(containers []corev1.Container) (result []map[string]interface{}) {
   108  	for _, c := range containers {
   109  		result = append(result, map[string]interface{}{
   110  			"resources": map[string]interface{}{
   111  				"requests": c.Resources.Requests,
   112  			},
   113  			"ports": c.Ports,
   114  		})
   115  	}
   116  
   117  	return result
   118  }
   119  
   120  // addRoleHash calculates the role hash and adds it to the pod's annotations
   121  func (p *Pod) addRoleHash() error {
   122  	if p.pod.Annotations == nil {
   123  		p.pod.Annotations = make(map[string]string)
   124  	}
   125  
   126  	hash, err := getRoleHash(p.pod)
   127  	if err != nil {
   128  		return err
   129  	}
   130  
   131  	p.pod.Annotations[RoleHashAnnotation] = hash
   132  	return nil
   133  }
   134  
   135  func (w *PodWebhook) Default(ctx context.Context, obj runtime.Object) error {
   136  	pod := fromObject(obj)
   137  	log := ctrl.LoggerFrom(ctx).WithName("pod-webhook").WithValues("pod", klog.KObj(&pod.pod))
   138  	log.V(5).Info("Applying defaults")
   139  
   140  	if IsPodOwnerManagedByKueue(pod) {
   141  		log.V(5).Info("Pod owner is managed by kueue, skipping")
   142  		return nil
   143  	}
   144  
   145  	// Check for pod label selector match
   146  	podSelector, err := metav1.LabelSelectorAsSelector(w.podSelector)
   147  	if err != nil {
   148  		return fmt.Errorf("failed to parse pod selector: %w", err)
   149  	}
   150  	if !podSelector.Matches(labels.Set(pod.pod.GetLabels())) {
   151  		return nil
   152  	}
   153  
   154  	// Get pod namespace and check for namespace label selector match
   155  	ns := corev1.Namespace{}
   156  	err = w.client.Get(ctx, client.ObjectKey{Name: pod.pod.GetNamespace()}, &ns)
   157  	if err != nil {
   158  		return fmt.Errorf("failed to run mutating webhook on pod %s, error while getting namespace: %w",
   159  			pod.pod.GetName(),
   160  			err,
   161  		)
   162  	}
   163  	log.V(5).Info("Found pod namespace", "Namespace.Name", ns.GetName())
   164  	nsSelector, err := metav1.LabelSelectorAsSelector(w.namespaceSelector)
   165  	if err != nil {
   166  		return fmt.Errorf("failed to parse namespace selector: %w", err)
   167  	}
   168  	if !nsSelector.Matches(labels.Set(ns.GetLabels())) {
   169  		return nil
   170  	}
   171  
   172  	if jobframework.QueueName(pod) != "" || w.manageJobsWithoutQueueName {
   173  		controllerutil.AddFinalizer(pod.Object(), PodFinalizer)
   174  
   175  		if pod.pod.Labels == nil {
   176  			pod.pod.Labels = make(map[string]string)
   177  		}
   178  		pod.pod.Labels[ManagedLabelKey] = ManagedLabelValue
   179  
   180  		if gateIndex(&pod.pod) == gateNotFound {
   181  			log.V(5).Info("Adding gate")
   182  			pod.pod.Spec.SchedulingGates = append(pod.pod.Spec.SchedulingGates, corev1.PodSchedulingGate{Name: SchedulingGateName})
   183  		}
   184  
   185  		if podGroupName(pod.pod) != "" {
   186  			if err := pod.addRoleHash(); err != nil {
   187  				return err
   188  			}
   189  		}
   190  	}
   191  
   192  	// copy back to the object
   193  	pod.pod.DeepCopyInto(obj.(*corev1.Pod))
   194  	return nil
   195  }
   196  
   197  // +kubebuilder:webhook:path=/validate--v1-pod,mutating=false,failurePolicy=fail,sideEffects=None,groups="",resources=pods,verbs=create;update,versions=v1,name=vpod.kb.io,admissionReviewVersions=v1
   198  
   199  var _ webhook.CustomValidator = &PodWebhook{}
   200  
   201  func (w *PodWebhook) ValidateCreate(ctx context.Context, obj runtime.Object) (admission.Warnings, error) {
   202  	var warnings admission.Warnings
   203  
   204  	pod := fromObject(obj)
   205  	log := ctrl.LoggerFrom(ctx).WithName("pod-webhook").WithValues("pod", klog.KObj(&pod.pod))
   206  	log.V(5).Info("Validating create")
   207  	allErrs := jobframework.ValidateCreateForQueueName(pod)
   208  
   209  	allErrs = append(allErrs, validateManagedLabel(pod)...)
   210  
   211  	allErrs = append(allErrs, validatePodGroupMetadata(pod)...)
   212  
   213  	if warn := warningForPodManagedLabel(pod); warn != "" {
   214  		warnings = append(warnings, warn)
   215  	}
   216  
   217  	return warnings, allErrs.ToAggregate()
   218  }
   219  
   220  func (w *PodWebhook) ValidateUpdate(ctx context.Context, oldObj, newObj runtime.Object) (admission.Warnings, error) {
   221  	var warnings admission.Warnings
   222  
   223  	oldPod := fromObject(oldObj)
   224  	newPod := fromObject(newObj)
   225  	log := ctrl.LoggerFrom(ctx).WithName("pod-webhook").WithValues("pod", klog.KObj(&newPod.pod))
   226  	log.V(5).Info("Validating update")
   227  	allErrs := jobframework.ValidateUpdateForQueueName(oldPod, newPod)
   228  
   229  	allErrs = append(allErrs, validateManagedLabel(newPod)...)
   230  
   231  	allErrs = append(allErrs, validation.ValidateImmutableField(podGroupName(newPod.pod), podGroupName(oldPod.pod), groupNameLabelPath)...)
   232  
   233  	allErrs = append(allErrs, validatePodGroupMetadata(newPod)...)
   234  
   235  	allErrs = append(allErrs, validateUpdateForRetriableInGroupAnnotation(oldPod, newPod)...)
   236  
   237  	if warn := warningForPodManagedLabel(newPod); warn != "" {
   238  		warnings = append(warnings, warn)
   239  	}
   240  
   241  	return warnings, allErrs.ToAggregate()
   242  }
   243  
   244  func (w *PodWebhook) ValidateDelete(context.Context, runtime.Object) (admission.Warnings, error) {
   245  	return nil, nil
   246  }
   247  
   248  func validateManagedLabel(pod *Pod) field.ErrorList {
   249  	var allErrs field.ErrorList
   250  
   251  	if managedLabel, ok := pod.pod.GetLabels()[ManagedLabelKey]; ok && managedLabel != ManagedLabelValue {
   252  		return append(allErrs, field.Forbidden(managedLabelPath, fmt.Sprintf("managed label value can only be '%s'", ManagedLabelValue)))
   253  	}
   254  
   255  	return allErrs
   256  }
   257  
   258  // warningForPodManagedLabel returns a warning message if the pod has a managed label, and it's parent is managed by kueue
   259  func warningForPodManagedLabel(p *Pod) string {
   260  	if managedLabel := p.pod.GetLabels()[ManagedLabelKey]; managedLabel == ManagedLabelValue && IsPodOwnerManagedByKueue(p) {
   261  		return fmt.Sprintf("pod owner is managed by kueue, label '%s=%s' might lead to unexpected behaviour",
   262  			ManagedLabelKey, ManagedLabelValue)
   263  	}
   264  
   265  	return ""
   266  }
   267  
   268  func validatePodGroupMetadata(p *Pod) field.ErrorList {
   269  	var allErrs field.ErrorList
   270  
   271  	gtc, gtcExists := p.pod.GetAnnotations()[GroupTotalCountAnnotation]
   272  
   273  	if podGroupName(p.pod) == "" {
   274  		if gtcExists {
   275  			return append(allErrs, field.Required(
   276  				groupNameLabelPath,
   277  				fmt.Sprintf("both the '%s' annotation and the '%s' label should be set", GroupTotalCountAnnotation, GroupNameLabel),
   278  			))
   279  		}
   280  	} else {
   281  		allErrs = append(allErrs, jobframework.ValidateLabelAsCRDName(p, GroupNameLabel)...)
   282  
   283  		if !gtcExists {
   284  			return append(allErrs, field.Required(
   285  				groupTotalCountAnnotationPath,
   286  				fmt.Sprintf("both the '%s' annotation and the '%s' label should be set", GroupTotalCountAnnotation, GroupNameLabel),
   287  			))
   288  		}
   289  	}
   290  
   291  	if _, err := p.groupTotalCount(); gtcExists && err != nil {
   292  		return append(allErrs, field.Invalid(
   293  			groupTotalCountAnnotationPath,
   294  			gtc,
   295  			err.Error(),
   296  		))
   297  	}
   298  
   299  	return allErrs
   300  }
   301  
   302  func validateUpdateForRetriableInGroupAnnotation(oldPod, newPod *Pod) field.ErrorList {
   303  	if podGroupName(newPod.pod) != "" && isUnretriablePod(oldPod.pod) && !isUnretriablePod(newPod.pod) {
   304  		return field.ErrorList{
   305  			field.Forbidden(retriableInGroupAnnotationPath, "unretriable pod group can't be converted to retriable"),
   306  		}
   307  	}
   308  
   309  	return field.ErrorList{}
   310  }