github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/tensorflow/tensorflow.go (about)

     1  // Copyright 2018 The Kubeflow Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
// Package tensorflow provides a Kubernetes controller for a TFJob resource.
    16  package tensorflow
    17  
    18  import (
    19  	"encoding/json"
    20  	"fmt"
    21  	"os"
    22  	"strconv"
    23  	"strings"
    24  
    25  	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    26  	"github.com/kubeflow/training-operator/pkg/controller.v1/common"
    27  )
    28  
    29  const (
    30  	// EnvCustomClusterDomain is the custom defined cluster domain, such as "svc.cluster.local".
    31  	// Ref: https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/#a-records
    32  	EnvCustomClusterDomain = "CUSTOM_CLUSTER_DOMAIN"
    33  )
    34  
// TaskSpec is the "task" section of TF_CONFIG. It identifies which task the
// receiving TensorFlow process is: its replica type and its index within that
// type (e.g. PS 1 or worker 0).
type TaskSpec struct {
	// Type is the lower-cased replica type name, e.g. "ps" or "worker".
	Type  string `json:"type"`
	// Index is the task's zero-based position among the replicas of Type.
	Index int    `json:"index"`
}
    40  
// ClusterSpec is the "cluster" section of TF_CONFIG: a map from job names
// (lower-cased replica types such as "ps" or "worker") to the network
// addresses of that job's tasks, where list position is the task index.
// See: https://www.tensorflow.org/api_docs/python/tf/train/ClusterSpec
type ClusterSpec map[string][]string
    45  
// TFConfig is the distributed TensorFlow configuration for one task.
// It is serialized to JSON and exposed to TensorFlow processes through the
// TF_CONFIG environment variable, which they read to configure themselves.
// https://www.tensorflow.org/api_docs/python/tf/estimator/RunConfig#methods
// https://cloud.google.com/ml-engine/docs/tensorflow/distributed-training-details
type TFConfig struct {
	// Cluster represents a TensorFlow ClusterSpec listing every task's address.
	// See: https://www.tensorflow.org/api_docs/python/tf/train/ClusterSpec
	Cluster ClusterSpec `json:"cluster"`
	// Task identifies which entry of Cluster the receiving process is.
	Task    TaskSpec    `json:"task"`
	// Environment is used by tensorflow.contrib.learn.python.learn in versions <= 1.3.
	// TODO(jlewi): I don't think it is used in versions TF >= 1.4. So we can eventually get rid of it.
	Environment string `json:"environment"`
}
    60  
// SparseClusterSpec is a reduced cluster section for TF_CONFIG. It enables a
// server to be configured without needing to know the identity of (for
// example) all other worker tasks.
// https://www.tensorflow.org/api_docs/python/tf/train/ClusterSpec
type SparseClusterSpec struct {
	// Worker maps a worker's task index to its address. encoding/json writes
	// the integer keys as JSON object keys (strings), e.g. {"1": "host:2222"}.
	Worker map[int32]string `json:"worker"`
	// PS lists parameter-server addresses.
	PS     []string         `json:"ps"`
}
    68  
