github.com/alwaysproblem/mlserving-tutorial@v0.0.0-20221124033215-121cfddbfbf4/TFserving/ClientAPI/cpp/cmake-static-lib/src/predict-service/main.cc (about)

     1  #include <grpcpp/grpcpp.h>
     2  
     3  #include <boost/program_options.hpp>
     4  #include <iostream>
     5  #include <memory>
     6  #include <string>
     7  
     8  #include "google/protobuf/map.h"
     9  #include "grpcpp/create_channel.h"
    10  #include "grpcpp/security/credentials.h"
    11  #include "tensorflow/core/example/example.pb.h"
    12  #include "tensorflow/core/framework/tensor.grpc.pb.h"
    13  #include "tensorflow/core/framework/tensor_shape.grpc.pb.h"
    14  #include "tensorflow/core/framework/types.grpc.pb.h"
    15  #include "tensorflow_serving/apis/predict.grpc.pb.h"
    16  #include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
    17  
    18  using grpc::Channel;
    19  using grpc::ClientContext;
    20  using grpc::Status;
    21  
    22  using tensorflow::TensorProto;
    23  using tensorflow::TensorShapeProto;
    24  using tensorflow::serving::PredictionService;
    25  using tensorflow::serving::PredictRequest;
    26  using tensorflow::serving::PredictResponse;
    27  
    28  using namespace boost::program_options;
    29  
    30  typedef google::protobuf::Map<std::string, tensorflow::TensorProto> OutMap;
    31  /*
    32  Application entry point
    33  */
    34  int main(int argc, char** argv) {
    35    std::string server_addr = "172.17.0.2:8500";
    36    std::string model_name = "Toy";
    37    int model_version = -1;
    38    std::string model_version_label = "";
    39    const std::string model_signature_name = "serving_default";
    40  
    41    // for parse argument
    42    variables_map vm;
    43  
    44    // grpc context
    45    ClientContext context;
    46    unsigned int timout_in_sec = 5;
    47  
    48    // predict request
    49    PredictRequest request;
    50    PredictResponse response;
    51  
    52    // input tensor
    53    tensorflow::TensorProto proto;
    54  
    55    // parse arguments
    56    options_description desc("Allowed options");
    57    desc.add_options()
    58        // First parameter describes option name/short name
    59        // The second is parameter to option
    60        // The third is description
    61        ("help,h", "print usage message")(
    62            "server_addr,s", value(&server_addr)->default_value(server_addr),
    63            "the destination address host:port")(
    64            "model_name,m", value(&model_name)->default_value(model_name),
    65            "the mode name for prediction")(
    66            "model_version,v",
    67            value<int>(&model_version)->default_value(model_version),
    68            "the model version for prediction")(
    69            "model_version_label,l",
    70            value(&model_version_label)->default_value(model_version_label),
    71            "the annotation name of model version for prediction");
    72  
    73    store(parse_command_line(argc, argv, desc), vm);
    74  
    75    if (vm.count("help")) {
    76      std::cout << desc << "\n";
    77      return 0;
    78    }
    79  
    80    // set grpc timeout
    81    std::chrono::system_clock::time_point deadline =
    82        std::chrono::system_clock::now() + std::chrono::seconds(timout_in_sec);
    83    context.set_deadline(deadline);
    84  
    85    server_addr = vm["server_addr"].as<std::string>();
    86    model_name = vm["model_name"].as<std::string>();
    87    model_version = vm["model_version"].as<int>();
    88    model_version_label = vm["model_version_label"].as<std::string>();
    89  
    90    // start a
    91    std::shared_ptr<Channel> channel =
    92        grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials());
    93    std::unique_ptr<PredictionService::Stub> stub =
    94        PredictionService::NewStub(channel);
    95  
    96    request.mutable_model_spec()->set_name(model_name);
    97    request.mutable_model_spec()->set_signature_name(model_signature_name);
    98  
    99    if (model_version > -1) {
   100      request.mutable_model_spec()->mutable_version()->set_value(model_version);
   101    }
   102  
   103    if (model_version_label != "") {
   104      request.mutable_model_spec()->set_version_label(model_version_label);
   105    }
   106  
   107    OutMap& inputs = *request.mutable_inputs();
   108  
   109    std::vector<float> data{
   110        1., 2., 1., 3., 1., 4.,
   111    };
   112  
   113    proto.set_dtype(tensorflow::DataType::DT_FLOAT);
   114  
   115    for (const float& e : data) {
   116      proto.add_float_val(e);
   117    }
   118  
   119    proto.mutable_tensor_shape()->add_dim()->set_size(3);
   120    proto.mutable_tensor_shape()->add_dim()->set_size(2);
   121  
   122    inputs["input_1"].CopyFrom(proto);
   123    // inputs["input_1"] = proto;
   124  
   125    std::cout << "calling prediction service on " << server_addr << std::endl;
   126    Status status = stub->Predict(&context, request, &response);
   127  
   128    // Act upon its status.
   129    if (status.ok()) {
   130      const std::string output_label = "output_1";
   131  
   132      std::cout << "call predict ok" << std::endl;
   133      std::cout << "outputs size is " << response.outputs_size() << std::endl;
   134  
   135      OutMap& map_outputs = *response.mutable_outputs();
   136  
   137      tensorflow::TensorProto& result_tensor_proto = map_outputs[output_label];
   138  
   139      std::cout << std::endl << output_label << ":" << std::endl;
   140  
   141      for (int titer = 0; titer != result_tensor_proto.float_val_size();
   142           ++titer) {
   143        std::cout << result_tensor_proto.float_val(titer) << "\n";
   144      }
   145      std::cout << "Done." << std::endl;
   146    } else {
   147      std::cout << "gRPC call return code: " << status.error_code() << ": "
   148                << status.error_message() << std::endl;
   149      std::cout << "RPC failed" << std::endl;
   150    }
   151  
   152    return 0;
   153  }