github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/examples/sdk/writing-webdataset.ipynb (about)

     1  {
     2   "cells": [
     3    {
     4     "cell_type": "markdown",
     5     "metadata": {},
     6     "source": [
     7      "# Writing a Dataset to AIS in WebDataset Format\n",
     8      "\n",
     9      "In this notebook we will download and store the following datasets in [WebDataset](https://github.com/webdataset/webdataset) format in AIS:\n",
    10      "\n",
    11      "- [The Oxford-IIIT Pet Dataset](https://academictorrents.com/details/b18bbd9ba03d50b0f7f479acc9f4228a408cecc1)\n",
    12      "- [Flickr Image dataset](https://www.kaggle.com/datasets/hsankesara/flickr-image-dataset)"
    13     ]
    14    },
    15    {
    16     "cell_type": "code",
    17     "execution_count": null,
    18     "metadata": {},
    19     "outputs": [],
    20     "source": [
    21      "pip install aistore"
    22     ]
    23    },
    24    {
    25     "cell_type": "markdown",
    26     "metadata": {},
    27     "source": [
    28      "## Setting Up Client"
    29     ]
    30    },
    31    {
    32     "cell_type": "code",
    33     "execution_count": null,
    34     "metadata": {},
    35     "outputs": [],
    36     "source": [
    37      "import os\n",
    38      "from aistore.client import Client\n",
    39      "\n",
    40      "ais_url = os.getenv(\"AIS_ENDPOINT\", \"http://localhost:8080\")\n",
    41      "client = Client(ais_url)"
    42     ]
    43    },
    44    {
    45     "cell_type": "markdown",
    46     "metadata": {},
    47     "source": [
    48      "## The Oxford-IIIT Pet Dataset"
    49     ]
    50    },
    51    {
    52     "cell_type": "markdown",
    53     "metadata": {},
    54     "source": [
    55      "### Downloading the Dataset"
    56     ]
    57    },
    58    {
    59     "cell_type": "code",
    60     "execution_count": null,
    61     "metadata": {},
    62     "outputs": [],
    63     "source": [
    64      "import requests\n",
    65      "import tarfile\n",
    66      "import os\n",
    67      "\n",
    68      "def download_and_extract(url, dest_path):\n",
    69      "    response = requests.get(url, stream=True)\n",
    70      "    if response.status_code == 200:\n",
    71      "        with open(dest_path, 'wb') as f:\n",
    72      "            f.write(response.raw.read())\n",
    73      "        with tarfile.open(dest_path) as tar:\n",
    74      "            tar.extractall(path=os.path.dirname(dest_path))\n",
    75      "        os.remove(dest_path)  # Clean up the tar file after extraction"
    76     ]
    77    },
    78    {
    79     "cell_type": "code",
    80     "execution_count": null,
    81     "metadata": {},
    82     "outputs": [],
    83     "source": [
    84      "base_url = \"http://www.robots.ox.ac.uk/~vgg/data/pets/data\"\n",
    85      "images_url = f\"{base_url}/images.tar.gz\"\n",
    86      "annotations_url = f\"{base_url}/annotations.tar.gz\"\n",
    87      "\n",
    88      "data_dir = \"/data\"\n",
    89      "images_path = os.path.join(data_dir, \"images.tar.gz\")\n",
    90      "annotations_path = os.path.join(data_dir, \"annotations.tar.gz\")\n",
    91      "\n",
    92      "if not os.path.exists(data_dir):\n",
    93      "    os.makedirs(data_dir)\n",
    94      "\n",
    95      "download_and_extract(images_url, images_path)\n",
    96      "download_and_extract(annotations_url, annotations_path)"
    97     ]
    98    },
    99    {
   100     "cell_type": "markdown",
   101     "metadata": {},
   102     "source": [
   103      "### Creating a bucket and writing the dataset"
   104     ]
   105    },
   106    {
   107     "cell_type": "code",
   108     "execution_count": null,
   109     "metadata": {},
   110     "outputs": [],
   111     "source": [
   112      "from pathlib import Path\n",
   113      "from aistore.sdk.dataset.dataset_config import DatasetConfig\n",
   114      "from aistore.sdk.dataset.data_attribute import DataAttribute\n",
   115      "from aistore.sdk.dataset.label_attribute import LabelAttribute\n",
   116      "\n",
   117      "bucket = client.bucket(\"pets-dataset\").create(exist_ok=True)\n",
   118      "base_path = Path(\"/data\")"
   119     ]
   120    },
   121    {
   122     "cell_type": "code",
   123     "execution_count": null,
   124     "metadata": {},
   125     "outputs": [],
   126     "source": [
   127      "# Function to get label from the annotation file\n",
   128      "\n",
   129      "def get_class_dict(path: Path):\n",
   130      "    parsed_dict = {}\n",
   131      "    with open(path, \"r\", encoding=\"utf-8\") as file:\n",
   132      "        for line in file.readlines():\n",
   133      "            if line[0] == \"#\":\n",
   134      "                continue\n",
   135      "            file_name, label = line.split(\" \")[:2]\n",
   136      "            parsed_dict[file_name] = label\n",
   137      "\n",
   138      "    return parsed_dict\n",
   139      "\n",
   140      "\n",
   141      "parsed_dict = get_class_dict(base_path.joinpath(\"annotations\").joinpath(\"list.txt\"))\n",
   142      "\n",
   143      "\n",
   144      "def get_label_for_filename(filename):\n",
   145      "    return parsed_dict.get(filename, None)"
   146     ]
   147    },
   148    {
   149     "cell_type": "code",
   150     "execution_count": null,
   151     "metadata": {},
   152     "outputs": [],
   153     "source": [
   154      "dataset_config = DatasetConfig(\n",
   155      "    primary_attribute= DataAttribute(\n",
   156      "            path=base_path.joinpath(\"images\"),\n",
   157      "            file_type=\"jpg\",\n",
   158      "            name=\"image\"\n",
   159      "        ),\n",
   160      "    secondary_attributes=[\n",
   161      "        DataAttribute(\n",
   162      "            path=base_path.joinpath(\"annotations\").joinpath(\"trimaps\"),\n",
   163      "            file_type=\"png\",\n",
   164      "            name=\"trimap\",\n",
   165      "        ),\n",
   166      "        LabelAttribute(\n",
   167      "            name=\"cls\",\n",
   168      "            label_identifier=get_label_for_filename,\n",
   169      "        ),\n",
   170      "    ],\n",
   171      ")\n",
   172      "\n",
   173      "bucket.write_dataset(config=dataset_config, pattern=\"img_dataset\", maxcount=1000)"
   174     ]
   175    },
   176    {
   177     "cell_type": "markdown",
   178     "metadata": {},
   179     "source": [
   180      "## Flickr Image dataset"
   181     ]
   182    },
   183    {
   184     "cell_type": "markdown",
   185     "metadata": {},
   186     "source": [
   187      "### Downloading the Dataset"
   188     ]
   189    },
   190    {
   191     "cell_type": "markdown",
   192     "metadata": {},
   193     "source": [
   194      "**NOTE:** We are using the [kaggle API](https://github.com/Kaggle/kaggle-api/blob/main/docs/README.md) to download the dataset. "
   195     ]
   196    },
   197    {
   198     "cell_type": "code",
   199     "execution_count": null,
   200     "metadata": {},
   201     "outputs": [],
   202     "source": [
   203      "pip install kaggle"
   204     ]
   205    },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: the kaggle CLI needs API credentials configured (e.g. ~/.kaggle/kaggle.json) -- see the kaggle-api docs linked above\n",
    "!kaggle datasets download -d hsankesara/flickr-image-dataset -p /data --unzip"
   ]
  },
   215    {
   216     "cell_type": "markdown",
   217     "metadata": {},
   218     "source": [
   219      "### Creating a bucket and writing the dataset"
   220     ]
   221    },
   222    {
   223     "cell_type": "code",
   224     "execution_count": null,
   225     "metadata": {},
   226     "outputs": [],
   227     "source": [
   228      "from pathlib import Path\n",
   229      "from aistore.sdk.dataset.dataset_config import DatasetConfig\n",
   230      "from aistore.sdk.dataset.data_attribute import DataAttribute\n",
   231      "from aistore.sdk.dataset.label_attribute import LabelAttribute\n",
   232      "\n",
   233      "bucket = client.bucket(\"flickr-dataset\").create(exist_ok=True)\n",
   234      "base_path = Path(\"/data\")"
   235     ]
   236    },
   237    {
   238     "cell_type": "code",
   239     "execution_count": null,
   240     "metadata": {},
   241     "outputs": [],
   242     "source": [
   243      "# Function to get the caption from results.csv\n",
   244      "def parse_csv(path: Path):\n",
   245      "    parsed_dict = {}\n",
   246      "    with open(path, \"r\", encoding=\"utf-8\") as file:\n",
   247      "        for line in file:\n",
   248      "            splitted = line.split(\"|\")\n",
   249      "            if len(splitted) < 3:\n",
   250      "                continue  \n",
   251      "            filename = splitted[0].strip().split(\".\")[0]\n",
   252      "            caption = splitted[2].strip()  \n",
   253      "            parsed_dict[filename] = caption\n",
   254      "    return parsed_dict   \n",
   255      "\n",
   256      "parsed_dict = parse_csv(base_path.joinpath(\"flickr30k_images/results.csv\"))\n",
   257      "def get_caption_for_filename(filename):\n",
   258      "    return parsed_dict.get(filename, None)"
   259     ]
   260    },
   261    {
   262     "cell_type": "code",
   263     "execution_count": null,
   264     "metadata": {},
   265     "outputs": [],
   266     "source": [
   267      "dataset_config = DatasetConfig(\n",
   268      "    primary_attribute=DataAttribute(\n",
   269      "            path=base_path.joinpath(\"flickr30k_images/flickr30k_images\"),\n",
   270      "            file_type=\"jpg\",\n",
   271      "            name=\"image\"\n",
   272      "        ),\n",
   273      "    secondary_attributes=[\n",
   274      "        LabelAttribute(\n",
   275      "            name=\"caption\",\n",
   276      "            label_identifier=get_caption_for_filename,\n",
   277      "        ),\n",
   278      "    ],\n",
   279      ")\n",
   280      "\n",
   281      "bucket.write_dataset(config=dataset_config, pattern=\"flickr_dataset\", maxcount=1000)"
   282     ]
   283    }
   284   ],
   285   "metadata": {
   286    "kernelspec": {
   287     "display_name": "my-python3-kernel",
   288     "language": "python",
   289     "name": "my-python3-kernel"
   290    },
   291    "language_info": {
   292     "codemirror_mode": {
   293      "name": "ipython",
   294      "version": 3
   295     },
   296     "file_extension": ".py",
   297     "mimetype": "text/x-python",
   298     "name": "python",
   299     "nbconvert_exporter": "python",
   300     "pygments_lexer": "ipython3",
   301     "version": "3.11.8"
   302    }
   303   },
   304   "nbformat": 4,
   305   "nbformat_minor": 2
   306  }