github.com/wlattner/mlserver@v0.0.0-20141113171038-895f261d2bfd/fit_py.go (about) 1 package main 2 3 var fitPy = ` 4 import os 5 import json 6 import datetime 7 8 from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier 9 from sklearn.linear_model import LogisticRegression 10 from sklearn.feature_extraction import DictVectorizer 11 from sklearn.pipeline import Pipeline 12 from sklearn.externals import joblib 13 from sklearn.cross_validation import cross_val_score 14 from sklearn.metrics import confusion_matrix 15 16 def fit(X, Y): 17 models = { 18 'LogisticRegression': LogisticRegression(), 19 'GradientBoostingClassifier': GradientBoostingClassifier(n_estimators=150), 20 'RandomForestClassifier': RandomForestClassifier(n_estimators=150) 21 } 22 23 best_score = 0 24 best_model = '' 25 for model in models: 26 vec = DictVectorizer(sparse=False) 27 clf = models[model] 28 pl = Pipeline([('vec', vec), ('clf', clf)]) 29 30 #TODO: grid search for model params 31 scores = cross_val_score(pl, X, Y, n_jobs=3) 32 if scores.mean() > best_score: 33 best_score = scores.mean() 34 best_model = model 35 36 # retrain best model with all data 37 vec = DictVectorizer(sparse=False) 38 clf = models[best_model] 39 pl = Pipeline([('vec', vec), ('clf', clf)]) 40 pl.fit(X, Y) 41 pl.score_ = best_score # report cv score 42 return pl 43 44 def save(path, model_id, model): 45 fname = model_id + '.pkl' 46 if not os.path.exists(path): 47 os.makedirs(path) 48 joblib.dump(model, os.path.join(path, fname)) 49 50 def save_metadata(path, model_id, model_name, model, X, Y): 51 Y_hat = model.predict(X) 52 53 labels = [l for l in model.named_steps['clf'].classes_] 54 55 cm = confusion_matrix(Y, Y_hat, labels=labels) 56 # this is an insane dict comprehension, need to encode the val as a float, json will not encode 0 57 cm_dict = {str(labels[inx]): {str(labels[c]):float(val) for c, val in enumerate(row)} for inx, row in enumerate(cm)} 58 59 model_data = { 60 "model_id": model_id, 61 "metadata": { 62 "name": model_name, 63 "created_at": datetime.datetime.utcnow().isoformat('T') + 'Z' 64 }, 65 "performance" : { 66 "algorithm": model.named_steps['clf'].__class__.__name__, 67 "score": model.score_, 68 "confusion_matrix": cm_dict 69 } 70 } 71 72 json.dump(model_data, open(os.path.join(path, model_id + '.json'), 'w')) 73 74 75 if __name__ == "__main__": 76 import sys 77 78 data = json.load(open(sys.argv[2])) 79 model_save_path = sys.argv[1] 80 model_id = os.path.basename(model_save_path) 81 82 model = fit(data['data'], data['labels']) 83 save(model_save_path, model_id, model) 84 save_metadata(model_save_path, model_id, data['name'], model, data['data'], data['labels']) 85 86 `