github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/docs/examples/etl-imagenet-dataset/train_pytorch.py
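# Single-node ImageNet training script (patterned on the PyTorch ImageNet example),
# reading a dataset laid out as "<sample>.jpg" / "<sample>.cls" pairs under train/ and val/.
# Example invocation (paths are placeholders):
#   python train_pytorch.py --arch resnet18 /path/to/transformed/imagenet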

import argparse
import os
import random
import shutil
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.models as models
from PIL import Image

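# every lowercase, callable entry in torchvision.models is exposed as an --arch choice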
model_names = sorted(
    name
    for name in models.__dict__
    if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
)

parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
parser.add_argument("data", metavar="DIR", help="path to dataset")
parser.add_argument(
    "-a",
    "--arch",
    metavar="ARCH",
    default="resnet18",
    choices=model_names,
    help="model architecture: " + " | ".join(model_names) + " (default: resnet18)",
)
parser.add_argument(
    "-j",
    "--workers",
    default=4,
    type=int,
    metavar="N",
    help="number of data loading workers (default: 4)",
)
parser.add_argument(
    "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run"
)
parser.add_argument(
    "--start-epoch",
    default=0,
    type=int,
    metavar="N",
    help="manual epoch number (useful on restarts)",
)
parser.add_argument(
    "-b",
    "--batch-size",
    default=256,
    type=int,
    metavar="N",
    help="mini-batch size (default: 256), this is the total "
    "batch size of all GPUs on the current node when "
    "using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
    "--lr",
    "--learning-rate",
    default=0.1,
    type=float,
    metavar="LR",
    help="initial learning rate",
    dest="lr",
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
    "--wd",
    "--weight-decay",
    default=1e-4,
    type=float,
    metavar="W",
    help="weight decay (default: 1e-4)",
    dest="weight_decay",
)
parser.add_argument(
    "-p",
    "--print-freq",
    default=10,
    type=int,
    metavar="N",
    help="print frequency (default: 10)",
)
parser.add_argument(
    "--resume",
    default="",
    type=str,
    metavar="PATH",
    help="path to latest checkpoint (default: none)",
)
parser.add_argument(
    "-e",
    "--evaluate",
    dest="evaluate",
    action="store_true",
    help="evaluate model on validation set",
)
parser.add_argument(
    "--pretrained", dest="pretrained", action="store_true", help="use pre-trained model"
)
parser.add_argument(
    "--seed", default=None, type=int, help="seed for initializing training"
)
parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")

best_acc1 = 0


def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. "
            "This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! "
            "You may see unexpected behavior when restarting "
            "from checkpoints."
        )

    if args.gpu is not None:
        warnings.warn(
            "You have chosen a specific GPU. This will completely "
            "disable data parallelism."
        )

    ngpus_per_node = torch.cuda.device_count()
    main_worker(args.gpu, ngpus_per_node, args)


def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if not torch.cuda.is_available():
        print("using CPU, this will be slow")
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith("alexnet") or args.arch.startswith("vgg"):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer; move the criterion to GPU
    # only when CUDA is available so the CPU fallback above keeps working
    criterion = nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        criterion = criterion.cuda(args.gpu)

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = "cuda:{}".format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint["epoch"]
            best_acc1 = checkpoint["best_acc1"]
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

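    # let cuDNN auto-tune convolution algorithms for the fixed input size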
    cudnn.benchmark = True

    # Data loading code
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    train_loader = torch.utils.data.DataLoader(
        LocalDataset(
            os.path.join(args.data, "train"),
            transforms.Compose(
                [
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
        ),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
        pin_memory=True,
    )

    val_loader = torch.utils.data.DataLoader(
        LocalDataset(
            os.path.join(args.data, "val"),
            transforms.Compose(
                [
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
        ),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": args.arch,
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
        )


def train(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def validate(val_loader, model, criterion, args):
    batch_time = AverageMeter("Time", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(val_loader), [batch_time, losses, top1, top5], prefix="Test: "
    )

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(
            " * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(top1=top1, top5=top5)
        )

    return top1.avg


def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")


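# LocalDataset expects the on-disk layout produced by the ETL example: for every
# sample there is a "<name>.jpg" image and a "<name>.cls" file holding the integer class label.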
class LocalDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform

        # collect unique sample basenames (each sample is a .jpg/.cls pair)
        paths = set()
        for root, _, fnames in sorted(os.walk(self.root_dir)):
            for fname in sorted(fnames):
                bname = os.path.splitext(fname)[0]
                paths.add(os.path.join(root, bname))

        # iterate in sorted order so the sample list is deterministic across runs
        self.samples = []
        for path in sorted(paths):
            with open(path + ".cls", "r") as f:
                target = int(f.read())
            self.samples.append((path + ".jpg", target))

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, index: int):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, target

    def loader(self, path: str):
        with open(path, "rb") as f:
            img = Image.open(f)
            return img.convert("RGB")


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


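# ProgressMeter prints a prefixed "[batch/total]" counter followed by each meter's
# current and running-average value, tab-separated.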
class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == "__main__":
    main()