github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/pytorch/elastic.go (about)

     1  // Copyright 2021 The Kubeflow Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License
    14  
    15  package pytorch
    16  
    17  import (
    18  	"fmt"
    19  	"strconv"
    20  	"sync"
    21  
    22  	corev1 "k8s.io/api/core/v1"
    23  
    24  	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    25  )
    26  
    27  const (
    28  	// Rendezvous related arguments
    29  
    30  	// EnvRDZVBackend is the environment variable name for the rdzv backend.
    31  	EnvRDZVBackend = "PET_RDZV_BACKEND"
    32  	// EnvRDZVID is the environment variable name for the rdzv id.
    33  	EnvRDZVID = "PET_RDZV_ID"
    34  	// ENVRDZVConf is the environment variable name for the rdzv conf.
    35  	EnvRDZVConf = "PET_RDZV_CONF"
    36  	// EnvRDZVEndpoint is the environment variable name for the rdzv endpoint.
    37  	EnvRDZVEndpoint = "PET_RDZV_ENDPOINT"
    38  	// EnvRDZVStandalone is the environment variable name for the standalone mode.
    39  	EnvStandalone = "PET_STANDALONE"
    40  
    41  	// User-code launch related arguments.
    42  
    43  	// EnvMaxRestarts is the environment variable name for the maximum number of worker group restarts before failing.
    44  	EnvMaxRestarts = "PET_MAX_RESTARTS"
    45  	// EnvMonitorInterval is the environment variable name for the interval, in seconds, to monitor the state of workers.
    46  	EnvMonitorInterval = "PET_MONITOR_INTERVAL"
    47  	// EnvStartMethod is the environment variable name for the multiprocessing start method to use when creating workers, which could be fork, spawn and forkserver.
    48  	EnvStartMethod = "PET_START_METHOD"
    49  
    50  	// EnvNNodes is the common environment variable name from envvar
    51  
    52  	// EnvNProcPerNode is the environment variable name for the number of processes per node.
    53  	EnvNProcPerNode = "PET_NPROC_PER_NODE"
    54  )
    55  
    56  var (
    57  	elasticGenerator EnvVarGenerator
    58  	onceElastic      sync.Once
    59  )
    60  
    61  // ElasticEnvVarGenerator is the environment variable generator for Elastic related arguments.
    62  type ElasticEnvVarGenerator struct{}
    63  
    64  func GetElasticEnvVarGenerator() EnvVarGenerator {
    65  	onceElastic.Do(func() {
    66  		elasticGenerator = &ElasticEnvVarGenerator{}
    67  	})
    68  	return elasticGenerator
    69  }
    70  
    71  func (e ElasticEnvVarGenerator) Generate(
    72  	job *kubeflowv1.PyTorchJob) ([]corev1.EnvVar, error) {
    73  	envVars := []corev1.EnvVar{}
    74  
    75  	elasticPolicy := job.Spec.ElasticPolicy
    76  	if elasticPolicy == nil {
    77  		// Return empty env vars.
    78  		return nil, nil
    79  	}
    80  
    81  	// Generate RDZV_ENDPOINT.
    82  	if envVar, err := e.generateEnvRDZVEndpoint(job); err != nil {
    83  		return nil, err
    84  	} else {
    85  		envVars = append(envVars, *envVar)
    86  	}
    87  	// Generate RDZV_BACKEND.
    88  	envVars = append(envVars, e.generateEnvBackend(elasticPolicy))
    89  	// Generate NNODES.
    90  	if envVar, err := e.generateEnvNnodes(job); err != nil {
    91  		return nil, err
    92  	} else {
    93  		envVars = append(envVars, *envVar)
    94  	}
    95  
    96  	if elasticPolicy.MaxRestarts != nil {
    97  		envVars = append(envVars, corev1.EnvVar{
    98  			Name:  EnvMaxRestarts,
    99  			Value: strconv.Itoa(int(*elasticPolicy.MaxRestarts)),
   100  		})
   101  	}
   102  	if elasticPolicy.NProcPerNode != nil {
   103  		envVars = append(envVars, corev1.EnvVar{
   104  			Name:  EnvNProcPerNode,
   105  			Value: strconv.Itoa(int(*elasticPolicy.NProcPerNode)),
   106  		})
   107  	}
   108  	if elasticPolicy.RDZVID != nil {
   109  		envVars = append(envVars, corev1.EnvVar{
   110  			Name:  EnvRDZVID,
   111  			Value: *elasticPolicy.RDZVID,
   112  		})
   113  	}
   114  	if envVar := e.generateEnvRDZVConf(elasticPolicy); envVar != nil {
   115  		envVars = append(envVars, *envVar)
   116  	}
   117  	if elasticPolicy.Standalone != nil && *elasticPolicy.Standalone {
   118  		envVars = append(envVars, corev1.EnvVar{
   119  			Name:  EnvStandalone,
   120  			Value: "",
   121  		})
   122  	}
   123  
   124  	return envVars, nil
   125  }
   126  
   127  func (e ElasticEnvVarGenerator) generateEnvNnodes(job *kubeflowv1.PyTorchJob) (*corev1.EnvVar, error) {
   128  	// Return worker.replicas if there is no max and min replicas specified.
   129  	if job.Spec.ElasticPolicy.MinReplicas == nil &&
   130  		job.Spec.ElasticPolicy.MaxReplicas == nil {
   131  		if job.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker] == nil {
   132  			return nil, fmt.Errorf("cannot find the worker spec")
   133  		}
   134  		return &corev1.EnvVar{
   135  			Name: EnvNnodes,
   136  			Value: strconv.Itoa(
   137  				int(*job.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].
   138  					Replicas)),
   139  		}, nil
   140  	}
   141  
   142  	return &corev1.EnvVar{
   143  		Name: EnvNnodes,
   144  		Value: fmt.Sprintf("%d:%d",
   145  			*job.Spec.ElasticPolicy.MinReplicas, *job.Spec.ElasticPolicy.MaxReplicas),
   146  	}, nil
   147  }
   148  
   149  func (e ElasticEnvVarGenerator) generateEnvRDZVEndpoint(job *kubeflowv1.PyTorchJob) (*corev1.EnvVar, error) {
   150  	var err error
   151  	host := ""
   152  	if job.Spec.ElasticPolicy.RDZVHost == nil {
   153  		host = fmt.Sprintf("%s-worker-0", job.Name)
   154  	} else {
   155  		host = *job.Spec.ElasticPolicy.RDZVHost
   156  	}
   157  
   158  	var port int32
   159  	if job.Spec.ElasticPolicy.RDZVPort == nil {
   160  		// Generate RDZV_Endpoint.
   161  		port, err = getPortFromPyTorchJob(job, kubeflowv1.PyTorchJobReplicaTypeWorker)
   162  		if err != nil {
   163  			return nil, err
   164  		}
   165  	} else {
   166  		port = *job.Spec.ElasticPolicy.RDZVPort
   167  	}
   168  	return &corev1.EnvVar{
   169  		Name:  EnvRDZVEndpoint,
   170  		Value: fmt.Sprintf("%s:%d", host, port),
   171  	}, nil
   172  }
   173  
   174  func (e ElasticEnvVarGenerator) generateEnvRDZVConf(elasticPolicy *kubeflowv1.ElasticPolicy) *corev1.EnvVar {
   175  	if elasticPolicy.RDZVConf == nil {
   176  		return nil
   177  	}
   178  	val := ""
   179  	for _, conf := range elasticPolicy.RDZVConf {
   180  		val += fmt.Sprintf("%s=%s,", conf.Key, conf.Value)
   181  	}
   182  	return &corev1.EnvVar{
   183  		Name: EnvRDZVConf,
   184  		// Remove the last comma.
   185  		Value: val[:len(val)-1],
   186  	}
   187  }
   188  
   189  func (e ElasticEnvVarGenerator) generateEnvBackend(elasticPolicy *kubeflowv1.ElasticPolicy) corev1.EnvVar {
   190  	if elasticPolicy.RDZVBackend != nil {
   191  		return corev1.EnvVar{
   192  			Name:  EnvRDZVBackend,
   193  			Value: string(*elasticPolicy.RDZVBackend),
   194  		}
   195  	}
   196  	return corev1.EnvVar{
   197  		Name:  EnvRDZVBackend,
   198  		Value: string(kubeflowv1.BackendC10D),
   199  	}
   200  }