volcano.sh/volcano@v1.9.0/pkg/controllers/job/plugins/distributed-framework/mpi/mpi.go (about) 1 /* 2 Copyright 2022 The Volcano Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package mpi 18 19 import ( 20 "flag" 21 22 v1 "k8s.io/api/core/v1" 23 "k8s.io/klog/v2" 24 25 batch "volcano.sh/apis/pkg/apis/batch/v1alpha1" 26 "volcano.sh/volcano/pkg/controllers/job/helpers" 27 pluginsinterface "volcano.sh/volcano/pkg/controllers/job/plugins/interface" 28 ) 29 30 const ( 31 // MPIPluginName is the name of the plugin 32 MPIPluginName = "mpi" 33 // DefaultPort is the default port for ssh 34 DefaultPort = 22 35 // DefaultMaster is the default task name of master host 36 DefaultMaster = "master" 37 // DefaultWorker is the default task name of worker host 38 DefaultWorker = "worker" 39 // MPIHost is the environment variable key of MPI host 40 MPIHost = "MPI_HOST" 41 ) 42 43 type Plugin struct { 44 mpiArguments []string 45 clientset pluginsinterface.PluginClientset 46 masterName string 47 workerName string 48 port int 49 } 50 51 // New creates mpi plugin. 52 func New(client pluginsinterface.PluginClientset, arguments []string) pluginsinterface.PluginInterface { 53 mp := Plugin{mpiArguments: arguments, clientset: client} 54 mp.addFlags() 55 return &mp 56 } 57 58 func NewInstance(arguments []string) Plugin { 59 mp := Plugin{mpiArguments: arguments} 60 mp.addFlags() 61 return mp 62 } 63 64 func (mp *Plugin) addFlags() { 65 flagSet := flag.NewFlagSet(mp.Name(), flag.ContinueOnError) 66 flagSet.StringVar(&mp.masterName, "master", DefaultMaster, "name of master role task") 67 flagSet.StringVar(&mp.workerName, "worker", DefaultWorker, "name of worker role task") 68 flagSet.IntVar(&mp.port, "port", DefaultPort, "open port for containers") 69 if err := flagSet.Parse(mp.mpiArguments); err != nil { 70 klog.Errorf("plugin %s flagset parse failed, err: %v", mp.Name(), err) 71 } 72 } 73 74 func (mp *Plugin) Name() string { 75 return MPIPluginName 76 } 77 78 func (mp *Plugin) OnPodCreate(pod *v1.Pod, job *batch.Job) error { 79 isMaster := false 80 workerHosts := "" 81 env := v1.EnvVar{} 82 if helpers.GetTaskKey(pod) == mp.masterName { 83 workerHosts = mp.generateTaskHosts(job.Spec.Tasks[helpers.GetTaskIndexUnderJob(mp.workerName, job)], job.Name) 84 env = v1.EnvVar{ 85 Name: MPIHost, 86 Value: workerHosts, 87 } 88 89 isMaster = true 90 } 91 92 // open port for ssh and add MPI_HOST env for master task 93 for index, ic := range pod.Spec.InitContainers { 94 mp.openContainerPort(&ic, index, pod, true) 95 if isMaster { 96 pod.Spec.InitContainers[index].Env = append(pod.Spec.InitContainers[index].Env, env) 97 } 98 } 99 100 for index, c := range pod.Spec.Containers { 101 mp.openContainerPort(&c, index, pod, false) 102 if isMaster { 103 pod.Spec.Containers[index].Env = append(pod.Spec.Containers[index].Env, env) 104 } 105 } 106 107 return nil 108 } 109 110 func (mp *Plugin) generateTaskHosts(task batch.TaskSpec, jobName string) string { 111 hosts := "" 112 for i := 0; i < int(task.Replicas); i++ { 113 hostName := task.Template.Spec.Hostname 114 subdomain := task.Template.Spec.Subdomain 115 if len(hostName) == 0 { 116 hostName = helpers.MakePodName(jobName, task.Name, i) 117 } 118 if len(subdomain) == 0 { 119 subdomain = jobName 120 } 121 hosts = hosts + hostName + "." + subdomain + "," 122 if len(task.Template.Spec.Hostname) != 0 { 123 break 124 } 125 } 126 return hosts[:len(hosts)-1] 127 } 128 129 func (mp *Plugin) openContainerPort(c *v1.Container, index int, pod *v1.Pod, isInitContainer bool) { 130 SSHPortRight := false 131 for _, p := range c.Ports { 132 if p.ContainerPort == int32(mp.port) { 133 SSHPortRight = true 134 break 135 } 136 } 137 if !SSHPortRight { 138 sshPort := v1.ContainerPort{ 139 Name: "mpijob-port", 140 ContainerPort: int32(mp.port), 141 } 142 if isInitContainer { 143 pod.Spec.InitContainers[index].Ports = append(pod.Spec.InitContainers[index].Ports, sshPort) 144 } else { 145 pod.Spec.Containers[index].Ports = append(pod.Spec.Containers[index].Ports, sshPort) 146 } 147 } 148 } 149 150 func (mp *Plugin) OnJobAdd(job *batch.Job) error { 151 if job.Status.ControlledResources["plugin-"+mp.Name()] == mp.Name() { 152 return nil 153 } 154 job.Status.ControlledResources["plugin-"+mp.Name()] = mp.Name() 155 return nil 156 } 157 158 func (mp *Plugin) OnJobDelete(job *batch.Job) error { 159 if job.Status.ControlledResources["plugin-"+mp.Name()] != mp.Name() { 160 return nil 161 } 162 delete(job.Status.ControlledResources, "plugin-"+mp.Name()) 163 return nil 164 } 165 166 func (mp *Plugin) OnJobUpdate(job *batch.Job) error { 167 return nil 168 } 169 170 func (mp *Plugin) GetMasterName() string { 171 return mp.masterName 172 } 173 174 func (mp *Plugin) GetWorkerName() string { 175 return mp.workerName 176 } 177 178 func (mp *Plugin) GetMpiArguments() []string { 179 return mp.mpiArguments 180 }