Source code for pe.data.text.text_csv

from pe.data import Data
import pandas as pd
from pe.constant.data import LABEL_ID_COLUMN_NAME
from pe.constant.data import TEXT_DATA_COLUMN_NAME


class TextCSV(Data):
    """The text dataset in CSV format."""

    def __init__(self, csv_path, label_columns=None, text_column="text", num_samples=None):
        """Constructor.

        :param csv_path: The path to the CSV file
        :type csv_path: str
        :param label_columns: The names of the columns that contain the labels. Defaults to None, which means no
            label columns: every sample then gets the same empty label
        :type label_columns: list, optional
        :param text_column: The name of the column that contains the text data, defaults to "text"
        :type text_column: str, optional
        :param num_samples: The number of samples to load from the CSV file. If None, load all samples. Defaults to
            None
        :type num_samples: int, optional
        :raises ValueError: If the label columns or text column does not exist in the CSV file
        """
        # Always work on a fresh list. This avoids a shared mutable default, accepts any iterable (e.g. a tuple),
        # and keeps the list stored in ``metadata`` from aliasing the caller's object.
        label_columns = [] if label_columns is None else list(label_columns)

        # Every column is read as string, so label values are kept verbatim rather than type-inferred.
        data_frame = pd.read_csv(csv_path, dtype=str)
        if num_samples is not None:
            # ``.copy()`` detaches the truncated frame from the full one. Without it, adding the label-ID column
            # below would trigger pandas' SettingWithCopyWarning.
            data_frame = data_frame.iloc[:num_samples].copy()

        for column in label_columns + [text_column]:
            if column not in data_frame.columns:
                raise ValueError(f"Column {column} does not exist in the CSV file")

        # One tuple of label values per row, in row order.
        # NOTE(review): missing cells are read as NaN (a float). A label column mixing NaN with strings would make
        # the ``sorted`` call below raise TypeError. Confirm that label columns are never empty.
        if label_columns:
            labels = list(zip(*(data_frame[column].tolist() for column in label_columns)))
        else:
            # ``zip()`` with no arguments yields nothing, so build the per-row empty tuples explicitly.
            labels = [()] * len(data_frame)

        # Label IDs are positions in the sorted list of distinct label tuples, so they are deterministic.
        label_set = sorted(set(labels))
        label_id_map = {label: label_id for label_id, label in enumerate(label_set)}
        data_frame[LABEL_ID_COLUMN_NAME] = [label_id_map[label] for label in labels]

        # label_info[i] describes the label whose ID is i.
        label_info = [
            {
                "name": " | ".join(f"{column}: {value}" for column, value in zip(label_columns, label)),
                "column_values": dict(zip(label_columns, label)),
            }
            for label in label_set
        ]
        metadata = {"label_columns": label_columns, "text_column": text_column, "label_info": label_info}

        data_frame = data_frame.rename(columns={text_column: TEXT_DATA_COLUMN_NAME})
        super().__init__(data_frame=data_frame, metadata=metadata)