// SparseTFConfig is the TF_CONFIG payload for a task whose cluster section is
// a SparseClusterSpec. Unlike TFConfig it carries no "environment" field.
type SparseTFConfig struct {
	// Cluster is the sparse view of the cluster visible to this task.
	Cluster SparseClusterSpec `json:"cluster"`
	// Task identifies which task the receiving process is.
	Task    TaskSpec          `json:"task"`
}
    73  
    74  func convertClusterSpecToSparseClusterSpec(clusterSpec ClusterSpec, rtype string, index int32) SparseClusterSpec {
    75  	sparseClusterSpec := SparseClusterSpec{Worker: map[int32]string{}, PS: []string{}}
    76  	if rtype == strings.ToLower(string(kubeflowv1.TFJobReplicaTypePS)) {
    77  		sparseClusterSpec.PS = append(sparseClusterSpec.PS, clusterSpec[rtype][index])
    78  	} else if rtype == strings.ToLower(string(kubeflowv1.TFJobReplicaTypeWorker)) {
    79  		sparseClusterSpec.PS = clusterSpec[strings.ToLower(string(kubeflowv1.TFJobReplicaTypePS))]
    80  		sparseClusterSpec.Worker[index] = clusterSpec[rtype][index]
    81  	}
    82  	return sparseClusterSpec
    83  }
    84  
    85  // genTFConfig will generate the environment variable TF_CONFIG
    86  //
    87  //	{
    88  //	    "cluster": {
    89  //	        "ps": ["ps1:2222", "ps2:2222"],
    90  //	        "worker": ["worker1:2222", "worker2:2222", "worker3:2222"]
    91  //	    },
    92  //	    "task": {
    93  //	        "type": "ps",
    94  //	        "index": 1
    95  //	        },
    96  //	    }
    97  //	}
    98  //
    99  // if EnableDynamicWorker set true
   100  //
   101  //	{
   102  //	    "cluster": {
   103  //	        "ps": ["ps1:2222", "ps2:2222"],
   104  //	        "worker": {"1":"worker1:2222"}
   105  //	    },
   106  //	    "task": {
   107  //	        "type": "worker",
   108  //	        "index": 1
   109  //	        },
   110  //	    }
   111  //	}
   112  func genTFConfigJSONStr(tfjob *kubeflowv1.TFJob, rtype, index string) (string, error) {
   113  	// Configure the TFCONFIG environment variable.
   114  	i, err := strconv.ParseInt(index, 0, 32)
   115  	if err != nil {
   116  		return "", err
   117  	}
   118  
   119  	cluster, err := genClusterSpec(tfjob)
   120  	if err != nil {
   121  		return "", err
   122  	}
   123  
   124  	var tfConfigJSONByteSlice []byte
   125  	if tfjob.Spec.EnableDynamicWorker {
   126  		sparseCluster := convertClusterSpecToSparseClusterSpec(cluster, strings.ToLower(rtype), int32(i))
   127  		sparseTFConfig := SparseTFConfig{
   128  			Cluster: sparseCluster,
   129  			Task: TaskSpec{
   130  				Type:  strings.ToLower(rtype),
   131  				Index: int(i),
   132  			},
   133  		}
   134  		tfConfigJSONByteSlice, err = json.Marshal(sparseTFConfig)
   135  	} else {
   136  		tfConfig := TFConfig{
   137  			Cluster: cluster,
   138  			Task: TaskSpec{
   139  				Type:  strings.ToLower(rtype),
   140  				Index: int(i),
   141  			},
   142  			// We need to set environment to cloud  otherwise it will default to local which isn't what we want.
   143  			// Environment is used by tensorflow.contrib.learn.python.learn in versions <= 1.3
   144  			// TODO(jlewi): I don't think it is used in versions TF >- 1.4. So we can eventually get rid of it.
   145  			Environment: "cloud",
   146  		}
   147  		tfConfigJSONByteSlice, err = json.Marshal(tfConfig)
   148  	}
   149  	if err != nil {
   150  		return "", err
   151  	}
   152  
   153  	return string(tfConfigJSONByteSlice), nil
   154  }
   155  
   156  // genClusterSpec will generate ClusterSpec.
   157  func genClusterSpec(tfjob *kubeflowv1.TFJob) (ClusterSpec, error) {
   158  	clusterSpec := make(ClusterSpec)
   159  
   160  	for rtype, spec := range tfjob.Spec.TFReplicaSpecs {
   161  		rt := strings.ToLower(string(rtype))
   162  		replicaNames := make([]string, 0, *spec.Replicas)
   163  
   164  		port, err := GetPortFromTFJob(tfjob, rtype)
   165  		if err != nil {
   166  			return nil, err
   167  		}
   168  		for i := int32(0); i < *spec.Replicas; i++ {
   169  			// As described here: https://kubernetes.io/docs/concepts/services-networking/dns-pod-service/#a-records.
   170  			// Headless service assigned a DNS A record for a name of the form "my-svc.my-namespace.svc.cluster.local".
   171  			// And the last part "svc.cluster.local" is called cluster domain
   172  			// which maybe different between kubernetes clusters.
   173  			hostName := common.GenGeneralName(tfjob.Name, rt, fmt.Sprintf("%d", i))
   174  			svcName := hostName + "." + tfjob.Namespace + "." + "svc"
   175  			clusterDomain := os.Getenv(EnvCustomClusterDomain)
   176  			if len(clusterDomain) > 0 {
   177  				svcName += "." + clusterDomain
   178  			}
   179  
   180  			endpoint := fmt.Sprintf("%s:%d", svcName, port)
   181  			replicaNames = append(replicaNames, endpoint)
   182  		}
   183  
   184  		clusterSpec[rt] = replicaNames
   185  	}
   186  
   187  	return clusterSpec, nil
   188  }