github.com/kubeflow/training-operator@v1.7.0/examples/pytorch/mnist/mnist.py

from __future__ import print_function

import argparse
import os

from tensorboardX import SummaryWriter
from torchvision import datasets, transforms
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

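# WORLD_SIZE is the total number of participating workers. When this script
# runs under the Kubeflow training operator (as a PyTorchJob) or torchrun,
# the launcher sets WORLD_SIZE, RANK, MASTER_ADDR, and MASTER_PORT in the
# environment; a standalone run falls back to a world size of 1.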
WORLD_SIZE = int(os.environ.get('WORLD_SIZE', 1))


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4*4*50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        # 28x28 input -> conv1 (5x5) -> 24x24 -> 2x2 max pool -> 12x12
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        # -> conv2 (5x5) -> 8x8 -> 2x2 max pool -> 4x4 over 50 channels,
        # hence the 4*4*50 flattened size below
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

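# Note: the log_softmax output above pairs with F.nll_loss in train() and
# test() below; the combination is equivalent to cross-entropy on raw logits.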
def train(args, model, device, train_loader, optimizer, epoch, writer):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tloss={:.4f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            # Global step for TensorBoard: batches seen across all epochs.
            niter = epoch * len(train_loader) + batch_idx
            writer.add_scalar('loss', loss.item(), niter)

def test(args, model, device, test_loader, writer, epoch):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.max(1, keepdim=True)[1]  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\naccuracy={:.4f}\n'.format(float(correct) / len(test_loader.dataset)))
    writer.add_scalar('accuracy', float(correct) / len(test_loader.dataset), epoch)
    # test_loss was computed above but never reported; log it as well.
    writer.add_scalar('test_loss', test_loss, epoch)


def should_distribute():
    return dist.is_available() and WORLD_SIZE > 1


def is_distributed():
    return dist.is_available() and dist.is_initialized()


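# main() wires everything together: parse flags, initialize the process group
# when WORLD_SIZE > 1, build the data loaders and model, then run the
# train/test loop for the requested number of epochs.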
def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=1, metavar='N',
                        help='number of epochs to train (default: 1)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', action='store_true', default=False,
                        help='for saving the current model')
    parser.add_argument('--dir', default='logs', metavar='L',
                        help='directory where summary logs are stored')
    if dist.is_available():
        parser.add_argument('--backend', type=str, help='Distributed backend',
                            choices=[dist.Backend.GLOO, dist.Backend.NCCL, dist.Backend.MPI],
                            default=dist.Backend.GLOO)
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    if use_cuda:
        print('Using CUDA')

    writer = SummaryWriter(args.dir)

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    if should_distribute():
        print('Using distributed PyTorch with {} backend'.format(args.backend))
        dist.init_process_group(backend=args.backend)
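        # With no init_method given, init_process_group defaults to the env://
        # rendezvous: it reads MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE
        # from the environment, which the PyTorchJob controller (or torchrun)
        # populates on each replica.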

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    # Despite the script's name, this example trains on Fashion-MNIST
    # (the normalization constants are the classic MNIST mean/std).
    train_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST('../data', train=True, download=True,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.1307,), (0.3081,))
                              ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.FashionMNIST('../data', train=False,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.1307,), (0.3081,))
                              ])),
        batch_size=args.test_batch_size, shuffle=False, **kwargs)

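    # Not in the original example: with DistributedDataParallel each replica
    # above iterates over the full training set. A DistributedSampler would
    # shard it instead; a minimal sketch (train_dataset standing in for the
    # FashionMNIST dataset object built above):
    #
    #     train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    #     train_loader = torch.utils.data.DataLoader(
    #         train_dataset, batch_size=args.batch_size,
    #         sampler=train_sampler, **kwargs)  # sampler replaces shuffle=True
    #
    # with train_sampler.set_epoch(epoch) called each epoch to reshuffle shards.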
    model = Net().to(device)

    if is_distributed():
        # DistributedDataParallel supports both CPU (gloo) and GPU (nccl)
        # training; the old DistributedDataParallelCPU wrapper was removed
        # from modern PyTorch.
        model = nn.parallel.DistributedDataParallel(model)

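    # Under DDP, backward() all-reduces and averages gradients across replicas,
    # so the optimizer step below applies the same update on every worker.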
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch, writer)
        test(args, model, device, test_loader, writer, epoch)

    if args.save_model:
        # Unwrap DDP before saving so checkpoint keys carry no 'module.' prefix.
        net = model.module if is_distributed() else model
        torch.save(net.state_dict(), "mnist_cnn.pt")


if __name__ == '__main__':
    main()
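
# Example launches (a sketch; not part of the original file). Standalone:
#
#     python mnist.py --epochs 1
#
# Two-process distributed run on one machine, supplying by hand the variables
# a PyTorchJob controller or torchrun would normally set:
#
#     MASTER_ADDR=127.0.0.1 MASTER_PORT=29500 WORLD_SIZE=2 RANK=0 \
#         python mnist.py --backend gloo &
#     MASTER_ADDR=127.0.0.1 MASTER_PORT=29500 WORLD_SIZE=2 RANK=1 \
#         python mnist.py --backend gloo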