# Source code for pe.api.image.nearest_image_api

import numpy as np
import pandas as pd

from pe.api import API
from pe.logging import execution_logger
from pe.data import Data
from pe.constant.data import IMAGE_DATA_COLUMN_NAME
from pe.constant.data import LABEL_ID_COLUMN_NAME
from pe.api.util import ConstantList


# Name of the data-frame column in which this API records, for every generated
# sample, the positional row index (as used with ``DataFrame.iloc``) of the
# image it was taken from in the source dataset. The VARIATION_API reads this
# column back to look up the precomputed nearest neighbors of each sample.
DATA_ID_COLUMN_NAME = "PE.NEAREST_IMAGE.DATA_ID"


def _to_constant_list_if_needed(value):
    if not isinstance(value, list):
        value = ConstantList(value)
    return value


class NearestImage(API):
    """The API that generates synthetic images by randomly drawing an image from the given dataset as the
    RANDOM_API and finding the nearest images in the given dataset as the VARIATION_API."""

    def __init__(
        self,
        data,
        embedding,
        nearest_neighbor_mode,
        variation_degrees,
        nearest_neighbor_backend="auto",
    ):
        """Constructor. The nearest neighbor table of the dataset is built immediately.

        :param data: The data object that contains the images
        :type data: :py:class:`pe.data.Data`
        :param embedding: The embedding object that computes the embeddings of the images
        :type embedding: :py:class:`pe.embedding.Embedding`
        :param nearest_neighbor_mode: The distance metric to use for finding the nearest neighbors. It should be
            one of the following: "l2" (l2 distance), "cos_sim" (cosine similarity), "ip" (inner product). Not all
            backends support all modes
        :type nearest_neighbor_mode: str
        :param variation_degrees: The variation degrees utilized at each PE iteration. If a single value is
            provided, the same variation degree will be used for all iterations. The value means the number of
            nearest neighbors to consider for the VARIATION_API
        :type variation_degrees: int or list[int]
        :param nearest_neighbor_backend: The backend to use for finding the nearest neighbors. It should be one of
            the following: "faiss" (FAISS), "sklearn" (scikit-learn), "auto" (using FAISS if available, otherwise
            scikit-learn). Defaults to "auto". FAISS supports GPU and is much faster when the number of samples is
            large. It requires the installation of `faiss-gpu` or `faiss-cpu` package. See https://faiss.ai/
        :type nearest_neighbor_backend: str, optional
        :raises ValueError: If the `nearest_neighbor_backend` is unknown
        """
        super().__init__()
        self._data = data
        self._embedding = embedding
        self._nearest_neighbor_mode = nearest_neighbor_mode
        self._nearest_neighbor_backend = nearest_neighbor_backend

        degrees = _to_constant_list_if_needed(variation_degrees)
        self._variation_degrees = degrees
        # The neighbor table must be wide enough for the largest degree that
        # any iteration can request.
        if isinstance(degrees, ConstantList):
            self._max_variation_degree = degrees[0]
        else:
            self._max_variation_degree = max(degrees)

        # Backends are imported lazily so that only the selected one has to be
        # importable. The backend name is matched case-insensitively.
        backend_key = nearest_neighbor_backend.lower()
        if backend_key == "faiss":
            from pe.histogram.nearest_neighbor_backend.faiss import search as faiss_search

            self._search = faiss_search
        elif backend_key == "sklearn":
            from pe.histogram.nearest_neighbor_backend.sklearn import search as sklearn_search

            self._search = sklearn_search
        elif backend_key == "auto":
            from pe.histogram.nearest_neighbor_backend.auto import search as auto_search

            self._search = auto_search
        else:
            raise ValueError(f"Unknown backend: {nearest_neighbor_backend}")

        self._build_nearest_neighbor_graph()

    def _build_nearest_neighbor_graph(self):
        """Precompute, for each sample in the given dataset, its nearest neighbors within the same dataset.

        The embeddings of the dataset are computed and searched against themselves. The returned ids and
        distances are stored with every row ordered by ascending value of the returned distances.
        """
        self._data = self._embedding.compute_embedding(self._data)
        embedding_column = self._data.data_frame[self._embedding.column_name]
        features = np.stack(embedding_column.values, axis=0).astype(np.float32)

        # The dataset plays both roles of the search: query set and reference set.
        distances, ids = self._search(
            syn_embedding=features,
            priv_embedding=features,
            num_nearest_neighbors=self._max_variation_degree,
            mode=self._nearest_neighbor_mode,
        )

        neighbor_order = np.argsort(distances, axis=1)

        def reorder_rows(matrix):
            # Apply the per-row ordering to a 2-D table via paired row/column indices.
            row_selector = np.arange(len(matrix))[:, np.newaxis]
            return matrix[row_selector, neighbor_order]

        self._nearest_neighbor_ids = reorder_rows(ids)
        self._nearest_neighbor_distances = reorder_rows(distances)

    def random_api(self, label_info, num_samples):
        """Generating random synthetic data by randomly drawing images from the given dataset.

        The images are drawn without replacement, and every generated sample is assigned label id 0.

        :param label_info: The info of the label. Only its name is used (for logging); it is also stored in the
            metadata of the returned data
        :type label_info: omegaconf.dictconfig.DictConfig
        :param num_samples: The number of random samples to generate
        :type num_samples: int
        :return: The data object of the generated synthetic data
        :rtype: :py:class:`pe.data.Data`
        """
        label_name = label_info.name
        execution_logger.info(f"RANDOM API: creating {num_samples} samples for label {label_name}")

        source_frame = self._data.data_frame
        chosen_rows = np.random.choice(len(source_frame), num_samples, replace=False)
        chosen_images = source_frame[IMAGE_DATA_COLUMN_NAME].iloc[chosen_rows].values

        columns = {
            IMAGE_DATA_COLUMN_NAME: chosen_images,
            # Remember which dataset row each sample came from for the VARIATION_API.
            DATA_ID_COLUMN_NAME: chosen_rows,
            LABEL_ID_COLUMN_NAME: 0,
        }
        result = Data(data_frame=pd.DataFrame(columns), metadata={"label_info": [label_info]})

        execution_logger.info(f"RANDOM API: finished creating {num_samples} samples for label {label_name}")
        return result

    def variation_api(self, syn_data):
        """Generating variations of the synthetic data by finding the nearest images in the given dataset.

        Each input sample is replaced by one of the first ``variation_degree`` entries of the precomputed
        neighbor list of its source image, chosen uniformly at random. Label ids and metadata are carried over
        from the input.

        :param syn_data: The data object of the synthetic data
        :type syn_data: :py:class:`pe.data.Data`
        :return: The data object of the variation of the input synthetic data
        :rtype: :py:class:`pe.data.Data`
        """
        execution_logger.info(f"VARIATION API: creating variations for {len(syn_data.data_frame)} samples")

        input_frame = syn_data.data_frame
        source_rows = list(input_frame[DATA_ID_COLUMN_NAME].values)

        # Metadata without an ``iteration`` attribute falls back to -1, which
        # selects the first configured variation degree.
        current_iteration = getattr(syn_data.metadata, "iteration", -1)
        variation_degree = self._variation_degrees[current_iteration + 1]
        execution_logger.info(f"VARIATION API parameters: variation_degree={variation_degree}")

        # For every sample, pick a random column among its first `variation_degree` neighbors.
        neighbor_ranks = np.random.choice(variation_degree, len(source_rows), replace=True)
        neighbor_rows = self._nearest_neighbor_ids[source_rows, neighbor_ranks]
        neighbor_images = self._data.data_frame[IMAGE_DATA_COLUMN_NAME].iloc[neighbor_rows].values

        columns = {
            IMAGE_DATA_COLUMN_NAME: neighbor_images,
            DATA_ID_COLUMN_NAME: neighbor_rows,
            LABEL_ID_COLUMN_NAME: input_frame[LABEL_ID_COLUMN_NAME].values,
        }
        result_frame = pd.DataFrame(columns)

        execution_logger.info(f"VARIATION API: finished creating variations for {len(syn_data.data_frame)} samples")
        return Data(data_frame=result_frame, metadata=syn_data.metadata)