# github.com/kubeflow/training-operator@v1.7.0/examples/xgboost/lightgbm-dist/main.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Entry point for the distributed (lightgbm-dist) training example.

Cluster topology is read from the environment variables MASTER_ADDR,
MASTER_PORT, WORKER_ADDRS, WORKER_PORT, WORLD_SIZE and RANK. For the
``Train`` job type a machine-list file and a training config file are
generated and ``train`` is invoked on that config. Command-line arguments
not declared by this script are forwarded into the training config.
"""

import os
import logging
import argparse

from train import train

from utils import generate_machine_list_file, generate_train_conf_file


logger = logging.getLogger(__name__)


def main(args, extra_args):
    """Run the job selected by ``args.job_type``.

    Args:
        args: Parsed command-line namespace; only ``args.job_type``
            ("Train" or "Predict") is read.
        extra_args: Command-line arguments not recognised by the parser,
            passed through to ``generate_train_conf_file`` for ``Train``.

    Raises:
        KeyError: If a required cluster environment variable is not set.
        ValueError: If WORLD_SIZE or RANK is not an integer string.
    """
    master_addr = os.environ["MASTER_ADDR"]
    master_port = os.environ["MASTER_PORT"]
    worker_addrs = os.environ["WORKER_ADDRS"]
    worker_port = os.environ["WORKER_PORT"]
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])

    logger.info(
        "extract cluster info from env variables \n"
        "master_addr: %s \n"
        "master_port: %s \n"
        "worker_addrs: %s \n"
        "worker_port: %s \n"
        "world_size: %s \n"
        "rank: %s \n",
        master_addr,
        master_port,
        worker_addrs,
        worker_port,
        world_size,
        rank,
    )

    if args.job_type == "Predict":
        logger.info("starting the predict job")
        # This script has no prediction logic; say so instead of silently
        # reporting a finished job that did nothing.
        logger.warning("predict job is not implemented in this example; nothing to run")

    elif args.job_type == "Train":
        logger.info("starting the train job")
        logger.info("extra args:\n %s", extra_args)
        machine_list_filepath = generate_machine_list_file(
            master_addr, master_port, worker_addrs, worker_port
        )
        logger.info("machine list generated in: %s", machine_list_filepath)
        # Rank 0 listens on the master port; every other rank uses the
        # shared worker port.
        local_port = worker_port if rank else master_port
        config_file = generate_train_conf_file(
            machine_list_file=machine_list_filepath,
            world_size=world_size,
            output_model="model.txt",
            local_port=local_port,
            extra_args=extra_args,
        )
        logger.info("config generated in: %s", config_file)
        train(config_file)

    logger.info("Finish distributed job")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--job_type",
        help="Job type to execute",
        choices=["Train", "Predict"],
        required=True,
    )

    logging.basicConfig(format="%(message)s")
    logging.getLogger().setLevel(logging.INFO)
    # parse_known_args keeps undeclared options (e.g. training params) in
    # extra_args instead of rejecting them.
    args, extra_args = parser.parse_known_args()
    main(args, extra_args)