# github.com/kubeflow/training-operator@v1.7.0/examples/xgboost/xgboost-dist/train.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import logging
import traceback

import xgboost as xgb

from tracker import RabitTracker
from utils import extract_xgbooost_cluster_env, read_train_data

logger = logging.getLogger(__name__)


def train(args):
    """Train an XGBoost model, optionally in a distributed (Rabit) setup.

    Cluster topology (tracker address, port, this worker's rank, and the
    world size) is read from the environment via
    ``extract_xgbooost_cluster_env``.

    :param args: configuration for the train job; must expose
        ``n_estimators`` (number of boosting rounds, convertible to int).
    :return: the trained booster on rank 0, ``None`` on all other ranks.
    :raises Exception: re-raises any failure from tracker setup, data
        loading, or training after logging the traceback.
    """
    addr, port, rank, world_size = extract_xgbooost_cluster_env()
    rabit_tracker = None
    # Only rank 0 keeps the trained booster; every other rank returns None.
    model = None

    try:
        # Build the Rabit network when running with more than one worker.
        if world_size > 1:
            if rank == 0:
                # Rank 0 hosts the tracker process that coordinates all
                # workers; it listens on [port, port + 1).
                logger.info("start the master node")
                rabit = RabitTracker(hostIP="0.0.0.0", nslave=world_size,
                                     port=port, port_end=port + 1)
                rabit.start(world_size)
                rabit_tracker = rabit
                logger.info('###### RabitTracker Setup Finished ######')

            envs = [
                'DMLC_NUM_WORKER=%d' % world_size,
                'DMLC_TRACKER_URI=%s' % addr,
                'DMLC_TRACKER_PORT=%d' % port,
                'DMLC_TASK_ID=%d' % rank,
            ]
            logger.info('##### Rabit rank setup with below envs #####')
            for i, env in enumerate(envs):
                logger.info(env)
                envs[i] = env.encode()  # rabit.init expects bytes entries

            xgb.rabit.init(envs)
            # Re-read the rank assigned by Rabit; it is authoritative once
            # the worker has joined the tracker.
            rank = xgb.rabit.get_rank()
            logger.info('##### Rabit rank = %d' % rank)
        else:
            world_size = 1
            logger.info("Start the train in a single node")

        # Each worker reads its own shard of the training data.
        df = read_train_data(rank=rank, num_workers=world_size, path=None)
        kwargs = {}
        kwargs["dtrain"] = df
        kwargs["num_boost_round"] = int(args.n_estimators)
        param_xgboost_default = {'max_depth': 2, 'eta': 1, 'silent': 1,
                                 'objective': 'multi:softprob', 'num_class': 3}
        kwargs["params"] = param_xgboost_default

        logger.info("starting to train xgboost at node with rank %d", rank)
        bst = xgb.train(**kwargs)

        if rank == 0:
            model = bst

        logger.info("finish xgboost training at node with rank %d", rank)
    except Exception:
        # Log the full traceback, then re-raise unchanged so the caller
        # sees the original exception with its original traceback.
        logger.error("something wrong happen: %s", traceback.format_exc())
        raise
    finally:
        logger.info("xgboost training job finished!")
        if world_size > 1:
            # Tear down the Rabit network; the master also waits for the
            # tracker process to exit.
            xgb.rabit.finalize()
            if rabit_tracker:
                rabit_tracker.join()

    return model