gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/images/gpu/pytorch/download_pytorch_datasets.py (about)

     1  # Copyright 2023 The gVisor Authors.
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the "License");
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #     http://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an "AS IS" BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  
    15  """Download PyTorch datasets used in tests."""
    16  
    17  import os
    18  
    19  from torchvision import datasets
    20  from torchvision import models
    21  
    22  datasets_dir = os.environ["PYTORCH_DATASETS_DIR"]
    23  for dataset in (
    24      datasets.MNIST,
    25      datasets.CIFAR100,
    26  ):
    27    dataset(datasets_dir, train=True, download=True)
    28    dataset(datasets_dir, train=False, download=True)
    29  
    30  # Download resnet50 weights to TORCH_HOME:
    31  models.resnet50(weights=models.ResNet50_Weights.DEFAULT)