github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/docs/examples/etl-imagenet-wd/pytorch_wd.py (about)

     1  import argparse
     2  import os
     3  import random
     4  import shutil
     5  import time
     6  import warnings
     7  
     8  import aistore
     9  from aistore.client import Bck, Client
    10  from aistore.client.transform import WDTransform
    11  
    12  import webdataset as wds
    13  
    14  import torch
    15  import torch.nn as nn
    16  import torch.nn.parallel
    17  import torch.backends.cudnn as cudnn
    18  import torch.optim
    19  import torch.utils.data
    20  import torch.utils.data.distributed
    21  import torchvision.models as models
    22  import torchvision.transforms as transforms
    23  
    24  model_names = sorted(
    25      name
    26      for name in models.__dict__
    27      if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
    28  )
    29  
# Command-line interface. Mirrors the classic PyTorch ImageNet training example,
# extended with AIStore-specific options (shard templates, endpoint, bucket).
parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
parser.add_argument(
    "-a",
    "--arch",
    metavar="ARCH",
    default="resnet18",
    choices=model_names,
    help="model architecture: " + " | ".join(model_names) + " (default: resnet18)",
)
parser.add_argument(
    "-j",
    "--workers",
    default=4,
    type=int,
    metavar="N",
    help="number of data loading workers (default: 4)",
)
parser.add_argument(
    "--epochs",
    default=90,
    type=int,
    metavar="N",
    help="number of total epochs to run",
)
parser.add_argument(
    "--start-epoch",
    default=0,
    type=int,
    metavar="N",
    help="manual epoch number (useful on restarts)",
)
parser.add_argument(
    "-b",
    "--batch-size",
    default=256,
    type=int,
    metavar="N",
    help="mini-batch size (default: 256), this is the total "
    "batch size of all GPUs on the current node when "
    "using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
    "--lr",
    "--learning-rate",
    default=0.1,
    type=float,
    metavar="LR",
    help="initial learning rate",
    dest="lr",
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
    "--wd",
    "--weight-decay",
    default=1e-4,
    type=float,
    metavar="W",
    help="weight decay (default: 1e-4)",
    dest="weight_decay",
)
parser.add_argument(
    "-p",
    "--print-freq",
    default=10,
    type=int,
    metavar="N",
    help="print frequency (default: 10)",
)
parser.add_argument(
    "--resume",
    default="",
    type=str,
    metavar="PATH",
    help="path to latest checkpoint (default: none)",
)
parser.add_argument(
    "-e",
    "--evaluate",
    dest="evaluate",
    action="store_true",
    help="evaluate model on validation set",
)
parser.add_argument(
    "--pretrained",
    dest="pretrained",
    action="store_true",
    help="use pre-trained model",
)
parser.add_argument(
    "--seed", default=None, type=int, help="seed for initializing training"
)
parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")
# Brace templates (e.g. {000000..000005}) are expanded server-side by
# Client.expand_object_urls in main_worker.
parser.add_argument(
    "--train-shards",
    default="imagenet-train-{000000..000005}.tar",
    type=str,
    help="template for shards to use for training",
)
# NOTE(review): this default reuses the *train* shard template even though the
# flag selects validation data — looks like a copy/paste slip; confirm, and
# pass --val-shards explicitly in the meantime.
parser.add_argument(
    "--val-shards",
    default="imagenet-train-{000000..000005}.tar",
    type=str,
    help="template for shards to use for validation",
)
parser.add_argument(
    "--ais-endpoint",
    default="http://aistore-sample-proxy:51080",
    type=str,
    help="AIStore proxy endpoint",
)
parser.add_argument(
    "--bucket-name", default="imagenet", type=str, help="dataset bucket name"
)
parser.add_argument(
    "--bucket-provider",
    default="ais",
    type=str,
    help='bucket provider (eg. "gcp", "aws", "ais")',
)

