# Source code for pe.callback.image.sample_images

import numpy as np

from pe.callback.callback import Callback
from pe.constant.data import IMAGE_DATA_COLUMN_NAME
from pe.constant.data import LABEL_ID_COLUMN_NAME
from pe.metric_item import ImageListMetricItem


class SampleImages(Callback):
    """The callback that samples images from the synthetic data."""

    def __init__(self, num_images_per_class=10):
        """Constructor.

        :param num_images_per_class: number of images to sample per class, defaults to 10
        :type num_images_per_class: int, optional
        """
        self._num_images_per_class = num_images_per_class

    def __call__(self, syn_data):
        """This function is called after each PE iteration that samples images from the synthetic data.

        For every class ID in ``range(len(syn_data.metadata.label_info))``, up to
        ``num_images_per_class`` images are drawn at random (without replacement) from the rows
        of ``syn_data.data_frame`` with that label. If a class has fewer images than requested,
        the remainder is filled with all-zero placeholder images (built with
        ``np.zeros_like`` from the first sampled image of that class) so that every class
        contributes exactly ``num_images_per_class`` entries to the result.

        :param syn_data: The :py:class:`pe.data.data.Data` object of the synthetic data
        :type syn_data: :py:class:`pe.data.data.Data`
        :raises AssertionError: If a class has no image to sample
        :return: A metric item with the list of sampled images
        :rtype: list[:py:class:`pe.metric_item.ImageListMetricItem`]
        """
        all_image_list = []
        num_classes = len(syn_data.metadata.label_info)
        for class_id in range(num_classes):
            image_list = syn_data.data_frame[syn_data.data_frame[LABEL_ID_COLUMN_NAME] == class_id][
                IMAGE_DATA_COLUMN_NAME
            ]
            image_list = image_list.sample(min(self._num_images_per_class, len(image_list)))
            # A placeholder needs at least one real image as a template, so check this
            # before touching the result list.
            assert len(image_list) > 0, f"No synthetic images available for class {class_id}"
            all_image_list.extend(image_list)

            num_missing = self._num_images_per_class - len(image_list)
            if num_missing > 0:
                # Positional access: sample() keeps the original index labels, so the
                # label 0 is generally not present in the sampled series.
                template = image_list.iloc[0]
                # One separate zero image per missing slot. (Multiplying a zero array by
                # the count would yield a single array, and extend() would then append
                # its rows instead of whole images.)
                all_image_list.extend(np.zeros_like(template) for _ in range(num_missing))

        metric_item = ImageListMetricItem(
            name="image_sample",
            value=all_image_list,
            # With several classes, one grid row per class; with a single class let the
            # metric item use its own default layout.
            num_images_per_row=None if num_classes == 1 else self._num_images_per_class,
        )
        return [metric_item]