volcano.sh/volcano@v1.9.0/pkg/controllers/job/plugins/distributed-framework/pytorch/pytorch.go (about) 1 package pytorch 2 3 import ( 4 "flag" 5 "fmt" 6 "strconv" 7 8 v1 "k8s.io/api/core/v1" 9 "k8s.io/klog/v2" 10 11 batch "volcano.sh/apis/pkg/apis/batch/v1alpha1" 12 "volcano.sh/volcano/pkg/controllers/job/helpers" 13 pluginsinterface "volcano.sh/volcano/pkg/controllers/job/plugins/interface" 14 ) 15 16 const ( 17 // PytorchPluginName is the name of the plugin 18 PytorchPluginName = "pytorch" 19 // DefaultPort is the default port for pytorch 20 DefaultPort = 23456 21 // DefaultMaster is the default task name of master host 22 DefaultMaster = "master" 23 // DefaultWorker is the default task name of worker host 24 DefaultWorker = "worker" 25 26 // EnvMasterPort is the env name of master port 27 EnvMasterPort = "MASTER_PORT" 28 // EnvMasterAddr is the env name of master addr 29 EnvMasterAddr = "MASTER_ADDR" 30 // EnvWorldSize is the env name of world size 31 EnvWorldSize = "WORLD_SIZE" 32 // EnvRank is the env name of rank 33 EnvRank = "RANK" 34 ) 35 36 type pytorchPlugin struct { 37 pytorchArguments []string 38 clientset pluginsinterface.PluginClientset 39 masterName string 40 workerName string 41 port int 42 } 43 44 // New creates pytorch plugin. 45 func New(client pluginsinterface.PluginClientset, arguments []string) pluginsinterface.PluginInterface { 46 pp := pytorchPlugin{pytorchArguments: arguments, clientset: client} 47 pp.addFlags() 48 return &pp 49 } 50 51 func (pp *pytorchPlugin) addFlags() { 52 flagSet := flag.NewFlagSet(pp.Name(), flag.ContinueOnError) 53 flagSet.StringVar(&pp.masterName, "master", DefaultMaster, "name of master role task") 54 flagSet.StringVar(&pp.workerName, "worker", DefaultWorker, "name of worker role task") 55 flagSet.IntVar(&pp.port, "port", DefaultPort, "open port for containers") 56 if err := flagSet.Parse(pp.pytorchArguments); err != nil { 57 klog.Errorf("plugin %s flagset parse failed, err: %v", pp.Name(), err) 58 } 59 } 60 61 func (pp *pytorchPlugin) Name() string { 62 return PytorchPluginName 63 } 64 65 func (pp *pytorchPlugin) OnPodCreate(pod *v1.Pod, job *batch.Job) error { 66 taskType := helpers.GetTaskKey(pod) 67 masterIndex := helpers.GetTaskIndexUnderJob(pp.masterName, job) 68 if masterIndex == -1 { 69 klog.Errorf("job %v doesn't have task %v", job.Name, pp.masterName) 70 return nil 71 } 72 73 masterEnvVars := []v1.EnvVar{} 74 masterAddr := pp.generateMasterAddr(job.Spec.Tasks[masterIndex], job.Name) 75 masterEnvVars = append(masterEnvVars, v1.EnvVar{ 76 Name: EnvMasterAddr, 77 Value: masterAddr, 78 }, v1.EnvVar{ 79 Name: EnvMasterPort, 80 Value: fmt.Sprintf("%v", pp.port), 81 }) 82 83 masterRank := 0 84 workerRank := 0 85 if taskType == pp.workerName { 86 index, err := strconv.Atoi(helpers.GetPodIndexUnderTask(pod)) 87 if err != nil { 88 return err 89 } 90 91 workerRank = index + 1 92 } 93 94 totalReplicas := pp.getTotalReplicas(job) 95 for i, c := range pod.Spec.Containers { 96 pp.openContainerPort(&c, i, pod) 97 98 pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, masterEnvVars...) 99 pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, v1.EnvVar{ 100 Name: EnvWorldSize, 101 Value: strconv.Itoa(int(totalReplicas)), 102 }) 103 104 if taskType == pp.workerName { 105 pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, v1.EnvVar{ 106 Name: EnvRank, 107 Value: strconv.Itoa(workerRank), 108 }) 109 } else if taskType == pp.masterName { 110 pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, v1.EnvVar{ 111 Name: EnvRank, 112 Value: strconv.Itoa(masterRank), 113 }) 114 } 115 } 116 117 return nil 118 } 119 120 func (pp *pytorchPlugin) getTotalReplicas(job *batch.Job) int32 { 121 jobReplicas := int32(0) 122 for _, task := range job.Spec.Tasks { 123 jobReplicas += task.Replicas 124 } 125 126 return jobReplicas 127 } 128 129 func (pp *pytorchPlugin) generateMasterAddr(task batch.TaskSpec, jobName string) string { 130 hostName := task.Template.Spec.Hostname 131 subdomain := task.Template.Spec.Subdomain 132 if len(hostName) == 0 { 133 hostName = helpers.MakePodName(jobName, task.Name, 0) 134 } 135 if len(subdomain) == 0 { 136 subdomain = jobName 137 } 138 139 host := hostName + "." + subdomain 140 return host 141 } 142 143 func (pp *pytorchPlugin) openContainerPort(c *v1.Container, index int, pod *v1.Pod) { 144 hasPort := false 145 for _, p := range c.Ports { 146 if p.ContainerPort == int32(pp.port) { 147 hasPort = true 148 break 149 } 150 } 151 152 if !hasPort { 153 port := v1.ContainerPort{ 154 Name: "pytorchjob-port", 155 ContainerPort: int32(pp.port), 156 } 157 158 pod.Spec.Containers[index].Ports = append(pod.Spec.Containers[index].Ports, port) 159 } 160 } 161 162 func (pp *pytorchPlugin) OnJobAdd(job *batch.Job) error { 163 if job.Status.ControlledResources["plugin-"+pp.Name()] == pp.Name() { 164 return nil 165 } 166 job.Status.ControlledResources["plugin-"+pp.Name()] = pp.Name() 167 return nil 168 } 169 170 func (pp *pytorchPlugin) OnJobDelete(job *batch.Job) error { 171 if job.Status.ControlledResources["plugin-"+pp.Name()] != pp.Name() { 172 return nil 173 } 174 delete(job.Status.ControlledResources, "plugin-"+pp.Name()) 175 return nil 176 } 177 178 func (pp *pytorchPlugin) OnJobUpdate(job *batch.Job) error { 179 return nil 180 }