github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/common/scheduling.go (about)

     1  /*
     2  Copyright 2023 The Kubeflow Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package common
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  
    24  	apiv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    25  
    26  	"github.com/google/go-cmp/cmp"
    27  	log "github.com/sirupsen/logrus"
    28  	policyapi "k8s.io/api/policy/v1beta1"
    29  	k8serrors "k8s.io/apimachinery/pkg/api/errors"
    30  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    31  	"k8s.io/apimachinery/pkg/util/intstr"
    32  	"k8s.io/klog/v2"
    33  	"sigs.k8s.io/controller-runtime/pkg/client"
    34  )
    35  
    36  type FillPodGroupSpecFunc func(object metav1.Object) error
    37  
    38  func (jc *JobController) SyncPodGroup(job metav1.Object, specFunc FillPodGroupSpecFunc) (metav1.Object, error) {
    39  	pgctl := jc.PodGroupControl
    40  
    41  	// Check whether podGroup exists or not
    42  	podGroup, err := pgctl.GetPodGroup(job.GetNamespace(), job.GetName())
    43  	if err == nil {
    44  		// update podGroup for gang scheduling
    45  		oldPodGroup := &podGroup
    46  		if err = specFunc(podGroup); err != nil {
    47  			return nil, fmt.Errorf("unable to fill the spec of PodGroup, '%v': %v", klog.KObj(podGroup), err)
    48  		}
    49  		if diff := cmp.Diff(oldPodGroup, podGroup); len(diff) != 0 {
    50  			return podGroup, pgctl.UpdatePodGroup(podGroup.(client.Object))
    51  		}
    52  		return podGroup, nil
    53  	} else if client.IgnoreNotFound(err) != nil {
    54  		return nil, fmt.Errorf("unable to get a PodGroup: %v", err)
    55  	} else {
    56  		// create podGroup for gang scheduling
    57  		newPodGroup := pgctl.NewEmptyPodGroup()
    58  		newPodGroup.SetName(job.GetName())
    59  		newPodGroup.SetNamespace(job.GetNamespace())
    60  		newPodGroup.SetAnnotations(job.GetAnnotations())
    61  		newPodGroup.SetOwnerReferences([]metav1.OwnerReference{*jc.GenOwnerReference(job)})
    62  		if err = specFunc(newPodGroup); err != nil {
    63  			return nil, fmt.Errorf("unable to fill the spec of PodGroup, '%v': %v", klog.KObj(newPodGroup), err)
    64  		}
    65  
    66  		err = pgctl.CreatePodGroup(newPodGroup)
    67  		if err != nil {
    68  			return podGroup, fmt.Errorf("unable to create PodGroup: %v", err)
    69  		}
    70  		createdPodGroupsCount.Inc()
    71  	}
    72  
    73  	createdPodGroup, err := pgctl.GetPodGroup(job.GetNamespace(), job.GetName())
    74  	if err != nil {
    75  		return nil, fmt.Errorf("unable to get PodGroup after success creation: %v", err)
    76  	}
    77  
    78  	return createdPodGroup, nil
    79  }
    80  
    81  // SyncPdb will create a PDB for gang scheduling.
    82  func (jc *JobController) SyncPdb(job metav1.Object, minAvailableReplicas int32) (*policyapi.PodDisruptionBudget, error) {
    83  	// Check the pdb exist or not
    84  	pdb, err := jc.KubeClientSet.PolicyV1beta1().PodDisruptionBudgets(job.GetNamespace()).Get(context.TODO(), job.GetName(), metav1.GetOptions{})
    85  	if err == nil || !k8serrors.IsNotFound(err) {
    86  		if err == nil {
    87  			err = errors.New(string(metav1.StatusReasonAlreadyExists))
    88  		}
    89  		return pdb, err
    90  	}
    91  
    92  	// Create pdb for gang scheduling
    93  	minAvailable := intstr.FromInt(int(minAvailableReplicas))
    94  	createPdb := &policyapi.PodDisruptionBudget{
    95  		ObjectMeta: metav1.ObjectMeta{
    96  			Name: job.GetName(),
    97  			OwnerReferences: []metav1.OwnerReference{
    98  				*jc.GenOwnerReference(job),
    99  			},
   100  		},
   101  		Spec: policyapi.PodDisruptionBudgetSpec{
   102  			MinAvailable: &minAvailable,
   103  			Selector: &metav1.LabelSelector{
   104  				MatchLabels: map[string]string{
   105  					apiv1.JobNameLabel: job.GetName(),
   106  				},
   107  			},
   108  		},
   109  	}
   110  	createdPdb, err := jc.KubeClientSet.PolicyV1beta1().PodDisruptionBudgets(job.GetNamespace()).Create(context.TODO(), createPdb, metav1.CreateOptions{})
   111  	if err != nil {
   112  		return createdPdb, fmt.Errorf("unable to create pdb: %v", err)
   113  	}
   114  	createdPDBCount.Inc()
   115  	return createdPdb, nil
   116  }
   117  
   118  func (jc *JobController) DeletePodGroup(job metav1.Object) error {
   119  	pgctl := jc.PodGroupControl
   120  
   121  	// Check whether podGroup exists or not
   122  	_, err := pgctl.GetPodGroup(job.GetNamespace(), job.GetName())
   123  	if err != nil && k8serrors.IsNotFound(err) {
   124  		return nil
   125  	}
   126  
   127  	log.Infof("Deleting PodGroup %s", job.GetName())
   128  
   129  	// Delete podGroup
   130  	err = pgctl.DeletePodGroup(job.GetNamespace(), job.GetName())
   131  	if err != nil {
   132  		return fmt.Errorf("unable to delete PodGroup: %v", err)
   133  	}
   134  	deletedPodGroupsCount.Inc()
   135  	return nil
   136  }
   137  
   138  func (jc *JobController) DeletePdb(job metav1.Object) error {
   139  	// Check whether pdb exists or not
   140  	_, err := jc.KubeClientSet.PolicyV1beta1().PodDisruptionBudgets(job.GetNamespace()).Get(context.TODO(), job.GetName(), metav1.GetOptions{})
   141  	if err != nil && k8serrors.IsNotFound(err) {
   142  		return nil
   143  	}
   144  
   145  	msg := fmt.Sprintf("Deleting pdb %s", job.GetName())
   146  	log.Info(msg)
   147  
   148  	if err := jc.KubeClientSet.PolicyV1beta1().PodDisruptionBudgets(job.GetNamespace()).Delete(context.TODO(), job.GetName(), metav1.DeleteOptions{}); err != nil {
   149  		return fmt.Errorf("unable to delete pdb: %v", err)
   150  	}
   151  	deletedPDBCount.Inc()
   152  	return nil
   153  }