volcano.sh/volcano@v1.9.0/pkg/controllers/job/plugins/distributed-framework/tensorflow/tensorflow.go (about) 1 /* 2 Copyright 2021 The Volcano Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package tensorflow 18 19 import ( 20 "encoding/json" 21 "flag" 22 "fmt" 23 "strconv" 24 25 v1 "k8s.io/api/core/v1" 26 "k8s.io/klog/v2" 27 28 batch "volcano.sh/apis/pkg/apis/batch/v1alpha1" 29 jobhelpers "volcano.sh/volcano/pkg/controllers/job/helpers" 30 pluginsinterface "volcano.sh/volcano/pkg/controllers/job/plugins/interface" 31 ) 32 33 const ( 34 // TFPluginName is the name of the plugin 35 TFPluginName = "tensorflow" 36 // DefaultPort defines default port for service 37 DefaultPort = 2222 38 // TFConfig defines environment variables for TF 39 TFConfig = "TF_CONFIG" 40 ) 41 42 type tensorflowPlugin struct { 43 tfArguments []string 44 Clientset pluginsinterface.PluginClientset 45 psName string 46 workerName string 47 chiefName string 48 evaluatorName string 49 port int 50 } 51 52 // New creates tensorflow plugin. 
53 func New(client pluginsinterface.PluginClientset, arguments []string) pluginsinterface.PluginInterface { 54 tp := tensorflowPlugin{tfArguments: arguments, Clientset: client} 55 tp.addFlags() 56 return &tp 57 } 58 59 func (tp *tensorflowPlugin) addFlags() { 60 flagSet := flag.NewFlagSet(tp.Name(), flag.ContinueOnError) 61 flagSet.StringVar(&tp.psName, "ps", "ps", "name of ps role task") 62 flagSet.StringVar(&tp.workerName, "worker", "worker", "name of ps role task") 63 flagSet.StringVar(&tp.chiefName, "chief", "chief", "name of chief role task") 64 flagSet.StringVar(&tp.evaluatorName, "evaluator", "evaluator", "name of evaluator role task") 65 flagSet.IntVar(&tp.port, "port", DefaultPort, "service port") 66 if err := flagSet.Parse(tp.tfArguments); err != nil { 67 klog.Errorf("plugin %s flagset parse failed, err: %v", tp.Name(), err) 68 } 69 } 70 71 func (tp *tensorflowPlugin) Name() string { 72 return TFPluginName 73 } 74 75 func (tp *tensorflowPlugin) OnPodCreate(pod *v1.Pod, job *batch.Job) error { 76 // No need to generate TF_CONFIG for stand-alone tensorflow job 77 if len(job.Spec.Tasks) == 1 && job.Spec.Tasks[0].Replicas == 1 { 78 return nil 79 } 80 // Generate TF_CONFIG value 81 spec, err := tp.generateTFClusterSpec(pod, job) 82 if err != nil { 83 return err 84 } 85 raw, err := json.Marshal(spec) 86 if err != nil { 87 return err 88 } 89 90 // Add TF_CONFIG enviroment variables 91 for i := range pod.Spec.Containers { 92 pod.Spec.Containers[i].Env = append(pod.Spec.Containers[i].Env, v1.EnvVar{ 93 Name: TFConfig, 94 Value: string(raw), 95 }) 96 } 97 return nil 98 } 99 100 func (tp *tensorflowPlugin) OnJobAdd(job *batch.Job) error { 101 if job.Status.ControlledResources["plugin-"+tp.Name()] == tp.Name() { 102 return nil 103 } 104 105 job.Status.ControlledResources["plugin-"+tp.Name()] = tp.Name() 106 107 return nil 108 } 109 110 func (tp *tensorflowPlugin) OnJobDelete(job *batch.Job) error { 111 if job.Status.ControlledResources["plugin-"+tp.Name()] != 
tp.Name() { 112 return nil 113 } 114 delete(job.Status.ControlledResources, "plugin-"+tp.Name()) 115 return nil 116 } 117 118 func (tp *tensorflowPlugin) OnJobUpdate(job *batch.Job) error { 119 return nil 120 } 121 122 func (tp *tensorflowPlugin) generateTFClusterSpec(pod *v1.Pod, job *batch.Job) (tfClusterSpec, error) { 123 index, err := strconv.Atoi(jobhelpers.GetPodIndexUnderTask(pod)) 124 if err != nil { 125 return tfClusterSpec{}, err 126 } 127 128 // Generate tensorflow task info 129 c := tfClusterSpec{ 130 Task: taskInfo{ 131 Type: tp.getTaskType(jobhelpers.GetTaskKey(pod)), 132 Index: index, 133 }, 134 } 135 136 // Generate tensorflow cluster info 137 for _, ts := range job.Spec.Tasks { 138 hosts := []string{} 139 for i := 0; i < int(ts.Replicas); i++ { 140 hosts = append(hosts, fmt.Sprintf("%s:%d", jobhelpers.MakeDomainName(ts, job, i), tp.port)) 141 } 142 switch ts.Name { 143 case tp.psName: 144 c.Cluster.PS = hosts 145 case tp.workerName: 146 c.Cluster.Worker = hosts 147 case tp.chiefName: 148 c.Cluster.Chief = hosts 149 case tp.evaluatorName: 150 c.Cluster.Evaluator = hosts 151 } 152 } 153 return c, nil 154 } 155 156 func (tp *tensorflowPlugin) getTaskType(taskKey string) tfTaskType { 157 switch taskKey { 158 case tp.chiefName: 159 return tfChief 160 case tp.workerName: 161 return tfWorker 162 case tp.psName: 163 return tfPS 164 case tp.evaluatorName: 165 return tfEvaluator 166 } 167 return tfTaskType(taskKey) 168 } 169 170 // TfClusterSpec is the spec of a tensorflow cluster 171 // It will be injected into container's environment variables, and be used by tensorflow framework. 172 // e.g. 
//
//	{
//	  "cluster": {
//	    "worker": ["worker-0:2222", "worker-1:2222"],
//	    "ps": ["ps-0:2222"]
//	  },
//	  "task": {
//	    "type": "worker",
//	    "index": 0
//	  }
//	}
type tfClusterSpec struct {
	Cluster clusterInfo `json:"cluster"`
	Task    taskInfo    `json:"task"`
}

// clusterInfo lists the "host:port" addresses of the cluster, grouped by role.
// A role without any address is omitted from the JSON encoding.
type clusterInfo struct {
	PS        []string `json:"ps,omitempty"`
	Worker    []string `json:"worker,omitempty"`
	Chief     []string `json:"chief,omitempty"`
	Evaluator []string `json:"evaluator,omitempty"`
}

// tfTaskType is the role name written to the "type" field of the task info.
type tfTaskType string

// Role names known to the plugin.
const (
	tfWorker    tfTaskType = "worker"
	tfChief     tfTaskType = "chief"
	tfPS        tfTaskType = "ps"
	tfEvaluator tfTaskType = "evaluator"
)

// taskInfo identifies a single member of the cluster by its role and by its
// index within that role.
type taskInfo struct {
	Type  tfTaskType `json:"type"`
	Index int        `json:"index"`
}