github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/mxnet/mxnet.go (about)

     1  // Copyright 2021 The Kubeflow Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License
    14  
    15  package mxnet
    16  
    17  import (
    18  	"encoding/json"
    19  	"fmt"
    20  	"strconv"
    21  	"strings"
    22  
    23  	kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1"
    24  	"github.com/kubeflow/training-operator/pkg/controller.v1/common"
    25  
    26  	corev1 "k8s.io/api/core/v1"
    27  )
    28  
    29  const (
    30  	// Label is used as tunerServerKey, it's designed for tvm auto-tuning.
    31  	mxJobTunerServerKey = "tuner-server-key"
    32  	// mxConfig is the environment variable name of MXNet cluster spec.
    33  	mxConfig = "MX_CONFIG"
    34  )
    35  
    36  var (
    37  	errPortNotFound = fmt.Errorf("failed to found the port")
    38  )
    39  
    40  // MXConfig is a struct representing the distributed Mxnet config.
    41  // This struct is turned into an environment variable MX_CONFIG
    42  // which is used by Mxnet processes to configure themselves.
    43  type MXConfig struct {
    44  	// Cluster represents a Mxnet ClusterSpec.
    45  	Cluster ClusterSpec `json:"cluster"`
    46  	// Labels include all label of task.
    47  	Labels LabelsSpec `json:"labels"`
    48  	// Task include information of current node.
    49  	Task TaskSpec `json:"task"`
    50  }
    51  
    52  // ClusterSpec represents a cluster Mxnet specification.
    53  type ClusterSpec map[string][]UrlPort
    54  
    55  type UrlPort struct {
    56  	Url  string `json:"url"`
    57  	Port int    `json:"port"`
    58  }
    59  
    60  // LabelsSpec represents a label specification.
    61  type LabelsSpec map[string]string
    62  
    63  // TaskSpec is the specification for a task (server or worker ...) of the MXJob.
    64  type TaskSpec struct {
    65  	Type  string `json:"type"`
    66  	Index int    `json:"index"`
    67  }
    68  
    69  func SetPodEnv(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
    70  	mxJob, ok := job.(*kubeflowv1.MXJob)
    71  	if !ok {
    72  		return fmt.Errorf("%v is not a type of MXJob", mxJob)
    73  	}
    74  
    75  	// Generate MX_CONFIG JSON.
    76  	mxConfigData, err := genMXConfig(mxJob, rtype, index)
    77  	if err != nil {
    78  		return err
    79  	}
    80  
    81  	// Generate MX_CONFIG JSON Str.
    82  	mxConfigJson, err := json.Marshal(mxConfigData)
    83  	if err != nil {
    84  		return err
    85  	}
    86  
    87  	// Add MX_CONFIG environment variable.
    88  	for i := range podTemplate.Spec.Containers {
    89  
    90  		c := &podTemplate.Spec.Containers[i]
    91  
    92  		// Set environment variable MX_CONFIG
    93  		c.Env = append(c.Env, corev1.EnvVar{
    94  			Name:  mxConfig,
    95  			Value: string(mxConfigJson),
    96  		})
    97  
    98  		// Set Mxnet Distributed Training environment variable
    99  		// We get these envs from MX_COFING to make them stay identical
   100  		c.Env = append(c.Env, corev1.EnvVar{
   101  			Name:  "DMLC_PS_ROOT_PORT",
   102  			Value: strconv.Itoa(getConfigAddr(&mxConfigData, kubeflowv1.MXJobReplicaTypeScheduler, 0).Port),
   103  		})
   104  
   105  		c.Env = append(c.Env, corev1.EnvVar{
   106  			Name:  "DMLC_PS_ROOT_URI",
   107  			Value: getConfigAddr(&mxConfigData, kubeflowv1.MXJobReplicaTypeScheduler, 0).Url,
   108  		})
   109  
   110  		c.Env = append(c.Env, corev1.EnvVar{
   111  			Name:  "DMLC_NUM_SERVER",
   112  			Value: strconv.Itoa(getConfigReplica(&mxConfigData, kubeflowv1.MXJobReplicaTypeServer)),
   113  		})
   114  
   115  		c.Env = append(c.Env, corev1.EnvVar{
   116  			Name:  "DMLC_NUM_WORKER",
   117  			Value: strconv.Itoa(getConfigReplica(&mxConfigData, kubeflowv1.MXJobReplicaTypeWorker)),
   118  		})
   119  
   120  		c.Env = append(c.Env, corev1.EnvVar{
   121  			Name:  "DMLC_ROLE",
   122  			Value: mxConfigData.Task.Type,
   123  		})
   124  
   125  		c.Env = append(c.Env, corev1.EnvVar{
   126  			Name:  "DMLC_USE_KUBERNETES",
   127  			Value: strconv.Itoa(1),
   128  		})
   129  
   130  		// BytePS needs env DMLC_WORKER_ID for each worker
   131  		addBytePSEnv(c, rtype, index)
   132  	}
   133  	return nil
   134  }
   135  
   136  func genMXConfig(mxjob *kubeflowv1.MXJob, rtype, index string) (MXConfig, error) {
   137  	// Configure the MXCONFIG environment variable.
   138  	i, err := strconv.ParseInt(index, 0, 32)
   139  	if err != nil {
   140  		return MXConfig{}, err
   141  	}
   142  
   143  	cluster, err := genClusterSpec(mxjob)
   144  	if err != nil {
   145  		return MXConfig{}, err
   146  	}
   147  
   148  	labels, err := genLabelsSpec(mxjob)
   149  	if err != nil {
   150  		return MXConfig{}, err
   151  	}
   152  
   153  	mxConfig := MXConfig{
   154  		Cluster: cluster,
   155  		Labels:  labels,
   156  		Task: TaskSpec{
   157  			Type:  rtype,
   158  			Index: int(i),
   159  		},
   160  	}
   161  
   162  	return mxConfig, nil
   163  }
   164  
   165  // genClusterSpec will generate ClusterSpec.
   166  func genClusterSpec(mxjob *kubeflowv1.MXJob) (ClusterSpec, error) {
   167  	clusterSpec := make(ClusterSpec)
   168  
   169  	for rtype, spec := range mxjob.Spec.MXReplicaSpecs {
   170  		rt := strings.ToLower(string(rtype))
   171  		replicaNames := make([]UrlPort, 0, *spec.Replicas)
   172  
   173  		port, err := getPortFromMXJob(mxjob, rtype)
   174  		if err != nil {
   175  			return nil, err
   176  		}
   177  		for i := int32(0); i < *spec.Replicas; i++ {
   178  			host := UrlPort{
   179  				Url:  common.GenGeneralName(mxjob.Name, rt, fmt.Sprintf("%d", i)),
   180  				Port: int(port),
   181  			}
   182  			replicaNames = append(replicaNames, host)
   183  		}
   184  
   185  		clusterSpec[rt] = replicaNames
   186  	}
   187  
   188  	return clusterSpec, nil
   189  }
   190  
   191  // genLabelsSpec will generate LabelsSpec.
   192  func genLabelsSpec(mxjob *kubeflowv1.MXJob) (LabelsSpec, error) {
   193  	labelsSpec := make(LabelsSpec)
   194  
   195  	for rtype, spec := range mxjob.Spec.MXReplicaSpecs {
   196  		rt := strings.ToLower(string(rtype))
   197  
   198  		labelsSpec[rt] = spec.Template.Annotations[mxJobTunerServerKey]
   199  	}
   200  
   201  	return labelsSpec, nil
   202  }
   203  
   204  func getConfigAddr(mxConfigData *MXConfig, rtype kubeflowv1.ReplicaType, index int) UrlPort {
   205  	rt := strings.ToLower(string(rtype))
   206  	var urlPort UrlPort
   207  	if len(mxConfigData.Cluster[rt]) <= index {
   208  		// index out of range, maybe this url doen't exist
   209  		urlPort = UrlPort{
   210  			Url:  "",
   211  			Port: 0,
   212  		}
   213  	} else {
   214  		urlPort = mxConfigData.Cluster[rt][index]
   215  	}
   216  	return urlPort
   217  }
   218  
   219  func getConfigReplica(mxConfigData *MXConfig, rtype kubeflowv1.ReplicaType) int {
   220  	rt := strings.ToLower(string(rtype))
   221  	return len(mxConfigData.Cluster[rt])
   222  }
   223  
   224  // getPortFromMXJob gets the port of mxnet container.
   225  func getPortFromMXJob(mxJob *kubeflowv1.MXJob, rtype kubeflowv1.ReplicaType) (int32, error) {
   226  	containers := mxJob.Spec.MXReplicaSpecs[rtype].Template.Spec.Containers
   227  	for _, container := range containers {
   228  		if container.Name == kubeflowv1.MXJobDefaultContainerName {
   229  			ports := container.Ports
   230  			for _, port := range ports {
   231  				if port.Name == kubeflowv1.MXJobDefaultPortName {
   232  					return port.ContainerPort, nil
   233  				}
   234  			}
   235  		}
   236  	}
   237  	return -1, errPortNotFound
   238  }
   239  
   240  func addBytePSEnv(c *corev1.Container, rtype, index string) {
   241  	if rtype == strings.ToLower(string(kubeflowv1.MXJobReplicaTypeWorker)) {
   242  		c.Env = append(c.Env, corev1.EnvVar{
   243  			Name:  "DMLC_WORKER_ID",
   244  			Value: index,
   245  		})
   246  	}
   247  }