# Source: github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/docs/examples/etl-imagenet-dataset/train_aistore.py
"""Train a torchvision classification model on images served from AIStore.

The training and validation sets are read through ``aistore.pytorch.Dataset``
from the ``imagenet`` bucket (object prefixes ``train/`` and ``val/``), with
the ``imagenet-train`` transform requested for both. The model, optimizer,
checkpointing and metering follow the usual single-node PyTorch ImageNet
training loop.
"""
import argparse
import os
import random
import shutil
import time
import warnings

import aistore
from aistore.client import Bck

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.models as models

# AIS proxy endpoint (IP address or hostname) that both data loaders talk to.
AIS_ENDPOINT = "http://aistore-sample-proxy:51080"
# Bucket holding the dataset objects.
AIS_BUCKET = "imagenet"
# Transform id requested for the training AND the validation objects.
# NOTE(review): validation also uses the "imagenet-train" transform --
# confirm this is intended rather than a dedicated validation transform.
AIS_TRANSFORM_ID = "imagenet-train"

# Every lowercase, non-dunder callable exported by torchvision.models is
# offered as a selectable architecture.
model_names = sorted(
    name
    for name in models.__dict__
    if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
)

parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
parser.add_argument(
    "-a",
    "--arch",
    metavar="ARCH",
    default="resnet18",
    choices=model_names,
    help="model architecture: " + " | ".join(model_names) + " (default: resnet18)",
)
parser.add_argument(
    "-j",
    "--workers",
    default=4,
    type=int,
    metavar="N",
    help="number of data loading workers (default: 4)",
)
parser.add_argument(
    "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run"
)
parser.add_argument(
    "--start-epoch",
    default=0,
    type=int,
    metavar="N",
    help="manual epoch number (useful on restarts)",
)
parser.add_argument(
    "-b",
    "--batch-size",
    default=256,
    type=int,
    metavar="N",
    help="mini-batch size (default: 256), this is the total "
    "batch size of all GPUs on the current node when "
    "using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
    "--lr",
    "--learning-rate",
    default=0.1,
    type=float,
    metavar="LR",
    help="initial learning rate",
    dest="lr",
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
    "--wd",
    "--weight-decay",
    default=1e-4,
    type=float,
    metavar="W",
    help="weight decay (default: 1e-4)",
    dest="weight_decay",
)
parser.add_argument(
    "-p",
    "--print-freq",
    default=10,
    type=int,
    metavar="N",
    help="print frequency (default: 10)",
)
parser.add_argument(
    "--resume",
    default="",
    type=str,
    metavar="PATH",
    help="path to latest checkpoint (default: none)",
)
parser.add_argument(
    "-e",
    "--evaluate",
    dest="evaluate",
    action="store_true",
    help="evaluate model on validation set",
)
parser.add_argument(
    "--pretrained", dest="pretrained", action="store_true", help="use pre-trained model"
)
parser.add_argument(
    "--seed", default=None, type=int, help="seed for initializing training"
)
parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")

# Best top-1 validation accuracy seen so far; updated by main_worker().
best_acc1 = 0


def main():
    """Parse the command line, apply optional seeding, and start the worker."""
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. "
            "This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! "
            "You may see unexpected behavior when restarting "
            "from checkpoints."
        )

    if args.gpu is not None:
        warnings.warn(
            "You have chosen a specific GPU. This will completely "
            "disable data parallelism."
        )

    ngpus_per_node = torch.cuda.device_count()
    main_worker(args.gpu, ngpus_per_node, args)


def _is_jpg(object_name):
    """Return True for object names ending in ``.jpg``.

    Defined at module level (instead of as a lambda) so that the dataset
    holding it stays picklable when DataLoader workers are started with the
    ``spawn`` method.
    """
    return object_name.endswith(".jpg")


def _build_loader(prefix, shuffle, args, transform_id=AIS_TRANSFORM_ID):
    """Create a DataLoader over the AIS objects stored under ``prefix``.

    Args:
        prefix: object-name prefix inside the bucket (``"train/"`` or ``"val/"``).
        shuffle: forwarded to ``DataLoader``.
        args: parsed CLI namespace; ``batch_size`` and ``workers`` are read.
        transform_id: transform id handed to ``aistore.pytorch.Dataset``.
    """
    dataset = aistore.pytorch.Dataset(
        AIS_ENDPOINT,
        Bck(AIS_BUCKET),
        prefix=prefix,
        transform_id=transform_id,
        # presumably selects the objects the transform applies to -- confirm
        # against the aistore.pytorch.Dataset documentation
        transform_filter=_is_jpg,
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
    )


def _to_device(images, target, gpu):
    """Move one batch to CUDA when it is available; otherwise return it as is.

    With an explicit ``gpu`` both tensors go to that device. Without one only
    ``target`` is moved (to the current CUDA device): in that configuration
    main_worker() wraps the model in ``DataParallel``, which is handed the
    images as they come from the loader.
    """
    if not torch.cuda.is_available():
        # CPU-only host: the model was left on the CPU, so the batch must be too,
        # even if --gpu was passed on the command line.
        return images, target
    if gpu is not None:
        images = images.cuda(gpu, non_blocking=True)
    target = target.cuda(gpu, non_blocking=True)
    return images, target


def main_worker(gpu, ngpus_per_node, args):
    """Build the model, optionally resume, then evaluate or run the epoch loop.

    Args:
        gpu: GPU id to use, or None for CPU / DataParallel over all GPUs.
        ngpus_per_node: accepted for interface compatibility; not used here.
        args: parsed CLI namespace (``args.gpu`` is overwritten with ``gpu``).
    """
    global best_acc1
    args.gpu = gpu
    use_cuda = torch.cuda.is_available()

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if not use_cuda:
        print("using CPU, this will be slow")
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith("alexnet") or args.arch.startswith("vgg"):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if not use_cuda:
                # A checkpoint written on a GPU cannot be deserialized on a
                # CPU-only host unless its tensors are remapped.
                loc = "cpu"
            elif args.gpu is not None:
                # Map model to be loaded to specified single gpu.
                loc = "cuda:{}".format(args.gpu)
            else:
                loc = None  # keep the devices recorded in the checkpoint
            checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint["epoch"]
            best_acc1 = checkpoint["best_acc1"]
            if use_cuda and args.gpu is not None and torch.is_tensor(best_acc1):
                # best_acc1 may be from a checkpoint from a different GPU;
                # a plain number (e.g. the initial 0) needs no move.
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    train_loader = _build_loader("train/", True, args)
    val_loader = _build_loader("val/", False, args)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": args.arch,
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
        )


def train(train_loader, model, criterion, optimizer, epoch, args):
    """Run one training epoch, printing meters every ``args.print_freq`` batches."""
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        images, target = _to_device(images, target, args.gpu)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def validate(val_loader, model, criterion, args):
    """Evaluate ``model`` on ``val_loader`` and return the average top-1 accuracy."""
    batch_time = AverageMeter("Time", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(val_loader), [batch_time, losses, top1, top5], prefix="Test: "
    )

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            images, target = _to_device(images, target, args.gpu)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(
            " * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(top1=top1, top5=top5)
        )

    return top1.avg


def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
    """Save ``state`` to ``filename``; also copy it to ``model_best.pth.tar`` if best."""
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")


class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the last value, the running sum, the count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against a zero count (e.g. update(..., n=0) on a fresh meter).
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter:
    """Prints a tab-separated line with a batch counter and a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, ``[batch/num_batches]`` and every meter."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running batch index to the width of the total batch count.
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == "__main__":
    main()