github.com/kubeflow/training-operator@v1.7.0/examples/xgboost/xgboost-dist/main.py (about)

     1  # Licensed under the Apache License, Version 2.0 (the "License");
     2  # you may not use this file except in compliance with the License.
     3  # You may obtain a copy of the License at
     4  #
     5  #     http://www.apache.org/licenses/LICENSE-2.0
     6  #
     7  # Unless required by applicable law or agreed to in writing, software
     8  # distributed under the License is distributed on an "AS IS" BASIS,
     9  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    10  # See the License for the specific language governing permissions and
    11  # limitations under the License.
    12  
    13  import argparse
    14  import logging
    15  
    16  from train import train
    17  from predict import predict
    18  from utils import dump_model
    19  
    20  
    21  def main(args):
    22  
    23      model_storage_type = args.model_storage_type
    24      if (model_storage_type == "local" or model_storage_type == "oss"):
    25        print ( "The storage type is " + model_storage_type)
    26      else:
    27        raise Exception("Only supports storage types like local and OSS")
    28  
    29      if args.job_type == "Predict":
    30          logging.info("starting the predict job")
    31          predict(args)
    32  
    33      elif args.job_type == "Train":
    34          logging.info("starting the train job")
    35          model = train(args)
    36  
    37          if model is not None:
    38              logging.info("finish the model training, and start to dump model ")
    39              model_path = args.model_path
    40              dump_model(model, model_storage_type, model_path, args)
    41  
    42      elif args.job_type == "All":
    43          logging.info("starting the train and predict job")
    44  
    45      logging.info("Finish distributed XGBoost job")
    46  
    47  
    48  if __name__ == '__main__':
    49      parser = argparse.ArgumentParser()
    50  
    51      parser.add_argument(
    52             '--job_type',
    53             help="Train, Predict, All",
    54             required=True
    55             )
    56      parser.add_argument(
    57             '--xgboost_parameter',
    58             help='XGBoost model parameter like: objective, number_class',
    59            )
    60      parser.add_argument(
    61            '--n_estimators',
    62            help='Number of trees in the model',
    63            type=int,
    64            default=1000
    65            )
    66      parser.add_argument(
    67             '--learning_rate',
    68             help='Learning rate for the model',
    69             default=0.1
    70            )
    71      parser.add_argument(
    72            '--early_stopping_rounds',
    73            help='XGBoost argument for stopping early',
    74            default=50
    75            )
    76      parser.add_argument(
    77            '--model_path',
    78            help='place to store model',
    79            default="/tmp/xgboost_model"
    80            )
    81      parser.add_argument(
    82            '--model_storage_type',
    83            help='place to store the model',
    84            default="oss"
    85            )
    86      parser.add_argument(
    87            '--oss_param',
    88            help='oss parameter if you choose the model storage as OSS type',
    89            )
    90  
    91      logging.basicConfig(format='%(message)s')
    92      logging.getLogger().setLevel(logging.INFO)
    93      main_args = parser.parse_args()
    94      main(main_args)