github.com/kubeflow/training-operator@v1.7.0/examples/xgboost/lightgbm-dist/main.py (about)

     1  # Licensed under the Apache License, Version 2.0 (the "License");
     2  # you may not use this file except in compliance with the License.
     3  # You may obtain a copy of the License at
     4  #
     5  #     http://www.apache.org/licenses/LICENSE-2.0
     6  #
     7  # Unless required by applicable law or agreed to in writing, software
     8  # distributed under the License is distributed on an "AS IS" BASIS,
     9  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    10  # See the License for the specific language governing permissions and
    11  # limitations under the License.
    12  
    13  import os
    14  import logging
    15  import argparse
    16  
    17  from train import train
    18  
    19  from utils import generate_machine_list_file, generate_train_conf_file
    20  
    21  
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
    23  
    24  
    25  def main(args, extra_args):
    26  
    27      master_addr = os.environ["MASTER_ADDR"]
    28      master_port = os.environ["MASTER_PORT"]
    29      worker_addrs = os.environ["WORKER_ADDRS"]
    30      worker_port = os.environ["WORKER_PORT"]
    31      world_size = int(os.environ["WORLD_SIZE"])
    32      rank = int(os.environ["RANK"])
    33  
    34      logger.info(
    35          "extract cluster info from env variables \n"
    36          f"master_addr: {master_addr} \n"
    37          f"master_port: {master_port} \n"
    38          f"worker_addrs: {worker_addrs} \n"
    39          f"worker_port: {worker_port} \n"
    40          f"world_size: {world_size} \n"
    41          f"rank: {rank} \n"
    42      )
    43  
    44      if args.job_type == "Predict":
    45          logging.info("starting the predict job")
    46  
    47      elif args.job_type == "Train":
    48          logging.info("starting the train job")
    49          logging.info(f"extra args:\n {extra_args}")
    50          machine_list_filepath = generate_machine_list_file(
    51              master_addr, master_port, worker_addrs, worker_port
    52          )
    53          logging.info(f"machine list generated in: {machine_list_filepath}")
    54          local_port = worker_port if rank else master_port
    55          config_file = generate_train_conf_file(
    56              machine_list_file=machine_list_filepath,
    57              world_size=world_size,
    58              output_model="model.txt",
    59              local_port=local_port,
    60              extra_args=extra_args,
    61          )
    62          logging.info(f"config generated in: {config_file}")
    63          train(config_file)
    64      logging.info("Finish distributed job")
    65  
    66  
    67  if __name__ == "__main__":
    68      parser = argparse.ArgumentParser()
    69  
    70      parser.add_argument(
    71          "--job_type",
    72          help="Job type to execute",
    73          choices=["Train", "Predict"],
    74          required=True,
    75      )
    76  
    77      logging.basicConfig(format="%(message)s")
    78      logging.getLogger().setLevel(logging.INFO)
    79      args, extra_args = parser.parse_known_args()
    80      main(args, extra_args)