github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/pytorch/hpa.go

// Copyright 2021 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License

package pytorch

import (
	"context"

	autoscalingv2 "k8s.io/api/autoscaling/v2"
	"k8s.io/apimachinery/pkg/api/equality"
	"k8s.io/apimachinery/pkg/api/errors"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"
	"k8s.io/klog/v2"
	controllerruntime "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/client"

	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
	trainutil "github.com/kubeflow/training-operator/pkg/util/train"
)

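// ReconcileHPA reconciles the HorizontalPodAutoscaler that scales an elastic
// PyTorchJob. It is a no-op unless the job's ElasticPolicy defines metrics;
// otherwise it creates the HPA when it is missing, deletes it while the job is
// suspended, and updates it when the desired spec has drifted from the live one.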
func (r *PyTorchJobReconciler) ReconcileHPA(pytorchJob *kubeflowv1.PyTorchJob) error {
	logger := r.Log.WithValues(kubeflowv1.PyTorchJobSingular, pytorchJob.Name)

	if pytorchJob.Spec.ElasticPolicy == nil || pytorchJob.Spec.ElasticPolicy.Metrics == nil {
		logger.V(1).Info(
			"No ElasticPolicy or Metrics specified, skipping HPA reconciliation")
		return nil
	}

	current := &autoscalingv2.HorizontalPodAutoscaler{}

	// Get the expected HPA.
	expected, err := desiredHPA(pytorchJob, r.Scheme)
	if err != nil {
		return err
	}

	err = r.Get(context.TODO(), client.ObjectKeyFromObject(expected), current)
	if err != nil {
		if errors.IsNotFound(err) {
			if trainutil.IsJobSuspended(&pytorchJob.Spec.RunPolicy) {
				// If the job is suspended, the HPA is expected not to exist.
				return nil
			}
			// Create the new HPA.
			logger.V(1).Info("Creating HPA", "namespace", expected.Namespace, "name", expected.Name)
			return r.Create(context.TODO(), expected)
		}
		return err
	}
	if trainutil.IsJobSuspended(&pytorchJob.Spec.RunPolicy) {
		// Delete the current HPA while the job is suspended.
		logger.V(1).Info("Deleting HPA", "HorizontalPodAutoscaler", klog.KObj(current))
		return r.Delete(context.TODO(), current)
	}

	if !equality.Semantic.DeepEqual(expected.Spec, current.Spec) {
		logger.V(1).Info("Updating HPA", "namespace", current.Namespace, "name", current.Name)
		expected.ResourceVersion = current.ResourceVersion
		err = r.Update(context.TODO(), expected)
		if err != nil {
			return err
		}
	}
	return nil
}

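// desiredHPA builds the HorizontalPodAutoscaler expected for the given
// PyTorchJob: it targets the job itself, copies MinReplicas, MaxReplicas and
// Metrics from the ElasticPolicy, and sets the job as controller owner so the
// HPA is garbage-collected together with the job.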
func desiredHPA(pytorchJob *kubeflowv1.PyTorchJob, scheme *runtime.Scheme) (
	*autoscalingv2.HorizontalPodAutoscaler, error) {
	hpa := &autoscalingv2.HorizontalPodAutoscaler{
		ObjectMeta: metav1.ObjectMeta{
			Name:      pytorchJob.Name,
			Namespace: pytorchJob.Namespace,
		},
		Spec: autoscalingv2.HorizontalPodAutoscalerSpec{
			ScaleTargetRef: autoscalingv2.CrossVersionObjectReference{
				Kind:       pytorchJob.Kind,
				Name:       pytorchJob.Name,
				APIVersion: pytorchJob.APIVersion,
			},
			MinReplicas: pytorchJob.Spec.ElasticPolicy.MinReplicas,
			MaxReplicas: *pytorchJob.Spec.ElasticPolicy.MaxReplicas,
			Metrics:     pytorchJob.Spec.ElasticPolicy.Metrics,
		},
	}
	if err := controllerruntime.SetControllerReference(pytorchJob, hpa, scheme); err != nil {
		return nil, err
	}
	return hpa, nil
}
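
For orientation, here is a hypothetical sketch (not part of hpa.go) of how desiredHPA could be exercised from within the pytorch package, for example in a test. The ElasticPolicy field names and the HPA mapping are taken from the code above; the kubeflowv1.ElasticPolicy type name, the kubeflowv1.AddToScheme helper, and the TestDesiredHPASketch name are assumptions made for the illustration.

// Hypothetical usage sketch, not part of the upstream file. It assumes the
// ElasticPolicy type and AddToScheme helper exported by the kubeflowv1 package.
package pytorch

import (
	"testing"

	autoscalingv2 "k8s.io/api/autoscaling/v2"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/apimachinery/pkg/runtime"

	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
)

func TestDesiredHPASketch(t *testing.T) {
	minReplicas, maxReplicas, utilization := int32(1), int32(4), int32(80)

	// An elastic PyTorchJob scaling between 1 and 4 replicas on average CPU usage.
	// Kind and APIVersion are set explicitly because desiredHPA copies them into
	// the HPA's ScaleTargetRef.
	job := &kubeflowv1.PyTorchJob{
		TypeMeta:   metav1.TypeMeta{Kind: "PyTorchJob", APIVersion: "kubeflow.org/v1"},
		ObjectMeta: metav1.ObjectMeta{Name: "elastic-demo", Namespace: "default"},
		Spec: kubeflowv1.PyTorchJobSpec{
			ElasticPolicy: &kubeflowv1.ElasticPolicy{
				MinReplicas: &minReplicas,
				MaxReplicas: &maxReplicas,
				Metrics: []autoscalingv2.MetricSpec{{
					Type: autoscalingv2.ResourceMetricSourceType,
					Resource: &autoscalingv2.ResourceMetricSource{
						Name: corev1.ResourceCPU,
						Target: autoscalingv2.MetricTarget{
							Type:               autoscalingv2.UtilizationMetricType,
							AverageUtilization: &utilization,
						},
					},
				}},
			},
		},
	}

	// The scheme must know the PyTorchJob GVK so the controller reference can be set.
	scheme := runtime.NewScheme()
	if err := kubeflowv1.AddToScheme(scheme); err != nil {
		t.Fatal(err)
	}

	hpa, err := desiredHPA(job, scheme)
	if err != nil {
		t.Fatal(err)
	}
	// The HPA mirrors the job's name/namespace, targets the job itself, and
	// copies MinReplicas/MaxReplicas/Metrics from the ElasticPolicy.
	if hpa.Spec.ScaleTargetRef.Name != job.Name || hpa.Spec.MaxReplicas != maxReplicas {
		t.Fatalf("unexpected HPA spec: %+v", hpa.Spec)
	}
}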