github.com/kubeflow/training-operator@v1.7.0/examples/mxnet/tune/auto-tuning.py (about)

     1  import os
     2  import sys
     3  
     4  import numpy as np
     5  import argparse
     6  
     7  import nnvm.testing
     8  import nnvm.compiler
     9  import tvm
    10  from tvm import autotvm
    11  from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
    12  from tvm.contrib.util import tempdir
    13  import tvm.contrib.graph_runtime as runtime
    14  #from mxnet.gluon.model_zoo.vision import get_model
    15  
    16  def get_network(name, batch_size):
    17      """Get the symbol definition and random weight of a network"""
    18      input_shape = (batch_size, 3, 224, 224)
    19      output_shape = (batch_size, 1000)
    20  
    21      if "resnet" in name:
    22          n_layer = int(name.split('-')[1])
    23          net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size)
    24      elif "vgg" in name:
    25          n_layer = int(name.split('-')[1])
    26          net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size)
    27      elif name == 'mobilenet':
    28          net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
    29      elif name == 'squeezenet_v1.1':
    30          net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1')
    31      elif name == 'inception_v3':
    32          input_shape = (1, 3, 299, 299)
    33          net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size)
    34      elif name == 'custom':
    35          # an example for custom network
    36          from nnvm.testing import utils
    37          net = nnvm.sym.Variable('data')
    38          net = nnvm.sym.conv2d(net, channels=4, kernel_size=(3,3), padding=(1,1))
    39          net = nnvm.sym.flatten(net)
    40          net = nnvm.sym.dense(net, units=1000)
    41          net, params = utils.create_workload(net, batch_size, (3, 224, 224))
    42      elif name == 'mxnet':
    43          # an example for mxnet model
    44          from mxnet.gluon.model_zoo.vision import get_model
    45          block = get_model('resnet18_v1', pretrained=True)
    46          net, params = nnvm.frontend.from_mxnet(block)
    47          net = nnvm.sym.softmax(net)
    48      else:
    49          raise ValueError("Unsupported network: " + name)
    50  
    51      return net, params, input_shape, output_shape
    52  
    53  '''def get_network_from_mxnet(batch_size):
    54      input_shape = (batch_size, 3, 224, 224)
    55      output_shape = (batch_size, 1000)
    56  
    57      block = get_model('resnet18_v1', pretrained=True)
    58      sym, params = nnvm.frontend.from_mxnet(block)
    59  
    60      return sym, params, input_shape, output_shape'''
    61  
    62  #### DEVICE CONFIG ####
    63  target = tvm.target.cuda()
    64  
    65  #### TUNING OPTION ####
    66  network = 'resnet-18'
    67  log_file = "%s.log" % network
    68  dtype = 'float32'
    69  
    70  # You can skip the implementation of this function for this tutorial.
    71  def tune_tasks(tasks,
    72                 measure_option,
    73                 tuner='xgb',
    74                 n_trial=1000,
    75                 early_stopping=None,
    76                 log_filename='tuning.log',
    77                 use_transfer_learning=True,
    78                 try_winograd=True):
    79      if try_winograd:
    80          for i in range(len(tasks)):
    81              try:  # try winograd template
    82                  tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
    83                                            tasks[i].target, tasks[i].target_host, 'winograd')
    84                  input_channel = tsk.workload[1][1]
    85                  if input_channel >= 64:
    86                      tasks[i] = tsk
    87              except Exception:
    88                  pass
    89  
    90      # create tmp log file
    91      tmp_log_file = log_filename + ".tmp"
    92      if os.path.exists(tmp_log_file):
    93          os.remove(tmp_log_file)
    94  
    95      for i, tsk in enumerate(reversed(tasks)):
    96          prefix = "[Task %2d/%2d] " %(i+1, len(tasks))
    97  
    98          # create tuner
    99          if tuner == 'xgb' or tuner == 'xgb-rank':
   100              tuner_obj = XGBTuner(tsk, loss_type='rank')
   101          elif tuner == 'ga':
   102              tuner_obj = GATuner(tsk, pop_size=100)
   103          elif tuner == 'random':
   104              tuner_obj = RandomTuner(tsk)
   105          elif tuner == 'gridsearch':
   106              tuner_obj = GridSearchTuner(tsk)
   107          else:
   108              raise ValueError("Invalid tuner: " + tuner)
   109  
   110          if use_transfer_learning:
   111              if os.path.isfile(tmp_log_file):
   112                  tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
   113  
   114          # do tuning
   115          tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
   116                         early_stopping=early_stopping,
   117                         measure_option=measure_option,
   118                         callbacks=[
   119                             autotvm.callback.progress_bar(n_trial, prefix=prefix),
   120                             autotvm.callback.log_to_file(tmp_log_file)])
   121  
   122      # pick best records to a cache file
   123      autotvm.record.pick_best(tmp_log_file, log_filename)
   124      os.remove(tmp_log_file)
   125  
   126  def tune_and_evaluate(tuning_opt):
   127      # extract workloads from nnvm graph
   128      print("Extract tasks...")
   129      #net, params, input_shape, out_shape = get_network_from_mxnet(batch_size=1)
   130      net, params, input_shape, out_shape = get_network(network, batch_size=1)
   131      tasks = autotvm.task.extract_from_graph(net, target=target,
   132                                              shape={'data': input_shape}, dtype=dtype,
   133                                              symbols=(nnvm.sym.conv2d,))
   134  
   135      # run tuning tasks
   136      print("Tuning...")
   137      tune_tasks(tasks, **tuning_opt)
   138  
   139      # compile kernels with history best records
   140      with autotvm.apply_history_best(log_file):
   141          print("Compile...")
   142          with nnvm.compiler.build_config(opt_level=3):
   143              graph, lib, params = nnvm.compiler.build(
   144                  net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)
   145  
   146          # export library
   147          tmp = tempdir()
   148          filename = "net.tar"
   149          lib.export_library(tmp.relpath(filename))
   150  
   151          # load parameters
   152          ctx = tvm.context(str(target), 0)
   153          module = runtime.create(graph, lib, ctx)
   154          data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
   155          module.set_input('data', data_tvm)
   156          module.set_input(**params)
   157  
   158          # evaluate
   159          print("Evaluate inference time cost...")
   160          ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=600)
   161          prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
   162          print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)))
   163  
# Running this script starts tuning immediately, which can take a long time.
# Use --tracker, --tracker_port and --server_key to point it at your RPC tracker and devices.
   166  
   167  if __name__ == '__main__':
   168  
   169      parser = argparse.ArgumentParser(description="auto tuning",
   170                                       formatter_class=argparse.ArgumentDefaultsHelpFormatter)
   171      parser.add_argument('--tracker', default='auto-tuning-job-tunertracker-0',
   172                          help='the url of tune tracker')
   173      parser.add_argument('--tracker_port', type=int, default=9190,
   174                          help='the port of tune tracker')
   175      parser.add_argument('--server_key', default='gpu',
   176                          help='the key to identify tune server')
   177      args = parser.parse_args()
   178  
   179      tuning_option = {
   180          'log_filename': log_file,
   181  
   182          'tuner': 'xgb',
   183          'n_trial': 100,
   184          'early_stopping': 600,
   185  
   186          'measure_option': autotvm.measure_option(
   187              builder=autotvm.LocalBuilder(timeout=10),
   188              runner=autotvm.RPCRunner(
   189                  args.server_key,  # change the device key to your key
   190                  args.tracker, args.tracker_port,
   191                  number=20, repeat=3, timeout=4),
   192          ),
   193      }
   194  
   195      tune_and_evaluate(tuning_option)