// Source: github.com/alwaysproblem/mlserving-tutorial@v0.0.0-20221124033215-121cfddbfbf4/TFserving/ClientAPI/cpp/cmake-static-lib/src/model-reload/main.cc

#include <grpcpp/grpcpp.h>

#include <boost/program_options.hpp>
#include <chrono>
#include <iostream>
#include <memory>
#include <string>

#include "google/protobuf/map.h"
#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "tensorflow/core/framework/tensor.grpc.pb.h"
#include "tensorflow/core/framework/tensor_shape.grpc.pb.h"
#include "tensorflow/core/framework/types.grpc.pb.h"
#include "tensorflow_serving/apis/model_service.grpc.pb.h"
    15  
    16  using grpc::Channel;
    17  using grpc::ClientContext;
    18  using grpc::Status;
    19  
    20  using tensorflow::serving::ModelConfig;
    21  using tensorflow::serving::ModelConfigList;
    22  using tensorflow::serving::ModelServerConfig;
    23  using tensorflow::serving::ModelService;
    24  using tensorflow::serving::ReloadConfigRequest;
    25  using tensorflow::serving::ReloadConfigResponse;
    26  
    27  using namespace boost::program_options;
    28  
    29  typedef google::protobuf::RepeatedPtrField<tensorflow::serving::ModelConfig>
    30      RepeatModelConfig;
    31  /*
    32  Application entry point
    33  */
    34  int main(int argc, char** argv) {
    35    std::string server_addr = "172.17.0.2:8500";
    36    std::string model_name = "Toy";
    37    int model_version = -1;
    38    std::string model_version_label = "";
    39    const std::string model_signature_name = "serving_default";
    40  
    41    // for parse argument
    42    variables_map vm;
    43  
    44    // grpc context
    45    ClientContext context;
    46    unsigned int timout_in_sec = 5;
    47  
    48    // ReloadConfig request & response
    49    ReloadConfigRequest request;
    50    ReloadConfigResponse response;
    51    ModelServerConfig model_server_config;
    52  
    53    // string stream for formatting
    54    std::ostringstream formatter;
    55  
    56    // parse arguments
    57    options_description desc("Allowed options");
    58    desc.add_options()
    59        // First parameter describes option name/short name
    60        // The second is parameter to option
    61        // The third is description
    62        ("help,h", "print usage message")(
    63            "server_addr,s", value(&server_addr)->default_value(server_addr),
    64            "the destination address host:port")(
    65            "model_name,m", value(&model_name)->default_value(model_name),
    66            "the mode name for prediction")(
    67            "model_version,v",
    68            value<int>(&model_version)->default_value(model_version),
    69            "the model version for prediction")(
    70            "model_version_label,l",
    71            value(&model_version_label)->default_value(model_version_label),
    72            "the annotation name of model version for prediction");
    73  
    74    store(parse_command_line(argc, argv, desc), vm);
    75  
    76    if (vm.count("help")) {
    77      std::cout << desc << "\n";
    78      return 0;
    79    }
    80  
    81    // set grpc timeout
    82    std::chrono::system_clock::time_point deadline =
    83        std::chrono::system_clock::now() + std::chrono::seconds(timout_in_sec);
    84    context.set_deadline(deadline);
    85  
    86    server_addr = vm["server_addr"].as<std::string>();
    87    model_name = vm["model_name"].as<std::string>();
    88    model_version = vm["model_version"].as<int>();
    89    model_version_label = vm["model_version_label"].as<std::string>();
    90  
    91    // create a new channel and stub
    92    std::shared_ptr<Channel> channel =
    93        grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials());
    94    std::unique_ptr<ModelService::Stub> stub = ModelService::NewStub(channel);
    95  
    96    ModelConfigList& config_list =
    97        *model_server_config.mutable_model_config_list();
    98    ModelConfig& one_config = *config_list.add_config();
    99  
   100    formatter << "/models/save/" << model_name << "/";
   101  
   102    one_config.set_name(model_name);
   103    one_config.set_base_path(formatter.str());
   104    one_config.set_model_platform("tensorflow");
   105    // one_config.set_model_type();
   106  
   107    model_server_config.mutable_model_config_list()->CopyFrom(config_list);
   108    request.mutable_config()->CopyFrom(model_server_config);
   109  
   110    std::cout << "calling model service on " << server_addr << std::endl;
   111    Status status = stub->HandleReloadConfigRequest(&context, request, &response);
   112  
   113    // Act upon its status.
   114    if (status.ok()) {
   115      std::cout << "call model service ok" << std::endl;
   116      if (0 == response.status().error_code()) {
   117        std::cout << "model " << model_name << " reloaded successfully."
   118                  << std::endl;
   119      } else {
   120        std::cout << "model " << model_name << " reloaded failed!\n"
   121                  << "error code is " << response.status().error_code() << "\n"
   122                  << response.status().error_message() << std::endl;
   123      }
   124  
   125    } else {
   126      std::cout << "gRPC call return code: " << status.error_code() << ": "
   127                << status.error_message() << std::endl;
   128      std::cout << "RPC failed" << std::endl;
   129    }
   130  
   131    return 0;
   132  }