volcano.sh/volcano@v1.9.0/pkg/controllers/job/plugins/distributed-framework/tensorflow/tensorflow.go (about)

     1  /*
     2  Copyright 2021 The Volcano Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package tensorflow
    18  
    19  import (
    20  	"encoding/json"
    21  	"flag"
    22  	"fmt"
    23  	"strconv"
    24  
    25  	v1 "k8s.io/api/core/v1"
    26  	"k8s.io/klog/v2"
    27  
    28  	batch "volcano.sh/apis/pkg/apis/batch/v1alpha1"
    29  	jobhelpers "volcano.sh/volcano/pkg/controllers/job/helpers"
    30  	pluginsinterface "volcano.sh/volcano/pkg/controllers/job/plugins/interface"
    31  )
    32  
const (
	// TFPluginName is the name reported by the plugin's Name method; it is also
	// used to build the plugin's key in the job's ControlledResources map.
	TFPluginName = "tensorflow"
	// DefaultPort is the default value of the "port" flag, i.e. the port
	// appended to every host address in the generated cluster spec.
	DefaultPort = 2222
	// TFConfig is the name of the environment variable that carries the
	// JSON-encoded cluster spec injected into each container of a pod.
	TFConfig = "TF_CONFIG"
)
    41  
// tensorflowPlugin is the job plugin that injects a TF_CONFIG environment
// variable into the containers of pods belonging to a multi-pod job.
type tensorflowPlugin struct {
	// tfArguments holds the raw plugin arguments; addFlags parses them.
	tfArguments []string
	// Clientset is the client set handed to New; it is stored but not used
	// by the methods in this file.
	Clientset pluginsinterface.PluginClientset
	// psName, workerName, chiefName and evaluatorName are the job task names
	// mapped to the corresponding tensorflow roles. They default to "ps",
	// "worker", "chief" and "evaluator" and can be overridden through flags.
	psName        string
	workerName    string
	chiefName     string
	evaluatorName string
	// port is the port appended to every host address in the cluster spec.
	port int
}
    51  
    52  // New creates tensorflow plugin.
    53  func New(client pluginsinterface.PluginClientset, arguments []string) pluginsinterface.PluginInterface {
    54  	tp := tensorflowPlugin{tfArguments: arguments, Clientset: client}
    55  	tp.addFlags()
    56  	return &tp
    57  }
    58  
    59  func (tp *tensorflowPlugin) addFlags() {
    60  	flagSet := flag.NewFlagSet(tp.Name(), flag.ContinueOnError)
    61  	flagSet.StringVar(&tp.psName, "ps", "ps", "name of ps role task")
    62  	flagSet.StringVar(&tp.workerName, "worker", "worker", "name of ps role task")
    63  	flagSet.StringVar(&tp.chiefName, "chief", "chief", "name of chief role task")
    64  	flagSet.StringVar(&tp.evaluatorName, "evaluator", "evaluator", "name of evaluator role task")
    65  	flagSet.IntVar(&tp.port, "port", DefaultPort, "service port")
    66  	if err := flagSet.Parse(tp.tfArguments); err != nil {
    67  		klog.Errorf("plugin %s flagset parse failed, err: %v", tp.Name(), err)
    68  	}
    69  }
    70  
// Name returns the name of the plugin, TFPluginName.
func (tp *tensorflowPlugin) Name() string {
	return TFPluginName
}
    74  
    75  func (tp *tensorflowPlugin) OnPodCreate(pod *v1.Pod, job *batch.Job) error {
    76  	// No need to generate TF_CONFIG for stand-alone tensorflow job
    77  	if len(job.Spec.Tasks) == 1 && job.Spec.Tasks[0].Replicas == 1 {
    78  		return nil
    79  	}
    80  	// Generate TF_CONFIG value
    81  	spec, err := tp.generateTFClusterSpec(pod, job)
    82  	if err != nil {
    83  		return err
    84  	}
    85  	raw, err := json.Marshal(spec)
    86  	if err != nil {
    87  		return err
    88  	}
    89  
    90  	// Add TF_CONFIG enviroment variables
    91  	for i := range pod.Spec.Containers {
    92  		pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, v1.EnvVar{
    93  			Name:  TFConfig,
    94  			Value: string(raw),
    95  		})
    96  	}
    97  	return nil
    98  }
    99  
   100  func (tp *tensorflowPlugin) OnJobAdd(job *batch.Job) error {
   101  	if job.Status.ControlledResources["plugin-"+tp.Name()] == tp.Name() {
   102  		return nil
   103  	}
   104  
   105  	job.Status.ControlledResources["plugin-"+tp.Name()] = tp.Name()
   106  
   107  	return nil
   108  }
   109  
   110  func (tp *tensorflowPlugin) OnJobDelete(job *batch.Job) error {
   111  	if job.Status.ControlledResources["plugin-"+tp.Name()] != tp.Name() {
   112  		return nil
   113  	}
   114  	delete(job.Status.ControlledResources, "plugin-"+tp.Name())
   115  	return nil
   116  }
   117  
// OnJobUpdate does nothing for this plugin and always returns nil.
func (tp *tensorflowPlugin) OnJobUpdate(job *batch.Job) error {
	return nil
}
   121  
   122  func (tp *tensorflowPlugin) generateTFClusterSpec(pod *v1.Pod, job *batch.Job) (tfClusterSpec, error) {
   123  	index, err := strconv.Atoi(jobhelpers.GetPodIndexUnderTask(pod))
   124  	if err != nil {
   125  		return tfClusterSpec{}, err
   126  	}
   127  
   128  	// Generate tensorflow task info
   129  	c := tfClusterSpec{
   130  		Task: taskInfo{
   131  			Type:  tp.getTaskType(jobhelpers.GetTaskKey(pod)),
   132  			Index: index,
   133  		},
   134  	}
   135  
   136  	// Generate tensorflow cluster info
   137  	for _, ts := range job.Spec.Tasks {
   138  		hosts := []string{}
   139  		for i := 0; i < int(ts.Replicas); i++ {
   140  			hosts = append(hosts, fmt.Sprintf("%s:%d", jobhelpers.MakeDomainName(ts, job, i), tp.port))
   141  		}
   142  		switch ts.Name {
   143  		case tp.psName:
   144  			c.Cluster.PS = hosts
   145  		case tp.workerName:
   146  			c.Cluster.Worker = hosts
   147  		case tp.chiefName:
   148  			c.Cluster.Chief = hosts
   149  		case tp.evaluatorName:
   150  			c.Cluster.Evaluator = hosts
   151  		}
   152  	}
   153  	return c, nil
   154  }
   155  
   156  func (tp *tensorflowPlugin) getTaskType(taskKey string) tfTaskType {
   157  	switch taskKey {
   158  	case tp.chiefName:
   159  		return tfChief
   160  	case tp.workerName:
   161  		return tfWorker
   162  	case tp.psName:
   163  		return tfPS
   164  	case tp.evaluatorName:
   165  		return tfEvaluator
   166  	}
   167  	return tfTaskType(taskKey)
   168  }
   169  
// tfClusterSpec is the spec of a tensorflow cluster as seen from one pod.
// Its JSON encoding is injected into the containers' environment as the value
// of TF_CONFIG, to be consumed by the tensorflow framework.
// e.g. (host names below are illustrative only)
//
//	{
//	   "cluster": {
//	     "worker": ["worker-0:2222", "worker-1:2222"],
//	     "ps": ["ps-0:2222"]
//	   },
//	   "task": {
//	     "type": "worker",
//	     "index": 0
//	   }
//	}
type tfClusterSpec struct {
	// Cluster lists the addresses of the cluster members, grouped by role.
	Cluster clusterInfo `json:"cluster"`
	// Task identifies the pod the spec was generated for.
	Task taskInfo `json:"task"`
}
   188  
// clusterInfo holds the "host:port" addresses of the cluster members, one
// list per tensorflow role. Empty lists are omitted from the JSON encoding.
type clusterInfo struct {
	PS        []string `json:"ps,omitempty"`
	Worker    []string `json:"worker,omitempty"`
	Chief     []string `json:"chief,omitempty"`
	Evaluator []string `json:"evaluator,omitempty"`
}
   195  
// tfTaskType is the role of a task as written to the "type" field of the
// task section of the cluster spec.
type tfTaskType string

// Roles that getTaskType maps the configured task names to.
const (
	tfWorker    tfTaskType = "worker"
	tfChief     tfTaskType = "chief"
	tfPS        tfTaskType = "ps"
	tfEvaluator tfTaskType = "evaluator"
)
   204  
// taskInfo identifies a single pod within the cluster spec.
type taskInfo struct {
	// Type is the role of the pod's task.
	Type tfTaskType `json:"type"`
	// Index is the pod's index under its task.
	Index int `json:"index"`
}