github.com/alwaysproblem/mlserving-tutorial@v0.0.0-20221124033215-121cfddbfbf4/TFserving/ClientAPI/cpp/cmake/src/model-status/main.cc (about)

     1  #include <grpcpp/grpcpp.h>
     2  
     3  #include <boost/program_options.hpp>
     4  #include <iostream>
     5  #include <memory>
     6  #include <string>
     7  
     8  #include "google/protobuf/map.h"
     9  #include "grpcpp/create_channel.h"
    10  #include "grpcpp/security/credentials.h"
    11  #include "tensorflow/core/framework/tensor.grpc.pb.h"
    12  #include "tensorflow/core/framework/tensor_shape.grpc.pb.h"
    13  #include "tensorflow/core/framework/types.grpc.pb.h"
    14  #include "tensorflow_serving/apis/model_service.grpc.pb.h"
    15  
    16  using grpc::Channel;
    17  using grpc::ClientContext;
    18  using grpc::Status;
    19  
    20  using tensorflow::TensorProto;
    21  using tensorflow::TensorShapeProto;
    22  using tensorflow::serving::GetModelStatusRequest;
    23  using tensorflow::serving::GetModelStatusResponse;
    24  using tensorflow::serving::ModelService;
    25  
    26  using namespace boost::program_options;
    27  
    28  /*
    29  Application entry point
    30  */
    31  int main(int argc, char** argv) {
    32    std::string server_addr = "172.17.0.2:8500";
    33    std::string model_name = "Toy";
    34    int model_version = -1;
    35    std::string model_version_label = "";
    36    const std::string model_signature_name = "serving_default";
    37  
    38    // for parse argument
    39    variables_map vm;
    40  
    41    // grpc context
    42    ClientContext context;
    43    unsigned int timout_in_sec = 5;
    44  
    45    // GetModelStatus request & response
    46    GetModelStatusRequest request;
    47    GetModelStatusResponse response;
    48  
    49    // parse arguments
    50    options_description desc("Allowed options");
    51    desc.add_options()
    52        // First parameter describes option name/short name
    53        // The second is parameter to option
    54        // The third is description
    55        ("help,h", "print usage message")(
    56            "server_addr,s", value(&server_addr)->default_value(server_addr),
    57            "the destination address host:port")(
    58            "model_name,m", value(&model_name)->default_value(model_name),
    59            "the mode name for prediction")(
    60            "model_version,v",
    61            value<int>(&model_version)->default_value(model_version),
    62            "the model version for prediction")(
    63            "model_version_label,l",
    64            value(&model_version_label)->default_value(model_version_label),
    65            "the annotation name of model version for prediction");
    66  
    67    store(parse_command_line(argc, argv, desc), vm);
    68  
    69    if (vm.count("help")) {
    70      std::cout << desc << "\n";
    71      return 0;
    72    }
    73  
    74    // set grpc timeout
    75    std::chrono::system_clock::time_point deadline =
    76        std::chrono::system_clock::now() + std::chrono::seconds(timout_in_sec);
    77    context.set_deadline(deadline);
    78  
    79    server_addr = vm["server_addr"].as<std::string>();
    80    model_name = vm["model_name"].as<std::string>();
    81    model_version = vm["model_version"].as<int>();
    82    model_version_label = vm["model_version_label"].as<std::string>();
    83  
    84    // create a new channel and stub
    85    std::shared_ptr<Channel> channel =
    86        grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials());
    87    std::unique_ptr<ModelService::Stub> stub = ModelService::NewStub(channel);
    88  
    89    request.mutable_model_spec()->set_name(model_name);
    90    request.mutable_model_spec()->set_signature_name(model_signature_name);
    91  
    92    if (model_version > -1) {
    93      request.mutable_model_spec()->mutable_version()->set_value(model_version);
    94    }
    95  
    96    if (model_version_label != "") {
    97      request.mutable_model_spec()->set_version_label(model_version_label);
    98    }
    99  
   100    std::cout << "calling model service on " << server_addr << std::endl;
   101    Status status = stub->GetModelStatus(&context, request, &response);
   102  
   103    // std::cout << request.DebugString() << std::endl;
   104  
   105    // Act upon its status.
   106    if (status.ok()) {
   107      const std::string output_label = "signature_def";
   108  
   109      std::cout << "call predict ok" << std::endl;
   110      std::cout << "metadata size is " << response.GetCachedSize() << std::endl;
   111      std::cout << "metadata DebugString is \n"
   112                << response.DebugString() << std::endl;
   113  
   114    } else {
   115      std::cout << "gRPC call return code: " << status.error_code() << ": "
   116                << status.error_message() << std::endl;
   117      std::cout << "RPC failed" << std::endl;
   118    }
   119  
   120    return 0;
   121  }