volcano.sh/volcano@v1.9.0/pkg/controllers/job/plugins/distributed-framework/pytorch/pytorch.go (about)

     1  package pytorch
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"strconv"
     7  
     8  	v1 "k8s.io/api/core/v1"
     9  	"k8s.io/klog/v2"
    10  
    11  	batch "volcano.sh/apis/pkg/apis/batch/v1alpha1"
    12  	"volcano.sh/volcano/pkg/controllers/job/helpers"
    13  	pluginsinterface "volcano.sh/volcano/pkg/controllers/job/plugins/interface"
    14  )
    15  
const (
	// PytorchPluginName is the name of the plugin, as returned by Name.
	PytorchPluginName = "pytorch"
	// DefaultPort is the default value of the "port" plugin argument.
	DefaultPort = 23456
	// DefaultMaster is the default value of the "master" plugin argument,
	// i.e. the name of the task that plays the master role.
	DefaultMaster = "master"
	// DefaultWorker is the default value of the "worker" plugin argument,
	// i.e. the name of the task that plays the worker role.
	DefaultWorker = "worker"

	// EnvMasterPort is the name of the environment variable that carries the
	// master port.
	EnvMasterPort = "MASTER_PORT"
	// EnvMasterAddr is the name of the environment variable that carries the
	// master address.
	EnvMasterAddr = "MASTER_ADDR"
	// EnvWorldSize is the name of the environment variable that carries the
	// total number of replicas of the job.
	EnvWorldSize = "WORLD_SIZE"
	// EnvRank is the name of the environment variable that carries the rank
	// of the pod.
	EnvRank = "RANK"
)
    35  
// pytorchPlugin is the job plugin that injects the PyTorch distributed
// settings (master address and port, world size, rank) into the pods of a job.
//
// pytorchArguments holds the raw plugin arguments and clientset the client set,
// both as handed to New. masterName, workerName and port are filled in by
// addFlags from the "master", "worker" and "port" arguments, falling back to
// DefaultMaster, DefaultWorker and DefaultPort.
type pytorchPlugin struct {
	pytorchArguments []string
	clientset        pluginsinterface.PluginClientset
	masterName       string
	workerName       string
	port             int
}
    43  
    44  // New creates pytorch plugin.
    45  func New(client pluginsinterface.PluginClientset, arguments []string) pluginsinterface.PluginInterface {
    46  	pp := pytorchPlugin{pytorchArguments: arguments, clientset: client}
    47  	pp.addFlags()
    48  	return &pp
    49  }
    50  
    51  func (pp *pytorchPlugin) addFlags() {
    52  	flagSet := flag.NewFlagSet(pp.Name(), flag.ContinueOnError)
    53  	flagSet.StringVar(&pp.masterName, "master", DefaultMaster, "name of master role task")
    54  	flagSet.StringVar(&pp.workerName, "worker", DefaultWorker, "name of worker role task")
    55  	flagSet.IntVar(&pp.port, "port", DefaultPort, "open port for containers")
    56  	if err := flagSet.Parse(pp.pytorchArguments); err != nil {
    57  		klog.Errorf("plugin %s flagset parse failed, err: %v", pp.Name(), err)
    58  	}
    59  }
    60  
