github.com/kubeflow/training-operator@v1.7.0/examples/xgboost/xgboost-dist/train.py (about)

     1  # Licensed under the Apache License, Version 2.0 (the "License");
     2  # you may not use this file except in compliance with the License.
     3  # You may obtain a copy of the License at
     4  #
     5  #     http://www.apache.org/licenses/LICENSE-2.0
     6  #
     7  # Unless required by applicable law or agreed to in writing, software
     8  # distributed under the License is distributed on an "AS IS" BASIS,
     9  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    10  # See the License for the specific language governing permissions and
    11  # limitations under the License.
    12  
    13  
    14  import logging
    15  import xgboost as xgb
    16  import traceback
    17  
    18  from tracker import RabitTracker
    19  from utils import read_train_data, extract_xgbooost_cluster_env
    20  
# Module-level logger; train() below reports progress and errors through it.
logger = logging.getLogger(__name__)
    22  
    23  
    24  def train(args):
    25      """
    26      :param args: configuration for train job
    27      :return: XGBoost model
    28      """
    29      addr, port, rank, world_size = extract_xgbooost_cluster_env()
    30      rabit_tracker = None
    31  
    32      try:
    33          """start to build the network"""
    34          if world_size > 1:
    35              if rank == 0:
    36                  logger.info("start the master node")
    37  
    38                  rabit = RabitTracker(hostIP="0.0.0.0", nslave=world_size,
    39                                       port=port, port_end=port + 1)
    40                  rabit.start(world_size)
    41                  rabit_tracker = rabit
    42                  logger.info('###### RabitTracker Setup Finished ######')
    43  
    44              envs = [
    45                  'DMLC_NUM_WORKER=%d' % world_size,
    46                  'DMLC_TRACKER_URI=%s' % addr,
    47                  'DMLC_TRACKER_PORT=%d' % port,
    48                  'DMLC_TASK_ID=%d' % rank
    49              ]
    50              logger.info('##### Rabit rank setup with below envs #####')
    51              for i, env in enumerate(envs):
    52                  logger.info(env)
    53                  envs[i] = str.encode(env)
    54  
    55              xgb.rabit.init(envs)
    56              logger.info('##### Rabit rank = %d' % xgb.rabit.get_rank())
    57              rank = xgb.rabit.get_rank()
    58  
    59          else:
    60              world_size = 1
    61              logging.info("Start the train in a single node")
    62  
    63          df = read_train_data(rank=rank, num_workers=world_size, path=None)
    64          kwargs = {}
    65          kwargs["dtrain"] = df
    66          kwargs["num_boost_round"] = int(args.n_estimators)
    67          param_xgboost_default = {'max_depth': 2, 'eta': 1, 'silent': 1,
    68                                   'objective': 'multi:softprob', 'num_class': 3}
    69          kwargs["params"] = param_xgboost_default
    70  
    71          logging.info("starting to train xgboost at node with rank %d", rank)
    72          bst = xgb.train(**kwargs)
    73  
    74          if rank == 0:
    75              model = bst
    76          else:
    77              model = None
    78  
    79          logging.info("finish xgboost training at node with rank %d", rank)
    80  
    81      except Exception as e:
    82          logger.error("something wrong happen: %s", traceback.format_exc())
    83          raise e
    84      finally:
    85          logger.info("xgboost training job finished!")
    86          if world_size > 1:
    87              xgb.rabit.finalize()
    88          if rabit_tracker:
    89              rabit_tracker.join()
    90  
    91      return model