github.com/wlattner/mlserver@v0.0.0-20141113171038-895f261d2bfd/predict_py.go (about)

     1  package main
     2  
     3  var predictPy = `
     4  import zmq
     5  import signal
     6  import sys
     7  from sklearn.externals import joblib
     8  
     9  def predict(model, X):
    10  	predictions = []
    11  	labels = [str(label) for label in model.steps[-1][-1].classes_]
    12  	for prediction in model.predict_proba(X):
    13  		predictions.append({labels[lab]: prob for lab, prob in enumerate(prediction)})
    14  	return predictions
    15  
    16  def load(path):
    17  	return joblib.load(path)
    18  
    19  def run(model_path, socket_path):
    20  	model = load(model_path)
    21  
    22  	context = zmq.Context()
    23  	socket = context.socket(zmq.REP)
    24  	socket.bind(socket_path)
    25  
    26  	try:
    27  		while True:
    28  			message = socket.recv_json()
    29  			predictions = predict(model, message['data'])
    30  			socket.send_json(predictions)
    31  	finally:
    32  		context.destroy()
    33  
    34  def exit_on_sigint(_sig, _stack_frame):
    35  	sys.exit(0)
    36  
    37  if __name__ == "__main__":
    38  	signal.signal(signal.SIGINT, exit_on_sigint)
    39  
    40  	socket_path = sys.argv[1]
    41  	model_path = sys.argv[2]
    42  
    43  	run(model_path, socket_path)
    44  
    45  `