github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/pytorch/envvar.go

// Copyright 2021 The Kubeflow Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License

package pytorch

import (
	"fmt"
	"strconv"
	"strings"

	corev1 "k8s.io/api/core/v1"

	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
)

const (
	// Worker/node size related arguments.

	// EnvNprocPerNode is the environment variable name for the number of processes per node.
	EnvNprocPerNode = "PET_NPROC_PER_NODE"
	// EnvNnodes is the environment variable name for the number of nodes.
	EnvNnodes = "PET_NNODES"
	// EnvNodeRank is the environment variable name for the node rank.
	EnvNodeRank = "PET_NODE_RANK"
)

// EnvVarGenerator is the environment variable generator interface.
type EnvVarGenerator interface {
	Generate(job *kubeflowv1.PyTorchJob) ([]corev1.EnvVar, error)
}

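// setPodEnv sets the PyTorch-related environment variables on every container of the given pod
// template. PYTHONUNBUFFERED is always set; when a Master replica is defined, the variables from
// the master env var generator plus WORLD_SIZE, RANK, PET_NPROC_PER_NODE, and PET_NODE_RANK are
// added; and, depending on whether an ElasticPolicy is configured, either the elastic variables
// or PET_NNODES are added as well.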
func setPodEnv(obj interface{}, podTemplateSpec *corev1.PodTemplateSpec, rtype, index string) error {
	pytorchjob, ok := obj.(*kubeflowv1.PyTorchJob)
	if !ok {
		return fmt.Errorf("%+v is not a PyTorchJob", obj)
	}

	for i := range podTemplateSpec.Spec.Containers {
		// Initialize the environment variables.
		if len(podTemplateSpec.Spec.Containers[i].Env) == 0 {
			podTemplateSpec.Spec.Containers[i].Env = make([]corev1.EnvVar, 0)
		}
		// Set PYTHONUNBUFFERED to 1 to disable Python's output buffering.
		// Ref https://stackoverflow.com/questions/59812009/what-is-the-use-of-pythonunbuffered-in-docker-file.
		podTemplateSpec.Spec.Containers[i].Env = append(
			podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
				Name:  "PYTHONUNBUFFERED",
				Value: "1",
			})

		// The world size is the total number of replicas times the number of processes per node.
		totalReplicas := getTotalReplicas(pytorchjob)
		nprocPerNode := getNprocPerNodeInt(pytorchjob)
		worldSize := int(totalReplicas) * nprocPerNode

		// If a Master replica is defined, set the MASTER_ADDR and RANK environment variables.
		if pytorchjob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster] != nil {
			envVars, err := GetMasterEnvVarGenerator().Generate(pytorchjob)
			if err != nil {
				return err
			}
			// Set master-related environment variables.
			podTemplateSpec.Spec.Containers[i].Env = append(
				podTemplateSpec.Spec.Containers[i].Env, envVars...)

			// Set world size and rank.
			rank, err := strconv.Atoi(index)
			if err != nil {
				return err
			}
			// The master replica has rank 0, so worker ranks are shifted up by one.
			if rtype == strings.ToLower(string(kubeflowv1.PyTorchJobReplicaTypeWorker)) {
				rank = rank + 1
			}

			podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
				Name:  "WORLD_SIZE",
				Value: strconv.Itoa(worldSize),
			})
			podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
				Name:  "RANK",
				Value: strconv.Itoa(rank),
			})
			podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
				Name:  EnvNprocPerNode,
				Value: *pytorchjob.Spec.NprocPerNode,
			})
			podTemplateSpec.Spec.Containers[i].Env = append(podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
				Name:  EnvNodeRank,
				Value: strconv.Itoa(rank),
			})
		}

		// Set the elastic environment variables if the ElasticPolicy is not nil.
		// In elastic mode, nnodes is set as a range, e.g. nnodes=1:4;
		// otherwise, nnodes is set as a single integer, e.g. nnodes=2.
		if pytorchjob.Spec.ElasticPolicy != nil {
			envVars, err := GetElasticEnvVarGenerator().Generate(pytorchjob)
			if err != nil {
				return err
			}
			// Set elastic-related environment variables.
			podTemplateSpec.Spec.Containers[i].Env = append(
				podTemplateSpec.Spec.Containers[i].Env, envVars...)
		} else {
			podTemplateSpec.Spec.Containers[i].Env = append(
				podTemplateSpec.Spec.Containers[i].Env, corev1.EnvVar{
					Name:  EnvNnodes,
					Value: strconv.Itoa(int(totalReplicas)),
				})
		}
	}

	return nil
}

// getNprocPerNodeInt returns the integer value of NprocPerNode, or 1 if it is unset or not an
// integer. When nproc_per_node is set to "auto", the number of processes is determined later in
// the user process phase; in that case the WORLD_SIZE environment variable is not used.
func getNprocPerNodeInt(job *kubeflowv1.PyTorchJob) int {
	if job.Spec.NprocPerNode == nil {
		return 1
	}
	if np, err := strconv.Atoi(*job.Spec.NprocPerNode); err == nil {
		return np
	}
	return 1
}

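// getTotalReplicas returns the sum of the replica counts across all PyTorch replica specs.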
func getTotalReplicas(job *kubeflowv1.PyTorchJob) int32 {
	jobReplicas := int32(0)
	for _, r := range job.Spec.PyTorchReplicaSpecs {
		jobReplicas += *r.Replicas
	}
	return jobReplicas
}

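// replicaName builds the name "<jobName>-<lowercased replica type>-<index>", replacing any "/"
// with "-".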
func replicaName(jobName string, rtype kubeflowv1.ReplicaType, index int) string {
	n := jobName + "-" + strings.ToLower(string(rtype)) + "-" + strconv.Itoa(index)
	return strings.Replace(n, "/", "-", -1)
}

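// getPortFromPyTorchJob returns the container port named with the default PyTorchJob port name
// on the default container of the given replica type, or -1 and an error if no such port exists.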
func getPortFromPyTorchJob(job *kubeflowv1.PyTorchJob, rtype kubeflowv1.ReplicaType) (int32, error) {
	containers := job.Spec.PyTorchReplicaSpecs[rtype].Template.Spec.Containers
	for _, container := range containers {
		if container.Name == kubeflowv1.PyTorchJobDefaultContainerName {
			ports := container.Ports
			for _, port := range ports {
				if port.Name == kubeflowv1.PyTorchJobDefaultPortName {
					return port.ContainerPort, nil
				}
			}
		}
	}
	return -1, fmt.Errorf("port not found")
}