github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/docs/examples/etl-imagenet-dataset/train_aistore.py (about)

     1  import argparse
     2  import os
     3  import random
     4  import shutil
     5  import time
     6  import warnings
     7  
     8  import aistore
     9  from aistore.client import Bck
    10  
    11  import torch
    12  import torch.nn as nn
    13  import torch.nn.parallel
    14  import torch.backends.cudnn as cudnn
    15  import torch.optim
    16  import torch.utils.data
    17  import torch.utils.data.distributed
    18  import torchvision.models as models
    19  
    20  model_names = sorted(
    21      name
    22      for name in models.__dict__
    23      if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
    24  )
    25  
    26  parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
    27  parser.add_argument(
    28      "-a",
    29      "--arch",
    30      metavar="ARCH",
    31      default="resnet18",
    32      choices=model_names,
    33      help="model architecture: " + " | ".join(model_names) + " (default: resnet18)",
    34  )
    35  parser.add_argument(
    36      "-j",
    37      "--workers",
    38      default=4,
    39      type=int,
    40      metavar="N",
    41      help="number of data loading workers (default: 4)",
    42  )
    43  parser.add_argument(
    44      "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run"
    45  )
    46  parser.add_argument(
    47      "--start-epoch",
    48      default=0,
    49      type=int,
    50      metavar="N",
    51      help="manual epoch number (useful on restarts)",
    52  )
    53  parser.add_argument(
    54      "-b",
    55      "--batch-size",
    56      default=256,
    57      type=int,
    58      metavar="N",
    59      help="mini-batch size (default: 256), this is the total "
    60      "batch size of all GPUs on the current node when "
    61      "using Data Parallel or Distributed Data Parallel",
    62  )
    63  parser.add_argument(
    64      "--lr",
    65      "--learning-rate",
    66      default=0.1,
    67      type=float,
    68      metavar="LR",
    69      help="initial learning rate",
    70      dest="lr",
    71  )
    72  parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    73  parser.add_argument(
    74      "--wd",
    75      "--weight-decay",
    76      default=1e-4,
    77      type=float,
    78      metavar="W",
    79      help="weight decay (default: 1e-4)",
    80      dest="weight_decay",
    81  )
    82  parser.add_argument(
    83      "-p",
    84      "--print-freq",
    85      default=10,
    86      type=int,
    87      metavar="N",
    88      help="print frequency (default: 10)",
    89  )
    90  parser.add_argument(
    91      "--resume",
    92      default="",
    93      type=str,
    94      metavar="PATH",
    95      help="path to latest checkpoint (default: none)",
    96  )
    97  parser.add_argument(
    98      "-e",
    99      "--evaluate",
   100      dest="evaluate",
   101      action="store_true",
   102      help="evaluate model on validation set",
   103  )
   104  parser.add_argument(
   105      "--pretrained", dest="pretrained", action="store_true", help="use pre-trained model"
   106  )
   107  parser.add_argument(
   108      "--seed", default=None, type=int, help="seed for initializing training"
   109  )
   110  parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")
   111  
   112  best_acc1 = 0
   113  
   114  
   115  def main():
   116      args = parser.parse_args()
   117  
   118      if args.seed is not None:
   119          random.seed(args.seed)
   120          torch.manual_seed(args.seed)
   121          cudnn.deterministic = True
   122          warnings.warn(
   123              "You have chosen to seed training. "
   124              "This will turn on the CUDNN deterministic setting, "
   125              "which can slow down your training considerably! "
   126              "You may see unexpected behavior when restarting "
   127              "from checkpoints."
   128          )
   129  
   130      if args.gpu is not None:
   131          warnings.warn(
   132              "You have chosen a specific GPU. This will completely "
   133              "disable data parallelism."
   134          )
   135  
   136      ngpus_per_node = torch.cuda.device_count()
   137      main_worker(args.gpu, ngpus_per_node, args)
   138  
   139  
   140  def main_worker(gpu, ngpus_per_node, args):
   141      global best_acc1
   142      args.gpu = gpu
   143  
   144      if args.gpu is not None:
   145          print("Use GPU: {} for training".format(args.gpu))
   146  
   147      # create model
   148      if args.pretrained:
   149          print("=> using pre-trained model '{}'".format(args.arch))
   150          model = models.__dict__[args.arch](pretrained=True)
   151      else:
   152          print("=> creating model '{}'".format(args.arch))
   153          model = models.__dict__[args.arch]()
   154  
   155      if not torch.cuda.is_available():
   156          print("using CPU, this will be slow")
   157      elif args.gpu is not None:
   158          torch.cuda.set_device(args.gpu)
   159          model = model.cuda(args.gpu)
   160      else:
   161          # DataParallel will divide and allocate batch_size to all available GPUs
   162          if args.arch.startswith("alexnet") or args.arch.startswith("vgg"):
   163              model.features = torch.nn.DataParallel(model.features)
   164              model.cuda()
   165          else:
   166              model = torch.nn.DataParallel(model).cuda()
   167  
   168      # define loss function (criterion) and optimizer
   169      criterion = nn.CrossEntropyLoss().cuda(args.gpu)
   170  
   171      optimizer = torch.optim.SGD(
   172          model.parameters(),
   173          args.lr,
   174          momentum=args.momentum,
   175          weight_decay=args.weight_decay,
   176      )
   177  
   178      # optionally resume from a checkpoint
   179      if args.resume:
   180          if os.path.isfile(args.resume):
   181              print("=> loading checkpoint '{}'".format(args.resume))
   182              if args.gpu is None:
   183                  checkpoint = torch.load(args.resume)
   184              else:
   185                  # Map model to be loaded to specified single gpu.
   186                  loc = "cuda:{}".format(args.gpu)
   187                  checkpoint = torch.load(args.resume, map_location=loc)
   188              args.start_epoch = checkpoint["epoch"]
   189              best_acc1 = checkpoint["best_acc1"]
   190              if args.gpu is not None:
   191                  # best_acc1 may be from a checkpoint from a different GPU
   192                  best_acc1 = best_acc1.to(args.gpu)
   193              model.load_state_dict(checkpoint["state_dict"])
   194              optimizer.load_state_dict(checkpoint["optimizer"])
   195              print(
   196                  "=> loaded checkpoint '{}' (epoch {})".format(
   197                      args.resume, checkpoint["epoch"]
   198                  )
   199              )
   200          else:
   201              print("=> no checkpoint found at '{}'".format(args.resume))
   202  
   203      cudnn.benchmark = True
   204  
   205      # Data loading code
   206      train_loader = torch.utils.data.DataLoader(
   207          aistore.pytorch.Dataset(
   208              "http://aistore-sample-proxy:51080",
   209              Bck("imagenet"),  # AIS IP address or hostname
   210              prefix="train/",
   211              transform_id="imagenet-train",
   212              transform_filter=lambda object_name: object_name.endswith(".jpg"),
   213          ),
   214          batch_size=args.batch_size,
   215          shuffle=True,
   216          num_workers=args.workers,
   217          pin_memory=True,
   218      )
   219  
   220      val_loader = torch.utils.data.DataLoader(
   221          aistore.pytorch.Dataset(
   222              "http://aistore-sample-proxy:51080",
   223              Bck("imagenet"),
   224              prefix="val/",
   225              transform_id="imagenet-train",
   226              transform_filter=lambda object_name: object_name.endswith(".jpg"),
   227          ),
   228          batch_size=args.batch_size,
   229          shuffle=False,
   230          num_workers=args.workers,
   231          pin_memory=True,
   232      )
   233  
   234      if args.evaluate:
   235          validate(val_loader, model, criterion, args)
   236          return
   237  
   238      for epoch in range(args.start_epoch, args.epochs):
   239          adjust_learning_rate(optimizer, epoch, args)
   240  
   241          # train for one epoch
   242          train(train_loader, model, criterion, optimizer, epoch, args)
   243  
   244          # evaluate on validation set
   245          acc1 = validate(val_loader, model, criterion, args)
   246  
   247          # remember best acc@1 and save checkpoint
   248          is_best = acc1 > best_acc1
   249          best_acc1 = max(acc1, best_acc1)
   250  
   251          save_checkpoint(
   252              {
   253                  "epoch": epoch + 1,
   254                  "arch": args.arch,
   255                  "state_dict": model.state_dict(),
   256                  "best_acc1": best_acc1,
   257                  "optimizer": optimizer.state_dict(),
   258              },
   259              is_best,
   260          )
   261  
   262  
   263  def train(train_loader, model, criterion, optimizer, epoch, args):
   264      batch_time = AverageMeter("Time", ":6.3f")
   265      data_time = AverageMeter("Data", ":6.3f")
   266      losses = AverageMeter("Loss", ":.4e")
   267      top1 = AverageMeter("Acc@1", ":6.2f")
   268      top5 = AverageMeter("Acc@5", ":6.2f")
   269      progress = ProgressMeter(
   270          len(train_loader),
   271          [batch_time, data_time, losses, top1, top5],
   272          prefix="Epoch: [{}]".format(epoch),
   273      )
   274  
   275      # switch to train mode
   276      model.train()
   277  
   278      end = time.time()
   279      for i, (images, target) in enumerate(train_loader):
   280          # measure data loading time
   281          data_time.update(time.time() - end)
   282  
   283          if args.gpu is not None:
   284              images = images.cuda(args.gpu, non_blocking=True)
   285          if torch.cuda.is_available():
   286              target = target.cuda(args.gpu, non_blocking=True)
   287  
   288          # compute output
   289          output = model(images)
   290          loss = criterion(output, target)
   291  
   292          # measure accuracy and record loss
   293          acc1, acc5 = accuracy(output, target, topk=(1, 5))
   294          losses.update(loss.item(), images.size(0))
   295          top1.update(acc1[0], images.size(0))
   296          top5.update(acc5[0], images.size(0))
   297  
   298          # compute gradient and do SGD step
   299          optimizer.zero_grad()
   300          loss.backward()
   301          optimizer.step()
   302  
   303          # measure elapsed time
   304          batch_time.update(time.time() - end)
   305          end = time.time()
   306  
   307          if i % args.print_freq == 0:
   308              progress.display(i)
   309  
   310  
   311  def validate(val_loader, model, criterion, args):
   312      batch_time = AverageMeter("Time", ":6.3f")
   313      losses = AverageMeter("Loss", ":.4e")
   314      top1 = AverageMeter("Acc@1", ":6.2f")
   315      top5 = AverageMeter("Acc@5", ":6.2f")
   316      progress = ProgressMeter(
   317          len(val_loader), [batch_time, losses, top1, top5], prefix="Test: "
   318      )
   319  
   320      # switch to evaluate mode
   321      model.eval()
   322  
   323      with torch.no_grad():
   324          end = time.time()
   325          for i, (images, target) in enumerate(val_loader):
   326              if args.gpu is not None:
   327                  images = images.cuda(args.gpu, non_blocking=True)
   328              if torch.cuda.is_available():
   329                  target = target.cuda(args.gpu, non_blocking=True)
   330  
   331              # compute output
   332              output = model(images)
   333              loss = criterion(output, target)
   334  
   335              # measure accuracy and record loss
   336              acc1, acc5 = accuracy(output, target, topk=(1, 5))
   337              losses.update(loss.item(), images.size(0))
   338              top1.update(acc1[0], images.size(0))
   339              top5.update(acc5[0], images.size(0))
   340  
   341              # measure elapsed time
   342              batch_time.update(time.time() - end)
   343              end = time.time()
   344  
   345              if i % args.print_freq == 0:
   346                  progress.display(i)
   347  
   348          # TODO: this should also be done with the ProgressMeter
   349          print(
   350              " * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(top1=top1, top5=top5)
   351          )
   352  
   353      return top1.avg
   354  
   355  
   356  def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
   357      torch.save(state, filename)
   358      if is_best:
   359          shutil.copyfile(filename, "model_best.pth.tar")
   360  
   361  
   362  class AverageMeter(object):
   363      """Computes and stores the average and current value"""
   364  
   365      def __init__(self, name, fmt=":f"):
   366          self.name = name
   367          self.fmt = fmt
   368          self.reset()
   369  
   370      def reset(self):
   371          self.val = 0
   372          self.avg = 0
   373          self.sum = 0
   374          self.count = 0
   375  
   376      def update(self, val, n=1):
   377          self.val = val
   378          self.sum += val * n
   379          self.count += n
   380          self.avg = self.sum / self.count
   381  
   382      def __str__(self):
   383          fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
   384          return fmtstr.format(**self.__dict__)
   385  
   386  
   387  class ProgressMeter(object):
   388      def __init__(self, num_batches, meters, prefix=""):
   389          self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
   390          self.meters = meters
   391          self.prefix = prefix
   392  
   393      def display(self, batch):
   394          entries = [self.prefix + self.batch_fmtstr.format(batch)]
   395          entries += [str(meter) for meter in self.meters]
   396          print("\t".join(entries))
   397  
   398      def _get_batch_fmtstr(self, num_batches):
   399          num_digits = len(str(num_batches // 1))
   400          fmt = "{:" + str(num_digits) + "d}"
   401          return "[" + fmt + "/" + fmt.format(num_batches) + "]"
   402  
   403  
   404  def adjust_learning_rate(optimizer, epoch, args):
   405      """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
   406      lr = args.lr * (0.1 ** (epoch // 30))
   407      for param_group in optimizer.param_groups:
   408          param_group["lr"] = lr
   409  
   410  
   411  def accuracy(output, target, topk=(1,)):
   412      """Computes the accuracy over the k top predictions for the specified values of k"""
   413      with torch.no_grad():
   414          maxk = max(topk)
   415          batch_size = target.size(0)
   416  
   417          _, pred = output.topk(maxk, 1, True, True)
   418          pred = pred.t()
   419          correct = pred.eq(target.view(1, -1).expand_as(pred))
   420  
   421          res = []
   422          for k in topk:
   423              correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
   424              res.append(correct_k.mul_(100.0 / batch_size))
   425          return res
   426  
   427  
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()