github.com/kubeflow/training-operator@v1.7.0/pkg/controller.v1/xgboost/xgboost.go (about) 1 // Copyright 2021 The Kubeflow Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License 14 15 package xgboost 16 17 import ( 18 "fmt" 19 "strconv" 20 "strings" 21 22 corev1 "k8s.io/api/core/v1" 23 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 24 25 kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" 26 ) 27 28 // SetPodEnv sets the pod env set for: 29 // - XGBoost Rabit Tracker and worker 30 // - LightGBM master and workers 31 func SetPodEnv(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error { 32 xgboostjob, ok := job.(*kubeflowv1.XGBoostJob) 33 if !ok { 34 return fmt.Errorf("%+v is not a type of XGBoostJob", xgboostjob) 35 } 36 37 rank, err := strconv.Atoi(index) 38 if err != nil { 39 return err 40 } 41 42 // Add master offset for worker pods 43 if strings.EqualFold(strings.ToLower(rtype), strings.ToLower(string(kubeflowv1.XGBoostJobReplicaTypeWorker))) { 44 masterSpec := xgboostjob.Spec.XGBReplicaSpecs[kubeflowv1.XGBoostJobReplicaTypeMaster] 45 masterReplicas := int(*masterSpec.Replicas) 46 rank += masterReplicas 47 } 48 49 masterAddr := replicaName(xgboostjob.Name, kubeflowv1.XGBoostJobReplicaTypeMaster, 0) 50 51 masterPort, err := getPortFromXGBoostJob(xgboostjob, kubeflowv1.XGBoostJobReplicaTypeMaster) 52 if err != nil { 53 return err 54 } 55 56 totalReplicas := computeTotalReplicas(xgboostjob) 57 58 var workerPort int32 59 var workerAddrs []string 60 61 if totalReplicas > 1 { 62 workerPortTemp, err := getPortFromXGBoostJob(xgboostjob, kubeflowv1.XGBoostJobReplicaTypeWorker) 63 if err != nil { 64 return err 65 } 66 workerPort = workerPortTemp 67 workerAddrs = make([]string, totalReplicas-1) 68 for i := range workerAddrs { 69 workerAddrs[i] = replicaName(xgboostjob.Name, kubeflowv1.XGBoostJobReplicaTypeWorker, i) 70 } 71 } 72 73 for i := range podTemplate.Spec.Containers { 74 if len(podTemplate.Spec.Containers[i].Env) == 0 { 75 podTemplate.Spec.Containers[i].Env = make([]corev1.EnvVar, 0) 76 } 77 podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 78 Name: "MASTER_PORT", 79 Value: strconv.Itoa(int(masterPort)), 80 }) 81 podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 82 Name: "MASTER_ADDR", 83 Value: masterAddr, 84 }) 85 podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 86 Name: "WORLD_SIZE", 87 Value: strconv.Itoa(int(totalReplicas)), 88 }) 89 podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 90 Name: "RANK", 91 Value: strconv.Itoa(rank), 92 }) 93 podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 94 Name: "PYTHONUNBUFFERED", 95 Value: "1", 96 }) 97 // This variables are used if it is a LightGBM job 98 if totalReplicas > 1 { 99 podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 100 Name: "WORKER_PORT", 101 Value: strconv.Itoa(int(workerPort)), 102 }) 103 podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 104 Name: "WORKER_ADDRS", 105 Value: strings.Join(workerAddrs, ","), 106 }) 107 } 108 } 109 110 return nil 111 } 112 113 func replicaName(jobName string, rtype kubeflowv1.ReplicaType, index int) string { 114 n := jobName + "-" + strings.ToLower(string(rtype)) + "-" + strconv.Itoa(index) 115 return strings.Replace(n, "/", "-", -1) 116 } 117 118 // getPortFromXGBoostJob gets the port of xgboost container. 119 func getPortFromXGBoostJob(job *kubeflowv1.XGBoostJob, rtype kubeflowv1.ReplicaType) (int32, error) { 120 containers := job.Spec.XGBReplicaSpecs[rtype].Template.Spec.Containers 121 for _, container := range containers { 122 if container.Name == kubeflowv1.XGBoostJobDefaultContainerName { 123 ports := container.Ports 124 for _, port := range ports { 125 if port.Name == kubeflowv1.XGBoostJobDefaultPortName { 126 return port.ContainerPort, nil 127 } 128 } 129 } 130 } 131 return -1, fmt.Errorf("failed to found the port") 132 } 133 134 func computeTotalReplicas(obj metav1.Object) int32 { 135 job := obj.(*kubeflowv1.XGBoostJob) 136 jobReplicas := int32(0) 137 138 if job.Spec.XGBReplicaSpecs == nil || len(job.Spec.XGBReplicaSpecs) == 0 { 139 return jobReplicas 140 } 141 for _, r := range job.Spec.XGBReplicaSpecs { 142 if r.Replicas == nil { 143 continue 144 } else { 145 jobReplicas += *r.Replicas 146 } 147 } 148 return jobReplicas 149 }