# Source code for archai.common.ordered_dict_logger_utils
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import sys
from logging import Filter, Formatter, Logger, LogRecord, StreamHandler
from logging.handlers import TimedRotatingFileHandler
# Shared record layout: timestamp, logger name, level name, then the message.
FORMATTER = Formatter(fmt="%(asctime)s - %(name)s — %(levelname)s — %(message)s")

# File name passed to the timed rotating file handler.
LOG_FILE = "archai.log"

# Rank read from the `LOCAL_RANK` environment variable; falls back to 0 when it is unset.
LOCAL_RANK = int(os.getenv("LOCAL_RANK", 0))
class RankFilter(Filter):
    """A filter for logging records based on the rank of the process.

    Only log records from the process with rank 0 will be logged,
    while log records from other processes will be filtered out.

    """

    def __init__(self, rank: int) -> None:
        """Initialize the filter with the rank of the process.

        Args:
            rank: The rank of the process that will generate log records.

        """

        # Run the base initializer so the attributes `logging.Filter` sets up
        # (`name`, `nlen`) exist on the instance as well.
        super().__init__()

        self.rank = rank

    def filter(self, record: LogRecord) -> bool:
        """Filter a logging record based on the process rank.

        The decision depends only on the rank given at construction time;
        the contents of `record` are not inspected.

        Args:
            record: The logging record to be filtered.

        Returns:
            `True` if the record should be logged, `False` otherwise.

        """

        return self.rank == 0
def get_console_handler() -> StreamHandler:
    """Build the handler that writes log records to the console.

    The handler is bound to `sys.stdout` and formats records with the
    module-level `FORMATTER`.

    Returns:
        A `StreamHandler` that writes to `sys.stdout`.

    """

    handler = StreamHandler(stream=sys.stdout)
    handler.setFormatter(FORMATTER)

    return handler
def get_timed_file_handler() -> TimedRotatingFileHandler:
    """Build the handler that writes log records to a time-rotated file.

    The handler targets `LOG_FILE`, is configured with `when="midnight"`,
    UTF-8 encoding and `delay=True`, and formats records with the
    module-level `FORMATTER`.

    Returns:
        A `TimedRotatingFileHandler` writing to `LOG_FILE`.

    """

    handler = TimedRotatingFileHandler(
        filename=LOG_FILE,
        when="midnight",
        encoding="utf-8",
        delay=True,
    )
    handler.setFormatter(FORMATTER)

    return handler
def get_logger(logger_name: str) -> Logger:
    """Get a logger with the specified name and default settings.

    The logger is set to `DEBUG` level, does not propagate to its parent,
    writes to the console and to the timed log file, and carries a
    `RankFilter` built from `LOCAL_RANK`.

    Calling this function repeatedly with the same name is safe: handlers
    and the rank filter are attached only once, so messages are not
    duplicated.

    Args:
        logger_name: The name of the logger.

    Returns:
        A `Logger` instance with the specified name and default settings.

    """

    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    # `logging.getLogger` hands back the same object for the same name, so
    # unconditionally adding handlers would emit every record once per call.
    if not logger.handlers:
        logger.addHandler(get_console_handler())
        logger.addHandler(get_timed_file_handler())

    # Each `RankFilter(...)` is a distinct object, so `addFilter` alone would
    # not de-duplicate it; check by type instead.
    if not any(isinstance(existing, RankFilter) for existing in logger.filters):
        logger.addFilter(RankFilter(LOCAL_RANK))

    return logger