# Best top-1 validation accuracy seen so far; updated in main_worker
# (which declares it global).
best_acc1 = 0
   151  
   152  
   153  def main():
   154      args = parser.parse_args()
   155  
   156      if args.seed is not None:
   157          random.seed(args.seed)
   158          torch.manual_seed(args.seed)
   159          cudnn.deterministic = True
   160          warnings.warn(
   161              "You have chosen to seed training. "
   162              "This will turn on the CUDNN deterministic setting, "
   163              "which can slow down your training considerably! "
   164              "You may see unexpected behavior when restarting "
   165              "from checkpoints."
   166          )
   167  
   168      if args.gpu is not None:
   169          warnings.warn(
   170              "You have chosen a specific GPU. This will completely "
   171              "disable data parallelism."
   172          )
   173  
   174      ngpus_per_node = torch.cuda.device_count()
   175      main_worker(args.gpu, ngpus_per_node, args)
   176  
   177  
   178  def wd_transform(client, pytorch_transform, name):
   179      def transform(sample):
   180          sample["npy"] = (
   181              pytorch_transform(sample.pop("jpg"))
   182              .permute(1, 2, 0)
   183              .numpy()
   184              .astype("float32")
   185          )
   186          return sample
   187  
   188      return WDTransform(client, transform, transform_name=name, verbose=True)
   189  
   190  
   191  def loader(urls, batch_size, workers):
   192      to_tensor = transforms.Compose([transforms.ToTensor()])
   193      etl_dataset = (
   194          wds.WebDataset(urls, handler=wds.handlers.warn_and_continue)
   195          .decode("rgb")
   196          .to_tuple("npy cls", handler=wds.handlers.warn_and_continue)
   197          .map_tuple(to_tensor, lambda x: x)
   198      )
   199      ds_size = (500 * len(urls)) // batch_size
   200      etl_dataset = etl_dataset.with_length(ds_size)
   201      loader = wds.WebLoader(
   202          etl_dataset,
   203          batch_size=batch_size,
   204          num_workers=workers,
   205      )
   206      return loader.with_length(ds_size)
   207  
   208  
   209  def main_worker(gpu, ngpus_per_node, args):
   210      global best_acc1
   211      args.gpu = gpu
   212  
   213      if args.gpu is not None:
   214          print("Use GPU: {} for training".format(args.gpu))
   215  
   216      # create model
   217      if args.pretrained:
   218          print("=> using pre-trained model '{}'".format(args.arch))
   219          model = models.__dict__[args.arch](pretrained=True)
   220      else:
   221          print("=> creating model '{}'".format(args.arch))
   222          model = models.__dict__[args.arch]()
   223  
   224      if not torch.cuda.is_available():
   225          print("using CPU, this will be slow")
   226      elif args.gpu is not None:
   227          torch.cuda.set_device(args.gpu)
   228          model = model.cuda(args.gpu)
   229      else:
   230          # DataParallel will divide and allocate batch_size to all available GPUs
   231          if args.arch.startswith("alexnet") or args.arch.startswith("vgg"):
   232              model.features = torch.nn.DataParallel(model.features)
   233              model.cuda()
   234          else:
   235              model = torch.nn.DataParallel(model).cuda()
   236  
   237      # define loss function (criterion) and optimizer
   238      criterion = nn.CrossEntropyLoss().cuda(args.gpu)
   239  
   240      optimizer = torch.optim.SGD(
   241          model.parameters(),
   242          args.lr,
   243          momentum=args.momentum,
   244          weight_decay=args.weight_decay,
   245      )
   246  
   247      # optionally resume from a checkpoint
   248      if args.resume:
   249          if os.path.isfile(args.resume):
   250              print("=> loading checkpoint '{}'".format(args.resume))
   251              if args.gpu is None:
   252                  checkpoint = torch.load(args.resume)
   253              else:
   254                  # Map model to be loaded to specified single gpu.
   255                  loc = "cuda:{}".format(args.gpu)
   256                  checkpoint = torch.load(args.resume, map_location=loc)
   257              args.start_epoch = checkpoint["epoch"]
   258              best_acc1 = checkpoint["best_acc1"]
   259              if args.gpu is not None:
   260                  # best_acc1 may be from a checkpoint from a different GPU
   261                  best_acc1 = best_acc1.to(args.gpu)
   262              model.load_state_dict(checkpoint["state_dict"])
   263              optimizer.load_state_dict(checkpoint["optimizer"])
   264              print(
   265                  "=> loaded checkpoint '{}' (epoch {})".format(
   266                      args.resume, checkpoint["epoch"]
   267                  )
   268              )
   269          else:
   270              print("=> no checkpoint found at '{}'".format(args.resume))
   271  
   272      cudnn.benchmark = True
   273  
   274      normalize = transforms.Normalize(
   275          mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
   276      )
   277  
   278      client = Client(args.ais_endpoint)
   279      bck = Bck(args.bucket_name, provider=args.bucket_provider)
   280  
   281      train_transform = transforms.Compose(
   282          [
   283              transforms.RandomResizedCrop(224),
   284              transforms.RandomHorizontalFlip(),
   285              transforms.ToTensor(),
   286              normalize,
   287          ]
   288      )
   289  
   290      train_etl = wd_transform(client, train_transform, "img-train")
   291      train_urls = client.expand_object_urls(
   292          bck, transform_id=train_etl.uuid, template=args.train_shards
   293      )
   294      train_loader = loader(train_urls, args.batch_size, args.workers)
   295  
   296      val_transform = transforms.Compose(
   297          [
   298              transforms.Resize(256),
   299              transforms.CenterCrop(224),
   300              transforms.ToTensor(),
   301              normalize,
   302          ]
   303      )
   304      val_etl = wd_transform(client, val_transform, "img-val")
   305      val_urls = client.expand_object_urls(
   306          bck, transform_id=val_etl.uuid, template=args.val_shards
   307      )
   308      val_loader = loader(val_urls, args.batch_size, args.workers)
   309  
   310      if args.evaluate:
   311          validate(val_loader, model, criterion, args)
   312          return
   313  
   314      for epoch in range(args.start_epoch, args.epochs):
   315          adjust_learning_rate(optimizer, epoch, args)
   316  
   317          # train for one epoch
   318          train(train_loader, model, criterion, optimizer, epoch, args)
   319  
   320          # evaluate on validation set
   321          acc1 = validate(val_loader, model, criterion, args)
   322  
   323          # remember best acc@1 and save checkpoint
   324          is_best = acc1 > best_acc1
   325          best_acc1 = max(acc1, best_acc1)
   326  
   327          save_checkpoint(
   328              {
   329                  "epoch": epoch + 1,
   330                  "arch": args.arch,
   331                  "state_dict": model.state_dict(),
   332                  "best_acc1": best_acc1,
   333                  "optimizer": optimizer.state_dict(),
   334              },
   335              is_best,
   336          )
   337  
   338  
   339  def train(train_loader, model, criterion, optimizer, epoch, args):
   340      batch_time = AverageMeter("Time", ":6.3f")
   341      data_time = AverageMeter("Data", ":6.3f")
   342      losses = AverageMeter("Loss", ":.4e")
   343      top1 = AverageMeter("Acc@1", ":6.2f")
   344      top5 = AverageMeter("Acc@5", ":6.2f")
   345      progress = ProgressMeter(
   346          len(train_loader),
   347          [batch_time, data_time, losses, top1, top5],
   348          prefix="Epoch: [{}]".format(epoch),
   349      )
   350  
   351      # switch to train mode
   352      model.train()
   353  
   354      end = time.time()
   355      for i, (images, target) in enumerate(train_loader):
   356          # measure data loading time
   357          data_time.update(time.time() - end)
   358  
   359          if args.gpu is not None:
   360              images = images.cuda(args.gpu, non_blocking=True)
   361          if torch.cuda.is_available():
   362              target = target.cuda(args.gpu, non_blocking=True)
   363  
   364          # compute output
   365          output = model(images)
   366          loss = criterion(output, target)
   367  
   368          # measure accuracy and record loss
   369          acc1, acc5 = accuracy(output, target, topk=(1, 5))
   370          losses.update(loss.item(), images.size(0))
   371          top1.update(acc1[0], images.size(0))
   372          top5.update(acc5[0], images.size(0))
   373  
   374          # compute gradient and do SGD step
   375          optimizer.zero_grad()
   376          loss.backward()
   377          optimizer.step()
   378  
   379          # measure elapsed time
   380          batch_time.update(time.time() - end)
   381          end = time.time()
   382  
   383          if i % args.print_freq == 0:
   384              progress.display(i)
   385  
   386  
   387  def validate(val_loader, model, criterion, args):
   388      batch_time = AverageMeter("Time", ":6.3f")
   389      losses = AverageMeter("Loss", ":.4e")
   390      top1 = AverageMeter("Acc@1", ":6.2f")
   391      top5 = AverageMeter("Acc@5", ":6.2f")
   392      progress = ProgressMeter(
   393          len(val_loader), [batch_time, losses, top1, top5], prefix="Test: "
   394      )
   395  
   396      # switch to evaluate mode
   397      model.eval()
   398  
   399      with torch.no_grad():
   400          end = time.time()
   401          for i, (images, target) in enumerate(val_loader):
   402              if args.gpu is not None:
   403                  images = images.cuda(args.gpu, non_blocking=True)
   404              if torch.cuda.is_available():
   405                  target = target.cuda(args.gpu, non_blocking=True)
   406  
   407              # compute output
   408              output = model(images)
   409              loss = criterion(output, target)
   410  
   411              # measure accuracy and record loss
   412              acc1, acc5 = accuracy(output, target, topk=(1, 5))
   413              losses.update(loss.item(), images.size(0))
   414              top1.update(acc1[0], images.size(0))
   415              top5.update(acc5[0], images.size(0))
   416  
   417              # measure elapsed time
   418              batch_time.update(time.time() - end)
   419              end = time.time()
   420  
   421              if i % args.print_freq == 0:
   422                  progress.display(i)
   423  
   424          # TODO: this should also be done with the ProgressMeter
   425          print(
   426              " * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(top1=top1, top5=top5)
   427          )
   428  
   429      return top1.avg
   430  
   431  
   432  def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
   433      torch.save(state, filename)
   434      if is_best:
   435          shutil.copyfile(filename, "model_best.pth.tar")
   436  
   437  
   438  class AverageMeter(object):
   439      """Computes and stores the average and current value"""
   440  
   441      def __init__(self, name, fmt=":f"):
   442          self.name = name
   443          self.fmt = fmt
   444          self.reset()
   445  
   446      def reset(self):
   447          self.val = 0
   448          self.avg = 0
   449          self.sum = 0
   450          self.count = 0
   451  
   452      def update(self, val, n=1):
   453          self.val = val
   454          self.sum += val * n
   455          self.count += n
   456          self.avg = self.sum / self.count
   457  
   458      def __str__(self):
   459          fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
   460          return fmtstr.format(**self.__dict__)
   461  
   462  
   463  class ProgressMeter(object):
   464      def __init__(self, num_batches, meters, prefix=""):
   465          self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
   466          self.meters = meters
   467          self.prefix = prefix
   468  
   469      def display(self, batch):
   470          entries = [self.prefix + self.batch_fmtstr.format(batch)]
   471          entries += [str(meter) for meter in self.meters]
   472          print("\t".join(entries))
   473  
   474      def _get_batch_fmtstr(self, num_batches):
   475          num_digits = len(str(num_batches // 1))
   476          fmt = "{:" + str(num_digits) + "d}"
   477          return "[" + fmt + "/" + fmt.format(num_batches) + "]"
   478  
   479  
   480  def adjust_learning_rate(optimizer, epoch, args):
   481      """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
   482      lr = args.lr * (0.1 ** (epoch // 30))
   483      for param_group in optimizer.param_groups:
   484          param_group["lr"] = lr
   485  
   486  
   487  def accuracy(output, target, topk=(1,)):
   488      """Computes the accuracy over the k top predictions for the specified values of k"""
   489      with torch.no_grad():
   490          maxk = max(topk)
   491          batch_size = target.size(0)
   492  
   493          _, pred = output.topk(maxk, 1, True, True)
   494          pred = pred.t()
   495          correct = pred.eq(target.view(1, -1).expand_as(pred))
   496  
   497          res = []
   498          for k in topk:
   499              correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
   500              res.append(correct_k.mul_(100.0 / batch_size))
   501          return res
   502  
   503  
# Script entry point: parse CLI args and run single-node training.
if __name__ == "__main__":
    main()