github.com/kubeflow/training-operator@v1.7.0/pkg/apis/kubeflow.org/v1/pytorch_defaults.go (about)

     1  // Copyright 2018 The Kubeflow Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package v1
    16  
    17  import (
    18  	corev1 "k8s.io/api/core/v1"
    19  	"k8s.io/apimachinery/pkg/runtime"
    20  )
    21  
var (
	// DefaultNprocPerNode is the value setDefaultNprocPerNode uses to fill in
	// a PyTorchJob's Spec.NprocPerNode when that field has been left unset.
	DefaultNprocPerNode = "auto"
)
    25  
// addPyTorchDefaultingFuncs passes scheme to RegisterDefaults (defined
// elsewhere in this package) and returns its error unchanged.
func addPyTorchDefaultingFuncs(scheme *runtime.Scheme) error {
	return RegisterDefaults(scheme)
}
    29  
    30  // setPyTorchDefaultPort sets the default ports for pytorch container.
    31  func setPyTorchDefaultPort(spec *corev1.PodSpec) {
    32  	index := getDefaultContainerIndex(spec, PyTorchJobDefaultContainerName)
    33  	if ok := hasDefaultPort(spec, index, PyTorchJobDefaultPortName); !ok {
    34  		setDefaultPort(spec, PyTorchJobDefaultPortName, PyTorchJobDefaultPort, index)
    35  	}
    36  }
    37  
    38  func setElasticPolicy(pytorchJob *PyTorchJob) {
    39  	if pytorchJob.Spec.ElasticPolicy != nil {
    40  		if pytorchJob.Spec.ElasticPolicy.MaxReplicas != nil &&
    41  			pytorchJob.Spec.ElasticPolicy.MinReplicas != nil {
    42  			return
    43  		} else if pytorchJob.Spec.ElasticPolicy.MaxReplicas != nil {
    44  			// Set MinRepliacs to elasticPolicy.MaxReplicas.
    45  			pytorchJob.Spec.ElasticPolicy.MinReplicas = pytorchJob.Spec.ElasticPolicy.MaxReplicas
    46  		} else if pytorchJob.Spec.ElasticPolicy.MinReplicas != nil {
    47  			pytorchJob.Spec.ElasticPolicy.MaxReplicas = pytorchJob.Spec.ElasticPolicy.MinReplicas
    48  		} else {
    49  			workerReplicas := pytorchJob.Spec.PyTorchReplicaSpecs[PyTorchJobReplicaTypeWorker].Replicas
    50  			// Set Min and Max to worker.spec.Replicas.
    51  			pytorchJob.Spec.ElasticPolicy.MaxReplicas = workerReplicas
    52  			pytorchJob.Spec.ElasticPolicy.MinReplicas = workerReplicas
    53  		}
    54  	}
    55  }
    56  
    57  // setPyTorchTypeNamesToCamelCase sets the name of all replica types from any case to correct case.
    58  func setPyTorchTypeNamesToCamelCase(pytorchJob *PyTorchJob) {
    59  	replicaTypes := []ReplicaType{
    60  		PyTorchJobReplicaTypeMaster,
    61  		PyTorchJobReplicaTypeWorker,
    62  	}
    63  	for _, replicaType := range replicaTypes {
    64  		setTypeNameToCamelCase(pytorchJob.Spec.PyTorchReplicaSpecs, replicaType)
    65  	}
    66  }
    67  
    68  func setDefaultNprocPerNode(job *PyTorchJob) {
    69  	if (job.Spec.ElasticPolicy != nil && job.Spec.ElasticPolicy.NProcPerNode == nil) || (job.Spec.ElasticPolicy == nil) {
    70  		if job.Spec.NprocPerNode == nil {
    71  			job.Spec.NprocPerNode = &DefaultNprocPerNode
    72  		}
    73  	}
    74  }
    75  
    76  // SetDefaults_PyTorchJob sets any unspecified values to defaults.
    77  func SetDefaults_PyTorchJob(job *PyTorchJob) {
    78  	// Set default cleanpod policy to None.
    79  	if job.Spec.RunPolicy.CleanPodPolicy == nil {
    80  		job.Spec.RunPolicy.CleanPodPolicy = CleanPodPolicyPointer(CleanPodPolicyNone)
    81  	}
    82  
    83  	// Update the key of PyTorchReplicaSpecs to camel case.
    84  	setPyTorchTypeNamesToCamelCase(job)
    85  
    86  	for _, spec := range job.Spec.PyTorchReplicaSpecs {
    87  		setDefaultReplicas(spec, 1)
    88  		setDefaultRestartPolicy(spec, PyTorchJobDefaultRestartPolicy)
    89  		setPyTorchDefaultPort(&spec.Template.Spec)
    90  	}
    91  	// Set default elastic policy.
    92  	setElasticPolicy(job)
    93  
    94  	// Set default nproc_per_node.
    95  	setDefaultNprocPerNode(job)
    96  }