package main

// predictPy holds the source of a Python worker script. Invoked as
// "<script> SOCKET_PATH MODEL_PATH", the script loads a joblib-serialized
// scikit-learn model from MODEL_PATH, binds a ZeroMQ REP socket to
// SOCKET_PATH, and answers each JSON request of the form {"data": X} with a
// JSON list holding one {class label: probability} object per row of X.
//
// The script lives in a raw string literal, so it must never contain a
// backquote character.
var predictPy = `
import signal
import sys

import zmq

try:
    # Standalone joblib; scikit-learn 0.23+ no longer bundles a copy.
    import joblib
except ImportError:
    # Older scikit-learn releases ship it as sklearn.externals.joblib.
    from sklearn.externals import joblib


def class_labels(model):
    """Return the model's class labels, converted to strings.

    A scikit-learn Pipeline keeps its final estimator as the last
    (name, estimator) pair in model.steps; any other model is expected to
    expose classes_ itself.
    """
    estimator = model.steps[-1][-1] if hasattr(model, 'steps') else model
    return [str(label) for label in estimator.classes_]


def predict(model, X):
    """Return a list with one {label: probability} dict per row of X."""
    labels = class_labels(model)
    return [dict(zip(labels, probs)) for probs in model.predict_proba(X)]


def load(path):
    """Load a joblib-serialized model from path."""
    # NOTE(review): joblib.load unpickles the file, so the model file must
    # come from a trusted source.
    return joblib.load(path)


def run(model_path, socket_path):
    """Serve predictions for the model at model_path on a REP socket.

    Each request is a JSON object whose 'data' entry is passed to
    predict(); the reply is predict()'s result encoded as JSON.
    """
    model = load(model_path)

    context = zmq.Context()
    socket = context.socket(zmq.REP)
    socket.bind(socket_path)

    try:
        while True:
            message = socket.recv_json()
            # NOTE(review): an exception here (malformed request, missing
            # 'data' key, model error) leaves the loop before any reply is
            # sent for this request.
            socket.send_json(predict(model, message['data']))
    finally:
        # Closes the bound socket as well as terminating the context.
        context.destroy()


def exit_on_sigint(_sig, _stack_frame):
    """Signal handler: raise SystemExit so run() can clean up."""
    sys.exit(0)


if __name__ == "__main__":
    # Turn SIGINT and SIGTERM into SystemExit so the finally block in run()
    # destroys the ZeroMQ context before the process exits.
    signal.signal(signal.SIGINT, exit_on_sigint)
    signal.signal(signal.SIGTERM, exit_on_sigint)

    if len(sys.argv) != 3:
        sys.exit("usage: %s SOCKET_PATH MODEL_PATH" % sys.argv[0])

    socket_path = sys.argv[1]
    model_path = sys.argv[2]

    run(model_path, socket_path)

`