github.com/alwaysproblem/mlserving-tutorial@v0.0.0-20221124033215-121cfddbfbf4/TFserving/ClientAPI/cpp/make/src/model-metadata/main.cc (about) 1 #include <grpcpp/grpcpp.h> 2 3 #include <boost/program_options.hpp> 4 #include <iostream> 5 #include <memory> 6 #include <string> 7 8 #include "google/protobuf/map.h" 9 #include "grpcpp/create_channel.h" 10 #include "grpcpp/security/credentials.h" 11 #include "tensorflow/core/framework/tensor.grpc.pb.h" 12 #include "tensorflow/core/framework/tensor_shape.grpc.pb.h" 13 #include "tensorflow/core/framework/types.grpc.pb.h" 14 #include "tensorflow_serving/apis/get_model_metadata.grpc.pb.h" 15 #include "tensorflow_serving/apis/predict.grpc.pb.h" 16 #include "tensorflow_serving/apis/prediction_service.grpc.pb.h" 17 18 using grpc::Channel; 19 using grpc::ClientContext; 20 using grpc::Status; 21 22 using tensorflow::TensorProto; 23 using tensorflow::TensorShapeProto; 24 using tensorflow::serving::GetModelMetadataRequest; 25 using tensorflow::serving::GetModelMetadataResponse; 26 using tensorflow::serving::PredictionService; 27 28 using namespace boost::program_options; 29 typedef google::protobuf::Map<std::string, google::protobuf::Any> MetadataMap; 30 31 /* 32 Application entry point 33 */ 34 int main(int argc, char** argv) { 35 std::string server_addr = "172.17.0.3:8500"; 36 std::string model_name = "Toy"; 37 int model_version = -1; 38 std::string model_version_label = ""; 39 const std::string model_signature_name = "serving_default"; 40 41 // for parse argument 42 variables_map vm; 43 44 // grpc context 45 ClientContext context; 46 unsigned int timout_in_sec = 5; 47 48 // get model metadata request & response 49 GetModelMetadataRequest request; 50 GetModelMetadataResponse response; 51 52 // parse arguments 53 options_description desc("Allowed options"); 54 desc.add_options() 55 // First parameter describes option name/short name 56 // The second is parameter to option 57 // The third is description 58 ("help,h", "print usage message")( 59 "server_addr,s", value(&server_addr)->default_value(server_addr), 60 "the destination address host:port")( 61 "model_name,m", value(&model_name)->default_value(model_name), 62 "the mode name for prediction")( 63 "model_version,v", 64 value<int>(&model_version)->default_value(model_version), 65 "the model version for prediction")( 66 "model_version_label,l", 67 value(&model_version_label)->default_value(model_version_label), 68 "the annotation name of model version for prediction"); 69 70 store(parse_command_line(argc, argv, desc), vm); 71 72 if (vm.count("help")) { 73 std::cout << desc << "\n"; 74 return 0; 75 } 76 77 // set grpc timeout 78 std::chrono::system_clock::time_point deadline = 79 std::chrono::system_clock::now() + std::chrono::seconds(timout_in_sec); 80 context.set_deadline(deadline); 81 82 server_addr = vm["server_addr"].as<std::string>(); 83 model_name = vm["model_name"].as<std::string>(); 84 model_version = vm["model_version"].as<int>(); 85 model_version_label = vm["model_version_label"].as<std::string>(); 86 87 // crate a channel 88 std::shared_ptr<Channel> channel = 89 grpc::CreateChannel(server_addr, grpc::InsecureChannelCredentials()); 90 std::unique_ptr<PredictionService::Stub> stub = 91 PredictionService::NewStub(channel); 92 93 request.mutable_model_spec()->set_name(model_name); 94 request.mutable_metadata_field()->Add("signature_def"); 95 request.mutable_model_spec()->set_signature_name(model_signature_name); 96 97 if (model_version > -1) { 98 request.mutable_model_spec()->mutable_version()->set_value(model_version); 99 } 100 101 if (model_version_label != "") { 102 request.mutable_model_spec()->set_version_label(model_version_label); 103 } 104 105 std::cout << "calling prediction service on " << server_addr << std::endl; 106 // Status status = stub->Predict(&context, request, &response); 107 Status status = stub->GetModelMetadata(&context, request, &response); 108 109 // Act upon its status. 110 if (status.ok()) { 111 const std::string output_label = "signature_def"; 112 113 std::cout << "call predict ok" << std::endl; 114 std::cout << "metadata size is " << response.metadata_size() << std::endl; 115 std::cout << "metadata DebugString is \n" 116 << response.DebugString() << std::endl; 117 118 } else { 119 std::cout << "gRPC call return code: " << status.error_code() << ": " 120 << status.error_message() << std::endl; 121 std::cout << "RPC failed" << std::endl; 122 } 123 124 return 0; 125 }