# Source: github.com/wlattner/mlserver@v0.0.0-20141113171038-895f261d2bfd/predict.py
"""Serve class-probability predictions from a saved model over a ZeroMQ REP socket.

Usage: python predict.py SOCKET_PATH MODEL_PATH

Each request is a JSON object whose ``data`` key holds the samples to score;
each reply is a JSON list with one ``{label: probability}`` mapping per sample.
"""
import signal
import sys

import zmq

try:
    from sklearn.externals import joblib
except ImportError:
    # sklearn.externals.joblib was removed in scikit-learn 0.23; the standalone
    # joblib package (a dependency of scikit-learn) is the same library.
    import joblib


def _classes(model):
    """Return the ``classes_`` of the estimator that produces *model*'s probabilities."""
    # A Pipeline stores its steps as (name, estimator) pairs; the final
    # estimator is the classifier whose classes_ order matches predict_proba.
    steps = getattr(model, "steps", None)
    estimator = steps[-1][-1] if steps else model
    return estimator.classes_


def predict(model, X):
    """Score *X* and label every probability with its class.

    Args:
        model: a fitted classifier exposing ``predict_proba`` and ``classes_``,
            or a Pipeline-like object whose last step does.
        X: the samples to score, in any form ``model.predict_proba`` accepts
            (presumably a list of feature rows decoded from JSON - confirm
            against the client).

    Returns:
        A list with one dict per sample, mapping ``str(label)`` to the
        probability ``predict_proba`` assigned to that label.
    """
    # JSON object keys must be strings, so labels are stringified up front.
    labels = [str(label) for label in _classes(model)]
    return [dict(zip(labels, row)) for row in model.predict_proba(X)]


def load(path):
    """Load a model previously saved with ``joblib.dump`` from *path*."""
    # joblib.load unpickles the file, which can execute arbitrary code:
    # only point this at model files from a trusted source.
    return joblib.load(path)


def run(model_path, socket_path):
    """Load the model at *model_path* and answer requests on *socket_path* forever.

    Binds a REP socket, then loops: receive a JSON request, reply with
    ``predict(model, request['data'])``. The ZeroMQ context (and with it the
    socket) is destroyed when the loop exits for any reason, including the
    SystemExit raised by ``exit_on_sigint``.
    """
    model = load(model_path)

    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind(socket_path)

    try:
        while True:
            message = socket.recv_json()
            # NOTE(review): an exception here (e.g. a request without a 'data'
            # key, or input predict_proba rejects) leaves the loop without
            # sending a reply and stops the server; confirm the caller
            # expects that rather than an error reply.
            predictions = predict(model, message['data'])
            socket.send_json(predictions)
    finally:
        context.destroy()


def exit_on_sigint(_sig, _stack_frame):
    """SIGINT handler: raise SystemExit so ``run``'s ``finally`` cleans up."""
    sys.exit(0)


if __name__ == "__main__":
    signal.signal(signal.SIGINT, exit_on_sigint)

    # Fail with a usage line instead of an IndexError when arguments are missing.
    if len(sys.argv) < 3:
        sys.exit("usage: %s SOCKET_PATH MODEL_PATH" % sys.argv[0])

    socket_path = sys.argv[1]
    model_path = sys.argv[2]

    run(model_path, socket_path)