github.com/abayer/test-infra@v0.0.5/mungegithub/issue_labeler/simple_app.py (about)

     1  #!/usr/bin/env python
     2  
     3  # Copyright 2016 The Kubernetes Authors.
     4  #
     5  # Licensed under the Apache License, Version 2.0 (the "License");
     6  # you may not use this file except in compliance with the License.
     7  # You may obtain a copy of the License at
     8  #
     9  #     http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  
    17  import os
    18  import logging
    19  from logging.handlers import RotatingFileHandler
    20  
    21  # pylint: disable=import-error
    22  import numpy as np
    23  from flask import Flask, request
    24  from sklearn.feature_extraction import FeatureHasher
    25  from sklearn.externals import joblib
    26  from sklearn.linear_model import SGDClassifier
    27  from nltk.tokenize import RegexpTokenizer
    28  from nltk.stem.porter import PorterStemmer
    29  # pylint: enable=import-error
    30  
    31  APP = Flask(__name__)
    32  # parameters
    33  TEAM_FN = './models/trained_teams_model.pkl'
    34  COMPONENT_FN = './models/trained_components_model.pkl'
    35  LOG_FILE = '/tmp/issue-labeler.log'
    36  LOG_SIZE = 1024*1024*100
    37  NUM_FEATURES = 262144
    38  MY_LOSS = 'hinge'
    39  MY_ALPHA = .1
    40  MY_PENALTY = 'l2'
    41  MY_HASHER = FeatureHasher(input_type='string', n_features=NUM_FEATURES, non_negative=True)
    42  MY_STEMMER = PorterStemmer()
    43  TOKENIZER = RegexpTokenizer(r'\w+')
    44  
    45  STOPWORDS = []
    46  try:
    47      if not STOPWORDS:
    48          STOPWORDS_FILENAME = './stopwords.txt'
    49      with open(STOPWORDS_FILENAME, 'r') as fp:
    50          STOPWORDS = list([word.strip() for word in fp])
    51  except: # pylint:disable=bare-except
    52      # don't remove any stopwords
    53      STOPWORDS = []
    54  
    55  @APP.errorhandler(500)
    56  def internal_error(exception):
    57      return str(exception), 500
    58  
    59  @APP.route("/", methods=['POST'])
    60  def get_labels():
    61      """
    62      The request should contain 2 form-urlencoded parameters
    63        1) title : title of the issue
    64        2) body: body of the issue
    65      It returns a team/<label> and a component/<label>
    66      """
    67      title = request.form.get('title', '')
    68      body = request.form.get('body', '')
    69      tokens = tokenize_stem_stop(" ".join([title, body]))
    70      team_mod = joblib.load(TEAM_FN)
    71      comp_mod = joblib.load(COMPONENT_FN)
    72      vec = MY_HASHER.transform([tokens])
    73      tlabel = team_mod.predict(vec)[0]
    74      clabel = comp_mod.predict(vec)[0]
    75      return ",".join([tlabel, clabel])
    76  
    77  
    78  def tokenize_stem_stop(input_string):
    79      input_string = input_string.encode('utf-8')
    80      cur_title_body = TOKENIZER.tokenize(input_string.decode('utf-8').lower())
    81      return [MY_STEMMER.stem(x) for x in cur_title_body if x not in STOPWORDS]
    82  
    83  
    84  @APP.route("/update_models", methods=['PUT'])
    85  def update_model(): # pylint: disable=too-many-locals
    86      """
    87      data should contain three fields
    88        titles: list of titles
    89        bodies: list of bodies
    90        labels: list of list of labels
    91      """
    92      data = request.json
    93      titles = data.get('titles')
    94      bodies = data.get('bodies')
    95      labels = data.get('labels')
    96  
    97      t_tokens = []
    98      c_tokens = []
    99      team_labels = []
   100      component_labels = []
   101      for (title, body, label_list) in zip(titles, bodies, labels):
   102          t_label = [x for x in label_list if x.startswith('team')]
   103          c_label = [x for x in label_list if x.startswith('component')]
   104          tokens = tokenize_stem_stop(" ".join([title, body]))
   105          if t_label:
   106              team_labels += t_label
   107              t_tokens += [tokens]
   108          if c_label:
   109              component_labels += c_label
   110              c_tokens += [tokens]
   111      t_vec = MY_HASHER.transform(t_tokens)
   112      c_vec = MY_HASHER.transform(c_tokens)
   113  
   114      if team_labels:
   115          if os.path.isfile(TEAM_FN):
   116              team_model = joblib.load(TEAM_FN)
   117              team_model.partial_fit(t_vec, np.array(team_labels))
   118          else:
   119              # no team model stored so build a new one
   120              team_model = SGDClassifier(loss=MY_LOSS, penalty=MY_PENALTY, alpha=MY_ALPHA)
   121              team_model.fit(t_vec, np.array(team_labels))
   122  
   123      if component_labels:
   124          if os.path.isfile(COMPONENT_FN):
   125              component_model = joblib.load(COMPONENT_FN)
   126              component_model.partial_fit(c_vec, np.array(component_labels))
   127          else:
   128              # no comp model stored so build a new one
   129              component_model = SGDClassifier(loss=MY_LOSS, penalty=MY_PENALTY, alpha=MY_ALPHA)
   130              component_model.fit(c_vec, np.array(component_labels))
   131  
   132      joblib.dump(team_model, TEAM_FN)
   133      joblib.dump(component_model, COMPONENT_FN)
   134      return ""
   135  
   136  def configure_logger():
   137      log_format = '%(asctime)-20s %(levelname)-10s %(message)s'
   138      file_handler = RotatingFileHandler(LOG_FILE, maxBytes=LOG_SIZE, backupCount=3)
   139      formatter = logging.Formatter(log_format)
   140      file_handler.setFormatter(formatter)
   141      APP.logger.addHandler(file_handler)
   142  
   143  if __name__ == "__main__":
   144      configure_logger()
   145      APP.run(host="0.0.0.0")