import numpy as np
import pandas as pd
from pe.api import API
from pe.logging import execution_logger
from pe.data import Data
from pe.constant.data import IMAGE_DATA_COLUMN_NAME
from pe.constant.data import LABEL_ID_COLUMN_NAME
from pe.api.util import ConstantList
# Name of the data-frame column in which NearestImage records, for every generated sample, the positional
# row index (as used with ``DataFrame.iloc``) of the dataset image that the sample was copied from.
DATA_ID_COLUMN_NAME = "PE.NEAREST_IMAGE.DATA_ID"
def _to_constant_list_if_needed(value):
if not isinstance(value, list):
value = ConstantList(value)
return value
[docs]
class NearestImage(API):
    """The API that generates synthetic images by randomly drawing an image from the given dataset as the RANDOM_API
    and finding the nearest images in the given dataset as the VARIATION_API."""

    def __init__(
        self,
        data,
        embedding,
        nearest_neighbor_mode,
        variation_degrees,
        nearest_neighbor_backend="auto",
    ):
        """Constructor.

        Stores the configuration, selects the nearest neighbor search function for the requested backend, and
        immediately builds the nearest neighbor table of the dataset.

        :param data: The data object that contains the images
        :type data: :py:class:`pe.data.Data`
        :param embedding: The embedding object that computes the embeddings of the images
        :type embedding: :py:class:`pe.embedding.Embedding`
        :param nearest_neighbor_mode: The distance metric to use for finding the nearest neighbors. It should be one
            of the following: "l2" (l2 distance), "cos_sim" (cosine similarity), "ip" (inner product). Not all
            backends support all modes
        :type nearest_neighbor_mode: str
        :param variation_degrees: The variation degrees utilized at each PE iteration. If a single value is provided,
            the same variation degree will be used for all iterations. The value means the number of nearest
            neighbors to consider for the VARIATION_API
        :type variation_degrees: int or list[int]
        :param nearest_neighbor_backend: The backend to use for finding the nearest neighbors. It should be one of
            the following: "faiss" (FAISS), "sklearn" (scikit-learn), "auto" (using FAISS if available, otherwise
            scikit-learn). The name is matched case-insensitively. Defaults to "auto". FAISS supports GPU and is
            much faster when the number of samples is large. It requires the installation of `faiss-gpu` or
            `faiss-cpu` package. See https://faiss.ai/
        :type nearest_neighbor_backend: str, optional
        :raises ValueError: If the `nearest_neighbor_backend` is unknown
        """
        super().__init__()
        self._data = data
        self._embedding = embedding
        self._nearest_neighbor_mode = nearest_neighbor_mode
        self._nearest_neighbor_backend = nearest_neighbor_backend
        self._variation_degrees = _to_constant_list_if_needed(variation_degrees)

        # The neighbor table must be wide enough for the largest degree that will ever be requested.
        if isinstance(self._variation_degrees, ConstantList):
            self._max_variation_degree = self._variation_degrees[0]
        else:
            self._max_variation_degree = max(self._variation_degrees)

        # Backends are imported lazily so that only the selected one has to be importable.
        backend = nearest_neighbor_backend.lower()
        if backend == "faiss":
            from pe.histogram.nearest_neighbor_backend.faiss import search
        elif backend == "sklearn":
            from pe.histogram.nearest_neighbor_backend.sklearn import search
        elif backend == "auto":
            from pe.histogram.nearest_neighbor_backend.auto import search
        else:
            raise ValueError(f"Unknown backend: {nearest_neighbor_backend}")
        self._search = search

        self._build_nearest_neighbor_graph()

    def _build_nearest_neighbor_graph(self):
        """Compute, for every sample of the dataset, its nearest neighbors within the same dataset.

        The dataset is embedded, searched against itself for ``self._max_variation_degree`` neighbors per sample,
        and each row of the result is reordered by ascending returned distance. The reordered neighbor ids and
        distances are stored in ``self._nearest_neighbor_ids`` and ``self._nearest_neighbor_distances``.
        """
        self._data = self._embedding.compute_embedding(self._data)
        embedding_column = self._data.data_frame[self._embedding.column_name]
        embedding = np.stack(embedding_column.values, axis=0).astype(np.float32)

        # The same embedding matrix is used on both sides of the search.
        # NOTE(review): because of this, each sample is presumably among its own neighbors -- confirm.
        distances, ids = self._search(
            syn_embedding=embedding,
            priv_embedding=embedding,
            num_nearest_neighbors=self._max_variation_degree,
            mode=self._nearest_neighbor_mode,
        )

        # NOTE(review): rows are ordered by ascending returned value; confirm this is the intended ranking for
        # the similarity modes ("cos_sim", "ip").
        order = np.argsort(distances, axis=1)
        id_rows = np.arange(len(ids)).reshape(-1, 1)
        distance_rows = np.arange(len(distances)).reshape(-1, 1)
        self._nearest_neighbor_ids = ids[id_rows, order]
        self._nearest_neighbor_distances = distances[distance_rows, order]

    def random_api(self, label_info, num_samples):
        """Generating random synthetic data by randomly drawing images from the given dataset.

        Images are drawn without replacement, and every generated sample is assigned label id 0.

        :param label_info: The info of the label. Its ``name`` is only used for logging, and the object is stored
            in the metadata of the returned data
        :type label_info: omegaconf.dictconfig.DictConfig
        :param num_samples: The number of random samples to generate
        :type num_samples: int
        :return: The data object of the generated synthetic data
        :rtype: :py:class:`pe.data.Data`
        """
        label_name = label_info.name
        execution_logger.info(f"RANDOM API: creating {num_samples} samples for label {label_name}")

        source_frame = self._data.data_frame
        data_ids = np.random.choice(len(source_frame), num_samples, replace=False)
        images = source_frame.iloc[data_ids][IMAGE_DATA_COLUMN_NAME].values

        columns = {
            IMAGE_DATA_COLUMN_NAME: images,
            # Remember where each image came from so that variation_api can look up its neighbors.
            DATA_ID_COLUMN_NAME: data_ids,
            LABEL_ID_COLUMN_NAME: 0,
        }
        data_frame = pd.DataFrame(columns)
        metadata = {"label_info": [label_info]}

        execution_logger.info(f"RANDOM API: finished creating {num_samples} samples for label {label_name}")
        return Data(data_frame=data_frame, metadata=metadata)

    def variation_api(self, syn_data):
        """Generating variations of the synthetic data by finding the nearest images in the given dataset.

        Each input sample is replaced by one of the first ``variation_degree`` entries of its precomputed neighbor
        row, chosen uniformly at random. The variation degree is looked up at index ``iteration + 1``, where
        ``iteration`` is read from the metadata of ``syn_data`` and defaults to -1 when absent.

        :param syn_data: The data object of the synthetic data
        :type syn_data: :py:class:`pe.data.Data`
        :return: The data object of the variation of the input synthetic data
        :rtype: :py:class:`pe.data.Data`
        """
        execution_logger.info(f"VARIATION API: creating variations for {len(syn_data.data_frame)} samples")

        source_ids = list(syn_data.data_frame[DATA_ID_COLUMN_NAME].values)
        iteration = getattr(syn_data.metadata, "iteration", -1)
        variation_degree = self._variation_degrees[iteration + 1]
        execution_logger.info(f"VARIATION API parameters: variation_degree={variation_degree}")

        # For every sample pick a random neighbor rank in [0, variation_degree), then map it to a dataset id.
        neighbor_ranks = np.random.choice(variation_degree, len(source_ids), replace=True)
        new_data_ids = self._nearest_neighbor_ids[source_ids, neighbor_ranks]
        images = self._data.data_frame.iloc[new_data_ids][IMAGE_DATA_COLUMN_NAME].values

        columns = {
            IMAGE_DATA_COLUMN_NAME: images,
            DATA_ID_COLUMN_NAME: new_data_ids,
            # Labels are carried over unchanged from the input samples.
            LABEL_ID_COLUMN_NAME: syn_data.data_frame[LABEL_ID_COLUMN_NAME].values,
        }
        data_frame = pd.DataFrame(columns)

        execution_logger.info(f"VARIATION API: finished creating variations for {len(syn_data.data_frame)} samples")
        return Data(data_frame=data_frame, metadata=syn_data.metadata)