# Source code for pe.data.image.cat

import pandas as pd
import os
from tqdm import tqdm
import numpy as np
import zipfile
from PIL import Image
import torchvision.transforms as T
from collections import defaultdict

from pe.data import Data
from pe.constant.data import LABEL_ID_COLUMN_NAME
from pe.constant.data import IMAGE_DATA_COLUMN_NAME
from pe.util import download

# Label names of the Cat dataset. The position of a name in this list is used as its integer
# label ID, and each name is matched against the top-level folder names inside the dataset zip.
CAT_LABEL_NAMES = ["cookie", "doudou"]


class Cat(Data):
    """The Cat dataset.

    On construction the dataset zip is downloaded from :py:attr:`URL` (unless it is already on disk),
    every image in it is read and resized, and the images together with their label IDs are handed to
    the parent :py:class:`pe.data.Data` constructor.
    """

    #: The URL of the dataset
    URL = "https://www.kaggle.com/api/v1/datasets/download/fjxmlzn/cat-cookie-doudou"

    def __init__(self, root_dir="data", res=512):
        """Constructor.

        :param root_dir: The root directory to save the dataset, defaults to "data"
        :type root_dir: str, optional
        :param res: The resolution of the images (passed to ``torchvision.transforms.Resize``), defaults
            to 512
        :type res: int, optional
        :raises ValueError: If the zip file contains a top-level folder that is not in
            ``CAT_LABEL_NAMES``
        """
        self._zip_path = os.path.join(root_dir, "cat-cookie-doudou.zip")
        self._download()
        data = self._read_data()
        transform = T.Resize(res)
        images = []
        labels = []
        for label, sub_images in data.items():
            # The label ID is the position of the label name in CAT_LABEL_NAMES; it is the same for
            # every image of this label, so look it up once instead of once per image.
            label_id = CAT_LABEL_NAMES.index(label)
            for image in tqdm(sub_images, desc=f"Processing {label} images"):
                # Resize works on PIL images here, so round-trip through PIL and back to numpy.
                image = Image.fromarray(image)
                image = transform(image)
                image = np.array(image)
                images.append(image)
                labels.append(label_id)
        data_frame = pd.DataFrame(
            {
                IMAGE_DATA_COLUMN_NAME: images,
                LABEL_ID_COLUMN_NAME: labels,
            }
        )
        metadata = {"label_info": [{"name": n} for n in CAT_LABEL_NAMES]}
        super().__init__(data_frame=data_frame, metadata=metadata)

    def _download(self):
        """Download the dataset if it does not exist."""
        if os.path.exists(self._zip_path):
            return
        zip_dir = os.path.dirname(self._zip_path)
        # The directory part is empty when root_dir is "" (i.e. the current directory), and
        # os.makedirs("") raises FileNotFoundError, so only create it when there is something to create.
        if zip_dir:
            os.makedirs(zip_dir, exist_ok=True)
        download(url=self.URL, fname=self._zip_path)

    def _read_data(self):
        """Read the data from the zip file.

        The first path component of each archive member is used as the label of the image.

        :return: A mapping from label name to the list of images (as numpy arrays) with that label
        :rtype: collections.defaultdict[str, list[np.ndarray]]
        """
        data = defaultdict(list)
        with zipfile.ZipFile(self._zip_path) as z:
            for name in tqdm(z.namelist(), desc="Reading zip file"):
                # Zip archives may store explicit directory entries (names ending in "/"); they hold
                # no image data and cannot be opened by PIL, so skip them.
                if name.endswith("/"):
                    continue
                label = name.split("/")[0]
                with z.open(name) as f:
                    image = Image.open(f)
                    # PIL decodes lazily, so the conversion to numpy must happen while f is open.
                    data[label].append(np.array(image))
        return data