github.com/kubeflow/training-operator@v1.7.0/examples/xgboost/smoke-dist/xgboost_smoke_test.py (about)

     1  
     2  # Copyright 2018 Google Inc. All Rights Reserved.
     3  #
     4  # Licensed under the Apache License, Version 2.0 (the "License");
     5  # you may not use this file except in compliance with the License.
     6  # You may obtain a copy of the License at
     7  #
     8  #     http://www.apache.org/licenses/LICENSE-2.0
     9  #
    10  # Unless required by applicable law or agreed to in writing, software
    11  # distributed under the License is distributed on an "AS IS" BASIS,
    12  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  # See the License for the specific language governing permissions and
    14  # limitations under the License.
    15  
    16  import logging
    17  import os
    18  import xgboost as xgb
    19  import traceback
    20  
    21  from tracker import RabitTracker
    22  
    23  logger = logging.getLogger(__name__)
    24  
    25  def extract_xgbooost_cluster_env():
    26  
    27      logger.info("start to extract system env")
    28  
    29      master_addr = os.environ.get("MASTER_ADDR", "{}")
    30      master_port = int(os.environ.get("MASTER_PORT", "{}"))
    31      rank = int(os.environ.get("RANK", "{}"))
    32      world_size = int(os.environ.get("WORLD_SIZE", "{}"))
    33  
    34      logger.info("extract the rabit env from cluster : %s, port: %d, rank: %d, word_size: %d ",
    35                  master_addr, master_port, rank, world_size)
    36  
    37      return master_addr, master_port, rank, world_size
    38  
    39  def setup_rabit_cluster():
    40      addr, port, rank, world_size = extract_xgbooost_cluster_env()
    41  
    42      rabit_tracker = None
    43      try:
    44          """start to build the network"""
    45          if world_size > 1:
    46              if rank == 0:
    47                  logger.info("start the master node")
    48  
    49                  rabit = RabitTracker(hostIP="0.0.0.0", nslave=world_size,
    50                                       port=port, port_end=port + 1)
    51                  rabit.start(world_size)
    52                  rabit_tracker = rabit
    53                  logger.info('########### RabitTracker Setup Finished #########')
    54  
    55              envs = [
    56                  'DMLC_NUM_WORKER=%d' % world_size,
    57                  'DMLC_TRACKER_URI=%s' % addr,
    58                  'DMLC_TRACKER_PORT=%d' % port,
    59                  'DMLC_TASK_ID=%d' % rank
    60              ]
    61              logger.info('##### Rabit rank setup with below envs #####')
    62              for i, env in enumerate(envs):
    63                  logger.info(env)
    64                  envs[i] = str.encode(env)
    65  
    66              xgb.rabit.init(envs)
    67              logger.info('##### Rabit rank = %d' % xgb.rabit.get_rank())
    68  
    69              rank = xgb.rabit.get_rank()
    70              s = None
    71              if rank == 0:
    72                  s = {'hello world': 100, 2: 3}
    73  
    74              logger.info('@node[%d] before-broadcast: s=\"%s\"' % (rank, str(s)))
    75              s = xgb.rabit.broadcast(s, 0)
    76  
    77              logger.info('@node[%d] after-broadcast: s=\"%s\"' % (rank, str(s)))
    78  
    79      except Exception as e:
    80          logger.error("something wrong happen: %s", traceback.format_exc())
    81          raise e
    82      finally:
    83          if world_size > 1:
    84              xgb.rabit.finalize()
    85          if rabit_tracker:
    86              rabit_tracker.join()
    87  
    88          logger.info("the rabit network testing finished!")
    89  
    90  def main():
    91  
    92      port = os.environ.get("MASTER_PORT", "{}")
    93      logging.info("MASTER_PORT: %s", port)
    94  
    95      addr = os.environ.get("MASTER_ADDR", "{}")
    96      logging.info("MASTER_ADDR: %s", addr)
    97  
    98      world_size = os.environ.get("WORLD_SIZE", "{}")
    99      logging.info("WORLD_SIZE: %s", world_size)
   100  
   101      rank = os.environ.get("RANK", "{}")
   102      logging.info("RANK: %s", rank)
   103  
   104      setup_rabit_cluster()
   105  
   106  if __name__ == "__main__":
   107      logging.getLogger().setLevel(logging.INFO)
   108      main()