github.com/kubeflow/training-operator@v1.7.0/examples/xgboost/lightgbm-dist/utils.py (about)

     1  # Licensed under the Apache License, Version 2.0 (the "License");
     2  # you may not use this file except in compliance with the License.
     3  # You may obtain a copy of the License at
     4  #
     5  #     http://www.apache.org/licenses/LICENSE-2.0
     6  #
     7  # Unless required by applicable law or agreed to in writing, software
     8  # distributed under the License is distributed on an "AS IS" BASIS,
     9  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    10  # See the License for the specific language governing permissions and
    11  # limitations under the License.
    12  
    13  import re
    14  import socket
    15  import logging
    16  import tempfile
    17  from time import sleep
    18  from typing import List, Union
    19  
    20  logger = logging.getLogger(__name__)
    21  
    22  
    23  def generate_machine_list_file(
    24      master_addr: str, master_port: str, worker_addrs: str, worker_port: str
    25  ) -> str:
    26      logger.info("starting to extract system env")
    27  
    28      filename = tempfile.NamedTemporaryFile(delete=False).name
    29  
    30      def _get_ips(
    31          master_addr_name,
    32          worker_addr_names,
    33          max_retries=10,
    34          sleep_secs=10,
    35          current_retry=0,
    36      ):
    37          try:
    38              worker_addr_ips = []
    39              master_addr_ip = socket.gethostbyname(master_addr_name)
    40  
    41              for addr in worker_addr_names.split(","):
    42                  worker_addr_ips.append(socket.gethostbyname(addr))
    43  
    44          except socket.gaierror as ex:
    45              if "Name or service not known" in str(ex) and current_retry < max_retries:
    46                  sleep(sleep_secs)
    47                  master_addr_ip, worker_addr_ips = _get_ips(
    48                      master_addr_name,
    49                      worker_addr_names,
    50                      max_retries=max_retries,
    51                      sleep_secs=sleep_secs,
    52                      current_retry=current_retry + 1,
    53                  )
    54              else:
    55                  raise ValueError("Couldn't get address names")
    56  
    57          return master_addr_ip, worker_addr_ips
    58  
    59      master_ip, worker_ips = _get_ips(master_addr, worker_addrs)
    60  
    61      with open(filename, "w") as file:
    62          print(f"{master_ip} {master_port}", file=file)
    63          for addr in worker_ips:
    64              print(f"{addr} {worker_port}", file=file)
    65  
    66      return filename
    67  
    68  
    69  def generate_train_conf_file(
    70      machine_list_file: str,
    71      world_size: int,
    72      output_model: str,
    73      local_port: Union[int, str],
    74      extra_args: List[str],
    75  ) -> str:
    76  
    77      filename = tempfile.NamedTemporaryFile(delete=False).name
    78  
    79      with open(filename, "w") as file:
    80          print("task = train", file=file)
    81          print(f"output_model = {output_model}", file=file)
    82          print(f"num_machines = {world_size}", file=file)
    83          print(f"local_listen_port = {local_port}", file=file)
    84          print(f"machine_list_file = {machine_list_file}", file=file)
    85          for arg in extra_args:
    86              m = re.match(r"--(.+)=([^\s]+)", arg)
    87              if m is not None:
    88                  k, v = m.groups()
    89                  print(f"{k} = {v}", file=file)
    90  
    91      return filename