"""Auto-tune conv2d kernels of a network with TVM's autotvm, then compile the
network with the best-found configs and benchmark inference on a CUDA device
reached through an autotvm RPC tracker.

From kubeflow/training-operator examples (mxnet/tune/auto-tuning.py), adapted
from the upstream TVM auto-tuning tutorial.
"""
import os
import sys

import numpy as np
import argparse

import nnvm.testing
import nnvm.compiler
import tvm
from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime


def get_network(name, batch_size):
    """Get the symbol definition and random weights of a network.

    Parameters
    ----------
    name : str
        Network identifier: 'resnet-N', 'vgg-N', 'mobilenet',
        'squeezenet_v1.1', 'inception_v3', 'custom', or 'mxnet'.
    batch_size : int
        Batch dimension used for the input/output shapes and the workload.

    Returns
    -------
    (net, params, input_shape, output_shape)
        The nnvm symbol, its parameter dict, and the NCHW input / output
        shape tuples.

    Raises
    ------
    ValueError
        If `name` matches no supported network.
    """
    input_shape = (batch_size, 3, 224, 224)
    output_shape = (batch_size, 1000)

    if "resnet" in name:
        n_layer = int(name.split('-')[1])
        net, params = nnvm.testing.resnet.get_workload(num_layers=n_layer, batch_size=batch_size)
    elif "vgg" in name:
        n_layer = int(name.split('-')[1])
        net, params = nnvm.testing.vgg.get_workload(num_layers=n_layer, batch_size=batch_size)
    elif name == 'mobilenet':
        net, params = nnvm.testing.mobilenet.get_workload(batch_size=batch_size)
    elif name == 'squeezenet_v1.1':
        net, params = nnvm.testing.squeezenet.get_workload(batch_size=batch_size, version='1.1')
    elif name == 'inception_v3':
        # Inception v3 takes 299x299 inputs. Bug fix: the original hard-coded
        # the batch dimension to 1 here, ignoring `batch_size` (and thus
        # disagreeing with the batch_size-based `output_shape`).
        input_shape = (batch_size, 3, 299, 299)
        net, params = nnvm.testing.inception_v3.get_workload(batch_size=batch_size)
    elif name == 'custom':
        # An example of assembling a small custom network symbolically.
        from nnvm.testing import utils
        net = nnvm.sym.Variable('data')
        net = nnvm.sym.conv2d(net, channels=4, kernel_size=(3, 3), padding=(1, 1))
        net = nnvm.sym.flatten(net)
        net = nnvm.sym.dense(net, units=1000)
        net, params = utils.create_workload(net, batch_size, (3, 224, 224))
    elif name == 'mxnet':
        # An example of importing a pretrained model from MXNet's model zoo.
        from mxnet.gluon.model_zoo.vision import get_model
        block = get_model('resnet18_v1', pretrained=True)
        net, params = nnvm.frontend.from_mxnet(block)
        net = nnvm.sym.softmax(net)
    else:
        raise ValueError("Unsupported network: " + name)

    return net, params, input_shape, output_shape


#### DEVICE CONFIG ####
target = tvm.target.cuda()

#### TUNING OPTION ####
network = 'resnet-18'
log_file = "%s.log" % network
dtype = 'float32'


