# Source: github.com/kubeflow/training-operator@v1.7.0/examples/xgboost/xgboost-dist/main.py
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import logging

from train import train
from predict import predict
from utils import dump_model


def main(args):
    """Run the job selected by ``args.job_type``.

    Args:
        args: Parsed command-line namespace. This function reads
            ``model_storage_type``, ``job_type`` and ``model_path`` and
            forwards the whole namespace to ``train``, ``predict`` and
            ``dump_model``.

    Raises:
        Exception: If ``args.model_storage_type`` is neither ``"local"``
            nor ``"oss"``.
    """
    model_storage_type = args.model_storage_type
    if model_storage_type in ("local", "oss"):
        print("The storage type is " + model_storage_type)
    else:
        raise Exception("Only supports storage types like local and OSS")

    job_type = args.job_type
    if job_type == "Predict":
        logging.info("starting the predict job")
        predict(args)

    elif job_type == "Train":
        logging.info("starting the train job")
        model = train(args)

        # train() may return None, in which case nothing is dumped.
        # NOTE(review): presumably only some workers hold the final model;
        # confirm against train.py.
        if model is not None:
            logging.info("finish the model training, and start to dump model ")
            model_path = args.model_path
            dump_model(model, model_storage_type, model_path, args)

    elif job_type == "All":
        # NOTE(review): this branch only logs; neither train() nor predict()
        # is invoked here. Kept as-is because the intended wiring cannot be
        # confirmed from this file.
        logging.info("starting the train and predict job")

    else:
        # Previously an unrecognised job type fell through silently and the
        # job still reported completion; make the no-op visible in the logs.
        logging.warning(
            "Unknown job type %r; expected one of Train, Predict, All",
            job_type,
        )

    logging.info("Finish distributed XGBoost job")


def _build_arg_parser():
    """Return the argument parser for the command-line options of this script."""
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--job_type',
        help="Train, Predict, All",
        required=True
    )
    parser.add_argument(
        '--xgboost_parameter',
        help='XGBoost model parameter like: objective, number_class',
    )
    parser.add_argument(
        '--n_estimators',
        help='Number of trees in the model',
        type=int,
        default=1000
    )
    # type= is given explicitly so that values supplied on the command line
    # have the same numeric type as the defaults instead of arriving as str.
    parser.add_argument(
        '--learning_rate',
        help='Learning rate for the model',
        type=float,
        default=0.1
    )
    parser.add_argument(
        '--early_stopping_rounds',
        help='XGBoost argument for stopping early',
        type=int,
        default=50
    )
    parser.add_argument(
        '--model_path',
        help='place to store model',
        default="/tmp/xgboost_model"
    )
    parser.add_argument(
        '--model_storage_type',
        help='place to store the model',
        default="oss"
    )
    parser.add_argument(
        '--oss_param',
        help='oss parameter if you choose the model storage as OSS type',
    )
    return parser


if __name__ == '__main__':
    parser = _build_arg_parser()

    logging.basicConfig(format='%(message)s')
    logging.getLogger().setLevel(logging.INFO)
    main_args = parser.parse_args()
    main(main_args)