github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/tensorflow/tensorflow.go (about) 1 // Copyright 2018 The Kubeflow Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package controller provides a Kubernetes controller for a TFJob resource. 16 package tensorflow 17 18 import ( 19 "encoding/json" 20 "fmt" 21 "os" 22 "strconv" 23 "strings" 24 25 kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" 26 "github.com/kubeflow/training-operator/pkg/controller.v1/common" 27 ) 28 29 const ( 30 // EnvCustomClusterDomain is the custom defined cluster domain, such as "svc.cluster.local". 31 // Ref: https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/#a-records 32 EnvCustomClusterDomain = "CUSTOM_CLUSTER_DOMAIN" 33 ) 34 35 // TaskSpec is the specification for a task (PS or worker) of the TFJob. 36 type TaskSpec struct { 37 Type string `json:"type"` 38 Index int `json:"index"` 39 } 40 41 // ClusterSpec represents a cluster TensorFlow specification. 42 // https://www.tensorflow.org/deploy/distributed#create_a_tftrainclusterspec_to_describe_the_cluster 43 // It is a map from job names to network addresses. 44 type ClusterSpec map[string][]string 45 46 // TFConfig is a struct representing the distributed TensorFlow config. 47 // This struct is turned into an environment variable TF_CONFIG 48 // which is used by TensorFlow processes to configure themselves. 49 // https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig#methods 50 // https://cloud.google.com/ml-engine/docs/tensorflow/distributed-training-details 51 type TFConfig struct { 52 // Cluster represents a TensorFlow ClusterSpec. 53 // See: https://www.tensorflow.org/api_docs/python/tf/train/ClusterSpec 54 Cluster ClusterSpec `json:"cluster"` 55 Task TaskSpec `json:"task"` 56 // Environment is used by tensorflow.contrib.learn.python.learn in versions <= 1.3 57 // TODO(jlewi): I don't think it is used in versions TF >- 1.4. So we can eventually get rid of it. 58 Environment string `json:"environment"` 59 } 60 61 // SparseClusterSpec enables a server to be configured without needing to know 62 // the identity of (for example) all other worker tasks. 63 // https://www.tensorflow.org/api_docs/python/tf/train/ClusterSpec 64 type SparseClusterSpec struct { 65 Worker map[int32]string `json:"worker"` 66 PS []string `json:"ps"` 67 } 68 69 type SparseTFConfig struct { 70 Cluster SparseClusterSpec `json:"cluster"` 71 Task TaskSpec `json:"task"` 72 } 73 74 func convertClusterSpecToSparseClusterSpec(clusterSpec ClusterSpec, rtype string, index int32) SparseClusterSpec { 75 sparseClusterSpec := SparseClusterSpec{Worker: map[int32]string{}, PS: []string{}} 76 if rtype == strings.ToLower(string(kubeflowv1.TFJobReplicaTypePS)) { 77 sparseClusterSpec.PS = append(sparseClusterSpec.PS, clusterSpec[rtype][index]) 78 } else if rtype == strings.ToLower(string(kubeflowv1.TFJobReplicaTypeWorker)) { 79 sparseClusterSpec.PS = clusterSpec[strings.ToLower(string(kubeflowv1.TFJobReplicaTypePS))] 80 sparseClusterSpec.Worker[index] = clusterSpec[rtype][index] 81 } 82 return sparseClusterSpec 83 } 84 85 // genTFConfig will generate the environment variable TF_CONFIG 86 // 87 // { 88 // "cluster": { 89 // "ps": ["ps1:2222", "ps2:2222"], 90 // "worker": ["worker1:2222", "worker2:2222", "worker3:2222"] 91 // }, 92 // "task": { 93 // "type": "ps", 94 // "index": 1 95 // }, 96 // } 97 // } 98 // 99 // if EnableDynamicWorker set true 100 // 101 // { 102 // "cluster": { 103 // "ps": ["ps1:2222", "ps2:2222"], 104 // "worker": {"1":"worker1:2222"} 105 // }, 106 // "task": { 107 // "type": "worker", 108 // "index": 1 109 // }, 110 // } 111 // } 112 func genTFConfigJSONStr(tfjob *kubeflowv1.TFJob, rtype, index string) (string, error) { 113 // Configure the TFCONFIG environment variable. 114 i, err := strconv.ParseInt(index, 0, 32) 115 if err != nil { 116 return "", err 117 } 118 119 cluster, err := genClusterSpec(tfjob) 120 if err != nil { 121 return "", err 122 } 123 124 var tfConfigJSONByteSlice []byte 125 if tfjob.Spec.EnableDynamicWorker { 126 sparseCluster := convertClusterSpecToSparseClusterSpec(cluster, strings.ToLower(rtype), int32(i)) 127 sparseTFConfig := SparseTFConfig{ 128 Cluster: sparseCluster, 129 Task: TaskSpec{ 130 Type: strings.ToLower(rtype), 131 Index: int(i), 132 }, 133 } 134 tfConfigJSONByteSlice, err = json.Marshal(sparseTFConfig) 135 } else { 136 tfConfig := TFConfig{ 137 Cluster: cluster, 138 Task: TaskSpec{ 139 Type: strings.ToLower(rtype), 140 Index: int(i), 141 }, 142 // We need to set environment to cloud otherwise it will default to local which isn't what we want. 143 // Environment is used by tensorflow.contrib.learn.python.learn in versions <= 1.3 144 // TODO(jlewi): I don't think it is used in versions TF >- 1.4. So we can eventually get rid of it. 145 Environment: "cloud", 146 } 147 tfConfigJSONByteSlice, err = json.Marshal(tfConfig) 148 } 149 if err != nil { 150 return "", err 151 } 152 153 return string(tfConfigJSONByteSlice), nil 154 } 155 156 // genClusterSpec will generate ClusterSpec. 157 func genClusterSpec(tfjob *kubeflowv1.TFJob) (ClusterSpec, error) { 158 clusterSpec := make(ClusterSpec) 159 160 for rtype, spec := range tfjob.Spec.TFReplicaSpecs { 161 rt := strings.ToLower(string(rtype)) 162 replicaNames := make([]string, 0, *spec.Replicas) 163 164 port, err := GetPortFromTFJob(tfjob, rtype) 165 if err != nil { 166 return nil, err 167 } 168 for i := int32(0); i < *spec.Replicas; i++ { 169 // As described here: https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/#a-records. 170 // Headless service assigned a DNS A record for a name of the form "my-svc.my-namespace.svc.cluster.local". 171 // And the last part "svc.cluster.local" is called cluster domain 172 // which maybe different between kubernetes clusters. 173 hostName := common.GenGeneralName(tfjob.Name, rt, fmt.Sprintf("%d", i)) 174 svcName := hostName + "." + tfjob.Namespace + "." + "svc" 175 clusterDomain := os.Getenv(EnvCustomClusterDomain) 176 if len(clusterDomain) > 0 { 177 svcName += "." + clusterDomain 178 } 179 180 endpoint := fmt.Sprintf("%s:%d", svcName, port) 181 replicaNames = append(replicaNames, endpoint) 182 } 183 184 clusterSpec[rt] = replicaNames 185 } 186 187 return clusterSpec, nil 188 }