Source code for pe.callback.tabular.save_tab_to_csv

import os
import pandas as pd
from pe.callback.callback import Callback
from pe.constant.data import LABEL_ID_COLUMN_NAME
from pe.constant.data import TABULAR_DATA_COLUMN_NAME
from pe.logging import execution_logger


class SaveTabToCSV(Callback):
    """The callback that saves the synthetic tabular data to a CSV file.

    Each call writes one CSV file into the output folder. The file name is the PE iteration number (read from the
    metadata of the synthetic data) rendered with ``iteration_format`` and suffixed with ``.csv``.
    """

    def __init__(
        self,
        output_folder,
        iteration_format="09d",
    ):
        """Constructor.

        :param output_folder: The output folder that will be used to save the CSV files
        :type output_folder: str
        :param iteration_format: The format of the iteration part of the CSV paths, defaults to "09d"
        :type iteration_format: str, optional
        """
        self._output_folder = output_folder
        self._iteration_format = iteration_format

    def _get_csv_path(self, iteration):
        """Get the CSV path. The output folder is created if it does not exist yet.

        :param iteration: The PE iteration number
        :type iteration: int
        :return: The CSV path
        :rtype: str
        """
        # Creating the folder here means the returned path can be written to right away.
        os.makedirs(self._output_folder, exist_ok=True)
        iteration_string = format(iteration, self._iteration_format)
        return os.path.join(self._output_folder, f"{iteration_string}.csv")

    def __call__(self, syn_data):
        """This function is called after each PE iteration that saves the synthetic tabular data to a CSV file.

        The CSV file contains one column per entry of ``metadata["feature_columns"]``, followed by one column per
        entry of ``metadata.label_columns``. The label columns are filled in by looking up, for every row, the
        ``column_values`` of the label info selected by that row's label ID.

        :param syn_data: The :py:class:`pe.data.Data` object of the synthetic data
        :type syn_data: :py:class:`pe.data.Data`
        """
        execution_logger.info("Saving the synthetic tabular data to a CSV file")
        metadata = syn_data.metadata
        data_frame = syn_data.data_frame

        feature_columns = metadata["feature_columns"]
        label_ids = data_frame[LABEL_ID_COLUMN_NAME].tolist()
        # Each entry of the tabular data column becomes one row of the output table.
        features_df = pd.DataFrame(
            data_frame[TABULAR_DATA_COLUMN_NAME].tolist(),
            columns=feature_columns,
        )

        # Expand the per-row label ID into the concrete value of every label column.
        for column_name in metadata.label_columns:
            features_df[column_name] = [
                metadata.label_info[label_id].column_values[column_name] for label_id in label_ids
            ]

        csv_path = self._get_csv_path(metadata.iteration)
        features_df.to_csv(csv_path, index=False)
        execution_logger.info("Finished saving the synthetic tabular data to a CSV file")