github.com/kubeflow/training-operator@v1.7.0/pkg/apis/kubeflow.org/v1/pytorch_defaults.go (about) 1 // Copyright 2018 The Kubeflow Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package v1 16 17 import ( 18 corev1 "k8s.io/api/core/v1" 19 "k8s.io/apimachinery/pkg/runtime" 20 ) 21 22 var ( 23 DefaultNprocPerNode = "auto" 24 ) 25 26 func addPyTorchDefaultingFuncs(scheme *runtime.Scheme) error { 27 return RegisterDefaults(scheme) 28 } 29 30 // setPyTorchDefaultPort sets the default ports for pytorch container. 31 func setPyTorchDefaultPort(spec *corev1.PodSpec) { 32 index := getDefaultContainerIndex(spec, PyTorchJobDefaultContainerName) 33 if ok := hasDefaultPort(spec, index, PyTorchJobDefaultPortName); !ok { 34 setDefaultPort(spec, PyTorchJobDefaultPortName, PyTorchJobDefaultPort, index) 35 } 36 } 37 38 func setElasticPolicy(pytorchJob *PyTorchJob) { 39 if pytorchJob.Spec.ElasticPolicy != nil { 40 if pytorchJob.Spec.ElasticPolicy.MaxReplicas != nil && 41 pytorchJob.Spec.ElasticPolicy.MinReplicas != nil { 42 return 43 } else if pytorchJob.Spec.ElasticPolicy.MaxReplicas != nil { 44 // Set MinRepliacs to elasticPolicy.MaxReplicas. 45 pytorchJob.Spec.ElasticPolicy.MinReplicas = pytorchJob.Spec.ElasticPolicy.MaxReplicas 46 } else if pytorchJob.Spec.ElasticPolicy.MinReplicas != nil { 47 pytorchJob.Spec.ElasticPolicy.MaxReplicas = pytorchJob.Spec.ElasticPolicy.MinReplicas 48 } else { 49 workerReplicas := pytorchJob.Spec.PyTorchReplicaSpecs[PyTorchJobReplicaTypeWorker].Replicas 50 // Set Min and Max to worker.spec.Replicas. 51 pytorchJob.Spec.ElasticPolicy.MaxReplicas = workerReplicas 52 pytorchJob.Spec.ElasticPolicy.MinReplicas = workerReplicas 53 } 54 } 55 } 56 57 // setPyTorchTypeNamesToCamelCase sets the name of all replica types from any case to correct case. 58 func setPyTorchTypeNamesToCamelCase(pytorchJob *PyTorchJob) { 59 replicaTypes := []ReplicaType{ 60 PyTorchJobReplicaTypeMaster, 61 PyTorchJobReplicaTypeWorker, 62 } 63 for _, replicaType := range replicaTypes { 64 setTypeNameToCamelCase(pytorchJob.Spec.PyTorchReplicaSpecs, replicaType) 65 } 66 } 67 68 func setDefaultNprocPerNode(job *PyTorchJob) { 69 if (job.Spec.ElasticPolicy != nil && job.Spec.ElasticPolicy.NProcPerNode == nil) || (job.Spec.ElasticPolicy == nil) { 70 if job.Spec.NprocPerNode == nil { 71 job.Spec.NprocPerNode = &DefaultNprocPerNode 72 } 73 } 74 } 75 76 // SetDefaults_PyTorchJob sets any unspecified values to defaults. 77 func SetDefaults_PyTorchJob(job *PyTorchJob) { 78 // Set default cleanpod policy to None. 79 if job.Spec.RunPolicy.CleanPodPolicy == nil { 80 job.Spec.RunPolicy.CleanPodPolicy = CleanPodPolicyPointer(CleanPodPolicyNone) 81 } 82 83 // Update the key of PyTorchReplicaSpecs to camel case. 84 setPyTorchTypeNamesToCamelCase(job) 85 86 for _, spec := range job.Spec.PyTorchReplicaSpecs { 87 setDefaultReplicas(spec, 1) 88 setDefaultRestartPolicy(spec, PyTorchJobDefaultRestartPolicy) 89 setPyTorchDefaultPort(&spec.Template.Spec) 90 } 91 // Set default elastic policy. 92 setElasticPolicy(job) 93 94 // Set default nproc_per_node. 95 setDefaultNprocPerNode(job) 96 }