github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/mxnet/mxnet.go (about) 1 // Copyright 2021 The Kubeflow Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License 14 15 package mxnet 16 17 import ( 18 "encoding/json" 19 "fmt" 20 "strconv" 21 "strings" 22 23 kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" 24 "github.com/kubeflow/training-operator/pkg/controller.v1/common" 25 26 corev1 "k8s.io/api/core/v1" 27 ) 28 29 const ( 30 // Label is used as tunerServerKey, it's designed for tvm auto-tuning. 31 mxJobTunerServerKey = "tuner-server-key" 32 // mxConfig is the environment variable name of MXNet cluster spec. 33 mxConfig = "MX_CONFIG" 34 ) 35 36 var ( 37 errPortNotFound = fmt.Errorf("failed to found the port") 38 ) 39 40 // MXConfig is a struct representing the distributed Mxnet config. 41 // This struct is turned into an environment variable MX_CONFIG 42 // which is used by Mxnet processes to configure themselves. 43 type MXConfig struct { 44 // Cluster represents a Mxnet ClusterSpec. 45 Cluster ClusterSpec `json:"cluster"` 46 // Labels include all label of task. 47 Labels LabelsSpec `json:"labels"` 48 // Task include information of current node. 49 Task TaskSpec `json:"task"` 50 } 51 52 // ClusterSpec represents a cluster Mxnet specification. 53 type ClusterSpec map[string][]UrlPort 54 55 type UrlPort struct { 56 Url string `json:"url"` 57 Port int `json:"port"` 58 } 59 60 // LabelsSpec represents a label specification. 61 type LabelsSpec map[string]string 62 63 // TaskSpec is the specification for a task (server or worker ...) of the MXJob. 64 type TaskSpec struct { 65 Type string `json:"type"` 66 Index int `json:"index"` 67 } 68 69 func SetPodEnv(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error { 70 mxJob, ok := job.(*kubeflowv1.MXJob) 71 if !ok { 72 return fmt.Errorf("%v is not a type of MXJob", mxJob) 73 } 74 75 // Generate MX_CONFIG JSON. 76 mxConfigData, err := genMXConfig(mxJob, rtype, index) 77 if err != nil { 78 return err 79 } 80 81 // Generate MX_CONFIG JSON Str. 82 mxConfigJson, err := json.Marshal(mxConfigData) 83 if err != nil { 84 return err 85 } 86 87 // Add MX_CONFIG environment variable. 88 for i := range podTemplate.Spec.Containers { 89 90 c := &podTemplate.Spec.Containers[i] 91 92 // Set environment variable MX_CONFIG 93 c.Env = append(c.Env, corev1.EnvVar{ 94 Name: mxConfig, 95 Value: string(mxConfigJson), 96 }) 97 98 // Set Mxnet Distributed Training environment variable 99 // We get these envs from MX_COFING to make them stay identical 100 c.Env = append(c.Env, corev1.EnvVar{ 101 Name: "DMLC_PS_ROOT_PORT", 102 Value: strconv.Itoa(getConfigAddr(&mxConfigData, kubeflowv1.MXJobReplicaTypeScheduler, 0).Port), 103 }) 104 105 c.Env = append(c.Env, corev1.EnvVar{ 106 Name: "DMLC_PS_ROOT_URI", 107 Value: getConfigAddr(&mxConfigData, kubeflowv1.MXJobReplicaTypeScheduler, 0).Url, 108 }) 109 110 c.Env = append(c.Env, corev1.EnvVar{ 111 Name: "DMLC_NUM_SERVER", 112 Value: strconv.Itoa(getConfigReplica(&mxConfigData, kubeflowv1.MXJobReplicaTypeServer)), 113 }) 114 115 c.Env = append(c.Env, corev1.EnvVar{ 116 Name: "DMLC_NUM_WORKER", 117 Value: strconv.Itoa(getConfigReplica(&mxConfigData, kubeflowv1.MXJobReplicaTypeWorker)), 118 }) 119 120 c.Env = append(c.Env, corev1.EnvVar{ 121 Name: "DMLC_ROLE", 122 Value: mxConfigData.Task.Type, 123 }) 124 125 c.Env = append(c.Env, corev1.EnvVar{ 126 Name: "DMLC_USE_KUBERNETES", 127 Value: strconv.Itoa(1), 128 }) 129 130 // BytePS needs env DMLC_WORKER_ID for each worker 131 addBytePSEnv(c, rtype, index) 132 } 133 return nil 134 } 135 136 func genMXConfig(mxjob *kubeflowv1.MXJob, rtype, index string) (MXConfig, error) { 137 // Configure the MXCONFIG environment variable. 138 i, err := strconv.ParseInt(index, 0, 32) 139 if err != nil { 140 return MXConfig{}, err 141 } 142 143 cluster, err := genClusterSpec(mxjob) 144 if err != nil { 145 return MXConfig{}, err 146 } 147 148 labels, err := genLabelsSpec(mxjob) 149 if err != nil { 150 return MXConfig{}, err 151 } 152 153 mxConfig := MXConfig{ 154 Cluster: cluster, 155 Labels: labels, 156 Task: TaskSpec{ 157 Type: rtype, 158 Index: int(i), 159 }, 160 } 161 162 return mxConfig, nil 163 } 164 165 // genClusterSpec will generate ClusterSpec. 166 func genClusterSpec(mxjob *kubeflowv1.MXJob) (ClusterSpec, error) { 167 clusterSpec := make(ClusterSpec) 168 169 for rtype, spec := range mxjob.Spec.MXReplicaSpecs { 170 rt := strings.ToLower(string(rtype)) 171 replicaNames := make([]UrlPort, 0, *spec.Replicas) 172 173 port, err := getPortFromMXJob(mxjob, rtype) 174 if err != nil { 175 return nil, err 176 } 177 for i := int32(0); i < *spec.Replicas; i++ { 178 host := UrlPort{ 179 Url: common.GenGeneralName(mxjob.Name, rt, fmt.Sprintf("%d", i)), 180 Port: int(port), 181 } 182 replicaNames = append(replicaNames, host) 183 } 184 185 clusterSpec[rt] = replicaNames 186 } 187 188 return clusterSpec, nil 189 } 190 191 // genLabelsSpec will generate LabelsSpec. 192 func genLabelsSpec(mxjob *kubeflowv1.MXJob) (LabelsSpec, error) { 193 labelsSpec := make(LabelsSpec) 194 195 for rtype, spec := range mxjob.Spec.MXReplicaSpecs { 196 rt := strings.ToLower(string(rtype)) 197 198 labelsSpec[rt] = spec.Template.Annotations[mxJobTunerServerKey] 199 } 200 201 return labelsSpec, nil 202 } 203 204 func getConfigAddr(mxConfigData *MXConfig, rtype kubeflowv1.ReplicaType, index int) UrlPort { 205 rt := strings.ToLower(string(rtype)) 206 var urlPort UrlPort 207 if len(mxConfigData.Cluster[rt]) <= index { 208 // index out of range, maybe this url doen't exist 209 urlPort = UrlPort{ 210 Url: "", 211 Port: 0, 212 } 213 } else { 214 urlPort = mxConfigData.Cluster[rt][index] 215 } 216 return urlPort 217 } 218 219 func getConfigReplica(mxConfigData *MXConfig, rtype kubeflowv1.ReplicaType) int { 220 rt := strings.ToLower(string(rtype)) 221 return len(mxConfigData.Cluster[rt]) 222 } 223 224 // getPortFromMXJob gets the port of mxnet container. 225 func getPortFromMXJob(mxJob *kubeflowv1.MXJob, rtype kubeflowv1.ReplicaType) (int32, error) { 226 containers := mxJob.Spec.MXReplicaSpecs[rtype].Template.Spec.Containers 227 for _, container := range containers { 228 if container.Name == kubeflowv1.MXJobDefaultContainerName { 229 ports := container.Ports 230 for _, port := range ports { 231 if port.Name == kubeflowv1.MXJobDefaultPortName { 232 return port.ContainerPort, nil 233 } 234 } 235 } 236 } 237 return -1, errPortNotFound 238 } 239 240 func addBytePSEnv(c *corev1.Container, rtype, index string) { 241 if rtype == strings.ToLower(string(kubeflowv1.MXJobReplicaTypeWorker)) { 242 c.Env = append(c.Env, corev1.EnvVar{ 243 Name: "DMLC_WORKER_ID", 244 Value: index, 245 }) 246 } 247 }