# Source path: gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/images/gpu/pytorch/download_pytorch_datasets.py
#
# Copyright 2023 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Download PyTorch datasets used in tests.

Every dataset class listed below is fetched with both train=True and
train=False into the directory named by the PYTORCH_DATASETS_DIR environment
variable. Afterwards the default ResNet-50 weights are fetched as well.
"""

import os

from torchvision import datasets
from torchvision import models

# Indexing os.environ directly (rather than .get) means the script fails with a
# KeyError if PYTORCH_DATASETS_DIR is not set, instead of guessing a location.
datasets_dir = os.environ["PYTORCH_DATASETS_DIR"]

# Download order is: each dataset's train=True split, then its train=False
# split, before moving on to the next dataset.
for dataset_cls in (datasets.MNIST, datasets.CIFAR100):
    for is_train in (True, False):
        dataset_cls(datasets_dir, train=is_train, download=True)

# Building the model with pretrained weights is done only for its side effect:
# it downloads the resnet50 weights to TORCH_HOME. The model object is discarded.
models.resnet50(weights=models.ResNet50_Weights.DEFAULT)