github.com/kubeflow/training-operator@v1.7.0/examples/tensorflow/mnist_with_summaries/mnist_with_summaries.py (about)

     1  # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the 'License');
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #     http://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an 'AS IS' BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  # ==============================================================================
    15  """A simple MNIST classifier which displays summaries in TensorBoard.
    16  This is an unimpressive MNIST model, but it is a good example of using
    17  tf.name_scope to make a graph legible in the TensorBoard graph explorer, and of
    18  naming summary tags so that they are grouped meaningfully in TensorBoard.
    19  It demonstrates the functionality of every TensorBoard dashboard.
    20  """
    21  from __future__ import absolute_import
    22  from __future__ import division
    23  from __future__ import print_function
    24  
    25  import argparse
    26  import os
    27  import sys
    28  
    29  import tensorflow as tf
    30  
    31  from tensorflow.examples.tutorials.mnist import input_data
    32  
# Parsed command-line flags (the argparse.Namespace returned by
# parser.parse_known_args()). Assigned in the ``__main__`` block before
# tf.app.run() invokes main(); read by main() and train().
FLAGS = None
    34  
    35  
    36  def train():
    37    # Import data
    38    mnist = input_data.read_data_sets(FLAGS.data_dir,
    39                                      fake_data=FLAGS.fake_data)
    40  
    41    sess = tf.InteractiveSession()
    42    # Create a multilayer model.
    43  
    44    # Input placeholders
    45    with tf.name_scope('input'):
    46      x = tf.placeholder(tf.float32, [None, 784], name='x-input')
    47      y_ = tf.placeholder(tf.int64, [None], name='y-input')
    48  
    49    with tf.name_scope('input_reshape'):
    50      image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
    51      tf.summary.image('input', image_shaped_input, 10)
    52  
    53    # We can't initialize these variables to 0 - the network will get stuck.
    54    def weight_variable(shape):
    55      """Create a weight variable with appropriate initialization."""
    56      initial = tf.truncated_normal(shape, stddev=0.1)
    57      return tf.Variable(initial)
    58  
    59    def bias_variable(shape):
    60      """Create a bias variable with appropriate initialization."""
    61      initial = tf.constant(0.1, shape=shape)
    62      return tf.Variable(initial)
    63  
    64    def variable_summaries(var):
    65      """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
    66      with tf.name_scope('summaries'):
    67        mean = tf.reduce_mean(var)
    68        tf.summary.scalar('mean', mean)
    69        with tf.name_scope('stddev'):
    70          stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
    71        tf.summary.scalar('stddev', stddev)
    72        tf.summary.scalar('max', tf.reduce_max(var))
    73        tf.summary.scalar('min', tf.reduce_min(var))
    74        tf.summary.histogram('histogram', var)
    75  
    76    def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
    77      """Reusable code for making a simple neural net layer.
    78      It does a matrix multiply, bias add, and then uses ReLU to nonlinearize.
    79      It also sets up name scoping so that the resultant graph is easy to read,
    80      and adds a number of summary ops.
    81      """
    82      # Adding a name scope ensures logical grouping of the layers in the graph.
    83      with tf.name_scope(layer_name):
    84        # This Variable will hold the state of the weights for the layer
    85        with tf.name_scope('weights'):
    86          weights = weight_variable([input_dim, output_dim])
    87          variable_summaries(weights)
    88        with tf.name_scope('biases'):
    89          biases = bias_variable([output_dim])
    90          variable_summaries(biases)
    91        with tf.name_scope('Wx_plus_b'):
    92          preactivate = tf.matmul(input_tensor, weights) + biases
    93          tf.summary.histogram('pre_activations', preactivate)
    94        activations = act(preactivate, name='activation')
    95        tf.summary.histogram('activations', activations)
    96        return activations
    97  
    98    hidden1 = nn_layer(x, 784, 500, 'layer1')
    99  
   100    with tf.name_scope('dropout'):
   101      keep_prob = tf.placeholder(tf.float32)
   102      tf.summary.scalar('dropout_keep_probability', keep_prob)
   103      dropped = tf.nn.dropout(hidden1, keep_prob)
   104  
   105    # Do not apply softmax activation yet, see below.
   106    y = nn_layer(dropped, 500, 10, 'layer2', act=tf.identity)
   107  
   108    with tf.name_scope('cross_entropy'):
   109      # The raw formulation of cross-entropy,
   110      #
   111      # tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.softmax(y)),
   112      #                               reduction_indices=[1]))
   113      #
   114      # can be numerically unstable.
   115      #
   116      # So here we use tf.losses.sparse_softmax_cross_entropy on the
   117      # raw logit outputs of the nn_layer above, and then average across
   118      # the batch.
   119      with tf.name_scope('total'):
   120        cross_entropy = tf.losses.sparse_softmax_cross_entropy(
   121            labels=y_, logits=y)
   122    tf.summary.scalar('cross_entropy', cross_entropy)
   123  
   124    with tf.name_scope('train'):
   125      train_step = tf.train.AdamOptimizer(FLAGS.learning_rate).minimize(
   126          cross_entropy)
   127  
   128    with tf.name_scope('accuracy'):
   129      with tf.name_scope('correct_prediction'):
   130        correct_prediction = tf.equal(tf.argmax(y, 1), y_)
   131      with tf.name_scope('accuracy'):
   132        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
   133    tf.summary.scalar('accuracy', accuracy)
   134  
   135    # Merge all the summaries and write them out to
   136    # /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default)
   137    merged = tf.summary.merge_all()
   138    train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
   139    test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
   140    tf.global_variables_initializer().run()
   141  
   142    # Train the model, and also write summaries.
   143    # Every 10th step, measure test-set accuracy, and write test summaries
   144    # All other steps, run train_step on training data, & add training summaries
   145  
   146    def feed_dict(train):     # pylint: disable=redefined-outer-name
   147      """Make a TensorFlow feed_dict: maps data onto Tensor placeholders."""
   148      if train or FLAGS.fake_data:
   149        xs, ys = mnist.train.next_batch(FLAGS.batch_size, fake_data=FLAGS.fake_data)
   150        k = FLAGS.dropout
   151      else:
   152        xs, ys = mnist.test.images, mnist.test.labels
   153        k = 1.0
   154      return {x: xs, y_: ys, keep_prob: k}
   155  
   156    for i in range(FLAGS.max_steps):
   157      if i % 10 == 0:  # Record summaries and test-set accuracy
   158        summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
   159        test_writer.add_summary(summary, i)
   160        print('Accuracy at step %s: %s' % (i, acc))
   161      else:  # Record train set summaries, and train
   162        if i % 100 == 99:  # Record execution stats
   163          run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
   164          run_metadata = tf.RunMetadata()
   165          summary, _ = sess.run([merged, train_step],
   166                                feed_dict=feed_dict(True),
   167                                options=run_options,
   168                                run_metadata=run_metadata)
   169          train_writer.add_run_metadata(run_metadata, 'step%03d' % i)
   170          train_writer.add_summary(summary, i)
   171          print('Adding run metadata for', i)
   172        else:  # Record a summary
   173          summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
   174          train_writer.add_summary(summary, i)
   175    train_writer.close()
   176    test_writer.close()
   177  
   178  
   179  def main(_):
   180    if tf.gfile.Exists(FLAGS.log_dir):
   181      tf.gfile.DeleteRecursively(FLAGS.log_dir)
   182    tf.gfile.MakeDirs(FLAGS.log_dir)
   183    train()
   184  
   185  
   186  if __name__ == '__main__':
   187    parser = argparse.ArgumentParser()
   188    parser.add_argument('--fake_data', nargs='?', const=True, type=bool,
   189                        default=False,
   190                        help='If true, uses fake data for unit testing.')
   191    parser.add_argument('--max_steps', type=int, default=1000,
   192                        help='Number of steps to run trainer.')
   193    parser.add_argument('--learning_rate', type=float, default=0.001,
   194                        help='Initial learning rate')
   195    parser.add_argument('--batch_size', type=int, default=100,
   196                        help='Training batch size')
   197    parser.add_argument('--dropout', type=float, default=0.9,
   198                        help='Keep probability for training dropout.')
   199    parser.add_argument(
   200        '--data_dir',
   201        type=str,
   202        default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
   203                             'tensorflow/mnist/input_data'),
   204        help='Directory for storing input data')
   205    parser.add_argument(
   206        '--log_dir',
   207        type=str,
   208        default=os.path.join(os.getenv('TEST_TMPDIR', '/tmp'),
   209                             'tensorflow/mnist/logs/mnist_with_summaries'),
   210        help='Summaries log directory')
   211    FLAGS, unparsed = parser.parse_known_args()
   212    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)