import os
import csv
import torch
import numpy as np
from collections import defaultdict
from .logger import Logger
from pe.metric_item import FloatMetricItem
from pe.metric_item import FloatListMetricItem
class CSVPrint(Logger):
    """The logger that prints the metrics to CSV files.

    Rows are buffered in memory per target file and appended to disk when
    :py:meth:`log` reaches a flush iteration or when :py:meth:`clean_up` is called.
    """

    def __init__(
        self,
        output_folder,
        path_separator="-",
        float_format=".8f",
        flush_iteration_freq=1,
    ):
        """Constructor.

        :param output_folder: The output folder that will be used to save the CSV files
        :type output_folder: str
        :param path_separator: The string that will be used to replace '\\' and '/' in log names, defaults to "-"
        :type path_separator: str, optional
        :param float_format: The format of the floating point numbers, defaults to ".8f"
        :type float_format: str, optional
        :param flush_iteration_freq: The frequency to flush the logs, defaults to 1
        :type flush_iteration_freq: int, optional
        """
        self._output_folder = output_folder
        # Create the folder up front so that later appends never fail on a missing directory.
        os.makedirs(self._output_folder, exist_ok=True)
        self._path_separator = path_separator
        self._float_format = float_format
        self._flush_iteration_freq = flush_iteration_freq
        self._clear_logs()

    def _clear_logs(self):
        """Clear the logs.

        Resets the in-memory buffer, which maps a CSV file path to the list of rows waiting to be written.
        """
        self._logs = defaultdict(list)

    def _get_log_path(self, iteration, item):
        """Get the log path.

        The path is derived from the item name only; ``iteration`` is accepted but not used.

        :param iteration: The PE iteration number
        :type iteration: int
        :param item: The metric item
        :type item: :py:class:`pe.metric_item.MetricItem`
        :return: The log path
        :rtype: str
        """
        # Replace both kinds of slashes so that the metric name maps to a single file name
        # inside the output folder instead of a nested path.
        file_stem = item.name.replace("/", self._path_separator)
        file_stem = file_stem.replace("\\", self._path_separator)
        return os.path.join(self._output_folder, file_stem + ".csv")

    def _flush(self):
        """Flush the logs.

        Appends every buffered row to its CSV file. The buffer itself is not cleared here.
        """
        for path, rows in self._logs.items():
            # newline="" is required by the csv module: the writer emits its own line
            # terminators, and without it platforms that translate "\n" would produce
            # "\r\r\n", i.e. blank rows between records.
            with open(path, "a", newline="") as f:
                csv.writer(f).writerows(rows)

    def _log_float(self, log_path, iteration, item):
        """Log a float metric item.

        Buffers one row ``[iteration, value]`` for ``log_path``. A list value is formatted element by element
        and joined with "," into a single cell.

        :param log_path: The path of the log file
        :type log_path: str
        :param iteration: The PE iteration number
        :type iteration: int
        :param item: The float metric item
        :type item: :py:class:`pe.metric_item.FloatMetricItem` or :py:class:`pe.metric_item.FloatListMetricItem`
        """
        value = item.value
        # Normalize tensors and arrays to plain Python numbers / lists before formatting.
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu().numpy()
        if isinstance(value, np.ndarray):
            value = value.tolist()

        if isinstance(value, list):
            str_value = ",".join(format(v, self._float_format) for v in value)
        else:
            str_value = format(value, self._float_format)
        self._logs[log_path].append([str(iteration), str_value])

    def log(self, iteration, metric_items):
        """Log the metrics.

        Only float and float-list metric items are recorded; all other items are ignored.

        :param iteration: The PE iteration number
        :type iteration: int
        :param metric_items: The metrics to log
        :type metric_items: list[:py:class:`pe.metric_item.MetricItem`]
        """
        for item in metric_items:
            if not isinstance(item, (FloatMetricItem, FloatListMetricItem)):
                continue
            log_path = self._get_log_path(iteration, item)
            self._log_float(log_path, iteration, item)
        # Write to disk only on iterations that are multiples of the flush frequency;
        # otherwise the rows stay buffered until a later flush or clean_up().
        if iteration % self._flush_iteration_freq == 0:
            self._flush()
            self._clear_logs()

    def clean_up(self):
        """Clean up the logger.

        Writes any rows that are still buffered and resets the buffer.
        """
        self._flush()
        self._clear_logs()