# Source: github.com/kubeflow/training-operator@v1.7.0/examples/xgboost/smoke-dist/xgboost_smoke_test.py
#
# Copyright 2018 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Smoke test for a distributed XGBoost job.

Reads the cluster layout from the MASTER_ADDR / MASTER_PORT / RANK /
WORLD_SIZE environment variables and brings up a Rabit network. Rank 0 also
hosts the RabitTracker. Rank 0 then broadcasts a small object, and every node
logs the value before and after the broadcast.
"""

import logging
import os
import traceback

import xgboost as xgb

from tracker import RabitTracker

logger = logging.getLogger(__name__)


def _int_env(name):
    """Return environment variable *name* parsed as an int.

    Raises:
        ValueError: if the variable is unset or is not an integer literal.
            The message names the offending variable.
    """
    raw = os.environ.get(name)
    if raw is None:
        raise ValueError("environment variable %s is not set" % name)
    try:
        return int(raw)
    except ValueError:
        raise ValueError(
            "environment variable %s=%r is not an integer" % (name, raw)) from None


def extract_xgbooost_cluster_env():
    """Read the Rabit cluster settings from the environment.

    Returns:
        A ``(master_addr, master_port, rank, world_size)`` tuple. The address
        is a string; the other three are ints.

    Raises:
        ValueError: if MASTER_PORT, RANK or WORLD_SIZE is unset or is not an
            integer.
    """
    logger.info("start to extract system env")

    # MASTER_ADDR keeps its "{}" placeholder default. It is only consumed when
    # world_size > 1, so single-node runs have never required it.
    master_addr = os.environ.get("MASTER_ADDR", "{}")
    master_port = _int_env("MASTER_PORT")
    rank = _int_env("RANK")
    world_size = _int_env("WORLD_SIZE")

    logger.info("extract the rabit env from cluster : %s, port: %d, "
                "rank: %d, world_size: %d ",
                master_addr, master_port, rank, world_size)

    return master_addr, master_port, rank, world_size


def setup_rabit_cluster():
    """Bring up the Rabit network, do one broadcast round-trip, tear it down.

    Nothing network-related happens when WORLD_SIZE <= 1. Otherwise:
    rank 0 starts a RabitTracker on MASTER_PORT, every rank calls
    ``xgb.rabit.init`` with the DMLC_* settings, and rank 0 broadcasts a small
    dict. Each node logs the value before and after the broadcast.

    Raises:
        Any exception from setup or the broadcast is logged with its
        traceback and re-raised unchanged.
    """
    addr, port, rank, world_size = extract_xgbooost_cluster_env()

    rabit_tracker = None
    # Set only after xgb.rabit.init() returns, so finalize() runs only on an
    # initialised engine. Otherwise, an error raised inside `finally` could
    # replace the exception that actually caused the failure.
    rabit_initialized = False
    try:
        # Start to build the network; only multi-worker jobs need it.
        if world_size > 1:
            if rank == 0:
                logger.info("start the master node")

                rabit = RabitTracker(hostIP="0.0.0.0", nslave=world_size,
                                     port=port, port_end=port + 1)
                rabit.start(world_size)
                rabit_tracker = rabit
                logger.info('########### RabitTracker Setup Finished #########')

            envs = [
                'DMLC_NUM_WORKER=%d' % world_size,
                'DMLC_TRACKER_URI=%s' % addr,
                'DMLC_TRACKER_PORT=%d' % port,
                'DMLC_TASK_ID=%d' % rank,
            ]
            logger.info('##### Rabit rank setup with below envs #####')
            for env in envs:
                logger.info(env)

            # xgb.rabit.init expects the settings as a list of bytes.
            # NOTE(review): newer XGBoost releases deprecate `xgboost.rabit` in
            # favour of `xgboost.collective`; confirm the pinned version.
            xgb.rabit.init([env.encode() for env in envs])
            rabit_initialized = True

            rank = xgb.rabit.get_rank()
            logger.info('##### Rabit rank = %d', rank)

            s = None
            if rank == 0:
                s = {'hello world': 100, 2: 3}

            logger.info('@node[%d] before-broadcast: s="%s"', rank, s)
            s = xgb.rabit.broadcast(s, 0)

            logger.info('@node[%d] after-broadcast: s="%s"', rank, s)

    except Exception:
        logger.error("something wrong happen: %s", traceback.format_exc())
        raise
    finally:
        if rabit_initialized:
            xgb.rabit.finalize()
        if rabit_tracker:
            # NOTE(review): if setup failed before all workers connected,
            # join() may block; confirm against tracker.RabitTracker.
            rabit_tracker.join()

    logger.info("the rabit network testing finished!")


def main():
    """Log the raw cluster environment, then run the Rabit smoke test."""
    for name in ("MASTER_PORT", "MASTER_ADDR", "WORLD_SIZE", "RANK"):
        logger.info("%s: %s", name, os.environ.get(name, "{}"))

    setup_rabit_cluster()


if __name__ == "__main__":
    # Attach a handler explicitly. The original relied on logging.info()
    # implicitly calling basicConfig() to make INFO output visible.
    logging.basicConfig(level=logging.INFO)
    main()