github.com/alwaysproblem/mlserving-tutorial@v0.0.0-20221124033215-121cfddbfbf4/TFserving/ClientAPI/cpp/make/src/model-metadata/main.cc (about)

     1  #include <grpcpp/grpcpp.h>
     2  
     3  #include <boost/program_options.hpp>
     4  #include <iostream>
     5  #include <memory>
     6  #include <string>
     7  
     8  #include "google/protobuf/map.h"
     9  #include "grpcpp/create_channel.h"
    10  #include "grpcpp/security/credentials.h"
    11  #include "tensorflow/core/framework/tensor.grpc.pb.h"
    12  #include "tensorflow/core/framework/tensor_shape.grpc.pb.h"
    13  #include "tensorflow/core/framework/types.grpc.pb.h"
    14  #include "tensorflow_serving/apis/get_model_metadata.grpc.pb.h"
    15  #include "tensorflow_serving/apis/predict.grpc.pb.h"
    16  #include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
    17  
    18  using grpc::Channel;
    19  using grpc::ClientContext;
    20  using grpc::Status;
    21  
    22  using tensorflow::TensorProto;
    23  using tensorflow::TensorShapeProto;
    24  using tensorflow::serving::GetModelMetadataRequest;
    25  using tensorflow::serving::GetModelMetadataResponse;
    26  using tensorflow::serving::PredictionService;
    27  
    28  using namespace boost::program_options;
// Protobuf map from string keys to google::protobuf::Any values.
// NOTE(review): not referenced anywhere in the visible part of this file.
typedef google::protobuf::Map<std::string, google::protobuf::Any> MetadataMap;
    30  
    31  /*
    32  Application entry point
    33  */
    34  int main(int argc, char** argv) {
    35    std::string server_addr = "172.17.0.3:8500";
    36    std::string model_name = "Toy";
    37    int model_version = -1;
    38    std::string model_version_label = "";
    39    const std::string model_signature_name = "serving_default";
    40  
    41    // for parse argument
    42    variables_map vm;
    43  
    44    // grpc context
    45    ClientContext context;
    46    unsigned int timout_in_sec = 5;
    47  
    48    // get model metadata request & response
    49    GetModelMetadataRequest request;
    50    GetModelMetadataResponse response;
    51  
    52    // parse arguments
    53    options_description desc("Allowed options");
    54    desc.add_options()
    55        // First parameter describes option name/short name
    56        // The second is parameter to option
    57        // The third is description
    58        ("help,h", "print usage message")(
    59            "server_addr,s", value(&server_addr)->default_value(server_addr),
    60            "the destination address host:port")(
    61            "model_name,m", value(&model_name)->default_value(model_name),
    62            "the mode name for prediction")(
    63            "model_version,v",
    64            value<int>(&model_version)->default_value(model_version),
    65            "the model version for prediction")(
    66            "model_version_label,l",
    67            value(&model_version_label)->default_value(model_version_label),
    68            "the annotation name of model version for prediction");
    69  
    70    store(parse_command_line(argc, argv, desc), vm);
    71  
    72    if (vm.count("help")) {
    73      std::cout << desc << "\n";
    74      return 0;
    75    }
    76  
    77    // set grpc timeout
    78    std::chrono::system_clock::time_point deadline =
    79        std::chrono::system_clock::now() + std::chrono::seconds(timout_in_sec);
    80    context.set_deadline(deadline);
    81  
    82    server_addr = vm["server_addr"].as<std::string>();
    83    model_name = vm["model_name"].as<std::string>();
    84    model_version = vm["model_version"].as<int>();
    85    model_version_label = vm["model_version_label"].as<std::string>();
    86  
    87    // crate a channel
    88    std::shared_ptr<Channel> channel =
    89        grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials());
    90    std::unique_ptr<PredictionService::Stub> stub =
    91        PredictionService::NewStub(channel);
    92  
    93    request.mutable_model_spec()->set_name(model_name);
    94    request.mutable_metadata_field()->Add("signature_def");
    95    request.mutable_model_spec()->set_signature_name(model_signature_name);
    96  
    97    if (model_version > -1) {
    98      request.mutable_model_spec()->mutable_version()->set_value(model_version);
    99    }
   100  
   101    if (model_version_label != "") {
   102      request.mutable_model_spec()->set_version_label(model_version_label);
   103    }
   104  
   105    std::cout << "calling prediction service on " << server_addr << std::endl;
   106    // Status status = stub->Predict(&context, request, &response);
   107    Status status = stub->GetModelMetadata(&context, request, &response);
   108  
   109    // Act upon its status.
   110    if (status.ok()) {
   111      const std::string output_label = "signature_def";
   112  
   113      std::cout << "call predict ok" << std::endl;
   114      std::cout << "metadata size is " << response.metadata_size() << std::endl;
   115      std::cout << "metadata DebugString is \n"
   116                << response.DebugString() << std::endl;
   117  
   118    } else {
   119      std::cout << "gRPC call return code: " << status.error_code() << ": "
   120                << status.error_message() << std::endl;
   121      std::cout << "RPC failed" << std::endl;
   122    }
   123  
   124    return 0;
   125  }