Source code for pe.data.image.camelyon17

import pandas as pd
from wilds import get_dataset
from tqdm import tqdm
import numpy as np
import torchvision.transforms as T

from pe.data import Data
from pe.constant.data import LABEL_ID_COLUMN_NAME
from pe.constant.data import IMAGE_DATA_COLUMN_NAME

# Human-readable class names for Camelyon17. Each entry becomes one
# ``{"name": ...}`` record in the ``label_info`` metadata built by
# ``Camelyon17.__init__``.
# NOTE(review): presumably the list index matches the integer label id
# (0 = "no_tumor", 1 = "tumor") -- confirm against the WILDS dataset.
CAMELYON17_LABEL_NAMES = [
    "no_tumor",
    "tumor",
]


class Camelyon17(Data):
    """Camelyon17 image data loaded through the WILDS package."""

    def __init__(self, split="train", root_dir="data", res=64):
        """Load one split of Camelyon17 into a data frame of images and labels.

        :param split: Which subset to load; one of "train", "val", or "test",
            defaults to "train"
        :type split: str, optional
        :param root_dir: Directory handed to WILDS as the dataset root (the
            dataset is requested with ``download=True``), defaults to "data"
        :type root_dir: str, optional
        :param res: Size passed to ``torchvision.transforms.Resize`` for every
            image, defaults to 64
        :type res: int, optional
        :raises ValueError: If ``split`` is not one of the accepted names
        """
        # Reject bad split names before touching the dataset at all.
        if split not in ("train", "val", "test"):
            raise ValueError(f"Invalid split: {split}")

        wilds_dataset = get_dataset(dataset="camelyon17", download=True, root_dir=root_dir)
        subset = wilds_dataset.get_subset(split)
        resize = T.Resize(res)

        pixel_arrays = []
        label_ids = []
        for index in tqdm(range(len(subset))):
            # Each sample unpacks into three items; the third is not used here.
            picture, target, _ = subset[index]
            pixel_arrays.append(np.array(resize(picture)))
            # ``.item()`` stores the label as a plain Python scalar.
            label_ids.append(target.item())

        columns = {
            IMAGE_DATA_COLUMN_NAME: pixel_arrays,
            LABEL_ID_COLUMN_NAME: label_ids,
        }
        frame = pd.DataFrame(columns)

        label_info = [{"name": label_name} for label_name in CAMELYON17_LABEL_NAMES]
        super().__init__(data_frame=frame, metadata={"label_info": label_info})