github.com/kubeflow/training-operator@v1.7.0/examples/pytorch/elastic/imagenet/imagenet.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

r"""
Source: `pytorch imagenet example <https://github.com/pytorch/examples/blob/master/imagenet/main.py>`_ # noqa B950

Modified and simplified to make the original pytorch example compatible with
torchelastic.distributed.launch.

Changes:

1. Removed ``rank``, ``gpu``, ``multiprocessing-distributed``, ``dist_url`` options.
   These are obsolete parameters when using ``torchelastic.distributed.launch``.

2. Removed ``seed``, ``evaluate``, ``pretrained`` options for simplicity.

3. Removed ``resume``, ``start-epoch`` options.
   Loads the most recent checkpoint by default.

4. ``batch-size`` is now per GPU (worker) batch size rather than for all GPUs.

5. Defaults ``workers`` (num data loader workers) to ``0``.

Usage

::

 >>> python -m torchelastic.distributed.launch
        --nnodes=$NUM_NODES
        --nproc_per_node=$WORKERS_PER_NODE
        --rdzv_id=$JOB_ID
        --rdzv_backend=etcd
        --rdzv_endpoint=$ETCD_HOST:$ETCD_PORT
        imagenet.py
        --arch resnet18
        --epochs 20
        --batch-size 32
        <DATA_DIR>
"""
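
# NOTE: newer PyTorch releases expose the elastic launcher as ``torchrun``
# (``python -m torch.distributed.run``). A roughly equivalent invocation of
# this script, using the built-in ``c10d`` rendezvous backend instead of etcd,
# would look like (host/port placeholders are illustrative):
#
#   torchrun --nnodes=$NUM_NODES --nproc_per_node=$WORKERS_PER_NODE \
#       --rdzv_id=$JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$HOST:$PORT \
#       imagenet.py --arch resnet18 --epochs 20 --batch-size 32 <DATA_DIR>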

import argparse
import io
import os
import shutil
import time
from contextlib import contextmanager
from datetime import timedelta
from typing import List, Tuple

import numpy
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.distributed.elastic.multiprocessing.errors import record
from torch.distributed.elastic.utils.data import ElasticDistributedSampler
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.utils.data import DataLoader


model_names = sorted(
    name
    for name in models.__dict__
    if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
)

parser = argparse.ArgumentParser(description="PyTorch Elastic ImageNet Training")
parser.add_argument("data", metavar="DIR", help="path to dataset")
parser.add_argument(
    "-a",
    "--arch",
    metavar="ARCH",
    default="resnet18",
    choices=model_names,
    help="model architecture: " + " | ".join(model_names) + " (default: resnet18)",
)
parser.add_argument(
    "-j",
    "--workers",
    default=0,
    type=int,
    metavar="N",
    help="number of data loading workers",
)
parser.add_argument(
    "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run"
)
parser.add_argument(
    "-b",
    "--batch-size",
    default=32,
    type=int,
    metavar="N",
    help="mini-batch size (default: 32), per worker (GPU)",
)
parser.add_argument(
    "--lr",
    "--learning-rate",
    default=0.1,
    type=float,
    metavar="LR",
    help="initial learning rate",
    dest="lr",
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
    "--wd",
    "--weight-decay",
    default=1e-4,
    type=float,
    metavar="W",
    help="weight decay (default: 1e-4)",
    dest="weight_decay",
)
parser.add_argument(
    "-p",
    "--print-freq",
    default=10,
    type=int,
    metavar="N",
    help="print frequency (default: 10)",
)
parser.add_argument(
    "--dist-backend",
    default="gloo",
    choices=["nccl", "gloo"],
    type=str,
    help="distributed backend",
)
parser.add_argument(
    "--checkpoint-file",
    default="/tmp/checkpoint.pth.tar",
    type=str,
    help="checkpoint file path, to load and save to",
)


@record
def main():
    args = parser.parse_args()
    device = torch.device("cpu")

    dist.init_process_group(
        backend=args.dist_backend, init_method="env://", timeout=timedelta(seconds=10)
    )
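    # NOTE: init_method="env://" reads MASTER_ADDR, MASTER_PORT, RANK and
    # WORLD_SIZE from the environment; the elastic launch agent exports these
    # for every worker it spawns, so no explicit rank or world-size arguments
    # are needed here.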

    model, criterion, optimizer = initialize_model(
        args.arch, args.lr, args.momentum, args.weight_decay, device
    )

    train_loader, val_loader = initialize_data_loader(
        args.data, args.batch_size, args.workers
    )

    # resume from checkpoint if one exists
    state = load_checkpoint(
        args.checkpoint_file, args.arch, model, optimizer
    )

    start_epoch = state.epoch + 1
    print(f"=> start_epoch: {start_epoch}, best_acc1: {state.best_acc1}")

    print_freq = args.print_freq
    for epoch in range(start_epoch, args.epochs):
        state.epoch = epoch
        train_loader.batch_sampler.sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args.lr)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, print_freq)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, print_freq)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > state.best_acc1
        state.best_acc1 = max(acc1, state.best_acc1)

        save_checkpoint(state, is_best, args.checkpoint_file)


class State:
    """
    Container for objects that we want to checkpoint. Represents the
    current "state" of the worker. This object is mutable.
    """

    def __init__(self, arch, model, optimizer):
        self.epoch = -1
        self.best_acc1 = 0
        self.arch = arch
        self.model = model
        self.optimizer = optimizer

    def capture_snapshot(self):
        """
        Essentially a ``serialize()`` function, returns the state as an
        object compatible with ``torch.save()``. The following should work
        ::

            snapshot = state_0.capture_snapshot()
            state_1.apply_snapshot(snapshot)
            assert state_0 == state_1
        """
        return {
            "epoch": self.epoch,
            "best_acc1": self.best_acc1,
            "arch": self.arch,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

    def apply_snapshot(self, obj):
        """
        The complementary function of ``capture_snapshot()``. Applies the
        snapshot object that was returned by ``capture_snapshot()``.
        This function mutates this state object.
        """

        self.epoch = obj["epoch"]
        self.best_acc1 = obj["best_acc1"]
        self.model.load_state_dict(obj["state_dict"])
        self.optimizer.load_state_dict(obj["optimizer"])

    def save(self, f):
        torch.save(self.capture_snapshot(), f)

    def load(self, f):
        # this example runs on CPU, so map the loaded snapshot to the CPU
        snapshot = torch.load(f, map_location="cpu")
        self.apply_snapshot(snapshot)
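
    # Illustrative (hypothetical) round trip, mirroring the contract described
    # in ``capture_snapshot()`` / ``apply_snapshot()``:
    #
    #   state_0 = State("resnet18", model, optimizer)
    #   ...train for a while...
    #   state_0.save("/tmp/snap.pth.tar")
    #   state_1 = State("resnet18", model, optimizer)
    #   state_1.load("/tmp/snap.pth.tar")   # state_1 now mirrors state_0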


def initialize_model(
    arch: str, lr: float, momentum: float, weight_decay: float, device
):
    print(f"=> creating model: {arch}")
    model = models.__dict__[arch]()
    # For multiprocessing distributed, DistributedDataParallel constructor
    # should always set the single device scope, otherwise,
    # DistributedDataParallel will use all available devices.
    model.to(device)
    model = nn.parallel.DistributedDataParallel(model)
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(
        model.parameters(), lr, momentum=momentum, weight_decay=weight_decay
    )
    return model, criterion, optimizer


def initialize_data_loader(
    data_dir, batch_size, num_data_workers
) -> Tuple[DataLoader, DataLoader]:
    traindir = os.path.join(data_dir, "train")
    valdir = os.path.join(data_dir, "val")
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    train_sampler = ElasticDistributedSampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_data_workers,
        # pin_memory=True,
        sampler=train_sampler,
    )
    val_loader = DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose(
                [
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
        ),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_data_workers,
        # pin_memory=True,
    )
    return train_loader, val_loader


def load_checkpoint(
    checkpoint_file: str,
    arch: str,
    model: DistributedDataParallel,
    optimizer,  # SGD
) -> State:
    """
    Loads a local checkpoint (if any). Otherwise, checks to see if any of
    the neighbors have a non-zero state. If so, restores the state
    from the rank that has the most up-to-date checkpoint.

    .. note:: when your job has access to a globally visible persistent storage
              (e.g. nfs mount, S3) you can simply have all workers load
              from the most recent checkpoint from such storage. Since this
              example is expected to run on vanilla hosts (with no shared
              storage) the checkpoints are written to local disk, hence
              we have the extra logic to broadcast the checkpoint from a
              surviving node.
    """

    state = State(arch, model, optimizer)

    if os.path.isfile(checkpoint_file):
        print(f"=> loading checkpoint file: {checkpoint_file}")
        state.load(checkpoint_file)
        print(f"=> loaded checkpoint file: {checkpoint_file}")

    # logic below is unnecessary when the checkpoint is visible on all nodes!
    # create a temporary cpu pg to broadcast most up-to-date checkpoint
    with tmp_process_group(backend="gloo") as pg:
        rank = dist.get_rank(group=pg)

        # get rank that has the largest state.epoch
        epochs = torch.zeros(dist.get_world_size(), dtype=torch.int32)
        epochs[rank] = state.epoch
        dist.all_reduce(epochs, op=dist.ReduceOp.SUM, group=pg)
        t_max_epoch, t_max_rank = torch.max(epochs, dim=0)
        max_epoch = t_max_epoch.item()
        max_rank = t_max_rank.item()

        # max_epoch == -1 means no one has checkpointed; return base state
        if max_epoch == -1:
            print("=> no workers have checkpoints, starting from epoch 0")
            return state

        # broadcast the state from max_rank (which has the most up-to-date state)
        # pickle the snapshot, convert it into a byte-blob tensor
        # then broadcast it, unpickle it and apply the snapshot
        print(f"=> using checkpoint from rank: {max_rank}, max_epoch: {max_epoch}")

        with io.BytesIO() as f:
            torch.save(state.capture_snapshot(), f)
            raw_blob = numpy.frombuffer(f.getvalue(), dtype=numpy.uint8)

        blob_len = torch.tensor(len(raw_blob))
        dist.broadcast(blob_len, src=max_rank, group=pg)
        print(f"=> checkpoint broadcast size is: {blob_len}")

        if rank != max_rank:
            blob = torch.zeros(blob_len.item(), dtype=torch.uint8)
        else:
            blob = torch.as_tensor(raw_blob, dtype=torch.uint8)

        dist.broadcast(blob, src=max_rank, group=pg)
        print("=> done broadcasting checkpoint")

        if rank != max_rank:
            with io.BytesIO(blob.numpy()) as f:
                snapshot = torch.load(f)
            state.apply_snapshot(snapshot)

        # wait till everyone has loaded the checkpoint
        dist.barrier(group=pg)

    print("=> done restoring from previous checkpoint")
    return state


@contextmanager
def tmp_process_group(backend):
    cpu_pg = dist.new_group(backend=backend)
    try:
        yield cpu_pg
    finally:
        dist.destroy_process_group(cpu_pg)


def save_checkpoint(state: State, is_best: bool, filename: str):
    checkpoint_dir = os.path.dirname(filename)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # save to tmp, then commit by moving the file in case the job
    # gets interrupted while writing the checkpoint
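    # (on POSIX filesystems os.rename is atomic when the temporary file and
    # the destination live on the same filesystem, so readers never observe
    # a partially written checkpoint)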
    tmp_filename = filename + ".tmp"
    torch.save(state.capture_snapshot(), tmp_filename)
    os.rename(tmp_filename, filename)
    print(f"=> saved checkpoint for epoch {state.epoch} at {filename}")
    if is_best:
        best = os.path.join(checkpoint_dir, "model_best.pth.tar")
        print(f"=> best model found at epoch {state.epoch} saving to {best}")
        shutil.copyfile(filename, best)


def train(
    train_loader: DataLoader,
    model: DistributedDataParallel,
    criterion,  # nn.CrossEntropyLoss
    optimizer,  # SGD
    epoch: int,
    print_freq: int,
):
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            progress.display(i)


def validate(
    val_loader: DataLoader,
    model: DistributedDataParallel,
    criterion,  # nn.CrossEntropyLoss
    print_freq: int,
):
    batch_time = AverageMeter("Time", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(val_loader), [batch_time, losses, top1, top5], prefix="Test: "
    )

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(
            " * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(top1=top1, top5=top5)
        )

    return top1.avg


class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self, name: str, fmt: str = ":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self) -> None:
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1) -> None:
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)
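
    # Illustrative usage (values are made up):
    #
    #   losses = AverageMeter("Loss", ":.4e")
    #   losses.update(0.75, n=32)   # record a batch of 32 samples
    #   str(losses)                 # -> "Loss 7.5000e-01 (7.5000e-01)"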


class ProgressMeter:
    def __init__(self, num_batches: int, meters: List[AverageMeter], prefix: str = ""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch: int) -> None:
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches: int) -> str:
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def adjust_learning_rate(optimizer, epoch: int, lr: float) -> None:
    """
    Sets the learning rate to the initial LR decayed by 10 every 30 epochs
    """
    learning_rate = lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group["lr"] = learning_rate
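    # e.g. with the default lr=0.1: epochs 0-29 train at 0.1, epochs 30-59 at
    # 0.01, epochs 60-89 at 0.001, and so on.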


def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(1, -1).view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
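
    # Illustrative example (made-up logits for a 2-sample, 3-class batch):
    #
    #   output = torch.tensor([[0.1, 0.8, 0.1], [0.7, 0.2, 0.1]])
    #   target = torch.tensor([1, 2])
    #   accuracy(output, target, topk=(1, 2))  # -> [tensor([50.]), tensor([50.])]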


if __name__ == "__main__":
    main()