github.com/wlattner/mlserver@v0.0.0-20141113171038-895f261d2bfd/fit.py (about)

     1  import os
     2  import json
     3  import datetime
     4  
     5  from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
     6  from sklearn.linear_model import LogisticRegression
     7  from sklearn.feature_extraction import DictVectorizer
     8  from sklearn.pipeline import Pipeline
     9  from sklearn.externals import joblib
    10  from sklearn.cross_validation import cross_val_score
    11  from sklearn.metrics import confusion_matrix
    12  
    13  def fit(X, Y):
    14    models = {
    15      'LogisticRegression': LogisticRegression(),
    16      'GradientBoostingClassifier': GradientBoostingClassifier(n_estimators=150),
    17      'RandomForestClassifier': RandomForestClassifier(n_estimators=150)
    18    }
    19  
    20    best_score = 0
    21    best_model = ''
    22    for model in models:
    23      vec = DictVectorizer(sparse=False)
    24      clf = models[model]
    25      pl = Pipeline([('vec', vec), ('clf', clf)])
    26  
    27      #TODO: grid search for model params
    28      scores = cross_val_score(pl, X, Y, n_jobs=3)
    29      if scores.mean() > best_score:
    30        best_score = scores.mean()
    31        best_model = model
    32  
    33    # retrain best model with all data
    34    vec = DictVectorizer(sparse=False)
    35    clf = models[best_model]
    36    pl = Pipeline([('vec', vec), ('clf', clf)])
    37    pl.fit(X, Y)
    38    pl.score_ = best_score  # report cv score
    39    return pl
    40  
    41  def save(path, model_id, model):
    42  	fname = model_id + '.pkl'
    43  	if not os.path.exists(path):
    44  		os.makedirs(path)
    45  	joblib.dump(model, os.path.join(path, fname))
    46  
    47  def save_metadata(path, model_id, model_name, model, X, Y):
    48  	Y_hat = model.predict(X)
    49  
    50  	labels = [l for l in model.named_steps['clf'].classes_]
    51  
    52  	cm = confusion_matrix(Y, Y_hat, labels=labels)
    53  	# this is an insane dict comprehension, need to encode the val as a float, json will not encode 0
    54  	cm_dict = {str(labels[inx]): {str(labels[c]):float(val) for c, val in enumerate(row)} for inx, row in enumerate(cm)}
    55  
    56  	model_data = {
    57  		"model_id": model_id,
    58  		"metadata": {
    59  			"name": model_name,
    60  			"created_at": datetime.datetime.utcnow().isoformat('T') + 'Z'
    61  		},
    62  		"performance" : {
    63  			"algorithm": model.named_steps['clf'].__class__.__name__,
    64  			"score": model.score_,
    65  			"confusion_matrix": cm_dict
    66  		}
    67  	}
    68  
    69  	json.dump(model_data, open(os.path.join(path, model_id + '.json'), 'w'))
    70  
    71  
    72  if __name__ == "__main__":
    73  	import sys
    74  
    75  	data = json.load(open(sys.argv[2]))
    76  	model_save_path = sys.argv[1]
    77  	model_id = os.path.basename(model_save_path)
    78  
    79  	model = fit(data['data'], data['labels'])
    80  	save(model_save_path, model_id, model)
    81  	save_metadata(model_save_path, model_id, data['name'], model, data['data'], data['labels'])