github.com/wlattner/mlserver@v0.0.0-20141113171038-895f261d2bfd/fit_py.go (about)

     1  package main
     2  
     3  var fitPy = `
     4  import os
     5  import json
     6  import datetime
     7  
     8  from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
     9  from sklearn.linear_model import LogisticRegression
    10  from sklearn.feature_extraction import DictVectorizer
    11  from sklearn.pipeline import Pipeline
    12  from sklearn.externals import joblib
    13  from sklearn.cross_validation import cross_val_score
    14  from sklearn.metrics import confusion_matrix
    15  
    16  def fit(X, Y):
    17    models = {
    18      'LogisticRegression': LogisticRegression(),
    19      'GradientBoostingClassifier': GradientBoostingClassifier(n_estimators=150),
    20      'RandomForestClassifier': RandomForestClassifier(n_estimators=150)
    21    }
    22  
    23    best_score = 0
    24    best_model = ''
    25    for model in models:
    26      vec = DictVectorizer(sparse=False)
    27      clf = models[model]
    28      pl = Pipeline([('vec', vec), ('clf', clf)])
    29  
    30      #TODO: grid search for model params
    31      scores = cross_val_score(pl, X, Y, n_jobs=3)
    32      if scores.mean() > best_score:
    33        best_score = scores.mean()
    34        best_model = model
    35  
    36    # retrain best model with all data
    37    vec = DictVectorizer(sparse=False)
    38    clf = models[best_model]
    39    pl = Pipeline([('vec', vec), ('clf', clf)])
    40    pl.fit(X, Y)
    41    pl.score_ = best_score  # report cv score
    42    return pl
    43  
    44  def save(path, model_id, model):
    45  	fname = model_id + '.pkl'
    46  	if not os.path.exists(path):
    47  		os.makedirs(path)
    48  	joblib.dump(model, os.path.join(path, fname))
    49  
    50  def save_metadata(path, model_id, model_name, model, X, Y):
    51  	Y_hat = model.predict(X)
    52  
    53  	labels = [l for l in model.named_steps['clf'].classes_]
    54  
    55  	cm = confusion_matrix(Y, Y_hat, labels=labels)
    56  	# this is an insane dict comprehension, need to encode the val as a float, json will not encode 0
    57  	cm_dict = {str(labels[inx]): {str(labels[c]):float(val) for c, val in enumerate(row)} for inx, row in enumerate(cm)}
    58  
    59  	model_data = {
    60  		"model_id": model_id,
    61  		"metadata": {
    62  			"name": model_name,
    63  			"created_at": datetime.datetime.utcnow().isoformat('T') + 'Z'
    64  		},
    65  		"performance" : {
    66  			"algorithm": model.named_steps['clf'].__class__.__name__,
    67  			"score": model.score_,
    68  			"confusion_matrix": cm_dict
    69  		}
    70  	}
    71  
    72  	json.dump(model_data, open(os.path.join(path, model_id + '.json'), 'w'))
    73  
    74  
    75  if __name__ == "__main__":
    76  	import sys
    77  
    78  	data = json.load(open(sys.argv[2]))
    79  	model_save_path = sys.argv[1]
    80  	model_id = os.path.basename(model_save_path)
    81  
    82  	model = fit(data['data'], data['labels'])
    83  	save(model_save_path, model_id, model)
    84  	save_metadata(model_save_path, model_id, data['name'], model, data['data'], data['labels'])
    85  
    86  `