def tune_tasks(tasks,
               measure_option,
               tuner='xgb',
               n_trial=1000,
               early_stopping=None,
               log_filename='tuning.log',
               use_transfer_learning=True,
               try_winograd=True):
    """Tune each task and record the best configs into `log_filename`.

    Parameters
    ----------
    tasks : list of autotvm.task.Task
        Tasks extracted from the network graph; tuned in reverse order.
    measure_option : dict
        Result of `autotvm.measure_option` describing how to build/run trials.
    tuner : str
        One of 'xgb', 'xgb-rank', 'ga', 'random', 'gridsearch'.
    n_trial : int
        Max trials per task (capped at the size of the task's config space).
    early_stopping : int or None
        Stop a task early after this many trials without improvement.
    log_filename : str
        Destination file for the best record of each workload.
    use_transfer_learning : bool
        Seed each tuner with the records gathered so far in this run.
    try_winograd : bool
        Attempt to swap tasks for their winograd template variant.

    Raises
    ------
    ValueError
        If `tuner` is not a recognized tuner name.
    """
    if try_winograd:
        # Best-effort: substitute the winograd template when it exists for the
        # op/target and the channel count makes it worthwhile.
        for i in range(len(tasks)):
            try:
                tsk = autotvm.task.create(tasks[i].name, tasks[i].args,
                                          tasks[i].target, tasks[i].target_host, 'winograd')
                input_channel = tsk.workload[1][1]
                if input_channel >= 64:
                    tasks[i] = tsk
            except Exception:
                # No winograd template for this task; keep the original.
                pass

    # Tune into a temporary log so a partial run never corrupts `log_filename`.
    tmp_log_file = log_filename + ".tmp"
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

    for i, tsk in enumerate(reversed(tasks)):
        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))

        # Create the requested tuner.
        if tuner == 'xgb' or tuner == 'xgb-rank':
            tuner_obj = XGBTuner(tsk, loss_type='rank')
        elif tuner == 'ga':
            tuner_obj = GATuner(tsk, pop_size=100)
        elif tuner == 'random':
            tuner_obj = RandomTuner(tsk)
        elif tuner == 'gridsearch':
            tuner_obj = GridSearchTuner(tsk)
        else:
            raise ValueError("Invalid tuner: " + tuner)

        # Warm-start from measurements already collected in this run.
        if use_transfer_learning:
            if os.path.isfile(tmp_log_file):
                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))

        # do tuning
        tuner_obj.tune(n_trial=min(n_trial, len(tsk.config_space)),
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(n_trial, prefix=prefix),
                           autotvm.callback.log_to_file(tmp_log_file)])

    # Pick the best record per workload into the final cache file.
    autotvm.record.pick_best(tmp_log_file, log_filename)
    os.remove(tmp_log_file)


def tune_and_evaluate(tuning_opt):
    """Extract conv2d tasks, tune them, compile with the best configs, and
    measure inference time on the local CUDA context.

    Parameters
    ----------
    tuning_opt : dict
        Keyword arguments forwarded to `tune_tasks` (tuner, n_trial,
        measure_option, log_filename, ...).
    """
    # extract workloads from nnvm graph
    print("Extract tasks...")
    net, params, input_shape, out_shape = get_network(network, batch_size=1)
    tasks = autotvm.task.extract_from_graph(net, target=target,
                                            shape={'data': input_shape}, dtype=dtype,
                                            symbols=(nnvm.sym.conv2d,))

    # run tuning tasks
    print("Tuning...")
    tune_tasks(tasks, **tuning_opt)

    # compile kernels with history best records
    with autotvm.apply_history_best(log_file):
        print("Compile...")
        with nnvm.compiler.build_config(opt_level=3):
            graph, lib, params = nnvm.compiler.build(
                net, target=target, shape={'data': input_shape}, params=params, dtype=dtype)

        # export library
        tmp = tempdir()
        filename = "net.tar"
        lib.export_library(tmp.relpath(filename))

        # load parameters
        ctx = tvm.context(str(target), 0)
        module = runtime.create(graph, lib, ctx)
        data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
        module.set_input('data', data_tvm)
        module.set_input(**params)

        # evaluate
        print("Evaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=600)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print("Mean inference time (std dev): %.2f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))

# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run it by yourself.
if __name__ == '__main__':
    # Command-line options locating the autotvm RPC tracker and device key.
    arg_parser = argparse.ArgumentParser(
        description="auto tuning",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    arg_parser.add_argument('--tracker', default='auto-tuning-job-tunertracker-0',
                            help='the url of tune tracker')
    arg_parser.add_argument('--tracker_port', type=int, default=9190,
                            help='the port of tune tracker')
    arg_parser.add_argument('--server_key', default='gpu',
                            help='the key to identify tune server')
    cli = arg_parser.parse_args()

    # Measurement: build locally, run trials on a remote device registered
    # under `server_key` at the RPC tracker.
    measure = autotvm.measure_option(
        builder=autotvm.LocalBuilder(timeout=10),
        runner=autotvm.RPCRunner(
            cli.server_key,  # change the device key to your key
            cli.tracker, cli.tracker_port,
            number=20, repeat=3, timeout=4),
    )

    tuning_option = {
        'log_filename': log_file,
        'tuner': 'xgb',
        'n_trial': 100,
        'early_stopping': 600,
        'measure_option': measure,
    }

    tune_and_evaluate(tuning_option)