r"""Implementation of SPINN in TensorFlow eager execution.

SPINN: Stack-Augmented Parser-Interpreter Neural Network.

This file contains the model definition and code for training the model.

The model definition is based on the PyTorch implementation at:
  https://github.com/jekbradbury/examples/tree/spinn/snli

which was released under a BSD 3-Clause License at:
https://github.com/jekbradbury/examples/blob/spinn/LICENSE:

Copyright (c) 2017,
All rights reserved.

See ./LICENSE for more details.

Instructions for use:
* See `README.md` for details on how to prepare the SNLI and GloVe data.
* Suppose you have prepared the data at "/tmp/spinn-data", use the following
  command to train the model:

  ```bash
  python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs
  ```

  Checkpoints and TensorBoard summaries will be written to "/tmp/spinn-logs".

References:
* Bowman, S.R., Gauthier, J., Rastogi, A., Gupta, R., Manning, C.D., & Potts,
  C. (2016). A Fast Unified Model for Parsing and Sentence Understanding.
  https://arxiv.org/abs/1603.06021
* Bradbury, J. (2017). Recursive Neural Networks with PyTorch.
  https://devblogs.nvidia.com/parallelforall/recursive-neural-networks-pytorch/
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import itertools
import os
import sys
import time

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.spinn import data


layers = tf.keras.layers


def _bundle(lstm_iter):
  """Concatenate a list of Tensors along 1st axis and split result into two.

  Args:
    lstm_iter: A `list` of `N` dense `Tensor`s, each of which has the shape
      (R, 2 * M).

  Returns:
    A `list` of two dense `Tensor`s, each of which has the shape (N * R, M).
  """
  return tf.split(tf.concat(lstm_iter, 0), 2, axis=1)


def _unbundle(state):
  """Concatenate a list of Tensors along 2nd axis and split result.

  This is the inverse of `_bundle`.

  Args:
    state: A `list` of two dense `Tensor`s, each of which has the shape (R, M).

  Returns:
    A `list` of `R` dense `Tensor`s, each of which has the shape (1, 2 * M).
  """
  return tf.split(tf.concat(state, 1), state[0].shape[0], axis=0)


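# Minimal shape sketch of `_bundle` and `_unbundle` (illustrative only: this
# helper is not used by the model, and the concrete sizes are made up). Like
# the rest of this file, it assumes eager execution.
def _bundle_unbundle_shape_demo():
  """Round-trips a toy LSTM state through `_bundle` and `_unbundle`."""
  # N = 2 tensors, each of shape (R, 2 * M) = (3, 8).
  lstm_iter = [tf.ones([3, 8]), tf.zeros([3, 8])]
  # h and c each have shape (N * R, M) = (6, 4).
  h, c = _bundle(lstm_iter)
  # Two tensors of shape (R, M) = (3, 4) map back to R rows of shape (1, 8).
  state = [h[:3, :], c[:3, :]]
  return _unbundle(state)

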
# pylint: disable=not-callable
class Reducer(tf.keras.Model):
  """A module that applies the reduce operation on left and right vectors."""

  def __init__(self, size, tracker_size=None):
    super(Reducer, self).__init__()
    self.left = layers.Dense(5 * size, activation=None)
    self.right = layers.Dense(5 * size, activation=None, use_bias=False)
    if tracker_size is not None:
      self.track = layers.Dense(5 * size, activation=None, use_bias=False)
    else:
      self.track = None

  def call(self, left_in, right_in, tracking=None):
    """Invoke forward pass of the Reduce module.

    This method feeds a linear combination of `left_in`, `right_in` and
    `tracking` into a Tree LSTM and returns the output of the Tree LSTM.

    Args:
      left_in: A list of length L. Each item is a dense `Tensor` with
        the shape (1, n_dims). n_dims is the size of the embedding vector.
      right_in: A list of the same length as `left_in`. Each item should have
        the same shape as the items of `left_in`.
      tracking: Optional list of the same length as `left_in`. Each item is a
        dense `Tensor` with shape (1, tracker_size * 2). tracker_size is the
        size of the Tracker's state vector.

    Returns:
      Output: A list of length batch_size. Each item has the shape (1, n_dims).
    """
    left, right = _bundle(left_in), _bundle(right_in)
    lstm_in = self.left(left[0]) + self.right(right[0])
    if self.track and tracking:
      lstm_in += self.track(_bundle(tracking)[0])
    return _unbundle(self._tree_lstm(left[1], right[1], lstm_in))

  def _tree_lstm(self, c1, c2, lstm_in):
    a, i, f1, f2, o = tf.split(lstm_in, 5, axis=1)
    c = tf.tanh(a) * tf.sigmoid(i) + tf.sigmoid(f1) * c1 + tf.sigmoid(f2) * c2
    h = tf.sigmoid(o) * tf.tanh(c)
    return h, c


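# For reference, the tree-LSTM composition computed by `Reducer._tree_lstm`
# above, written out schematically (biases omitted; this restates the code and
# adds no new behavior):
#
#   a, i, f1, f2, o = split(W_l h_left + W_r h_right [+ W_t h_track])
#   c = tanh(a) * sigmoid(i) + sigmoid(f1) * c_left + sigmoid(f2) * c_right
#   h = sigmoid(o) * tanh(c)
#
# where (h_left, c_left) and (h_right, c_right) are the bundled child states,
# and the bracketed tracker term is present only when a Tracker is used.

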
class Tracker(tf.keras.Model):
  """A module that tracks the history of the sentence with an LSTM."""

  def __init__(self, tracker_size, predict):
    """Constructor of Tracker.

    Args:
      tracker_size: Number of dimensions of the underlying `LSTMCell`.
      predict: (`bool`) Whether prediction mode is enabled.
    """
    super(Tracker, self).__init__()
    self._rnn = tf.nn.rnn_cell.LSTMCell(tracker_size)
    self._state_size = tracker_size
    if predict:
      self._transition = layers.Dense(4)
    else:
      self._transition = None

  def reset_state(self):
    self.state = None

  def call(self, bufs, stacks):
    """Invoke the forward pass of the Tracker module.

    This method feeds the concatenation of the top of the buffers and the top
    two elements of the stacks into an LSTM cell and returns the resultant
    state of the LSTM cell.

    Args:
      bufs: A `list` of length batch_size. Each item is a `list` of length
        max_sequence_len (the maximum sequence length of the batch). Each item
        of the nested list is a dense `Tensor` of shape (1, d_proj), where
        d_proj is the size of the word embedding vector or the size of the
        vector space that the word embedding vector is projected to.
      stacks: A `list` of size batch_size. Each item is a `list` of
        variable length corresponding to the current height of the stack.
        Each item of the nested list is a dense `Tensor` of shape (1, d_proj).

    Returns:
      1. A list of length batch_size. Each item is a dense `Tensor` of shape
        (1, d_tracker * 2).
      2. If under predict mode, result of applying a Dense layer on the
        first state vector of the RNN. Else, `None`.
    """
    buf = _bundle([buf[-1] for buf in bufs])[0]
    stack1 = _bundle([stack[-1] for stack in stacks])[0]
    stack2 = _bundle([stack[-2] for stack in stacks])[0]
    x = tf.concat([buf, stack1, stack2], 1)
    if self.state is None:
      batch_size = int(x.shape[0])
      zeros = tf.zeros((batch_size, self._state_size), dtype=tf.float32)
      self.state = [zeros, zeros]
    _, self.state = self._rnn(x, self.state)
    unbundled = _unbundle(self.state)
    if self._transition:
      return unbundled, self._transition(self.state[0])
    else:
      return unbundled, None


class SPINN(tf.keras.Model):
  """Stack-augmented Parser-Interpreter Neural Network.

  See https://arxiv.org/abs/1603.06021 for more details.
  """

  def __init__(self, config):
    """Constructor of SPINN.

    Args:
      config: A `namedtuple` with the following attributes.
        d_proj - (`int`) number of dimensions of the vector space to project
          the word embeddings to.
        d_tracker - (`int`) number of dimensions of the Tracker's state vector.
        d_hidden - (`int`) number of the dimensions of the hidden state, for
          the Reducer module.
        n_mlp_layers - (`int`) number of multi-layer perceptron layers to use
          to convert the output of the `Feature` module to logits.
        predict - (`bool`) Whether the Tracker will enable predictions.
    """
    super(SPINN, self).__init__()
    self.config = config
    self.reducer = Reducer(config.d_hidden, config.d_tracker)
    if config.d_tracker is not None:
      self.tracker = Tracker(config.d_tracker, config.predict)
    else:
      self.tracker = None

  def call(self, buffers, transitions, training=False):
    """Invoke the forward pass of the SPINN model.

    Args:
      buffers: Dense `Tensor` of shape
        (max_sequence_len, batch_size, config.d_proj).
      transitions: Dense `Tensor` with integer values that represent the parse
        trees of the sentences. A value of 2 indicates "reduce"; a value of 3
        indicates "shift". Shape: (max_sequence_len * 2 - 3, batch_size).
      training: Whether the invocation is under training mode.

    Returns:
      Output `Tensor` of shape (batch_size, config.d_embed).
    """
    max_sequence_len, batch_size, d_proj = (int(x) for x in buffers.shape)

    # Split the buffers into left and right word items and put the initial
    # items in a stack.
    splitted = tf.split(
        tf.reshape(tf.transpose(buffers, [1, 0, 2]), [-1, d_proj]),
        max_sequence_len * batch_size, axis=0)
    buffers = [splitted[k:k + max_sequence_len]
               for k in xrange(0, len(splitted), max_sequence_len)]
    stacks = [[buf[0], buf[0]] for buf in buffers]

    if self.tracker:
      # Reset tracker state for new batch.
      self.tracker.reset_state()

    num_transitions = transitions.shape[0]

    # Iterate through transitions and perform the appropriate stack-pop, reduce
    # and stack-push operations.
    transitions = transitions.numpy()
    for i in xrange(num_transitions):
      trans = transitions[i]
      if self.tracker:
        # Invoke tracker to obtain the current tracker states for the
        # sentences.
        tracker_states, trans_hypothesis = self.tracker(buffers, stacks=stacks)
        if trans_hypothesis:
          trans = tf.argmax(trans_hypothesis, axis=-1)
      else:
        tracker_states = itertools.repeat(None)
      lefts, rights, trackings = [], [], []
      for transition, buf, stack, tracking in zip(
          trans, buffers, stacks, tracker_states):
        if int(transition) == 3:  # Shift.
          stack.append(buf.pop())
        elif int(transition) == 2:  # Reduce.
          rights.append(stack.pop())
          lefts.append(stack.pop())
          trackings.append(tracking)

      if rights:
        reducer_output = self.reducer(lefts, rights, trackings)
        reduced = iter(reducer_output)

        for transition, stack in zip(trans, stacks):
          if int(transition) == 2:  # Reduce.
            stack.append(next(reduced))
    return _bundle([stack.pop() for stack in stacks])[0]


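# Note on the transition encoding consumed by SPINN.call above: at step i,
# transitions[i] holds one integer per sentence in the batch, so the inner
# loop advances every sentence by one parser action in lock-step. A value of 3
# ("shift") moves the next word from that sentence's buffer onto its stack; a
# value of 2 ("reduce") pops the top two stack entries, feeds them (plus the
# optional tracker state) through the Reducer, and pushes the single reduced
# vector back. After the last transition, the top of each stack is popped,
# bundled across the batch, and returned as the sentence encodings.

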
class Perceptron(tf.keras.Model):
  """One layer of the SNLIClassifier multi-layer perceptron."""

  def __init__(self, dimension, dropout_rate, previous_layer):
    """Configure the Perceptron."""
    super(Perceptron, self).__init__()
    self.dense = tf.keras.layers.Dense(dimension, activation=tf.nn.elu)
    self.batchnorm = layers.BatchNormalization()
    self.dropout = layers.Dropout(rate=dropout_rate)
    self.previous_layer = previous_layer

  def call(self, x, training):
    """Run previous Perceptron layers, then this one."""
    x = self.previous_layer(x, training=training)
    x = self.dense(x)
    x = self.batchnorm(x, training=training)
    x = self.dropout(x, training=training)
    return x


class SNLIClassifier(tf.keras.Model):
  """SNLI Classifier Model.

  A model aimed at solving the SNLI (Stanford Natural Language Inference)
  task, using the SPINN model from above. For details of the task, see:
    https://nlp.stanford.edu/projects/snli/
  """

  def __init__(self, config, embed):
    """Constructor of SNLIClassifier.

    Args:
      config: A namedtuple containing required configurations for the model. It
        needs to have the following attributes.
        projection - (`bool`) whether the word vectors are to be projected onto
          another vector space (of `d_proj` dimensions).
        d_proj - (`int`) number of dimensions of the vector space to project
          the word embeddings to.
        embed_dropout - (`float`) dropout rate for the word embedding vectors.
        n_mlp_layers - (`int`) number of multi-layer perceptron (MLP) layers to
          use to convert the output of the `Feature` module to logits.
        mlp_dropout - (`float`) dropout rate of the MLP layers.
        d_out - (`int`) number of dimensions of the final output of the MLP
          layers.
        lr - (`float`) learning rate.
      embed: An embedding matrix of shape (vocab_size, d_embed).
    """
    super(SNLIClassifier, self).__init__()
    self.config = config
    self.embed = tf.constant(embed)

    self.projection = layers.Dense(config.d_proj)
    self.embed_bn = layers.BatchNormalization()
    self.embed_dropout = layers.Dropout(rate=config.embed_dropout)
    self.encoder = SPINN(config)

    self.feature_bn = layers.BatchNormalization()
    self.feature_dropout = layers.Dropout(rate=config.mlp_dropout)

    current_mlp = lambda result, training: result
    for _ in range(config.n_mlp_layers):
      current_mlp = Perceptron(dimension=config.d_mlp,
                               dropout_rate=config.mlp_dropout,
                               previous_layer=current_mlp)
    self.mlp = current_mlp
    self.mlp_output = layers.Dense(
        config.d_out,
        kernel_initializer=tf.random_uniform_initializer(minval=-5e-3,
                                                         maxval=5e-3))

  def call(self,
           premise,
           premise_transition,
           hypothesis,
           hypothesis_transition,
           training=False):
    """Invoke the forward pass of the SNLIClassifier model.

    Args:
      premise: The word indices of the premise sentences, with shape
        (max_prem_seq_len, batch_size).
      premise_transition: The transitions for the premise sentences, with shape
        (max_prem_seq_len * 2 - 3, batch_size).
      hypothesis: The word indices of the hypothesis sentences, with shape
        (max_hypo_seq_len, batch_size).
      hypothesis_transition: The transitions for the hypothesis sentences, with
        shape (max_hypo_seq_len * 2 - 3, batch_size).
      training: Whether the invocation is under training mode.

    Returns:
      The logits, as a dense `Tensor` of shape (batch_size, d_out), where d_out
      is the size of the output vector.
    """
    # Perform embedding lookup on the premise and hypothesis inputs, which have
    # the word-index format.
    premise_embed = tf.nn.embedding_lookup(self.embed, premise)
    hypothesis_embed = tf.nn.embedding_lookup(self.embed, hypothesis)

    if self.config.projection:
      # Project the embedding vectors to another vector space.
      premise_embed = self.projection(premise_embed)
      hypothesis_embed = self.projection(hypothesis_embed)

    # Perform batch normalization and dropout on the possibly projected word
    # vectors.
    premise_embed = self.embed_bn(premise_embed, training=training)
    hypothesis_embed = self.embed_bn(hypothesis_embed, training=training)
    premise_embed = self.embed_dropout(premise_embed, training=training)
    hypothesis_embed = self.embed_dropout(hypothesis_embed, training=training)

    # Run the batch-normalized and dropout-processed word vectors through the
    # SPINN encoder.
    premise = self.encoder(premise_embed, premise_transition,
                           training=training)
    hypothesis = self.encoder(hypothesis_embed, hypothesis_transition,
                              training=training)

    # Combine encoder outputs for premises and hypotheses into logits.
    # Then apply batch normalization and dropout on the logits.
    logits = tf.concat(
        [premise, hypothesis, premise - hypothesis, premise * hypothesis], 1)
    logits = self.feature_dropout(
        self.feature_bn(logits, training=training), training=training)

    # Apply the multi-layer perceptron on the logits.
    logits = self.mlp(logits, training=training)
    logits = self.mlp_output(logits)
    return logits


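# Note on the `current_mlp` chain built in SNLIClassifier.__init__ above: each
# Perceptron keeps the previously built callable as `previous_layer` and
# invokes it first in its own `call`, so `self.mlp(x, training=...)` applies
# the n_mlp_layers Perceptrons in construction order, starting from the
# identity lambda at the bottom of the chain.

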
class SNLIClassifierTrainer(tfe.Checkpointable):
  """A class that coordinates the training of an SNLIClassifier."""

  def __init__(self, snli_classifier, lr):
    """Constructor of SNLIClassifierTrainer.

    Args:
      snli_classifier: An instance of `SNLIClassifier`.
      lr: Learning rate.
    """
    self._model = snli_classifier
    # Create a custom learning rate Variable for the RMSProp optimizer, because
    # the learning rate needs to be manually decayed later (see
    # decay_learning_rate()).
    self._learning_rate = tf.Variable(lr, name="learning_rate")
    self._optimizer = tf.train.RMSPropOptimizer(self._learning_rate,
                                                epsilon=1e-6)

  def loss(self, labels, logits):
    """Calculate the loss given a batch of data.

    Args:
      labels: The truth labels, with shape (batch_size,).
      logits: The logits output from the forward pass of the SNLIClassifier
        model, with shape (batch_size, d_out), where d_out is the output
        dimension size of the SNLIClassifier.

    Returns:
      The loss value, as a scalar `Tensor`.
    """
    return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits))

  def train_batch(self,
                  labels,
                  premise,
                  premise_transition,
                  hypothesis,
                  hypothesis_transition):
    """Train model on batch of data.

    Args:
      labels: The truth labels, with shape (batch_size,).
      premise: The word indices of the premise sentences, with shape
        (max_prem_seq_len, batch_size).
      premise_transition: The transitions for the premise sentences, with shape
        (max_prem_seq_len * 2 - 3, batch_size).
      hypothesis: The word indices of the hypothesis sentences, with shape
        (max_hypo_seq_len, batch_size).
      hypothesis_transition: The transitions for the hypothesis sentences, with
        shape (max_hypo_seq_len * 2 - 3, batch_size).

    Returns:
      1. loss value as a scalar `Tensor`.
      2. logits as a dense `Tensor` of shape (batch_size, d_out), where d_out
        is the output dimension size of the SNLIClassifier.
    """
    with tf.GradientTape() as tape:
      tape.watch(self._model.variables)
      logits = self._model(premise,
                           premise_transition,
                           hypothesis,
                           hypothesis_transition,
                           training=True)
      loss = self.loss(labels, logits)
    gradients = tape.gradient(loss, self._model.variables)
    self._optimizer.apply_gradients(zip(gradients, self._model.variables),
                                    global_step=tf.train.get_global_step())
    return loss, logits

  def decay_learning_rate(self, decay_by):
    """Decay learning rate of the optimizer by factor decay_by."""
    self._learning_rate.assign(self._learning_rate * decay_by)
    print("Decayed learning rate of optimizer to: %s" %
          self._learning_rate.numpy())

  @property
  def learning_rate(self):
    return self._learning_rate

  @property
  def model(self):
    return self._model

  @property
  def variables(self):
    return (self._model.variables + [self.learning_rate] +
            self._optimizer.variables())


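# Note on decay_learning_rate above: with this file's default flags
# (--lr 2e-3, --lr_decay_by 0.75, --lr_decay_every 1), the learning rate after
# k decay steps is 2e-3 * 0.75**k, e.g. ~1.5e-3 after one epoch and ~8.4e-4
# after three.

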
def _batch_n_correct(logits, label):
  """Calculate number of correct predictions in a batch.

  Args:
    logits: A logits Tensor of shape `(batch_size, num_categories)` and dtype
      `float32`.
    label: A labels Tensor of shape `(batch_size,)` and dtype `int64`.

  Returns:
    Number of correct predictions.
  """
  return tf.reduce_sum(
      tf.cast((tf.equal(
          tf.argmax(logits, axis=1), label)), tf.float32)).numpy()


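# Illustrative sketch of `_batch_n_correct` on a tiny hand-made batch (the
# numbers are made up; this helper is not called elsewhere in this file and,
# like the rest of the file, assumes eager execution).
def _batch_n_correct_demo():
  logits = tf.constant([[0.1, 2.0, -1.0],
                        [3.0, 0.0, 0.5]])       # argmax -> classes [1, 0].
  label = tf.constant([1, 2], dtype=tf.int64)   # True classes are [1, 2].
  return _batch_n_correct(logits, label)        # 1.0: only the first is right.

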
def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu):
  """Run evaluation on a dataset.

  Args:
    snli_data: The `data.SnliData` to use in this evaluation.
    batch_size: The batch size to use during this evaluation.
    trainer: An instance of `SNLIClassifierTrainer` to use for this
      evaluation.
    use_gpu: Whether GPU is being used.

  Returns:
    1. Average loss across all examples of the dataset.
    2. Average accuracy rate across all examples of the dataset.
  """
  mean_loss = tfe.metrics.Mean()
  accuracy = tfe.metrics.Accuracy()
  for label, prem, prem_trans, hypo, hypo_trans in _get_dataset_iterator(
      snli_data, batch_size):
    if use_gpu:
      label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
    logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False)
    loss_val = trainer.loss(label, logits)
    batch_size = tf.shape(label)[0]
    mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size)
    accuracy(tf.argmax(logits, axis=1), label)
  return mean_loss.result().numpy(), accuracy.result().numpy()


def _get_dataset_iterator(snli_data, batch_size):
  """Get a data iterator for a split of SNLI data.

  Args:
    snli_data: A `data.SnliData` object.
    batch_size: The desired batch size.

  Returns:
    A dataset iterator.
  """
  with tf.device("/device:CPU:0"):
    # Some tf.data ops, such as ShuffleDataset, are available only on CPU.
    dataset = tf.data.Dataset.from_generator(
        snli_data.get_generator(batch_size),
        (tf.int64, tf.int64, tf.int64, tf.int64, tf.int64))
    dataset = dataset.shuffle(snli_data.num_batches(batch_size))
    return tfe.Iterator(dataset)


def train_or_infer_spinn(embed,
                         word2index,
                         train_data,
                         dev_data,
                         test_data,
                         config):
  """Perform training or inference on a SPINN model.

  Args:
    embed: The embedding matrix as a float32 numpy array with shape
      [vocabulary_size, word_vector_len]. word_vector_len is the length of a
      word embedding vector.
    word2index: A `dict` mapping word to word index.
    train_data: An instance of `data.SnliData`, for the train split.
    dev_data: Same as above, for the dev split.
    test_data: Same as above, for the test split.
    config: A configuration object. See the argument to this Python binary for
      details.

  Returns:
    If `config.inference_premise` and `config.inference_hypothesis` are not
      `None`, i.e., inference mode: the logits for the possible labels of the
      SNLI data set, as a `Tensor` of three floats.
    else:
      The trainer object.

  Raises:
    ValueError: if only one of config.inference_premise and
      config.inference_hypothesis is specified.
  """
  # TODO(cais): Refactor this function into separate ones for training and
  #   inference.
  use_gpu = tfe.num_gpus() > 0 and not config.force_cpu
  device = "gpu:0" if use_gpu else "cpu:0"
  print("Using device: %s" % device)

  if ((config.inference_premise and not config.inference_hypothesis) or
      (not config.inference_premise and config.inference_hypothesis)):
    raise ValueError(
        "--inference_premise and --inference_hypothesis must be both "
        "specified or both unspecified, but only one is specified.")

  if config.inference_premise:
    # Inference mode.
    inference_sentence_pair = [
        data.encode_sentence(config.inference_premise, word2index),
        data.encode_sentence(config.inference_hypothesis, word2index)]
  else:
    inference_sentence_pair = None

  log_header = (
      " Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss"
      " Accuracy Dev/Accuracy")
  log_template = (
      "{:>6.0f} {:>5.0f} {:>9.0f} {:>5.0f}/{:<5.0f} {:>7.0f}% {:>8.6f} {} "
      "{:12.4f} {}")
  dev_log_template = (
      "{:>6.0f} {:>5.0f} {:>9.0f} {:>5.0f}/{:<5.0f} {:>7.0f}% {:>8.6f} "
      "{:8.6f} {:12.4f} {:12.4f}")

  summary_writer = tf.contrib.summary.create_file_writer(
      config.logdir, flush_millis=10000)

  with tf.device(device), \
      summary_writer.as_default(), \
      tf.contrib.summary.always_record_summaries():
    model = SNLIClassifier(config, embed)
    global_step = tf.train.get_or_create_global_step()
    trainer = SNLIClassifierTrainer(model, config.lr)
    checkpoint = tf.train.Checkpoint(trainer=trainer, global_step=global_step)
    checkpoint.restore(tf.train.latest_checkpoint(config.logdir))

    if inference_sentence_pair:
      # Inference mode.
      prem, prem_trans = inference_sentence_pair[0]
      hypo, hypo_trans = inference_sentence_pair[1]
      inference_logits = model(
          tf.constant(prem), tf.constant(prem_trans),
          tf.constant(hypo), tf.constant(hypo_trans), training=False)
      inference_logits = inference_logits[0][1:]
      max_index = tf.argmax(inference_logits)
      print("\nInference logits:")
      for i, (label, logit) in enumerate(
          zip(data.POSSIBLE_LABELS, inference_logits)):
        winner_tag = " (winner)" if max_index == i else ""
        print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag))
      return inference_logits

    train_len = train_data.num_batches(config.batch_size)
    start = time.time()
    iterations = 0
    mean_loss = tfe.metrics.Mean()
    accuracy = tfe.metrics.Accuracy()
    print(log_header)
    for epoch in xrange(config.epochs):
      batch_idx = 0
      for label, prem, prem_trans, hypo, hypo_trans in _get_dataset_iterator(
          train_data, config.batch_size):
        if use_gpu:
          label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu()
          # prem_trans and hypo_trans are used for dynamic control flow and can
          # remain on CPU. Same in _evaluate_on_dataset().

        iterations += 1
        batch_train_loss, batch_train_logits = trainer.train_batch(
            label, prem, prem_trans, hypo, hypo_trans)
        batch_size = tf.shape(label)[0]
        mean_loss(batch_train_loss.numpy(),
                  weights=batch_size.gpu() if use_gpu else batch_size)
        accuracy(tf.argmax(batch_train_logits, axis=1), label)

        if iterations % config.save_every == 0:
          checkpoint.save(os.path.join(config.logdir, "ckpt"))

        if iterations % config.dev_every == 0:
          dev_loss, dev_frac_correct = _evaluate_on_dataset(
              dev_data, config.batch_size, trainer, use_gpu)
          print(dev_log_template.format(
              time.time() - start,
              epoch, iterations, 1 + batch_idx, train_len,
              100.0 * (1 + batch_idx) / train_len,
              mean_loss.result(), dev_loss,
              accuracy.result() * 100.0, dev_frac_correct * 100.0))
          tf.contrib.summary.scalar("dev/loss", dev_loss)
          tf.contrib.summary.scalar("dev/accuracy", dev_frac_correct)
        elif iterations % config.log_every == 0:
          mean_loss_val = mean_loss.result()
          accuracy_val = accuracy.result()
          print(log_template.format(
              time.time() - start,
              epoch, iterations, 1 + batch_idx, train_len,
              100.0 * (1 + batch_idx) / train_len,
              mean_loss_val, " " * 8, accuracy_val * 100.0, " " * 12))
          tf.contrib.summary.scalar("train/loss", mean_loss_val)
          tf.contrib.summary.scalar("train/accuracy", accuracy_val)
          # Reset metrics.
          mean_loss = tfe.metrics.Mean()
          accuracy = tfe.metrics.Accuracy()

        batch_idx += 1
      if (epoch + 1) % config.lr_decay_every == 0:
        trainer.decay_learning_rate(config.lr_decay_by)

    test_loss, test_frac_correct = _evaluate_on_dataset(
        test_data, config.batch_size, trainer, use_gpu)
    print("Final test loss: %g; accuracy: %g%%" %
          (test_loss, test_frac_correct * 100.0))

  return trainer


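# Example invocation of inference mode (the sentences below are placeholders;
# --inference_premise and --inference_hypothesis must be supplied together):
#
#   python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \
#       --inference_premise "a man is playing a guitar" \
#       --inference_hypothesis "a person is making music"

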
def main(_):
  config = FLAGS

  # Load embedding vectors.
  vocab = data.load_vocabulary(FLAGS.data_root)
  word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab)

  if not (config.inference_premise or config.inference_hypothesis):
    print("Loading train, dev and test data...")
    train_data = data.SnliData(
        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"),
        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
    dev_data = data.SnliData(
        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"),
        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
    test_data = data.SnliData(
        os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"),
        word2index, sentence_len_limit=FLAGS.sentence_len_limit)
  else:
    train_data = None
    dev_data = None
    test_data = None

  train_or_infer_spinn(
      embed, word2index, train_data, dev_data, test_data, config)


if __name__ == "__main__":
  parser = argparse.ArgumentParser(
      description=
      "TensorFlow eager implementation of the SPINN SNLI classifier.")
  parser.add_argument("--data_root", type=str, default="/tmp/spinn-data",
                      help="Root directory in which the training data and "
                      "embedding matrix are found. See README.md for how to "
                      "generate such a directory.")
  parser.add_argument("--sentence_len_limit", type=int, default=-1,
                      help="Maximum allowed sentence length (# of words). "
                      "The default of -1 means unlimited.")
  parser.add_argument("--logdir", type=str, default="/tmp/spinn-logs",
                      help="Directory in which summaries will be written for "
                      "TensorBoard.")
  parser.add_argument("--inference_premise", type=str, default=None,
                      help="Premise sentence for inference. Must be "
                      "accompanied by --inference_hypothesis. If specified, "
                      "will override all training parameters and perform "
                      "inference.")
  parser.add_argument("--inference_hypothesis", type=str, default=None,
                      help="Hypothesis sentence for inference. Must be "
                      "accompanied by --inference_premise. If specified, will "
                      "override all training parameters and perform "
                      "inference.")
  parser.add_argument("--epochs", type=int, default=50,
                      help="Number of epochs to train.")
  parser.add_argument("--batch_size", type=int, default=128,
                      help="Batch size to use during training.")
  parser.add_argument("--d_proj", type=int, default=600,
                      help="Dimensions to project the word embedding vectors "
                      "to.")
  parser.add_argument("--d_hidden", type=int, default=300,
                      help="Size of the hidden layer of the Tracker.")
  parser.add_argument("--d_out", type=int, default=4,
                      help="Output dimensions of the SNLIClassifier.")
  parser.add_argument("--d_mlp", type=int, default=1024,
                      help="Size of each layer of the multi-layer perceptron "
                      "of the SNLIClassifier.")
  parser.add_argument("--n_mlp_layers", type=int, default=2,
                      help="Number of layers in the multi-layer perceptron "
                      "of the SNLIClassifier.")
  parser.add_argument("--d_tracker", type=int, default=64,
                      help="Size of the tracker LSTM.")
  parser.add_argument("--log_every", type=int, default=50,
                      help="Print log and write TensorBoard summary every _ "
                      "training batches.")
  parser.add_argument("--lr", type=float, default=2e-3,
                      help="Initial learning rate.")
  parser.add_argument("--lr_decay_by", type=float, default=0.75,
                      help="The ratio to multiply the learning rate by every "
                      "time the learning rate is decayed.")
  parser.add_argument("--lr_decay_every", type=float, default=1,
                      help="Decay the learning rate every _ epoch(s).")
  parser.add_argument("--dev_every", type=int, default=1000,
                      help="Run evaluation on the dev split every _ training "
                      "batches.")
  parser.add_argument("--save_every", type=int, default=1000,
                      help="Save checkpoint every _ training batches.")
  parser.add_argument("--embed_dropout", type=float, default=0.08,
                      help="Word embedding dropout rate.")
  parser.add_argument("--mlp_dropout", type=float, default=0.07,
                      help="SNLIClassifier multi-layer perceptron dropout "
                      "rate.")
  parser.add_argument("--no-projection", action="store_false",
                      dest="projection",
                      help="Whether word embedding vectors are projected to "
                      "another set of vectors (see d_proj).")
  parser.add_argument("--predict_transitions", action="store_true",
                      dest="predict",
                      help="Whether the Tracker will perform prediction.")
  parser.add_argument("--force_cpu", action="store_true", dest="force_cpu",
                      help="Force CPU-only execution regardless of whether a "
                      "GPU is available.")
  FLAGS, unparsed = parser.parse_known_args()

  tfe.run(main=main, argv=[sys.argv[0]] + unparsed)