github.com/alwaysproblem/mlserving-tutorial@v0.0.0-20221124033215-121cfddbfbf4/TFserving/ClientAPI/cpp/cmake-static-lib/src/predict-service/main.cc (about) 1 #include <grpcpp/grpcpp.h> 2 3 #include <boost/program_options.hpp> 4 #include <iostream> 5 #include <memory> 6 #include <string> 7 8 #include "google/protobuf/map.h" 9 #include "grpcpp/create_channel.h" 10 #include "grpcpp/security/credentials.h" 11 #include "tensorflow/core/example/example.pb.h" 12 #include "tensorflow/core/framework/tensor.grpc.pb.h" 13 #include "tensorflow/core/framework/tensor_shape.grpc.pb.h" 14 #include "tensorflow/core/framework/types.grpc.pb.h" 15 #include "tensorflow_serving/apis/predict.grpc.pb.h" 16 #include "tensorflow_serving/apis/prediction_service.grpc.pb.h" 17 18 using grpc::Channel; 19 using grpc::ClientContext; 20 using grpc::Status; 21 22 using tensorflow::TensorProto; 23 using tensorflow::TensorShapeProto; 24 using tensorflow::serving::PredictionService; 25 using tensorflow::serving::PredictRequest; 26 using tensorflow::serving::PredictResponse; 27 28 using namespace boost::program_options; 29 30 typedef google::protobuf::Map<std::string, tensorflow::TensorProto> OutMap; 31 /* 32 Application entry point 33 */ 34 int main(int argc, char** argv) { 35 std::string server_addr = "172.17.0.2:8500"; 36 std::string model_name = "Toy"; 37 int model_version = -1; 38 std::string model_version_label = ""; 39 const std::string model_signature_name = "serving_default"; 40 41 // for parse argument 42 variables_map vm; 43 44 // grpc context 45 ClientContext context; 46 unsigned int timout_in_sec = 5; 47 48 // predict request 49 PredictRequest request; 50 PredictResponse response; 51 52 // input tensor 53 tensorflow::TensorProto proto; 54 55 // parse arguments 56 options_description desc("Allowed options"); 57 desc.add_options() 58 // First parameter describes option name/short name 59 // The second is parameter to option 60 // The third is description 61 ("help,h", "print usage message")( 62 "server_addr,s", value(&server_addr)->default_value(server_addr), 63 "the destination address host:port")( 64 "model_name,m", value(&model_name)->default_value(model_name), 65 "the mode name for prediction")( 66 "model_version,v", 67 value<int>(&model_version)->default_value(model_version), 68 "the model version for prediction")( 69 "model_version_label,l", 70 value(&model_version_label)->default_value(model_version_label), 71 "the annotation name of model version for prediction"); 72 73 store(parse_command_line(argc, argv, desc), vm); 74 75 if (vm.count("help")) { 76 std::cout << desc << "\n"; 77 return 0; 78 } 79 80 // set grpc timeout 81 std::chrono::system_clock::time_point deadline = 82 std::chrono::system_clock::now() + std::chrono::seconds(timout_in_sec); 83 context.set_deadline(deadline); 84 85 server_addr = vm["server_addr"].as<std::string>(); 86 model_name = vm["model_name"].as<std::string>(); 87 model_version = vm["model_version"].as<int>(); 88 model_version_label = vm["model_version_label"].as<std::string>(); 89 90 // start a 91 std::shared_ptr<Channel> channel = 92 grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials()); 93 std::unique_ptr<PredictionService::Stub> stub = 94 PredictionService::NewStub(channel); 95 96 request.mutable_model_spec()->set_name(model_name); 97 request.mutable_model_spec()->set_signature_name(model_signature_name); 98 99 if (model_version > -1) { 100 request.mutable_model_spec()->mutable_version()->set_value(model_version); 101 } 102 103 if (model_version_label != "") { 104 request.mutable_model_spec()->set_version_label(model_version_label); 105 } 106 107 OutMap& inputs = *request.mutable_inputs(); 108 109 std::vector<float> data{ 110 1., 2., 1., 3., 1., 4., 111 }; 112 113 proto.set_dtype(tensorflow::DataType::DT_FLOAT); 114 115 for (const float& e : data) { 116 proto.add_float_val(e); 117 } 118 119 proto.mutable_tensor_shape()->add_dim()->set_size(3); 120 proto.mutable_tensor_shape()->add_dim()->set_size(2); 121 122 inputs["input_1"].CopyFrom(proto); 123 // inputs["input_1"] = proto; 124 125 std::cout << "calling prediction service on " << server_addr << std::endl; 126 Status status = stub->Predict(&context, request, &response); 127 128 // Act upon its status. 129 if (status.ok()) { 130 const std::string output_label = "output_1"; 131 132 std::cout << "call predict ok" << std::endl; 133 std::cout << "outputs size is " << response.outputs_size() << std::endl; 134 135 OutMap& map_outputs = *response.mutable_outputs(); 136 137 tensorflow::TensorProto& result_tensor_proto = map_outputs[output_label]; 138 139 std::cout << std::endl << output_label << ":" << std::endl; 140 141 for (int titer = 0; titer != result_tensor_proto.float_val_size(); 142 ++titer) { 143 std::cout << result_tensor_proto.float_val(titer) << "\n"; 144 } 145 std::cout << "Done." << std::endl; 146 } else { 147 std::cout << "gRPC call return code: " << status.error_code() << ": " 148 << status.error_message() << std::endl; 149 std::cout << "RPC failed" << std::endl; 150 } 151 152 return 0; 153 }