github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/common/scheduling.go (about) 1 /* 2 Copyright 2023 The Kubeflow Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package common 18 19 import ( 20 "context" 21 "errors" 22 "fmt" 23 24 apiv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" 25 26 "github.com/google/go-cmp/cmp" 27 log "github.com/sirupsen/logrus" 28 policyapi "k8s.io/api/policy/v1beta1" 29 k8serrors "k8s.io/apimachinery/pkg/api/errors" 30 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 31 "k8s.io/apimachinery/pkg/util/intstr" 32 "k8s.io/klog/v2" 33 "sigs.k8s.io/controller-runtime/pkg/client" 34 ) 35 36 type FillPodGroupSpecFunc func(object metav1.Object) error 37 38 func (jc *JobController) SyncPodGroup(job metav1.Object, specFunc FillPodGroupSpecFunc) (metav1.Object, error) { 39 pgctl := jc.PodGroupControl 40 41 // Check whether podGroup exists or not 42 podGroup, err := pgctl.GetPodGroup(job.GetNamespace(), job.GetName()) 43 if err == nil { 44 // update podGroup for gang scheduling 45 oldPodGroup := &podGroup 46 if err = specFunc(podGroup); err != nil { 47 return nil, fmt.Errorf("unable to fill the spec of PodGroup, '%v': %v", klog.KObj(podGroup), err) 48 } 49 if diff := cmp.Diff(oldPodGroup, podGroup); len(diff) != 0 { 50 return podGroup, pgctl.UpdatePodGroup(podGroup.(client.Object)) 51 } 52 return podGroup, nil 53 } else if client.IgnoreNotFound(err) != nil { 54 return nil, fmt.Errorf("unable to get a PodGroup: %v", err) 55 } else { 56 // create podGroup for gang scheduling 57 newPodGroup := pgctl.NewEmptyPodGroup() 58 newPodGroup.SetName(job.GetName()) 59 newPodGroup.SetNamespace(job.GetNamespace()) 60 newPodGroup.SetAnnotations(job.GetAnnotations()) 61 newPodGroup.SetOwnerReferences([]metav1.OwnerReference{*jc.GenOwnerReference(job)}) 62 if err = specFunc(newPodGroup); err != nil { 63 return nil, fmt.Errorf("unable to fill the spec of PodGroup, '%v': %v", klog.KObj(newPodGroup), err) 64 } 65 66 err = pgctl.CreatePodGroup(newPodGroup) 67 if err != nil { 68 return podGroup, fmt.Errorf("unable to create PodGroup: %v", err) 69 } 70 createdPodGroupsCount.Inc() 71 } 72 73 createdPodGroup, err := pgctl.GetPodGroup(job.GetNamespace(), job.GetName()) 74 if err != nil { 75 return nil, fmt.Errorf("unable to get PodGroup after success creation: %v", err) 76 } 77 78 return createdPodGroup, nil 79 } 80 81 // SyncPdb will create a PDB for gang scheduling. 82 func (jc *JobController) SyncPdb(job metav1.Object, minAvailableReplicas int32) (*policyapi.PodDisruptionBudget, error) { 83 // Check the pdb exist or not 84 pdb, err := jc.KubeClientSet.PolicyV1beta1().PodDisruptionBudgets(job.GetNamespace()).Get(context.TODO(), job.GetName(), metav1.GetOptions{}) 85 if err == nil || !k8serrors.IsNotFound(err) { 86 if err == nil { 87 err = errors.New(string(metav1.StatusReasonAlreadyExists)) 88 } 89 return pdb, err 90 } 91 92 // Create pdb for gang scheduling 93 minAvailable := intstr.FromInt(int(minAvailableReplicas)) 94 createPdb := &policyapi.PodDisruptionBudget{ 95 ObjectMeta: metav1.ObjectMeta{ 96 Name: job.GetName(), 97 OwnerReferences: []metav1.OwnerReference{ 98 *jc.GenOwnerReference(job), 99 }, 100 }, 101 Spec: policyapi.PodDisruptionBudgetSpec{ 102 MinAvailable: &minAvailable, 103 Selector: &metav1.LabelSelector{ 104 MatchLabels: map[string]string{ 105 apiv1.JobNameLabel: job.GetName(), 106 }, 107 }, 108 }, 109 } 110 createdPdb, err := jc.KubeClientSet.PolicyV1beta1().PodDisruptionBudgets(job.GetNamespace()).Create(context.TODO(), createPdb, metav1.CreateOptions{}) 111 if err != nil { 112 return createdPdb, fmt.Errorf("unable to create pdb: %v", err) 113 } 114 createdPDBCount.Inc() 115 return createdPdb, nil 116 } 117 118 func (jc *JobController) DeletePodGroup(job metav1.Object) error { 119 pgctl := jc.PodGroupControl 120 121 // Check whether podGroup exists or not 122 _, err := pgctl.GetPodGroup(job.GetNamespace(), job.GetName()) 123 if err != nil && k8serrors.IsNotFound(err) { 124 return nil 125 } 126 127 log.Infof("Deleting PodGroup %s", job.GetName()) 128 129 // Delete podGroup 130 err = pgctl.DeletePodGroup(job.GetNamespace(), job.GetName()) 131 if err != nil { 132 return fmt.Errorf("unable to delete PodGroup: %v", err) 133 } 134 deletedPodGroupsCount.Inc() 135 return nil 136 } 137 138 func (jc *JobController) DeletePdb(job metav1.Object) error { 139 // Check whether pdb exists or not 140 _, err := jc.KubeClientSet.PolicyV1beta1().PodDisruptionBudgets(job.GetNamespace()).Get(context.TODO(), job.GetName(), metav1.GetOptions{}) 141 if err != nil && k8serrors.IsNotFound(err) { 142 return nil 143 } 144 145 msg := fmt.Sprintf("Deleting pdb %s", job.GetName()) 146 log.Info(msg) 147 148 if err := jc.KubeClientSet.PolicyV1beta1().PodDisruptionBudgets(job.GetNamespace()).Delete(context.TODO(), job.GetName(), metav1.DeleteOptions{}); err != nil { 149 return fmt.Errorf("unable to delete pdb: %v", err) 150 } 151 deletedPDBCount.Inc() 152 return nil 153 }