github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/docs/examples/etl-imagenet-dataset/train_pytorch.py

import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image

model_names = sorted(
    name
    for name in models.__dict__
    if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
)

parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
parser.add_argument("data", metavar="DIR", help="path to dataset")
parser.add_argument(
    "-a",
    "--arch",
    metavar="ARCH",
    default="resnet18",
    choices=model_names,
    help="model architecture: " + " | ".join(model_names) + " (default: resnet18)",
)
parser.add_argument(
    "-j",
    "--workers",
    default=4,
    type=int,
    metavar="N",
    help="number of data loading workers (default: 4)",
)
parser.add_argument(
    "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run"
)
parser.add_argument(
    "--start-epoch",
    default=0,
    type=int,
    metavar="N",
    help="manual epoch number (useful on restarts)",
)
parser.add_argument(
    "-b",
    "--batch-size",
    default=256,
    type=int,
    metavar="N",
    help="mini-batch size (default: 256), this is the total "
    "batch size of all GPUs on the current node when "
    "using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
    "--lr",
    "--learning-rate",
    default=0.1,
    type=float,
    metavar="LR",
    help="initial learning rate",
    dest="lr",
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
    "--wd",
    "--weight-decay",
    default=1e-4,
    type=float,
    metavar="W",
    help="weight decay (default: 1e-4)",
    dest="weight_decay",
)
parser.add_argument(
    "-p",
    "--print-freq",
    default=10,
    type=int,
    metavar="N",
    help="print frequency (default: 10)",
)
parser.add_argument(
    "--resume",
    default="",
    type=str,
    metavar="PATH",
    help="path to latest checkpoint (default: none)",
)
parser.add_argument(
    "-e",
    "--evaluate",
    dest="evaluate",
    action="store_true",
    help="evaluate model on validation set",
)
parser.add_argument(
    "--pretrained", dest="pretrained", action="store_true", help="use pre-trained model"
)
parser.add_argument(
    "--seed", default=None, type=int, help="seed for initializing training"
)
parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")

best_acc1 = 0
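
# A minimal sketch of how this script is typically invoked (the dataset path
# below is illustrative, not part of the original file):
#
#   python train_pytorch.py -a resnet18 --epochs 90 -b 256 /path/to/imagenet
#
# The positional DIR argument must contain "train" and "val" subdirectories;
# see LocalDataset below for the expected per-sample layout.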


def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. "
            "This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! "
            "You may see unexpected behavior when restarting "
            "from checkpoints."
        )

    if args.gpu is not None:
        warnings.warn(
            "You have chosen a specific GPU. This will completely "
            "disable data parallelism."
        )

    ngpus_per_node = torch.cuda.device_count()
    main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if not torch.cuda.is_available():
        print("using CPU, this will be slow")
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith("alexnet") or args.arch.startswith("vgg"):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = "cuda:{}".format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint["epoch"]
            best_acc1 = checkpoint["best_acc1"]
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    train_loader = torch.utils.data.DataLoader(
        LocalDataset(
            os.path.join(args.data, "train"),
            transforms.Compose(
                [
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
        ),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )

    val_loader = torch.utils.data.DataLoader(
        LocalDataset(
            os.path.join(args.data, "val"),
            transforms.Compose(
                [
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
        ),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": args.arch,
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
        )
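

# Checkpoints written by save_checkpoint() (called from the epoch loop above,
# defined below) are plain dicts, so the best model can be reloaded outside
# this script. A minimal, illustrative sketch (the file name matches
# save_checkpoint's default; the rest is an assumption about downstream use):
#
#   ckpt = torch.load("model_best.pth.tar", map_location="cpu")
#   model = models.__dict__[ckpt["arch"]]()
#   # Note: if training wrapped the model in DataParallel, the state_dict
#   # keys carry a "module." prefix that would need stripping first.
#   model.load_state_dict(ckpt["state_dict"])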


def train(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def validate(val_loader, model, criterion, args):
    batch_time = AverageMeter("Time", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(val_loader), [batch_time, losses, top1, top5], prefix="Test: "
    )

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(
            " * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(top1=top1, top5=top5)
        )

    return top1.avg


def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")
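

# LocalDataset below walks the given directory and pairs each "<name>.jpg"
# image with a "<name>.cls" file holding the sample's integer class label,
# as produced by the ETL ImageNet dataset example. An illustrative layout
# (file names are hypothetical):
#
#   <DIR>/train/n01440764_10026.jpg
#   <DIR>/train/n01440764_10026.cls   # plain-text integer label, e.g. "0"
#   <DIR>/val/ILSVRC2012_val_00000001.jpg
#   <DIR>/val/ILSVRC2012_val_00000001.cls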
class LocalDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        paths = set()
        for dirpath, _, fnames in sorted(os.walk(self.root_dir)):
            for fname in sorted(fnames):
                bname = os.path.splitext(fname)[0]
                paths.add(os.path.join(dirpath, bname))

        self.samples = []
        for path in paths:
            target = None
            with open(path + ".cls", "r") as f:
                target = int(f.read())
            self.samples.append((path + ".jpg", target))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, target

    def loader(self, path: str):
        with open(path, "rb") as f:
            img = Image.open(f)
            return img.convert("RGB")


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == "__main__":
    main()
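
# Illustrative check of accuracy() above, with made-up logits (not part of
# the original file); each returned value is a percentage over the batch:
#
#   >>> output = torch.tensor([[0.1, 0.9], [0.8, 0.2]])  # 2 samples, 2 classes
#   >>> target = torch.tensor([1, 0])                    # both predictions correct
#   >>> accuracy(output, target, topk=(1,))
#   [tensor([100.])]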