github.com/alwaysproblem/mlserving-tutorial@v0.0.0-20221124033215-121cfddbfbf4/TFserving/ClientAPI/cpp/cmake/src/predict-log/main.cc (about)

#include <grpcpp/grpcpp.h>

#include <chrono>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <boost/program_options.hpp>

#include "google/protobuf/map.h"
#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "tensorflow/core/framework/tensor.grpc.pb.h"
#include "tensorflow/core/framework/tensor_shape.grpc.pb.h"
#include "tensorflow/core/framework/types.grpc.pb.h"
#include "tensorflow_serving/apis/predict.grpc.pb.h"
#include "tensorflow_serving/apis/prediction_log.grpc.pb.h"
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
    17  
    18  using grpc::Channel;
    19  using grpc::ClientContext;
    20  using grpc::Status;
    21  
    22  using tensorflow::TensorProto;
    23  using tensorflow::TensorShapeProto;
    24  using tensorflow::serving::PredictionService;
    25  using tensorflow::serving::PredictLog;
    26  using tensorflow::serving::PredictRequest;
    27  using tensorflow::serving::PredictResponse;
    28  
    29  using namespace boost::program_options;
    30  
    31  typedef google::protobuf::Map<std::string, tensorflow::TensorProto> OutMap;
    32  /*
    33  Application entry point
    34  */
    35  int main(int argc, char** argv) {
    36    std::string server_addr = "172.17.0.2:8500";
    37    std::string model_name = "Toy";
    38    int model_version = -1;
    39    std::string model_version_label = "";
    40    const std::string model_signature_name = "serving_default";
    41  
    42    // for parse argument
    43    variables_map vm;
    44  
    45    // grpc context
    46    ClientContext context;
    47    unsigned int timout_in_sec = 5;
    48  
    49    // predict request
    50    PredictRequest request;
    51    PredictResponse response;
    52  
    53    // predict log
    54    PredictLog logs;
    55  
    56    // input tensor
    57    tensorflow::TensorProto proto;
    58  
    59    // parse arguments
    60    options_description desc("Allowed options");
    61    desc.add_options()
    62        // First parameter describes option name/short name
    63        // The second is parameter to option
    64        // The third is description
    65        ("help,h", "print usage message")(
    66            "server_addr,s", value(&server_addr)->default_value(server_addr),
    67            "the destination address host:port")(
    68            "model_name,m", value(&model_name)->default_value(model_name),
    69            "the mode name for prediction")(
    70            "model_version,v",
    71            value<int>(&model_version)->default_value(model_version),
    72            "the model version for prediction")(
    73            "model_version_label,l",
    74            value(&model_version_label)->default_value(model_version_label),
    75            "the annotation name of model version for prediction");
    76  
    77    store(parse_command_line(argc, argv, desc), vm);
    78  
    79    if (vm.count("help")) {
    80      std::cout << desc << "\n";
    81      return 0;
    82    }
    83  
    84    // set grpc timeout
    85    std::chrono::system_clock::time_point deadline =
    86        std::chrono::system_clock::now() + std::chrono::seconds(timout_in_sec);
    87    context.set_deadline(deadline);
    88  
    89    server_addr = vm["server_addr"].as<std::string>();
    90    model_name = vm["model_name"].as<std::string>();
    91    model_version = vm["model_version"].as<int>();
    92    model_version_label = vm["model_version_label"].as<std::string>();
    93  
    94    // start a
    95    std::shared_ptr<Channel> channel =
    96        grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials());
    97    std::unique_ptr<PredictionService::Stub> stub =
    98        PredictionService::NewStub(channel);
    99  
   100    request.mutable_model_spec()->set_name(model_name);
   101    request.mutable_model_spec()->set_signature_name(model_signature_name);
   102  
   103    if (model_version > -1) {
   104      request.mutable_model_spec()->mutable_version()->set_value(model_version);
   105    }
   106  
   107    if (model_version_label != "") {
   108      request.mutable_model_spec()->set_version_label(model_version_label);
   109    }
   110  
   111    OutMap& inputs = *request.mutable_inputs();
   112  
   113    std::vector<float> data{
   114        1., 2., 1., 3., 1., 4.,
   115    };
   116  
   117    proto.set_dtype(tensorflow::DataType::DT_FLOAT);
   118  
   119    for (const float& e : data) {
   120      proto.add_float_val(e);
   121    }
   122  
   123    proto.mutable_tensor_shape()->add_dim()->set_size(3);
   124    proto.mutable_tensor_shape()->add_dim()->set_size(2);
   125  
   126    inputs["input_1"].CopyFrom(proto);
   127    // inputs["input_1"] = proto;
   128  
   129    std::cout << "calling prediction service on " << server_addr << std::endl;
   130    Status status = stub->Predict(&context, request, &response);
   131  
   132    // Act upon its status.
   133    if (status.ok()) {
   134      const std::string output_label = "output_1";
   135  
   136      std::cout << "call predict ok" << std::endl;
   137      std::cout << "outputs size is " << response.outputs_size() << std::endl;
   138  
   139      OutMap& map_outputs = *response.mutable_outputs();
   140  
   141      tensorflow::TensorProto& result_tensor_proto = map_outputs[output_label];
   142  
   143      std::cout << std::endl << output_label << ":" << std::endl;
   144  
   145      for (int titer = 0; titer != result_tensor_proto.float_val_size();
   146           ++titer) {
   147        std::cout << result_tensor_proto.float_val(titer) << "\n";
   148      }
   149  
   150      logs.mutable_request()->CopyFrom(request);
   151      logs.mutable_response()->CopyFrom(response);
   152  
   153      std::cout << "********************Predict Log*********************\n"
   154                << logs.DebugString()
   155                << "****************************************************"
   156                << std::endl;
   157  
   158      std::cout << "Done." << std::endl;
   159  
   160    } else {
   161      std::cout << "gRPC call return code: " << status.error_code() << ": "
   162                << status.error_message() << std::endl;
   163      std::cout << "RPC failed" << std::endl;
   164    }
   165  
   166    return 0;
   167  }