#!/usr/bin/env python

# Copyright 2016 The Kubernetes Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flask service that predicts team/<label> and component/<label> for
GitHub issues, and accepts training updates for the underlying models."""

import os
import logging
from logging.handlers import RotatingFileHandler

# pylint: disable=import-error
import numpy as np
from flask import Flask, request
from sklearn.feature_extraction import FeatureHasher
from sklearn.externals import joblib
from sklearn.linear_model import SGDClassifier
from nltk.tokenize import RegexpTokenizer
from nltk.stem.porter import PorterStemmer
# pylint: enable=import-error

APP = Flask(__name__)

# Paths of the persisted scikit-learn models (written by update_model).
TEAM_FN = './models/trained_teams_model.pkl'
COMPONENT_FN = './models/trained_components_model.pkl'
LOG_FILE = '/tmp/issue-labeler.log'
LOG_SIZE = 1024 * 1024 * 100  # rotate each log file at 100 MiB
NUM_FEATURES = 262144  # 2**18 hash buckets for the feature hasher
# SGD hyper-parameters; must be identical between the initial fit and any
# later partial_fit so the model stays consistent.
MY_LOSS = 'hinge'
MY_ALPHA = .1
MY_PENALTY = 'l2'
MY_HASHER = FeatureHasher(input_type='string', n_features=NUM_FEATURES, non_negative=True)
MY_STEMMER = PorterStemmer()
TOKENIZER = RegexpTokenizer(r'\w+')


def _load_stopwords(filename='./stopwords.txt'):
    """Return the stopword file's words as a set.

    Returns an empty set when the file is missing or unreadable, in which
    case no stopwords are removed during tokenization.
    """
    try:
        with open(filename, 'r') as fp:
            return set(word.strip() for word in fp)
    except (IOError, OSError):
        # No stopword file available: don't remove any stopwords.
        return set()


# Set (not list) so membership tests in tokenize_stem_stop are O(1).
STOPWORDS = _load_stopwords()


@APP.errorhandler(500)
def internal_error(exception):
    """Render any unhandled exception as a plain-text 500 response."""
    return str(exception), 500


@APP.route("/", methods=['POST'])
def get_labels():
    """Predict labels for one issue.

    The request should contain 2 form-urlencoded parameters
      1) title: title of the issue
      2) body: body of the issue
    It returns a team/<label> and a component/<label>, comma-separated.
    """
    title = request.form.get('title', '')
    body = request.form.get('body', '')
    tokens = tokenize_stem_stop(" ".join([title, body]))
    # Models are re-loaded on every request so updates made through
    # /update_models take effect without restarting the server.
    team_mod = joblib.load(TEAM_FN)
    comp_mod = joblib.load(COMPONENT_FN)
    vec = MY_HASHER.transform([tokens])
    tlabel = team_mod.predict(vec)[0]
    clabel = comp_mod.predict(vec)[0]
    return ",".join([tlabel, clabel])


def tokenize_stem_stop(input_string):
    """Lowercase, split on word characters, drop stopwords, Porter-stem.

    Returns the list of stemmed tokens for *input_string*.
    """
    # Round-trip through UTF-8 bytes to normalize str/unicode input
    # (this module targets Python 2 — sklearn.externals.joblib and the
    # FeatureHasher non_negative argument are both from that era).
    input_string = input_string.encode('utf-8')
    cur_title_body = TOKENIZER.tokenize(input_string.decode('utf-8').lower())
    return [MY_STEMMER.stem(x) for x in cur_title_body if x not in STOPWORDS]


def _fit_or_update(model_fn, vec, label_list):
    """Load the model stored at *model_fn* and partial_fit it on (vec,
    label_list), or fit a fresh SGDClassifier when no model file exists.

    Returns the (re)trained model; the caller is responsible for saving it.
    """
    if os.path.isfile(model_fn):
        model = joblib.load(model_fn)
        model.partial_fit(vec, np.array(label_list))
    else:
        # No stored model, so build a new one from scratch.
        model = SGDClassifier(loss=MY_LOSS, penalty=MY_PENALTY, alpha=MY_ALPHA)
        model.fit(vec, np.array(label_list))
    return model


@APP.route("/update_models", methods=['PUT'])
def update_model():
    """Incrementally (re)train the team and component models.

    The JSON body should contain three fields:
      titles: list of titles
      bodies: list of bodies
      labels: list of list of labels
    """
    data = request.json
    titles = data.get('titles')
    bodies = data.get('bodies')
    labels = data.get('labels')

    t_tokens = []
    c_tokens = []
    team_labels = []
    component_labels = []
    for (title, body, label_list) in zip(titles, bodies, labels):
        t_label = [x for x in label_list if x.startswith('team')]
        c_label = [x for x in label_list if x.startswith('component')]
        tokens = tokenize_stem_stop(" ".join([title, body]))
        if t_label:
            team_labels += t_label
            t_tokens += [tokens]
        if c_label:
            component_labels += c_label
            c_tokens += [tokens]

    # BUG FIX: the original dumped both models unconditionally, raising
    # NameError when a request carried only team labels or only component
    # labels (the other model variable was never bound). Each model is now
    # fitted and saved only when it actually received training examples.
    if team_labels:
        t_vec = MY_HASHER.transform(t_tokens)
        joblib.dump(_fit_or_update(TEAM_FN, t_vec, team_labels), TEAM_FN)
    if component_labels:
        c_vec = MY_HASHER.transform(c_tokens)
        joblib.dump(_fit_or_update(COMPONENT_FN, c_vec, component_labels), COMPONENT_FN)
    return ""


def configure_logger():
    """Attach a size-rotating file handler to the Flask app logger."""
    log_format = '%(asctime)-20s %(levelname)-10s %(message)s'
    file_handler = RotatingFileHandler(LOG_FILE, maxBytes=LOG_SIZE, backupCount=3)
    file_handler.setFormatter(logging.Formatter(log_format))
    APP.logger.addHandler(file_handler)


if __name__ == "__main__":
    configure_logger()
    APP.run(host="0.0.0.0")