github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/examples/eager/spinn/spinn.py (about)

     1  r"""Implementation of SPINN in TensorFlow eager execution.
     2  
     3  SPINN: Stack-Augmented Parser-Interpreter Neural Network.
     4  
     5  Ths file contains model definition and code for training the model.
     6  
     7  The model definition is based on PyTorch implementation at:
     8    https://github.com/jekbradbury/examples/tree/spinn/snli
     9  
    10  which was released under a BSD 3-Clause License at:
    11  https://github.com/jekbradbury/examples/blob/spinn/LICENSE:
    12  
    13  Copyright (c) 2017,
    14  All rights reserved.
    15  
    16  See ./LICENSE for more details.
    17  
    18  Instructions for use:
    19  * See `README.md` for details on how to prepare the SNLI and GloVe data.
    20  * Suppose you have prepared the data at "/tmp/spinn-data", use the folloing
    21    command to train the model:
    22  
    23    ```bash
    24    python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs
    25    ```
    26  
    27    Checkpoints and TensorBoard summaries will be written to "/tmp/spinn-logs".
    28  
    29  References:
    30  * Bowman, S.R., Gauthier, J., Rastogi A., Gupta, R., Manning, C.D., & Potts, C.
    31    (2016). A Fast Unified Model for Parsing and Sentence Understanding.
    32    https://arxiv.org/abs/1603.06021
    33  * Bradbury, J. (2017). Recursive Neural Networks with PyTorch.
    34    https://devblogs.nvidia.com/parallelforall/recursive-neural-networks-pytorch/
    35  """
    36  
    37  from __future__ import absolute_import
    38  from __future__ import division
    39  from __future__ import print_function
    40  
    41  import argparse
    42  import itertools
    43  import os
    44  import sys
    45  import time
    46  
    47  from six.moves import xrange  # pylint: disable=redefined-builtin
    48  import tensorflow as tf
    49  
    50  import tensorflow.contrib.eager as tfe
    51  from tensorflow.contrib.eager.python.examples.spinn import data
    52  
    53  
    54  layers = tf.keras.layers
    55  
    56  
    57  def _bundle(lstm_iter):
    58    """Concatenate a list of Tensors along 1st axis and split result into two.
    59  
    60    Args:
    61      lstm_iter: A `list` of `N` dense `Tensor`s, each of which has the shape
    62        (R, 2 * M).
    63  
    64    Returns:
    65      A `list` of two dense `Tensor`s, each of which has the shape (N * R, M).
    66    """
    67    return tf.split(tf.concat(lstm_iter, 0), 2, axis=1)
    68  
    69  
    70  def _unbundle(state):
    71    """Concatenate a list of Tensors along 2nd axis and split result.
    72  
    73    This is the inverse of `_bundle`.
    74  
    75    Args:
    76      state: A `list` of two dense `Tensor`s, each of which has the shape (R, M).
    77  
    78    Returns:
    79      A `list` of `R` dense `Tensors`, each of which has the shape (1, 2 * M).
    80    """
    81    return tf.split(tf.concat(state, 1), state[0].shape[0], axis=0)
    82  
    83  
    84  # pylint: disable=not-callable
    85  class Reducer(tf.keras.Model):
    86    """A module that applies reduce operation on left and right vectors."""
    87  
    88    def __init__(self, size, tracker_size=None):
    89      super(Reducer, self).__init__()
    90      self.left = layers.Dense(5 * size, activation=None)
    91      self.right = layers.Dense(5 * size, activation=None, use_bias=False)
    92      if tracker_size is not None:
    93        self.track = layers.Dense(5 * size, activation=None, use_bias=False)
    94      else:
    95        self.track = None
    96  
    97    def call(self, left_in, right_in, tracking=None):
    98      """Invoke forward pass of the Reduce module.
    99  
   100      This method feeds a linear combination of `left_in`, `right_in` and
   101      `tracking` into a Tree LSTM and returns the output of the Tree LSTM.
   102  
   103      Args:
   104        left_in: A list of length L. Each item is a dense `Tensor` with
   105          the shape (1, n_dims). n_dims is the size of the embedding vector.
   106        right_in: A list of the same length as `left_in`. Each item should have
   107          the same shape as the items of `left_in`.
   108        tracking: Optional list of the same length as `left_in`. Each item is a
   109          dense `Tensor` with shape (1, tracker_size * 2). tracker_size is the
   110          size of the Tracker's state vector.
   111  
   112      Returns:
   113        Output: A list of length batch_size. Each item has the shape (1, n_dims).
   114      """
   115      left, right = _bundle(left_in), _bundle(right_in)
   116      lstm_in = self.left(left[0]) + self.right(right[0])
   117      if self.track and tracking:
   118        lstm_in += self.track(_bundle(tracking)[0])
   119      return _unbundle(self._tree_lstm(left[1], right[1], lstm_in))
   120  
   121    def _tree_lstm(self, c1, c2, lstm_in):
   122      a, i, f1, f2, o = tf.split(lstm_in, 5, axis=1)
   123      c = tf.tanh(a) * tf.sigmoid(i) + tf.sigmoid(f1) * c1 + tf.sigmoid(f2) * c2
   124      h = tf.sigmoid(o) * tf.tanh(c)
   125      return h, c
   126  
   127  
   128  class Tracker(tf.keras.Model):
   129    """A module that tracks the history of the sentence with an LSTM."""
   130  
   131    def __init__(self, tracker_size, predict):
   132      """Constructor of Tracker.
   133  
   134      Args:
   135        tracker_size: Number of dimensions of the underlying `LSTMCell`.
   136        predict: (`bool`) Whether prediction mode is enabled.
   137      """
   138      super(Tracker, self).__init__()
   139      self._rnn = tf.nn.rnn_cell.LSTMCell(tracker_size)
   140      self._state_size = tracker_size
   141      if predict:
   142        self._transition = layers.Dense(4)
   143      else:
   144        self._transition = None
   145  
   146    def reset_state(self):
   147      self.state = None
   148  
   149    def call(self, bufs, stacks):
   150      """Invoke the forward pass of the Tracker module.
   151  
   152      This method feeds the concatenation of the top two elements of the stacks
   153      into an LSTM cell and returns the resultant state of the LSTM cell.
   154  
   155      Args:
   156        bufs: A `list` of length batch_size. Each item is a `list` of
   157          max_sequence_len (maximum sequence length of the batch). Each item
   158          of the nested list is a dense `Tensor` of shape (1, d_proj), where
   159          d_proj is the size of the word embedding vector or the size of the
   160          vector space that the word embedding vector is projected to.
   161        stacks: A `list` of size batch_size. Each item is a `list` of
   162          variable length corresponding to the current height of the stack.
   163          Each item of the nested list is a dense `Tensor` of shape (1, d_proj).
   164  
   165      Returns:
   166        1. A list of length batch_size. Each item is a dense `Tensor` of shape
   167          (1, d_tracker * 2).
   168        2.  If under predict mode, result of applying a Dense layer on the
   169          first state vector of the RNN. Else, `None`.
   170      """
   171      buf = _bundle([buf[-1] for buf in bufs])[0]
   172      stack1 = _bundle([stack[-1] for stack in stacks])[0]
   173      stack2 = _bundle([stack[-2] for stack in stacks])[0]
   174      x = tf.concat([buf, stack1, stack2], 1)
   175      if self.state is None:
   176        batch_size = int(x.shape[0])
   177        zeros = tf.zeros((batch_size, self._state_size), dtype=tf.float32)
   178        self.state = [zeros, zeros]
   179      _, self.state = self._rnn(x, self.state)
   180      unbundled = _unbundle(self.state)
   181      if self._transition:
   182        return unbundled, self._transition(self.state[0])
   183      else:
   184        return unbundled, None
   185  
   186  
   187  class SPINN(tf.keras.Model):
   188    """Stack-augmented Parser-Interpreter Neural Network.
   189  
   190    See https://arxiv.org/abs/1603.06021 for more details.
   191    """
   192  
   193    def __init__(self, config):
   194      """Constructor of SPINN.
   195  
   196      Args:
   197        config: A `namedtupled` with the following attributes.
   198          d_proj - (`int`) number of dimensions of the vector space to project the
   199            word embeddings to.
   200          d_tracker - (`int`) number of dimensions of the Tracker's state vector.
   201          d_hidden - (`int`) number of the dimensions of the hidden state, for the
   202            Reducer module.
   203          n_mlp_layers - (`int`) number of multi-layer perceptron layers to use to
   204            convert the output of the `Feature` module to logits.
   205          predict - (`bool`) Whether the Tracker will enabled predictions.
   206      """
   207      super(SPINN, self).__init__()
   208      self.config = config
   209      self.reducer = Reducer(config.d_hidden, config.d_tracker)
   210      if config.d_tracker is not None:
   211        self.tracker = Tracker(config.d_tracker, config.predict)
   212      else:
   213        self.tracker = None
   214  
   215    def call(self, buffers, transitions, training=False):
   216      """Invoke the forward pass of the SPINN model.
   217  
   218      Args:
   219        buffers: Dense `Tensor` of shape
   220          (max_sequence_len, batch_size, config.d_proj).
   221        transitions: Dense `Tensor` with integer values that represent the parse
   222          trees of the sentences. A value of 2 indicates "reduce"; a value of 3
   223          indicates "shift". Shape: (max_sequence_len * 2 - 3, batch_size).
   224        training: Whether the invocation is under training mode.
   225  
   226      Returns:
   227        Output `Tensor` of shape (batch_size, config.d_embed).
   228      """
   229      max_sequence_len, batch_size, d_proj = (int(x) for x in buffers.shape)
   230  
   231      # Split the buffers into left and right word items and put the initial
   232      # items in a stack.
   233      splitted = tf.split(
   234          tf.reshape(tf.transpose(buffers, [1, 0, 2]), [-1, d_proj]),
   235          max_sequence_len * batch_size, axis=0)
   236      buffers = [splitted[k:k + max_sequence_len]
   237                 for k in xrange(0, len(splitted), max_sequence_len)]
   238      stacks = [[buf[0], buf[0]] for buf in buffers]
   239  
   240      if self.tracker:
   241        # Reset tracker state for new batch.
   242        self.tracker.reset_state()
   243  
   244      num_transitions = transitions.shape[0]
   245  
   246      # Iterate through transitions and perform the appropriate stack-pop, reduce
   247      # and stack-push operations.
   248      transitions = transitions.numpy()
   249      for i in xrange(num_transitions):
   250        trans = transitions[i]
   251        if self.tracker:
   252          # Invoke tracker to obtain the current tracker states for the sentences.
   253          tracker_states, trans_hypothesis = self.tracker(buffers, stacks=stacks)
   254          if trans_hypothesis:
   255            trans = tf.argmax(trans_hypothesis, axis=-1)
   256        else:
   257          tracker_states = itertools.repeat(None)
   258        lefts, rights, trackings = [], [], []
   259        for transition, buf, stack, tracking in zip(
   260            trans, buffers, stacks, tracker_states):
   261          if int(transition) == 3:  # Shift.
   262            stack.append(buf.pop())
   263          elif int(transition) == 2:  # Reduce.
   264            rights.append(stack.pop())
   265            lefts.append(stack.pop())
   266            trackings.append(tracking)
   267  
   268        if rights:
   269          reducer_output = self.reducer(lefts, rights, trackings)
   270          reduced = iter(reducer_output)
   271  
   272          for transition, stack in zip(trans, stacks):
   273            if int(transition) == 2:  # Reduce.
   274              stack.append(next(reduced))
   275      return _bundle([stack.pop() for stack in stacks])[0]
   276  
   277  
   278  class Perceptron(tf.keras.Model):
   279    """One layer of the SNLIClassifier multi-layer perceptron."""
   280  
   281    def __init__(self, dimension, dropout_rate, previous_layer):
   282      """Configure the Perceptron."""
   283      super(Perceptron, self).__init__()
   284      self.dense = tf.keras.layers.Dense(dimension, activation=tf.nn.elu)
   285      self.batchnorm = layers.BatchNormalization()
   286      self.dropout = layers.Dropout(rate=dropout_rate)
   287      self.previous_layer = previous_layer
   288  
   289    def call(self, x, training):
   290      """Run previous Perceptron layers, then this one."""
   291      x = self.previous_layer(x, training=training)
   292      x = self.dense(x)
   293      x = self.batchnorm(x, training=training)
   294      x = self.dropout(x, training=training)
   295      return x
   296  
   297  
   298  class SNLIClassifier(tf.keras.Model):
   299    """SNLI Classifier Model.
   300  
   301    A model aimed at solving the SNLI (Standford Natural Language Inference)
   302    task, using the SPINN model from above. For details of the task, see:
   303      https://nlp.stanford.edu/projects/snli/
   304    """
   305  
   306    def __init__(self, config, embed):
   307      """Constructor of SNLICLassifier.
   308  
   309      Args:
   310        config: A namedtuple containing required configurations for the model. It
   311          needs to have the following attributes.
   312          projection - (`bool`) whether the word vectors are to be projected onto
   313            another vector space (of `d_proj` dimensions).
   314          d_proj - (`int`) number of dimensions of the vector space to project the
   315            word embeddings to.
   316          embed_dropout - (`float`) dropout rate for the word embedding vectors.
   317          n_mlp_layers - (`int`) number of multi-layer perceptron (MLP) layers to
   318            use to convert the output of the `Feature` module to logits.
   319          mlp_dropout - (`float`) dropout rate of the MLP layers.
   320          d_out - (`int`) number of dimensions of the final output of the MLP
   321            layers.
   322          lr - (`float`) learning rate.
   323        embed: A embedding matrix of shape (vocab_size, d_embed).
   324      """
   325      super(SNLIClassifier, self).__init__()
   326      self.config = config
   327      self.embed = tf.constant(embed)
   328  
   329      self.projection = layers.Dense(config.d_proj)
   330      self.embed_bn = layers.BatchNormalization()
   331      self.embed_dropout = layers.Dropout(rate=config.embed_dropout)
   332      self.encoder = SPINN(config)
   333  
   334      self.feature_bn = layers.BatchNormalization()
   335      self.feature_dropout = layers.Dropout(rate=config.mlp_dropout)
   336  
   337      current_mlp = lambda result, training: result
   338      for _ in range(config.n_mlp_layers):
   339        current_mlp = Perceptron(dimension=config.d_mlp,
   340                                 dropout_rate=config.mlp_dropout,
   341                                 previous_layer=current_mlp)
   342      self.mlp = current_mlp
   343      self.mlp_output = layers.Dense(
   344          config.d_out,
   345          kernel_initializer=tf.random_uniform_initializer(minval=-5e-3,
   346                                                           maxval=5e-3))
   347  
   348    def call(self,
   349             premise,
   350             premise_transition,
   351             hypothesis,
   352             hypothesis_transition,
   353             training=False):
   354      """Invoke the forward pass the SNLIClassifier model.
   355  
   356      Args:
   357        premise: The word indices of the premise sentences, with shape
   358          (max_prem_seq_len, batch_size).
   359        premise_transition: The transitions for the premise sentences, with shape
   360          (max_prem_seq_len * 2 - 3, batch_size).
   361        hypothesis: The word indices of the hypothesis sentences, with shape
   362          (max_hypo_seq_len, batch_size).
   363        hypothesis_transition: The transitions for the hypothesis sentences, with
   364          shape (max_hypo_seq_len * 2 - 3, batch_size).
   365        training: Whether the invocation is under training mode.
   366  
   367      Returns:
   368        The logits, as a dense `Tensor` of shape (batch_size, d_out), where d_out
   369        is the size of the output vector.
   370      """
   371      # Perform embedding lookup on the premise and hypothesis inputs, which have
   372      # the word-index format.
   373      premise_embed = tf.nn.embedding_lookup(self.embed, premise)
   374      hypothesis_embed = tf.nn.embedding_lookup(self.embed, hypothesis)
   375  
   376      if self.config.projection:
   377        # Project the embedding vectors to another vector space.
   378        premise_embed = self.projection(premise_embed)
   379        hypothesis_embed = self.projection(hypothesis_embed)
   380  
   381      # Perform batch normalization and dropout on the possibly projected word
   382      # vectors.
   383      premise_embed = self.embed_bn(premise_embed, training=training)
   384      hypothesis_embed = self.embed_bn(hypothesis_embed, training=training)
   385      premise_embed = self.embed_dropout(premise_embed, training=training)
   386      hypothesis_embed = self.embed_dropout(hypothesis_embed, training=training)
   387  
   388      # Run the batch-normalized and dropout-processed word vectors through the
   389      # SPINN encoder.
   390      premise = self.encoder(premise_embed, premise_transition,
   391                             training=training)
   392      hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
   393                                training=training)
   394  
   395      # Combine encoder outputs for premises and hypotheses into logits.
   396      # Then apply batch normalization and dropuout on the logits.
   397      logits = tf.concat(
   398          [premise, hypothesis, premise - hypothesis, premise * hypothesis], 1)
   399      logits = self.feature_dropout(
   400          self.feature_bn(logits, training=training), training=training)
   401  
   402      # Apply the multi-layer perceptron on the logits.
   403      logits = self.mlp(logits, training=training)
   404      logits = self.mlp_output(logits)
   405      return logits
   406  
   407  
   408  class SNLIClassifierTrainer(tfe.Checkpointable):
   409    """A class that coordinates the training of an SNLIClassifier."""
   410  
   411    def __init__(self, snli_classifier, lr):
   412      """Constructor of SNLIClassifierTrainer.
   413  
   414      Args:
   415        snli_classifier: An instance of `SNLIClassifier`.
   416        lr: Learning rate.
   417      """
   418      self._model = snli_classifier
   419      # Create a custom learning rate Variable for the RMSProp optimizer, because
   420      # the learning rate needs to be manually decayed later (see
   421      # decay_learning_rate()).
   422      self._learning_rate = tf.Variable(lr, name="learning_rate")
   423      self._optimizer = tf.train.RMSPropOptimizer(self._learning_rate,
   424                                                  epsilon=1e-6)
   425  
   426    def loss(self, labels, logits):
   427      """Calculate the loss given a batch of data.
   428  
   429      Args:
   430        labels: The truth labels, with shape (batch_size,).
   431        logits: The logits output from the forward pass of the SNLIClassifier
   432          model, with shape (batch_size, d_out), where d_out is the output
   433          dimension size of the SNLIClassifier.
   434  
   435      Returns:
   436        The loss value, as a scalar `Tensor`.
   437      """
   438      return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
   439          labels=labels, logits=logits))
   440  
   441    def train_batch(self,
   442                    labels,
   443                    premise,
   444                    premise_transition,
   445                    hypothesis,
   446                    hypothesis_transition):
   447      """Train model on batch of data.
   448  
   449      Args:
   450        labels: The truth labels, with shape (batch_size,).
   451        premise: The word indices of the premise sentences, with shape
   452          (max_prem_seq_len, batch_size).
   453        premise_transition: The transitions for the premise sentences, with shape
   454          (max_prem_seq_len * 2 - 3, batch_size).
   455        hypothesis: The word indices of the hypothesis sentences, with shape
   456          (max_hypo_seq_len, batch_size).
   457        hypothesis_transition: The transitions for the hypothesis sentences, with
   458          shape (max_hypo_seq_len * 2 - 3, batch_size).
   459  
   460      Returns:
   461        1. loss value as a scalar `Tensor`.
   462        2. logits as a dense `Tensor` of shape (batch_size, d_out), where d_out is
   463          the output dimension size of the SNLIClassifier.
   464      """
   465      with tf.GradientTape() as tape:
   466        tape.watch(self._model.variables)
   467        logits = self._model(premise,
   468                             premise_transition,
   469                             hypothesis,
   470                             hypothesis_transition,
   471                             training=True)
   472        loss = self.loss(labels, logits)
   473      gradients = tape.gradient(loss, self._model.variables)
   474      self._optimizer.apply_gradients(zip(gradients, self._model.variables),
   475                                      global_step=tf.train.get_global_step())
   476      return loss, logits
   477  
   478    def decay_learning_rate(self, decay_by):
   479      """Decay learning rate of the optimizer by factor decay_by."""
   480      self._learning_rate.assign(self._learning_rate * decay_by)
   481      print("Decayed learning rate of optimizer to: %s" %
   482            self._learning_rate.numpy())
   483  
   484    @property
   485    def learning_rate(self):
   486      return self._learning_rate
   487  
   488    @property
   489    def model(self):
   490      return self._model
   491  
   492    @property
   493    def variables(self):
   494      return (self._model.variables + [self.learning_rate] +
   495              self._optimizer.variables())
   496  
   497  
   498  def _batch_n_correct(logits, label):
   499    """Calculate number of correct predictions in a batch.
   500  
   501    Args:
   502      logits: A logits Tensor of shape `(batch_size, num_categories)` and dtype
   503        `float32`.
   504      label: A labels Tensor of shape `(batch_size,)` and dtype `int64`
   505  
   506    Returns:
   507      Number of correct predictions.
   508    """
   509    return tf.reduce_sum(
   510        tf.cast((tf.equal(
   511            tf.argmax(logits, axis=1), label)), tf.float32)).numpy()
   512  
   513  
   514  def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
   515    """Run evaluation on a dataset.
   516  
   517    Args:
   518      snli_data: The `data.SnliData` to use in this evaluation.
   519      batch_size: The batch size to use during this evaluation.
   520      trainer: An instance of `SNLIClassifierTrainer to use for this
   521        evaluation.
   522      use_gpu: Whether GPU is being used.
   523  
   524    Returns:
   525      1. Average loss across all examples of the dataset.
   526      2. Average accuracy rate across all examples of the dataset.
   527    """
   528    mean_loss = tfe.metrics.Mean()
   529    accuracy = tfe.metrics.Accuracy()
   530    for label, prem, prem_trans, hypo, hypo_trans in _get_dataset_iterator(
   531        snli_data, batch_size):
   532      if use_gpu:
   533        label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
   534      logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
   535      loss_val = trainer.loss(label, logits)
   536      batch_size = tf.shape(label)[0]
   537      mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
   538      accuracy(tf.argmax(logits, axis=1), label)
   539    return mean_loss.result().numpy(), accuracy.result().numpy()
   540  
   541  
   542  def _get_dataset_iterator(snli_data, batch_size):
   543    """Get a data iterator for a split of SNLI data.
   544  
   545    Args:
   546      snli_data: A `data.SnliData` object.
   547      batch_size: The desired batch size.
   548  
   549    Returns:
   550      A dataset iterator.
   551    """
   552    with tf.device("/device:CPU:0"):
   553      # Some tf.data ops, such as ShuffleDataset, are available only on CPU.
   554      dataset = tf.data.Dataset.from_generator(
   555          snli_data.get_generator(batch_size),
   556          (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64))
   557      dataset = dataset.shuffle(snli_data.num_batches(batch_size))
   558      return tfe.Iterator(dataset)
   559  
   560  
   561  def train_or_infer_spinn(embed,
   562                           word2index,
   563                           train_data,
   564                           dev_data,
   565                           test_data,
   566                           config):
   567    """Perform Training or Inference on a SPINN model.
   568  
   569    Args:
   570      embed: The embedding matrix as a float32 numpy array with shape
   571        [vocabulary_size, word_vector_len]. word_vector_len is the length of a
   572        word embedding vector.
   573      word2index: A `dict` mapping word to word index.
   574      train_data: An instance of `data.SnliData`, for the train split.
   575      dev_data: Same as above, for the dev split.
   576      test_data: Same as above, for the test split.
   577      config: A configuration object. See the argument to this Python binary for
   578        details.
   579  
   580    Returns:
   581      If `config.inference_premise ` and `config.inference_hypothesis` are not
   582        `None`, i.e., inference mode: the logits for the possible labels of the
   583        SNLI data set, as a `Tensor` of three floats.
   584      else:
   585        The trainer object.
   586    Raises:
   587      ValueError: if only one of config.inference_premise and
   588        config.inference_hypothesis is specified.
   589    """
   590    # TODO(cais): Refactor this function into separate one for training and
   591    #   inference.
   592    use_gpu = tfe.num_gpus() > 0 and not config.force_cpu
   593    device = "gpu:0" if use_gpu else "cpu:0"
   594    print("Using device: %s" % device)
   595  
   596    if ((config.inference_premise and not config.inference_hypothesis) or
   597        (not config.inference_premise and config.inference_hypothesis)):
   598      raise ValueError(
   599          "--inference_premise and --inference_hypothesis must be both "
   600          "specified or both unspecified, but only one is specified.")
   601  
   602    if config.inference_premise:
   603      # Inference mode.
   604      inference_sentence_pair = [
   605          data.encode_sentence(config.inference_premise, word2index),
   606          data.encode_sentence(config.inference_hypothesis, word2index)]
   607    else:
   608      inference_sentence_pair = None
   609  
   610    log_header = (
   611        "  Time Epoch Iteration Progress    (%Epoch)   Loss   Dev/Loss"
   612        "     Accuracy  Dev/Accuracy")
   613    log_template = (
   614        "{:>6.0f} {:>5.0f} {:>9.0f} {:>5.0f}/{:<5.0f} {:>7.0f}% {:>8.6f} {} "
   615        "{:12.4f} {}")
   616    dev_log_template = (
   617        "{:>6.0f} {:>5.0f} {:>9.0f} {:>5.0f}/{:<5.0f} {:>7.0f}% {:>8.6f} "
   618        "{:8.6f} {:12.4f} {:12.4f}")
   619  
   620    summary_writer = tf.contrib.summary.create_file_writer(
   621        config.logdir, flush_millis=10000)
   622  
   623    with tf.device(device), \
   624         summary_writer.as_default(), \
   625         tf.contrib.summary.always_record_summaries():
   626      model = SNLIClassifier(config, embed)
   627      global_step = tf.train.get_or_create_global_step()
   628      trainer = SNLIClassifierTrainer(model, config.lr)
   629      checkpoint = tf.train.Checkpoint(trainer=trainer, global_step=global_step)
   630      checkpoint.restore(tf.train.latest_checkpoint(config.logdir))
   631  
   632      if inference_sentence_pair:
   633        # Inference mode.
   634        prem, prem_trans = inference_sentence_pair[0]
   635        hypo, hypo_trans = inference_sentence_pair[1]
   636        hypo_trans = inference_sentence_pair[1][1]
   637        inference_logits = model(
   638            tf.constant(prem), tf.constant(prem_trans),
   639            tf.constant(hypo), tf.constant(hypo_trans), training=False)
   640        inference_logits = inference_logits[0][1:]
   641        max_index = tf.argmax(inference_logits)
   642        print("\nInference logits:")
   643        for i, (label, logit) in enumerate(
   644            zip(data.POSSIBLE_LABELS, inference_logits)):
   645          winner_tag = " (winner)" if max_index == i else ""
   646          print("  {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
   647        return inference_logits
   648  
   649      train_len = train_data.num_batches(config.batch_size)
   650      start = time.time()
   651      iterations = 0
   652      mean_loss = tfe.metrics.Mean()
   653      accuracy = tfe.metrics.Accuracy()
   654      print(log_header)
   655      for epoch in xrange(config.epochs):
   656        batch_idx = 0
   657        for label, prem, prem_trans, hypo, hypo_trans in _get_dataset_iterator(
   658            train_data, config.batch_size):
   659          if use_gpu:
   660            label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
   661            # prem_trans and hypo_trans are used for dynamic control flow and can
   662            # remain on CPU. Same in _evaluate_on_dataset().
   663  
   664          iterations += 1
   665          batch_train_loss, batch_train_logits = trainer.train_batch(
   666              label, prem, prem_trans, hypo, hypo_trans)
   667          batch_size = tf.shape(label)[0]
   668          mean_loss(batch_train_loss.numpy(),
   669                    weights=batch_size.gpu() if use_gpu else batch_size)
   670          accuracy(tf.argmax(batch_train_logits, axis=1), label)
   671  
   672          if iterations % config.save_every == 0:
   673            checkpoint.save(os.path.join(config.logdir, "ckpt"))
   674  
   675          if iterations % config.dev_every == 0:
   676            dev_loss, dev_frac_correct = _evaluate_on_dataset(
   677                dev_data, config.batch_size, trainer, use_gpu)
   678            print(dev_log_template.format(
   679                time.time() - start,
   680                epoch, iterations, 1 + batch_idx, train_len,
   681                100.0 * (1 + batch_idx) / train_len,
   682                mean_loss.result(), dev_loss,
   683                accuracy.result() * 100.0, dev_frac_correct * 100.0))
   684            tf.contrib.summary.scalar("dev/loss", dev_loss)
   685            tf.contrib.summary.scalar("dev/accuracy", dev_frac_correct)
   686          elif iterations % config.log_every == 0:
   687            mean_loss_val = mean_loss.result()
   688            accuracy_val = accuracy.result()
   689            print(log_template.format(
   690                time.time() - start,
   691                epoch, iterations, 1 + batch_idx, train_len,
   692                100.0 * (1 + batch_idx) / train_len,
   693                mean_loss_val, " " * 8, accuracy_val * 100.0, " " * 12))
   694            tf.contrib.summary.scalar("train/loss", mean_loss_val)
   695            tf.contrib.summary.scalar("train/accuracy", accuracy_val)
   696            # Reset metrics.
   697            mean_loss = tfe.metrics.Mean()
   698            accuracy = tfe.metrics.Accuracy()
   699  
   700          batch_idx += 1
   701        if (epoch + 1) % config.lr_decay_every == 0:
   702          trainer.decay_learning_rate(config.lr_decay_by)
   703  
   704      test_loss, test_frac_correct = _evaluate_on_dataset(
   705          test_data, config.batch_size, trainer, use_gpu)
   706      print("Final test loss: %g; accuracy: %g%%" %
   707            (test_loss, test_frac_correct * 100.0))
   708  
   709    return trainer
   710  
   711  
   712  def main(_):
   713    config = FLAGS
   714  
   715    # Load embedding vectors.
   716    vocab = data.load_vocabulary(FLAGS.data_root)
   717    word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)
   718  
   719    if not (config.inference_premise or config.inference_hypothesis):
   720      print("Loading train, dev and test data...")
   721      train_data = data.SnliData(
   722          os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
   723          word2index, sentence_len_limit=FLAGS.sentence_len_limit)
   724      dev_data = data.SnliData(
   725          os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
   726          word2index, sentence_len_limit=FLAGS.sentence_len_limit)
   727      test_data = data.SnliData(
   728          os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
   729          word2index, sentence_len_limit=FLAGS.sentence_len_limit)
   730    else:
   731      train_data = None
   732      dev_data = None
   733      test_data = None
   734  
   735    train_or_infer_spinn(
   736        embed, word2index, train_data, dev_data, test_data, config)
   737  
   738  
   739  if __name__ == "__main__":
   740    parser = argparse.ArgumentParser(
   741        description=
   742        "TensorFlow eager implementation of the SPINN SNLI classifier.")
   743    parser.add_argument("--data_root", type=str, default="/tmp/spinn-data",
   744                        help="Root directory in which the training data and "
   745                        "embedding matrix are found. See README.md for how to "
   746                        "generate such a directory.")
   747    parser.add_argument("--sentence_len_limit", type=int, default=-1,
   748                        help="Maximum allowed sentence length (# of words). "
   749                        "The default of -1 means unlimited.")
   750    parser.add_argument("--logdir", type=str, default="/tmp/spinn-logs",
   751                        help="Directory in which summaries will be written for "
   752                        "TensorBoard.")
   753    parser.add_argument("--inference_premise", type=str, default=None,
   754                        help="Premise sentence for inference. Must be "
   755                        "accompanied by --inference_hypothesis. If specified, "
   756                        "will override all training parameters and perform "
   757                        "inference.")
   758    parser.add_argument("--inference_hypothesis", type=str, default=None,
   759                        help="Hypothesis sentence for inference. Must be "
   760                        "accompanied by --inference_premise. If specified, will "
   761                        "override all training parameters and perform inference.")
   762    parser.add_argument("--epochs", type=int, default=50,
   763                        help="Number of epochs to train.")
   764    parser.add_argument("--batch_size", type=int, default=128,
   765                        help="Batch size to use during training.")
   766    parser.add_argument("--d_proj", type=int, default=600,
   767                        help="Dimensions to project the word embedding vectors "
   768                        "to.")
   769    parser.add_argument("--d_hidden", type=int, default=300,
   770                        help="Size of the hidden layer of the Tracker.")
   771    parser.add_argument("--d_out", type=int, default=4,
   772                        help="Output dimensions of the SNLIClassifier.")
   773    parser.add_argument("--d_mlp", type=int, default=1024,
   774                        help="Size of each layer of the multi-layer perceptron "
   775                        "of the SNLICLassifier.")
   776    parser.add_argument("--n_mlp_layers", type=int, default=2,
   777                        help="Number of layers in the multi-layer perceptron "
   778                        "of the SNLICLassifier.")
   779    parser.add_argument("--d_tracker", type=int, default=64,
   780                        help="Size of the tracker LSTM.")
   781    parser.add_argument("--log_every", type=int, default=50,
   782                        help="Print log and write TensorBoard summary every _ "
   783                        "training batches.")
   784    parser.add_argument("--lr", type=float, default=2e-3,
   785                        help="Initial learning rate.")
   786    parser.add_argument("--lr_decay_by", type=float, default=0.75,
   787                        help="The ratio to multiply the learning rate by every "
   788                        "time the learning rate is decayed.")
   789    parser.add_argument("--lr_decay_every", type=float, default=1,
   790                        help="Decay the learning rate every _ epoch(s).")
   791    parser.add_argument("--dev_every", type=int, default=1000,
   792                        help="Run evaluation on the dev split every _ training "
   793                        "batches.")
   794    parser.add_argument("--save_every", type=int, default=1000,
   795                        help="Save checkpoint every _ training batches.")
   796    parser.add_argument("--embed_dropout", type=float, default=0.08,
   797                        help="Word embedding dropout rate.")
   798    parser.add_argument("--mlp_dropout", type=float, default=0.07,
   799                        help="SNLIClassifier multi-layer perceptron dropout "
   800                        "rate.")
   801    parser.add_argument("--no-projection", action="store_false",
   802                        dest="projection",
   803                        help="Whether word embedding vectors are projected to "
   804                        "another set of vectors (see d_proj).")
   805    parser.add_argument("--predict_transitions", action="store_true",
   806                        dest="predict",
   807                        help="Whether the Tracker will perform prediction.")
   808    parser.add_argument("--force_cpu", action="store_true", dest="force_cpu",
   809                        help="Force use CPU-only regardless of whether a GPU is "
   810                        "available.")
   811    FLAGS, unparsed = parser.parse_known_args()
   812  
   813    tfe.run(main=main, argv=[sys.argv[0]] + unparsed)