github.com/kubeflow/training-operator@v1.7.0/pkg/core/pod.go (about)

     1  /*
     2  Copyright 2023 The Kubeflow Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package core
    18  
    19  import (
    20  	utillabels "github.com/kubeflow/training-operator/pkg/util/labels"
    21  
    22  	apiv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    23  	log "github.com/sirupsen/logrus"
    24  	v1 "k8s.io/api/core/v1"
    25  	"k8s.io/apimachinery/pkg/labels"
    26  )
    27  
    28  // FilterPodsForReplicaType returns pods belong to a replicaType.
    29  func FilterPodsForReplicaType(pods []*v1.Pod, replicaType string) ([]*v1.Pod, error) {
    30  	var result []*v1.Pod
    31  
    32  	selector := labels.SelectorFromValidatedSet(labels.Set{
    33  		apiv1.ReplicaTypeLabel: replicaType,
    34  	})
    35  
    36  	for _, pod := range pods {
    37  		set := labels.Set(pod.Labels)
    38  		if !selector.Matches(set) {
    39  			continue
    40  		}
    41  		result = append(result, pod)
    42  	}
    43  	return result, nil
    44  }
    45  
    46  // GetPodSlices returns a slice, which element is the slice of pod.
    47  // It gives enough information to caller to make decision to up/down scale resources.
    48  func GetPodSlices(pods []*v1.Pod, replicas int, logger *log.Entry) [][]*v1.Pod {
    49  	podSlices := make([][]*v1.Pod, CalculatePodSliceSize(pods, replicas))
    50  	for _, pod := range pods {
    51  		index, err := utillabels.ReplicaIndex(pod.Labels)
    52  		if err != nil {
    53  			logger.Warningf("Error obtaining replica index from Pod %s/%s: %v", pod.Namespace, pod.Name, err)
    54  			continue
    55  		}
    56  		if index < 0 || index >= replicas {
    57  			logger.Warningf("The label index is not expected: %d, pod: %s/%s", index, pod.Namespace, pod.Name)
    58  		}
    59  
    60  		podSlices[index] = append(podSlices[index], pod)
    61  	}
    62  	return podSlices
    63  }
    64  
    65  // CalculatePodSliceSize compare max pod index with desired replicas and return larger size
    66  func CalculatePodSliceSize(pods []*v1.Pod, replicas int) int {
    67  	size := 0
    68  	for _, pod := range pods {
    69  		index, err := utillabels.ReplicaIndex(pod.Labels)
    70  		if err != nil {
    71  			continue
    72  		}
    73  		size = MaxInt(size, index)
    74  	}
    75  
    76  	// size comes from index, need to +1 to indicate real size
    77  	return MaxInt(size+1, replicas)
    78  }
    79  
    80  // SetRestartPolicy check the RestartPolicy defined in job spec and overwrite RestartPolicy in podTemplate if necessary
    81  func SetRestartPolicy(podTemplateSpec *v1.PodTemplateSpec, spec *apiv1.ReplicaSpec) {
    82  	// This is necessary since restartPolicyExitCode is not supported in v1.PodTemplateSpec
    83  	if spec.RestartPolicy == apiv1.RestartPolicyExitCode {
    84  		podTemplateSpec.Spec.RestartPolicy = v1.RestartPolicyNever
    85  	} else {
    86  		podTemplateSpec.Spec.RestartPolicy = v1.RestartPolicy(spec.RestartPolicy)
    87  	}
    88  }