github.com/kubeflow/training-operator@v1.7.0/examples/pytorch/elastic/imagenet/imagenet.py

#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

r"""
Source: `pytorch imagenet example <https://github.com/pytorch/examples/blob/master/imagenet/main.py>`_  # noqa B950

Modified and simplified to make the original pytorch example compatible with
torchelastic.distributed.launch.

Changes:

1. Removed ``rank``, ``gpu``, ``multiprocessing-distributed``, ``dist_url`` options.
   These are obsolete parameters when using ``torchelastic.distributed.launch``.

2. Removed ``seed``, ``evaluate``, ``pretrained`` options for simplicity.

3. Removed ``resume``, ``start-epoch`` options.
   Loads the most recent checkpoint by default.

4. ``batch-size`` is now the per-worker (per-GPU) batch size rather than the
   total batch size across all GPUs.

5. Defaults ``workers`` (num data loader workers) to ``0``.

Usage

::

 >>> python -m torchelastic.distributed.launch
        --nnodes=$NUM_NODES
        --nproc_per_node=$WORKERS_PER_NODE
        --rdzv_id=$JOB_ID
        --rdzv_backend=etcd
        --rdzv_endpoint=$ETCD_HOST:$ETCD_PORT
        main.py
        --arch resnet18
        --epochs 20
        --batch-size 32
        <DATA_DIR>
"""
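
# NOTE: a rough, illustrative sketch of an equivalent launch with the newer
# ``torchrun`` CLI (which superseded ``torchelastic.distributed.launch`` in
# recent PyTorch releases). The c10d rendezvous backend shown here avoids the
# external etcd dependency; it is an assumption, not part of the original
# example:
#
#   torchrun --nnodes=$NUM_NODES --nproc_per_node=$WORKERS_PER_NODE \
#       --rdzv_id=$JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$HOST:$PORT \
#       imagenet.py --arch resnet18 --epochs 20 --batch-size 32 <DATA_DIR>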

import argparse
import io
import os
import shutil
import time
from contextlib import contextmanager
from datetime import timedelta
from typing import List, Tuple

import numpy
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.distributed.elastic.utils.data import ElasticDistributedSampler
from torch.distributed.elastic.multiprocessing.errors import record
from torch.nn.parallel import DistributedDataParallel
from torch.optim import SGD
from torch.utils.data import DataLoader


model_names = sorted(
    name
    for name in models.__dict__
    if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
)

parser = argparse.ArgumentParser(description="PyTorch Elastic ImageNet Training")
parser.add_argument("data", metavar="DIR", help="path to dataset")
parser.add_argument(
    "-a",
    "--arch",
    metavar="ARCH",
    default="resnet18",
    choices=model_names,
    help="model architecture: " + " | ".join(model_names) + " (default: resnet18)",
)
parser.add_argument(
    "-j",
    "--workers",
    default=0,
    type=int,
    metavar="N",
    help="number of data loading workers",
)
parser.add_argument(
    "--epochs", default=90, type=int, metavar="N", help="number of total epochs to run"
)
parser.add_argument(
    "-b",
    "--batch-size",
    default=32,
    type=int,
    metavar="N",
    help="mini-batch size (default: 32), per worker (GPU)",
)
parser.add_argument(
    "--lr",
    "--learning-rate",
    default=0.1,
    type=float,
    metavar="LR",
    help="initial learning rate",
    dest="lr",
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
    "--wd",
    "--weight-decay",
    default=1e-4,
    type=float,
    metavar="W",
    help="weight decay (default: 1e-4)",
    dest="weight_decay",
)
parser.add_argument(
    "-p",
    "--print-freq",
    default=10,
    type=int,
    metavar="N",
    help="print frequency (default: 10)",
)
parser.add_argument(
    "--dist-backend",
    default="gloo",
    choices=["nccl", "gloo"],
    type=str,
    help="distributed backend",
)
parser.add_argument(
    "--checkpoint-file",
    default="/tmp/checkpoint.pth.tar",
    type=str,
    help="checkpoint file path, to load and save to",
)


@record
def main():
    args = parser.parse_args()
    device = torch.device("cpu")

    dist.init_process_group(
        backend=args.dist_backend, init_method="env://", timeout=timedelta(seconds=10)
    )

    model, criterion, optimizer = initialize_model(
        args.arch, args.lr, args.momentum, args.weight_decay, device
    )

    train_loader, val_loader = initialize_data_loader(
        args.data, args.batch_size, args.workers
    )

    # resume from checkpoint if one exists
    state = load_checkpoint(
        args.checkpoint_file, args.arch, model, optimizer
    )

    start_epoch = state.epoch + 1
    print(f"=> start_epoch: {start_epoch}, best_acc1: {state.best_acc1}")

    print_freq = args.print_freq
    for epoch in range(start_epoch, args.epochs):
        state.epoch = epoch
        train_loader.batch_sampler.sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, args.lr)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, print_freq)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, print_freq)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > state.best_acc1
        state.best_acc1 = max(acc1, state.best_acc1)

        save_checkpoint(state, is_best, args.checkpoint_file)


class State:
    """
    Container for objects that we want to checkpoint. Represents the
    current "state" of the worker. This object is mutable.
    """

    def __init__(self, arch, model, optimizer):
        self.epoch = -1
        self.best_acc1 = 0
        self.arch = arch
        self.model = model
        self.optimizer = optimizer

    def capture_snapshot(self):
        """
        Essentially a ``serialize()`` function, returns the state as an
        object compatible with ``torch.save()``. The following should work
        ::

            snapshot = state_0.capture_snapshot()
            state_1.apply_snapshot(snapshot)
            assert state_0 == state_1
        """
        return {
            "epoch": self.epoch,
            "best_acc1": self.best_acc1,
            "arch": self.arch,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }

    def apply_snapshot(self, obj):
        """
        The complementary function of ``capture_snapshot()``. Applies the
        snapshot object that was returned by ``capture_snapshot()``.
        This function mutates this state object.
        """

        self.epoch = obj["epoch"]
        self.best_acc1 = obj["best_acc1"]
        self.state_dict = obj["state_dict"]
        self.model.load_state_dict(obj["state_dict"])
        self.optimizer.load_state_dict(obj["optimizer"])

    def save(self, f):
        torch.save(self.capture_snapshot(), f)

    def load(self, f):
        # Map model to be loaded to specified single gpu.
        snapshot = torch.load(f)
        self.apply_snapshot(snapshot)
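

# Illustrative sketch (not used by the training loop): the snapshot round trip
# described in ``capture_snapshot()`` also works through an in-memory buffer,
# which is essentially how ``load_checkpoint()`` below ships a checkpoint from
# one rank to another:
#
#   buf = io.BytesIO()
#   state_0.save(buf)          # torch.save() of the captured snapshot
#   buf.seek(0)
#   state_1.load(buf)          # state_1 now mirrors state_0's epoch/weights/optimizer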


def initialize_model(
    arch: str, lr: float, momentum: float, weight_decay: float, device
):
    print(f"=> creating model: {arch}")
    model = models.__dict__[arch]()
    # For multiprocessing distributed, DistributedDataParallel constructor
    # should always set the single device scope, otherwise,
    # DistributedDataParallel will use all available devices.
    model.to(device)
    model = nn.parallel.DistributedDataParallel(model)
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = SGD(
        model.parameters(), lr, momentum=momentum, weight_decay=weight_decay
    )
    return model, criterion, optimizer


def initialize_data_loader(
    data_dir, batch_size, num_data_workers
) -> Tuple[DataLoader, DataLoader]:
    traindir = os.path.join(data_dir, "train")
    valdir = os.path.join(data_dir, "val")
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
    train_sampler = ElasticDistributedSampler(train_dataset)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        num_workers=num_data_workers,
        # pin_memory=True,
        sampler=train_sampler,
    )
    val_loader = DataLoader(
        datasets.ImageFolder(
            valdir,
            transforms.Compose(
                [
                    transforms.Resize(256),
                    transforms.CenterCrop(224),
                    transforms.ToTensor(),
                    normalize,
                ]
            ),
        ),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_data_workers,
        # pin_memory=True,
    )
    return train_loader, val_loader
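

# Expected on-disk layout for <DATA_DIR>, following the standard torchvision
# ``ImageFolder`` convention (class directory names below are illustrative):
#
#   <DATA_DIR>/train/<class_name>/*.JPEG
#   <DATA_DIR>/val/<class_name>/*.JPEG
#
# Since ``--batch-size`` is per worker, the effective global batch size is
# batch_size * world_size (e.g. 32 per worker across 8 workers => 256).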


def load_checkpoint(
    checkpoint_file: str,
    arch: str,
    model: DistributedDataParallel,
    optimizer,  # SGD
) -> State:
    """
    Loads a local checkpoint (if any). Otherwise, checks to see if any of
    the neighbors have a non-zero state. If so, restore the state
    from the rank that has the most up-to-date checkpoint.

    .. note:: when your job has access to a globally visible persistent storage
              (e.g. nfs mount, S3) you can simply have all workers load
              from the most recent checkpoint from such storage. Since this
              example is expected to run on vanilla hosts (with no shared
              storage) the checkpoints are written to local disk, hence
              we have the extra logic to broadcast the checkpoint from a
              surviving node.
    """

    state = State(arch, model, optimizer)

    if os.path.isfile(checkpoint_file):
        print(f"=> loading checkpoint file: {checkpoint_file}")
        state.load(checkpoint_file)
        print(f"=> loaded checkpoint file: {checkpoint_file}")

    # logic below is unnecessary when the checkpoint is visible on all nodes!
    # create a temporary cpu pg to broadcast most up-to-date checkpoint
    with tmp_process_group(backend="gloo") as pg:
        rank = dist.get_rank(group=pg)

        # get rank that has the largest state.epoch
        epochs = torch.zeros(dist.get_world_size(), dtype=torch.int32)
        epochs[rank] = state.epoch
        dist.all_reduce(epochs, op=dist.ReduceOp.SUM, group=pg)
        t_max_epoch, t_max_rank = torch.max(epochs, dim=0)
        max_epoch = t_max_epoch.item()
        max_rank = t_max_rank.item()

        # max_epoch == -1 means no one has checkpointed; return base state
        if max_epoch == -1:
            print(f"=> no workers have checkpoints, starting from epoch 0")
            return state

        # broadcast the state from max_rank (which has the most up-to-date state)
        # pickle the snapshot, convert it into a byte-blob tensor,
        # then broadcast it, unpickle it and apply the snapshot
        print(f"=> using checkpoint from rank: {max_rank}, max_epoch: {max_epoch}")

        with io.BytesIO() as f:
            torch.save(state.capture_snapshot(), f)
            raw_blob = numpy.frombuffer(f.getvalue(), dtype=numpy.uint8)

        blob_len = torch.tensor(len(raw_blob))
        dist.broadcast(blob_len, src=max_rank, group=pg)
        print(f"=> checkpoint broadcast size is: {blob_len}")

        if rank != max_rank:
            blob = torch.zeros(blob_len.item(), dtype=torch.uint8)
        else:
            blob = torch.as_tensor(raw_blob, dtype=torch.uint8)

        dist.broadcast(blob, src=max_rank, group=pg)
        print(f"=> done broadcasting checkpoint")

        if rank != max_rank:
            with io.BytesIO(blob.numpy()) as f:
                snapshot = torch.load(f)
            state.apply_snapshot(snapshot)

        # wait till everyone has loaded the checkpoint
        dist.barrier(group=pg)

    print(f"=> done restoring from previous checkpoint")
    return state


@contextmanager
def tmp_process_group(backend):
    cpu_pg = dist.new_group(backend=backend)
    try:
        yield cpu_pg
    finally:
        dist.destroy_process_group(cpu_pg)


def save_checkpoint(state: State, is_best: bool, filename: str):
    checkpoint_dir = os.path.dirname(filename)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # save to tmp, then commit by moving the file in case the job
    # gets interrupted while writing the checkpoint
    tmp_filename = filename + ".tmp"
    torch.save(state.capture_snapshot(), tmp_filename)
    os.rename(tmp_filename, filename)
    print(f"=> saved checkpoint for epoch {state.epoch} at {filename}")
    if is_best:
        best = os.path.join(checkpoint_dir, "model_best.pth.tar")
        print(f"=> best model found at epoch {state.epoch} saving to {best}")
        shutil.copyfile(filename, best)
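

# Illustrative sketch (an assumption, not part of the original example): a saved
# snapshot can be reloaded offline, e.g. to evaluate the best model with the
# default checkpoint paths. Note the weights were captured from a
# DistributedDataParallel-wrapped model, so keys carry a "module." prefix that
# must be stripped before loading into a bare model:
#
#   snapshot = torch.load("/tmp/model_best.pth.tar", map_location="cpu")
#   model = models.__dict__[snapshot["arch"]]()
#   state_dict = {k.replace("module.", "", 1): v for k, v in snapshot["state_dict"].items()}
#   model.load_state_dict(state_dict)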


def train(
    train_loader: DataLoader,
    model: DistributedDataParallel,
    criterion,  # nn.CrossEntropyLoss
    optimizer,  # SGD
    epoch: int,
    print_freq: int,
):
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % print_freq == 0:
            progress.display(i)


def validate(
    val_loader: DataLoader,
    model: DistributedDataParallel,
    criterion,  # nn.CrossEntropyLoss
    print_freq: int,
):
    batch_time = AverageMeter("Time", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(val_loader), [batch_time, losses, top1, top5], prefix="Test: "
    )

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(
            " * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(top1=top1, top5=top5)
        )

    return top1.avg


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name: str, fmt: str = ":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self) -> None:
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1) -> None:
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches: int, meters: List[AverageMeter], prefix: str = ""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch: int) -> None:
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches: int) -> str:
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def adjust_learning_rate(optimizer, epoch: int, lr: float) -> None:
    """
    Sets the learning rate to the initial LR decayed by 10 every 30 epochs
    """
    learning_rate = lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group["lr"] = learning_rate
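

# For example, with the default ``--lr 0.1`` the schedule above yields
# lr = 0.1 for epochs 0-29, 0.01 for epochs 30-59, and 0.001 for epochs 60-89.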


def accuracy(output, target, topk=(1,)):
    """
    Computes the accuracy over the k top predictions for the specified values of k
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(1, -1).view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == "__main__":
    main()