import os
from omegaconf import OmegaConf
import pandas as pd
from pe.constant.data import LABEL_ID_COLUMN_NAME
class Data:
    """The class that holds the private data or synthetic data from PE."""

    def __init__(self, data_frame=None, metadata=None):
        """Constructor.

        :param data_frame: A pandas dataframe that holds the data, defaults to None
        :type data_frame: :py:class:`pandas.DataFrame`, optional
        :param metadata: The metadata of the data. ``None`` (the default) is treated as empty metadata
        :type metadata: dict, optional
        """
        # Avoid a mutable default argument; ``None`` stands for "no metadata".
        if metadata is None:
            metadata = {}
        self.data_frame = data_frame
        self.metadata = OmegaConf.create(metadata)
        # File names used inside a checkpoint folder by save_checkpoint / load_checkpoint.
        self._data_frame_file_name = "data_frame.pkl"
        self._metadata_file_name = "metadata.yaml"

    def __str__(self):
        """Return a human-readable dump of the metadata followed by the data frame."""
        return f"Metadata:\n{self.metadata}\nData frame:\n{self.data_frame}"

    def save_checkpoint(self, path):
        """Save the data to a checkpoint.

        The data frame is pickled and the metadata is written as YAML, both inside ``path``. The folder is
        created if it does not exist yet.

        :param path: The folder to save the checkpoint
        :type path: str
        :raises ValueError: If the path is None
        :raises ValueError: If the data frame is None
        """
        if path is None:
            raise ValueError("Path is None")
        if self.data_frame is None:
            raise ValueError("Data frame is empty")
        os.makedirs(path, exist_ok=True)
        self.data_frame.to_pickle(os.path.join(path, self._data_frame_file_name))
        # Explicit UTF-8 so the YAML does not depend on the platform's locale encoding.
        with open(os.path.join(path, self._metadata_file_name), "w", encoding="utf-8") as file:
            file.write(OmegaConf.to_yaml(self.metadata))

    def load_checkpoint(self, path):
        """Load data from a checkpoint.

        Nothing is modified when either of the two checkpoint files is missing.

        :param path: The folder that contains the checkpoint
        :type path: str
        :return: Whether the checkpoint is loaded successfully
        :rtype: bool
        """
        data_frame_path = os.path.join(path, self._data_frame_file_name)
        metadata_path = os.path.join(path, self._metadata_file_name)
        if not (os.path.exists(data_frame_path) and os.path.exists(metadata_path)):
            return False
        # SECURITY NOTE: unpickling can execute arbitrary code. Only load checkpoints from trusted locations.
        self.data_frame = pd.read_pickle(data_frame_path)
        # Read with the same explicit encoding that save_checkpoint writes with.
        with open(metadata_path, "r", encoding="utf-8") as file:
            self.metadata = OmegaConf.create(file.read())
        return True

    def filter_label_id(self, label_id):
        """Filter the data frame according to a label id.

        :param label_id: The label id that is used to filter the data frame
        :type label_id: int
        :return: :py:class:`pe.data.data.Data` object with the filtered data frame
        :rtype: :py:class:`pe.data.data.Data`
        """
        row_mask = self.data_frame[LABEL_ID_COLUMN_NAME] == label_id
        return Data(data_frame=self.data_frame[row_mask], metadata=self.metadata)

    def set_label_id(self, label_id):
        """Set the label id for the data frame.

        The label id column of the data frame is overwritten in place.

        :param label_id: The label id to set
        :type label_id: int
        """
        self.data_frame[LABEL_ID_COLUMN_NAME] = label_id

    def truncate(self, num_samples):
        """Truncate the data frame to a certain number of samples.

        :param num_samples: The number of samples to keep, counted from the start of the data frame
        :type num_samples: int
        :return: A new :py:class:`pe.data.data.Data` object with the truncated data frame
        :rtype: :py:class:`pe.data.data.Data`
        """
        return Data(data_frame=self.data_frame[:num_samples], metadata=self.metadata)

    def random_truncate(self, num_samples):
        """Randomly truncate the data frame to a certain number of samples.

        :param num_samples: The number of samples to randomly keep
        :type num_samples: int
        :return: A new :py:class:`pe.data.data.Data` object with the randomly truncated data frame
        :rtype: :py:class:`pe.data.data.Data`
        """
        sampled_frame = self.data_frame.sample(n=num_samples)
        return Data(data_frame=sampled_frame, metadata=self.metadata)

    def merge(self, data):
        """Merge the data object with another data object.

        Only the columns of ``data`` that this object does not already have are added; columns present in both
        keep the values of this object. When ``data`` has no new columns, this object itself is returned.

        :param data: The data object to merge
        :type data: :py:class:`pe.data.data.Data`
        :raises ValueError: If the metadata of `data` is not the same as the metadata of the current object
        :return: The merged data object
        :rtype: :py:class:`pe.data.data.Data`
        """
        if self.metadata != data.metadata:
            raise ValueError("Metadata must be the same")
        cols_to_use = data.data_frame.columns.difference(self.data_frame.columns)
        if len(cols_to_use) == 0:
            # Nothing to add: hand back this very object rather than a copy.
            return self
        data_frame = self.data_frame.join(data.data_frame[cols_to_use])
        return Data(data_frame=data_frame, metadata=self.metadata)

    @classmethod
    def concat(cls, data_list, metadata=None):
        """Concatenate the data frames of a list of data objects.

        The result is always a plain :py:class:`pe.data.data.Data` object, whichever class this is called on.

        :param data_list: The data objects to concatenate. Any iterable is accepted
        :type data_list: list[:py:class:`pe.data.data.Data`]
        :param metadata: The metadata of the concatenated data. When None, the metadata of the list of data objects
            must be the same and will be used. Defaults to None
        :type metadata: dict, optional
        :raises ValueError: If `data_list` is empty
        :raises ValueError: If the metadata of the data objects are not the same
        :return: The concatenated data object
        :rtype: :py:class:`pe.data.data.Data`
        """
        # Materialise once: the input is traversed for the frames and again for the metadata,
        # which would silently exhaust a generator/iterator.
        data_list = list(data_list)
        if not data_list:
            raise ValueError("No data objects to concatenate")
        data_frame_list = [data.data_frame for data in data_list]
        if metadata is None:
            reference_metadata = data_list[0].metadata
            # Check that all metadata are the same, using the same equality test as merge().
            if any(data.metadata != reference_metadata for data in data_list[1:]):
                raise ValueError("Metadata must be the same")
            metadata = reference_metadata
        return Data(data_frame=pd.concat(data_frame_list), metadata=metadata)