// Name returns the name of the plugin, PytorchPluginName.
func (pp *pytorchPlugin) Name() string {
	return PytorchPluginName
}
    64  
    65  func (pp *pytorchPlugin) OnPodCreate(pod *v1.Pod, job *batch.Job) error {
    66  	taskType := helpers.GetTaskKey(pod)
    67  	masterIndex := helpers.GetTaskIndexUnderJob(pp.masterName, job)
    68  	if masterIndex == -1 {
    69  		klog.Errorf("job %v doesn't have task %v", job.Name, pp.masterName)
    70  		return nil
    71  	}
    72  
    73  	masterEnvVars := []v1.EnvVar{}
    74  	masterAddr := pp.generateMasterAddr(job.Spec.Tasks[masterIndex], job.Name)
    75  	masterEnvVars = append(masterEnvVars, v1.EnvVar{
    76  		Name:  EnvMasterAddr,
    77  		Value: masterAddr,
    78  	}, v1.EnvVar{
    79  		Name:  EnvMasterPort,
    80  		Value: fmt.Sprintf("%v", pp.port),
    81  	})
    82  
    83  	masterRank := 0
    84  	workerRank := 0
    85  	if taskType == pp.workerName {
    86  		index, err := strconv.Atoi(helpers.GetPodIndexUnderTask(pod))
    87  		if err != nil {
    88  			return err
    89  		}
    90  
    91  		workerRank = index + 1
    92  	}
    93  
    94  	totalReplicas := pp.getTotalReplicas(job)
    95  	for i, c := range pod.Spec.Containers {
    96  		pp.openContainerPort(&c, i, pod)
    97  
    98  		pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, masterEnvVars...)
    99  		pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, v1.EnvVar{
   100  			Name:  EnvWorldSize,
   101  			Value: strconv.Itoa(int(totalReplicas)),
   102  		})
   103  
   104  		if taskType == pp.workerName {
   105  			pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, v1.EnvVar{
   106  				Name:  EnvRank,
   107  				Value: strconv.Itoa(workerRank),
   108  			})
   109  		} else if taskType == pp.masterName {
   110  			pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, v1.EnvVar{
   111  				Name:  EnvRank,
   112  				Value: strconv.Itoa(masterRank),
   113  			})
   114  		}
   115  	}
   116  
   117  	return nil
   118  }
   119  
   120  func (pp *pytorchPlugin) getTotalReplicas(job *batch.Job) int32 {
   121  	jobReplicas := int32(0)
   122  	for _, task := range job.Spec.Tasks {
   123  		jobReplicas += task.Replicas
   124  	}
   125  
   126  	return jobReplicas
   127  }
   128  
   129  func (pp *pytorchPlugin) generateMasterAddr(task batch.TaskSpec, jobName string) string {
   130  	hostName := task.Template.Spec.Hostname
   131  	subdomain := task.Template.Spec.Subdomain
   132  	if len(hostName) == 0 {
   133  		hostName = helpers.MakePodName(jobName, task.Name, 0)
   134  	}
   135  	if len(subdomain) == 0 {
   136  		subdomain = jobName
   137  	}
   138  
   139  	host := hostName + "." + subdomain
   140  	return host
   141  }
   142  
   143  func (pp *pytorchPlugin) openContainerPort(c *v1.Container, index int, pod *v1.Pod) {
   144  	hasPort := false
   145  	for _, p := range c.Ports {
   146  		if p.ContainerPort == int32(pp.port) {
   147  			hasPort = true
   148  			break
   149  		}
   150  	}
   151  
   152  	if !hasPort {
   153  		port := v1.ContainerPort{
   154  			Name:          "pytorchjob-port",
   155  			ContainerPort: int32(pp.port),
   156  		}
   157  
   158  		pod.Spec.Containers[index].Ports = append(pod.Spec.Containers[index].Ports, port)
   159  	}
   160  }
   161  
   162  func (pp *pytorchPlugin) OnJobAdd(job *batch.Job) error {
   163  	if job.Status.ControlledResources["plugin-"+pp.Name()] == pp.Name() {
   164  		return nil
   165  	}
   166  	job.Status.ControlledResources["plugin-"+pp.Name()] = pp.Name()
   167  	return nil
   168  }
   169  
   170  func (pp *pytorchPlugin) OnJobDelete(job *batch.Job) error {
   171  	if job.Status.ControlledResources["plugin-"+pp.Name()] != pp.Name() {
   172  		return nil
   173  	}
   174  	delete(job.Status.ControlledResources, "plugin-"+pp.Name())
   175  	return nil
   176  }
   177  
// OnJobUpdate is a no-op for this plugin: it does not touch job and always
// returns nil.
func (pp *pytorchPlugin) OnJobUpdate(job *batch.Job) error {
	return nil
}