github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/docs/examples/etl-imagenet-wd/pytorch_wd.py

import argparse
import os
import random
import shutil
import time
import warnings

import aistore
from aistore.client import Bck, Client
from aistore.client.transform import WDTransform

import webdataset as wds

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.models as models
import torchvision.transforms as transforms

model_names = sorted(
    name
    for name in models.__dict__
    if name.islower() and not name.startswith("__") and callable(models.__dict__[name])
)

parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")
parser.add_argument(
    "-a",
    "--arch",
    metavar="ARCH",
    default="resnet18",
    choices=model_names,
    help="model architecture: " + " | ".join(model_names) + " (default: resnet18)",
)
parser.add_argument(
    "-j",
    "--workers",
    default=4,
    type=int,
    metavar="N",
    help="number of data loading workers (default: 4)",
)
parser.add_argument(
    "--epochs",
    default=90,
    type=int,
    metavar="N",
    help="number of total epochs to run",
)
parser.add_argument(
    "--start-epoch",
    default=0,
    type=int,
    metavar="N",
    help="manual epoch number (useful on restarts)",
)
parser.add_argument(
    "-b",
    "--batch-size",
    default=256,
    type=int,
    metavar="N",
    help="mini-batch size (default: 256), this is the total "
    "batch size of all GPUs on the current node when "
    "using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
    "--lr",
    "--learning-rate",
    default=0.1,
    type=float,
    metavar="LR",
    help="initial learning rate",
    dest="lr",
)
parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
parser.add_argument(
    "--wd",
    "--weight-decay",
    default=1e-4,
    type=float,
    metavar="W",
    help="weight decay (default: 1e-4)",
    dest="weight_decay",
)
parser.add_argument(
    "-p",
    "--print-freq",
    default=10,
    type=int,
    metavar="N",
    help="print frequency (default: 10)",
)
parser.add_argument(
    "--resume",
    default="",
    type=str,
    metavar="PATH",
    help="path to latest checkpoint (default: none)",
)
parser.add_argument(
    "-e",
    "--evaluate",
    dest="evaluate",
    action="store_true",
    help="evaluate model on validation set",
)
parser.add_argument(
    "--pretrained",
    dest="pretrained",
    action="store_true",
    help="use pre-trained model",
)
parser.add_argument(
    "--seed", default=None, type=int, help="seed for initializing training"
)
parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.")
parser.add_argument(
    "--train-shards",
    default="imagenet-train-{000000..000005}.tar",
    type=str,
    help="template for shards to use for training",
)
parser.add_argument(
    "--val-shards",
    default="imagenet-train-{000000..000005}.tar",
    type=str,
    help="template for shards to use for validation",
)
parser.add_argument(
    "--ais-endpoint",
    default="http://aistore-sample-proxy:51080",
    type=str,
    help="AIStore proxy endpoint",
)
parser.add_argument(
    "--bucket-name", default="imagenet", type=str, help="dataset bucket name"
)
parser.add_argument(
    "--bucket-provider",
    default="ais",
    type=str,
    help='bucket provider (eg. "gcp", "aws", "ais")',
)

best_acc1 = 0
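
# Example invocation (flag values shown are this script's argparse defaults;
# adjust the endpoint, bucket, and shard template to your own AIStore deployment):
#
#   python pytorch_wd.py \
#       --arch resnet18 \
#       --ais-endpoint http://aistore-sample-proxy:51080 \
#       --bucket-name imagenet --bucket-provider ais \
#       --train-shards "imagenet-train-{000000..000005}.tar"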


def main():
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. "
            "This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! "
            "You may see unexpected behavior when restarting "
            "from checkpoints."
        )

    if args.gpu is not None:
        warnings.warn(
            "You have chosen a specific GPU. This will completely "
            "disable data parallelism."
        )

    ngpus_per_node = torch.cuda.device_count()
    main_worker(args.gpu, ngpus_per_node, args)


def wd_transform(client, pytorch_transform, name):
    def transform(sample):
        sample["npy"] = (
            pytorch_transform(sample.pop("jpg"))
            .permute(1, 2, 0)
            .numpy()
            .astype("float32")
        )
        return sample

    return WDTransform(client, transform, transform_name=name, verbose=True)


def loader(urls, batch_size, workers):
    to_tensor = transforms.Compose([transforms.ToTensor()])
    etl_dataset = (
        wds.WebDataset(urls, handler=wds.handlers.warn_and_continue)
        .decode("rgb")
        .to_tuple("npy cls", handler=wds.handlers.warn_and_continue)
        .map_tuple(to_tensor, lambda x: x)
    )
    ds_size = (500 * len(urls)) // batch_size
    etl_dataset = etl_dataset.with_length(ds_size)
    loader = wds.WebLoader(
        etl_dataset,
        batch_size=batch_size,
        num_workers=workers,
    )
    return loader.with_length(ds_size)
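
# Rough standalone sketch of the ETL + WebDataset pipeline wired up in
# main_worker() below (the "img-demo" transform name and the plain ToTensor
# transform are illustrative; assumes the endpoint/bucket/template defaults
# defined above are reachable):
#
#   client = Client("http://aistore-sample-proxy:51080")
#   bck = Bck("imagenet", provider="ais")
#   etl = wd_transform(client, transforms.Compose([transforms.ToTensor()]), "img-demo")
#   urls = client.expand_object_urls(bck, transform_id=etl.uuid,
#                                    template="imagenet-train-{000000..000005}.tar")
#   batches = loader(urls, batch_size=256, workers=4)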


def main_worker(gpu, ngpus_per_node, args):
    global best_acc1
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if not torch.cuda.is_available():
        print("using CPU, this will be slow")
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
    else:
        # DataParallel will divide and allocate batch_size to all available GPUs
        if args.arch.startswith("alexnet") or args.arch.startswith("vgg"):
            model.features = torch.nn.DataParallel(model.features)
            model.cuda()
        else:
            model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda(args.gpu)

    optimizer = torch.optim.SGD(
        model.parameters(),
        args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    )

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            if args.gpu is None:
                checkpoint = torch.load(args.resume)
            else:
                # Map model to be loaded to specified single gpu.
                loc = "cuda:{}".format(args.gpu)
                checkpoint = torch.load(args.resume, map_location=loc)
            args.start_epoch = checkpoint["epoch"]
            best_acc1 = checkpoint["best_acc1"]
            if args.gpu is not None:
                # best_acc1 may be from a checkpoint from a different GPU
                best_acc1 = best_acc1.to(args.gpu)
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            print(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    client = Client(args.ais_endpoint)
    bck = Bck(args.bucket_name, provider=args.bucket_provider)

    train_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )

    train_etl = wd_transform(client, train_transform, "img-train")
    train_urls = client.expand_object_urls(
        bck, transform_id=train_etl.uuid, template=args.train_shards
    )
    train_loader = loader(train_urls, args.batch_size, args.workers)

    val_transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )
    val_etl = wd_transform(client, val_transform, "img-val")
    val_urls = client.expand_object_urls(
        bck, transform_id=val_etl.uuid, template=args.val_shards
    )
    val_loader = loader(val_urls, args.batch_size, args.workers)

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        return

    for epoch in range(args.start_epoch, args.epochs):
        adjust_learning_rate(optimizer, epoch, args)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1 = validate(val_loader, model, criterion, args)

        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)

        save_checkpoint(
            {
                "epoch": epoch + 1,
                "arch": args.arch,
                "state_dict": model.state_dict(),
                "best_acc1": best_acc1,
                "optimizer": optimizer.state_dict(),
            },
            is_best,
        )
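
# Note: save_checkpoint() below writes checkpoint.pth.tar after every epoch and
# copies it to model_best.pth.tar whenever acc@1 improves; training can later be
# resumed from either file via --resume.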


def train(train_loader, model, criterion, optimizer, epoch, args):
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch),
    )

    # switch to train mode
    model.train()

    end = time.time()
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
        if torch.cuda.is_available():
            target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # measure accuracy and record loss
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0], images.size(0))
        top5.update(acc5[0], images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


def validate(val_loader, model, criterion, args):
    batch_time = AverageMeter("Time", ":6.3f")
    losses = AverageMeter("Loss", ":.4e")
    top1 = AverageMeter("Acc@1", ":6.2f")
    top5 = AverageMeter("Acc@5", ":6.2f")
    progress = ProgressMeter(
        len(val_loader), [batch_time, losses, top1, top5], prefix="Test: "
    )

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
            if torch.cuda.is_available():
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(
            " * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(top1=top1, top5=top5)
        )

    return top1.avg


def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "model_best.pth.tar")


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
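
# With the default --lr 0.1, adjust_learning_rate() above yields lr = 0.1 for
# epochs 0-29, 0.01 for epochs 30-59, and 0.001 for epochs 60-89.
#
# accuracy() below returns one percentage per requested k; e.g. with topk=(1, 5),
# a 256-image batch with 200 correct top-1 and 240 correct top-5 predictions
# yields values of roughly 78.125 and 93.75.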


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


if __name__ == "__main__":
    main()