github.com/alwaysproblem/mlserving-tutorial@v0.0.0-20221124033215-121cfddbfbf4/TFserving/ClientAPI/cpp/cmake/src/predict-log/main.cc (about) 1 #include <grpcpp/grpcpp.h> 2 3 #include <boost/program_options.hpp> 4 #include <iostream> 5 #include <memory> 6 #include <string> 7 8 #include "google/protobuf/map.h" 9 #include "grpcpp/create_channel.h" 10 #include "grpcpp/security/credentials.h" 11 #include "tensorflow/core/framework/tensor.grpc.pb.h" 12 #include "tensorflow/core/framework/tensor_shape.grpc.pb.h" 13 #include "tensorflow/core/framework/types.grpc.pb.h" 14 #include "tensorflow_serving/apis/predict.grpc.pb.h" 15 #include "tensorflow_serving/apis/prediction_log.grpc.pb.h" 16 #include "tensorflow_serving/apis/prediction_service.grpc.pb.h" 17 18 using grpc::Channel; 19 using grpc::ClientContext; 20 using grpc::Status; 21 22 using tensorflow::TensorProto; 23 using tensorflow::TensorShapeProto; 24 using tensorflow::serving::PredictionService; 25 using tensorflow::serving::PredictLog; 26 using tensorflow::serving::PredictRequest; 27 using tensorflow::serving::PredictResponse; 28 29 using namespace boost::program_options; 30 31 typedef google::protobuf::Map<std::string, tensorflow::TensorProto> OutMap; 32 /* 33 Application entry point 34 */ 35 int main(int argc, char** argv) { 36 std::string server_addr = "172.17.0.2:8500"; 37 std::string model_name = "Toy"; 38 int model_version = -1; 39 std::string model_version_label = ""; 40 const std::string model_signature_name = "serving_default"; 41 42 // for parse argument 43 variables_map vm; 44 45 // grpc context 46 ClientContext context; 47 unsigned int timout_in_sec = 5; 48 49 // predict request 50 PredictRequest request; 51 PredictResponse response; 52 53 // predict log 54 PredictLog logs; 55 56 // input tensor 57 tensorflow::TensorProto proto; 58 59 // parse arguments 60 options_description desc("Allowed options"); 61 desc.add_options() 62 // First parameter describes option name/short name 63 // The second is parameter to option 64 // The third is description 65 ("help,h", "print usage message")( 66 "server_addr,s", value(&server_addr)->default_value(server_addr), 67 "the destination address host:port")( 68 "model_name,m", value(&model_name)->default_value(model_name), 69 "the mode name for prediction")( 70 "model_version,v", 71 value<int>(&model_version)->default_value(model_version), 72 "the model version for prediction")( 73 "model_version_label,l", 74 value(&model_version_label)->default_value(model_version_label), 75 "the annotation name of model version for prediction"); 76 77 store(parse_command_line(argc, argv, desc), vm); 78 79 if (vm.count("help")) { 80 std::cout << desc << "\n"; 81 return 0; 82 } 83 84 // set grpc timeout 85 std::chrono::system_clock::time_point deadline = 86 std::chrono::system_clock::now() + std::chrono::seconds(timout_in_sec); 87 context.set_deadline(deadline); 88 89 server_addr = vm["server_addr"].as<std::string>(); 90 model_name = vm["model_name"].as<std::string>(); 91 model_version = vm["model_version"].as<int>(); 92 model_version_label = vm["model_version_label"].as<std::string>(); 93 94 // start a 95 std::shared_ptr<Channel> channel = 96 grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials()); 97 std::unique_ptr<PredictionService::Stub> stub = 98 PredictionService::NewStub(channel); 99 100 request.mutable_model_spec()->set_name(model_name); 101 request.mutable_model_spec()->set_signature_name(model_signature_name); 102 103 if (model_version > -1) { 104 request.mutable_model_spec()->mutable_version()->set_value(model_version); 105 } 106 107 if (model_version_label != "") { 108 request.mutable_model_spec()->set_version_label(model_version_label); 109 } 110 111 OutMap& inputs = *request.mutable_inputs(); 112 113 std::vector<float> data{ 114 1., 2., 1., 3., 1., 4., 115 }; 116 117 proto.set_dtype(tensorflow::DataType::DT_FLOAT); 118 119 for (const float& e : data) { 120 proto.add_float_val(e); 121 } 122 123 proto.mutable_tensor_shape()->add_dim()->set_size(3); 124 proto.mutable_tensor_shape()->add_dim()->set_size(2); 125 126 inputs["input_1"].CopyFrom(proto); 127 // inputs["input_1"] = proto; 128 129 std::cout << "calling prediction service on " << server_addr << std::endl; 130 Status status = stub->Predict(&context, request, &response); 131 132 // Act upon its status. 133 if (status.ok()) { 134 const std::string output_label = "output_1"; 135 136 std::cout << "call predict ok" << std::endl; 137 std::cout << "outputs size is " << response.outputs_size() << std::endl; 138 139 OutMap& map_outputs = *response.mutable_outputs(); 140 141 tensorflow::TensorProto& result_tensor_proto = map_outputs[output_label]; 142 143 std::cout << std::endl << output_label << ":" << std::endl; 144 145 for (int titer = 0; titer != result_tensor_proto.float_val_size(); 146 ++titer) { 147 std::cout << result_tensor_proto.float_val(titer) << "\n"; 148 } 149 150 logs.mutable_request()->CopyFrom(request); 151 logs.mutable_response()->CopyFrom(response); 152 153 std::cout << "********************Predict Log*********************\n" 154 << logs.DebugString() 155 << "****************************************************" 156 << std::endl; 157 158 std::cout << "Done." << std::endl; 159 160 } else { 161 std::cout << "gRPC call return code: " << status.error_code() << ": " 162 << status.error_message() << std::endl; 163 std::cout << "RPC failed" << std::endl; 164 } 165 166 return 0; 167 }