github.com/wlattner/mlserver@v0.0.0-20141113171038-895f261d2bfd/predict.py (about)

     1  import zmq
     2  import signal
     3  import sys
     4  from sklearn.externals import joblib
     5  
     6  def predict(model, X):
     7  	predictions = []
     8  	labels = [str(label) for label in model.steps[-1][-1].classes_]
     9  	for prediction in model.predict_proba(X):
    10  		predictions.append({labels[lab]: prob for lab, prob in enumerate(prediction)})
    11  	return predictions
    12  
    13  def load(path):
    14  	return joblib.load(path)
    15  
    16  def run(model_path, socket_path):
    17  	model = load(model_path)
    18  
    19  	context = zmq.Context()
    20  	socket = context.socket(zmq.REP)
    21  	socket.bind(socket_path)
    22  
    23  	try:
    24  		while True:
    25  			message = socket.recv_json()
    26  			predictions = predict(model, message['data'])
    27  			socket.send_json(predictions)
    28  	finally:
    29  		context.destroy()
    30  
    31  def exit_on_sigint(_sig, _stack_frame):
    32  	sys.exit(0)
    33  
    34  if __name__ == "__main__":
    35  	signal.signal(signal.SIGINT, exit_on_sigint)
    36  
    37  	socket_path = sys.argv[1]
    38  	model_path = sys.argv[2]
    39  
    40  	run(model_path, socket_path)