volcano.sh/volcano@v1.9.0/pkg/controllers/job/plugins/distributed-framework/mpi/mpi.go (about)

     1  /*
     2  Copyright 2022 The Volcano Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package mpi
    18  
import (
	"flag"
	"fmt"

	v1 "k8s.io/api/core/v1"
	"k8s.io/klog/v2"

	batch "volcano.sh/apis/pkg/apis/batch/v1alpha1"
	"volcano.sh/volcano/pkg/controllers/job/helpers"
	pluginsinterface "volcano.sh/volcano/pkg/controllers/job/plugins/interface"
)
    29  
    30  const (
    31  	// MPIPluginName is the name of the plugin
    32  	MPIPluginName = "mpi"
    33  	// DefaultPort is the default port for ssh
    34  	DefaultPort = 22
    35  	// DefaultMaster is the default task name of master host
    36  	DefaultMaster = "master"
    37  	// DefaultWorker is the default task name of worker host
    38  	DefaultWorker = "worker"
    39  	// MPIHost is the environment variable key of MPI host
    40  	MPIHost = "MPI_HOST"
    41  )
    42  
    43  type Plugin struct {
    44  	mpiArguments []string
    45  	clientset    pluginsinterface.PluginClientset
    46  	masterName   string
    47  	workerName   string
    48  	port         int
    49  }
    50  
    51  // New creates mpi plugin.
    52  func New(client pluginsinterface.PluginClientset, arguments []string) pluginsinterface.PluginInterface {
    53  	mp := Plugin{mpiArguments: arguments, clientset: client}
    54  	mp.addFlags()
    55  	return &mp
    56  }
    57  
    58  func NewInstance(arguments []string) Plugin {
    59  	mp := Plugin{mpiArguments: arguments}
    60  	mp.addFlags()
    61  	return mp
    62  }
    63  
    64  func (mp *Plugin) addFlags() {
    65  	flagSet := flag.NewFlagSet(mp.Name(), flag.ContinueOnError)
    66  	flagSet.StringVar(&mp.masterName, "master", DefaultMaster, "name of master role task")
    67  	flagSet.StringVar(&mp.workerName, "worker", DefaultWorker, "name of worker role task")
    68  	flagSet.IntVar(&mp.port, "port", DefaultPort, "open port for containers")
    69  	if err := flagSet.Parse(mp.mpiArguments); err != nil {
    70  		klog.Errorf("plugin %s flagset parse failed, err: %v", mp.Name(), err)
    71  	}
    72  }
    73  
    74  func (mp *Plugin) Name() string {
    75  	return MPIPluginName
    76  }
    77  
    78  func (mp *Plugin) OnPodCreate(pod *v1.Pod, job *batch.Job) error {
    79  	isMaster := false
    80  	workerHosts := ""
    81  	env := v1.EnvVar{}
    82  	if helpers.GetTaskKey(pod) == mp.masterName {
    83  		workerHosts = mp.generateTaskHosts(job.Spec.Tasks[helpers.GetTaskIndexUnderJob(mp.workerName, job)], job.Name)
    84  		env = v1.EnvVar{
    85  			Name:  MPIHost,
    86  			Value: workerHosts,
    87  		}
    88  
    89  		isMaster = true
    90  	}
    91  
    92  	// open port for ssh and add MPI_HOST env for master task
    93  	for index, ic := range pod.Spec.InitContainers {
    94  		mp.openContainerPort(&ic, index, pod, true)
    95  		if isMaster {
    96  			pod.Spec.InitContainers[index].Env = append(pod.Spec.InitContainers[index].Env, env)
    97  		}
    98  	}
    99  
   100  	for index, c := range pod.Spec.Containers {
   101  		mp.openContainerPort(&c, index, pod, false)
   102  		if isMaster {
   103  			pod.Spec.Containers[index].Env = append(pod.Spec.Containers[index].Env, env)
   104  		}
   105  	}
   106  
   107  	return nil
   108  }
   109  
   110  func (mp *Plugin) generateTaskHosts(task batch.TaskSpec, jobName string) string {
   111  	hosts := ""
   112  	for i := 0; i < int(task.Replicas); i++ {
   113  		hostName := task.Template.Spec.Hostname
   114  		subdomain := task.Template.Spec.Subdomain
   115  		if len(hostName) == 0 {
   116  			hostName = helpers.MakePodName(jobName, task.Name, i)
   117  		}
   118  		if len(subdomain) == 0 {
   119  			subdomain = jobName
   120  		}
   121  		hosts = hosts + hostName + "." + subdomain + ","
   122  		if len(task.Template.Spec.Hostname) != 0 {
   123  			break
   124  		}
   125  	}
   126  	return hosts[:len(hosts)-1]
   127  }
   128  
   129  func (mp *Plugin) openContainerPort(c *v1.Container, index int, pod *v1.Pod, isInitContainer bool) {
   130  	SSHPortRight := false
   131  	for _, p := range c.Ports {
   132  		if p.ContainerPort == int32(mp.port) {
   133  			SSHPortRight = true
   134  			break
   135  		}
   136  	}
   137  	if !SSHPortRight {
   138  		sshPort := v1.ContainerPort{
   139  			Name:          "mpijob-port",
   140  			ContainerPort: int32(mp.port),
   141  		}
   142  		if isInitContainer {
   143  			pod.Spec.InitContainers[index].Ports = append(pod.Spec.InitContainers[index].Ports, sshPort)
   144  		} else {
   145  			pod.Spec.Containers[index].Ports = append(pod.Spec.Containers[index].Ports, sshPort)
   146  		}
   147  	}
   148  }
   149  
   150  func (mp *Plugin) OnJobAdd(job *batch.Job) error {
   151  	if job.Status.ControlledResources["plugin-"+mp.Name()] == mp.Name() {
   152  		return nil
   153  	}
   154  	job.Status.ControlledResources["plugin-"+mp.Name()] = mp.Name()
   155  	return nil
   156  }
   157  
   158  func (mp *Plugin) OnJobDelete(job *batch.Job) error {
   159  	if job.Status.ControlledResources["plugin-"+mp.Name()] != mp.Name() {
   160  		return nil
   161  	}
   162  	delete(job.Status.ControlledResources, "plugin-"+mp.Name())
   163  	return nil
   164  }
   165  
   166  func (mp *Plugin) OnJobUpdate(job *batch.Job) error {
   167  	return nil
   168  }
   169  
   170  func (mp *Plugin) GetMasterName() string {
   171  	return mp.masterName
   172  }
   173  
   174  func (mp *Plugin) GetWorkerName() string {
   175  	return mp.workerName
   176  }
   177  
   178  func (mp *Plugin) GetMpiArguments() []string {
   179  	return mp.mpiArguments
   180  }