github.com/pachyderm/pachyderm@v1.13.4/examples/ml/neon/inference/auto_inference.py (about)

     1  #!/usr/bin/env python
     2  """
     3  Example that does inference on an LSTM networks for amazon review analysis
     4  
     5  $ python examples/imdb/auto_inference.py --model_weights imdb.p --vocab_file imdb.vocab 
     6      --review_files /pfs/reviews --output_dir /pfs/out
     7  
     8  """
     9  
    10  from __future__ import print_function
    11  from future import standard_library
    12  standard_library.install_aliases()  # triggers E402, hence noqa below
    13  from builtins import input  # noqa
    14  import numpy as np  # noqa
    15  from neon.backends import gen_backend  # noqa
    16  from neon.initializers import Uniform, GlorotUniform  # noqa
    17  from neon.layers import LSTM, Affine, Dropout, LookupTable, RecurrentSum  # noqa
    18  from neon.models import Model  # noqa
    19  from neon.transforms import Logistic, Tanh, Softmax  # noqa
    20  from neon.util.argparser import NeonArgparser, extract_valid_args  # noqa
    21  from neon.util.compat import pickle  # noqa
    22  from neon.data.text_preprocessing import clean_string  # noqa
    23  import os
    24  
    25  # parse the command line arguments
    26  parser = NeonArgparser(__doc__)
    27  parser.add_argument('--model_weights', required=True,
    28                      help='pickle file of trained weights')
    29  parser.add_argument('--vocab_file', required=True,
    30                      help='vocabulary file')
    31  parser.add_argument('--review_files', required=True,
    32                      help='directory containing reviews in text files')
    33  parser.add_argument('--output_dir', required=True,
    34                      help='directory where results will be saved')
    35  args = parser.parse_args()
    36  
    37  
    38  # hyperparameters from the reference
    39  batch_size = 1
    40  clip_gradients = True
    41  gradient_limit = 5
    42  vocab_size = 20000
    43  sentence_length = 128
    44  embedding_dim = 128
    45  hidden_size = 128
    46  reset_cells = True
    47  num_epochs = args.epochs
    48  
    49  # setup backend
    50  be = gen_backend(**extract_valid_args(args, gen_backend))
    51  be.bsz = 1
    52  
    53  
    54  # define same model as in train
    55  init_glorot = GlorotUniform()
    56  init_emb = Uniform(low=-0.1 / embedding_dim, high=0.1 / embedding_dim)
    57  nclass = 2
    58  layers = [
    59      LookupTable(vocab_size=vocab_size, embedding_dim=embedding_dim, init=init_emb,
    60                  pad_idx=0, update=True),
    61      LSTM(hidden_size, init_glorot, activation=Tanh(),
    62           gate_activation=Logistic(), reset_cells=True),
    63      RecurrentSum(),
    64      Dropout(keep=0.5),
    65      Affine(nclass, init_glorot, bias=init_glorot, activation=Softmax())
    66  ]
    67  
    68  
    69  # load the weights
    70  print("Initialized the models - ")
    71  model_new = Model(layers=layers)
    72  print("Loading the weights from {0}".format(args.model_weights))
    73  
    74  model_new.load_params(args.model_weights)
    75  model_new.initialize(dataset=(sentence_length, batch_size))
    76  
    77  # setup buffers before accepting reviews
    78  xdev = be.zeros((sentence_length, 1), dtype=np.int32)  # bsz is 1, feature size
    79  xbuf = np.zeros((1, sentence_length), dtype=np.int32)
    80  oov = 2
    81  start = 1
    82  index_from = 3
    83  pad_char = 0
    84  vocab, rev_vocab = pickle.load(open(args.vocab_file, 'rb'))
    85  
    86  # walk over the reviews in the text files, making inferences
    87  for dirpath, dirs, files in os.walk(args.review_files):
    88      for file in files:
    89          with open(os.path.join(dirpath, file), 'r') as myfile:
    90                  data=myfile.read()
    91  
    92                  # clean the input
    93                  tokens = clean_string(data).strip().split()
    94  
    95                  # check for oov and add start
    96                  sent = [len(vocab) + 1 if t not in vocab else vocab[t] for t in tokens]
    97                  sent = [start] + [w + index_from for w in sent]
    98                  sent = [oov if w >= vocab_size else w for w in sent]
    99  
   100                  # pad sentences
   101                  xbuf[:] = 0
   102                  trunc = sent[-sentence_length:]
   103                  xbuf[0, -len(trunc):] = trunc
   104                  xdev[:] = xbuf.T.copy()
   105                  y_pred = model_new.fprop(xdev, inference=True)  # inference flag dropout
   106  
   107                  with open(os.path.join(args.output_dir, file), "w") as output_file:
   108                          output_file.write("Pred - {0}\n".format(y_pred.get().T))