# Source code for pe.data.image.imagenet

import pandas as pd
import torchvision.datasets
import torchvision.transforms as T
from tqdm import tqdm
import torch

from pe.data import Data
from pe.constant.data import LABEL_ID_COLUMN_NAME
from pe.constant.data import IMAGE_DATA_COLUMN_NAME


class ImageNet(Data):
    """The ImageNet dataset.

    All images of the requested split are loaded eagerly into memory and stored in a data frame with one row per
    image: the ``IMAGE_DATA_COLUMN_NAME`` column holds a channel-last (H, W, C) array per image, and the
    ``LABEL_ID_COLUMN_NAME`` column holds the class label (conditional) or ``0`` for every row (unconditional).
    """

    def __init__(self, root_dir, conditional=False, split="train", res=32, batch_size=1000, num_workers=10):
        """Constructor.

        :param root_dir: The root directory of the dataset.
        :param conditional: Whether to use conditional ImageNet. When True, the per-image labels from the
            dataset are kept and one ``label_info`` entry is created per dataset class; otherwise every image
            gets label ``0`` and a single ``"none"`` label is recorded. Defaults to False
        :type conditional: bool, optional
        :param split: The split of the dataset, defaults to "train"
        :type split: str, optional
        :param res: The resolution of the images, defaults to 32
        :type res: int, optional
        :param batch_size: The batch size to load the images, defaults to 1000
        :type batch_size: int, optional
        :param num_workers: The number of workers to load the images, defaults to 10
        :type num_workers: int, optional
        """
        # Resize the shorter side to 256 and take a 224x224 center crop. The final Resize(res) then yields a
        # square res x res image. PILToTensor produces channel-first tensors, which are permuted to
        # channel-last further below.
        transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.Resize(res), T.PILToTensor()])
        dataset = torchvision.datasets.ImageNet(
            root=root_dir,
            split=split,
            transform=transform,
        )
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            # Must stay False: the label column below is taken directly from ``dataset.targets``, so the
            # loaded images have to come out in dataset order.
            shuffle=False,
            # Honor the caller's setting (this was previously hard-coded to 4, silently ignoring the argument).
            num_workers=num_workers,
            drop_last=False,
        )

        images = []
        for batch in tqdm(data_loader, desc="Loading ImageNet", unit="batch"):
            # Each batch is (images, labels); only the images are needed here.
            images.append(batch[0])
        images = torch.cat(images, dim=0)
        # (N, C, H, W) -> (N, H, W, C) as a NumPy array.
        images = images.permute(0, 2, 3, 1).numpy()

        data_frame = pd.DataFrame(
            {
                IMAGE_DATA_COLUMN_NAME: list(images),
                LABEL_ID_COLUMN_NAME: dataset.targets if conditional else [0] * len(images),
            }
        )
        if conditional:
            metadata = {"label_info": [{"name": n} for n in map(str, dataset.classes)]}
        else:
            metadata = {"label_info": [{"name": "none"}]}
        super().__init__(data_frame=data_frame, metadata=metadata)