Source code for archai.common.file_utils

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import os
import pathlib
import re
import shutil
import tempfile
from pathlib import Path
from typing import Optional
from types import TracebackType

import torch

# File-related constants
CHECKPOINT_FOLDER_PREFIX = "checkpoint"
CHECKPOINT_REGEX = re.compile(r"^" + CHECKPOINT_FOLDER_PREFIX + r"\-(\d+)$")


[docs]def calculate_onnx_model_size(model_path: str) -> float: """Calculate the size of an ONNX model. This function calculates the size of an ONNX model by reading the size of the file on disk. Args: model_path: The path to the ONNX model on disk. Returns: The size of the model in megabytes. """ size = os.path.getsize(model_path) / 1e6 return size
[docs]def calculate_torch_model_size(model: torch.nn.Module) -> float: """Calculate the size of a PyTorch model. This function calculates the size of a PyTorch model by saving its state dictionary to a temporary file and reading the size of the file on disk. Args: model: The PyTorch model. Returns: The size of the model in megabytes. """ torch.save(model.state_dict(), "temp.p") size = os.path.getsize("temp.p") / 1e6 os.remove("temp.p") return size
[docs]def check_available_checkpoint(folder_name: str) -> bool: """Check if there are any available checkpoints in a given folder. This function checks if a given folder contains any checkpoints by looking for directories that match a regular expression for checkpoint names. Args: folder_name: The path to the folder that might contain checkpoints. Returns: `True` if there are available checkpoints, `False` otherwise. """ if not os.path.exists(folder_name): return False folder_content = os.listdir(folder_name) checkpoints = [ path for path in folder_content if CHECKPOINT_REGEX.search(path) is not None and os.path.isdir(os.path.join(folder_name, path)) ] if len(checkpoints) == 0: return False return True
[docs]def create_file_name_identifier(file_name: str, identifier: str) -> str: """Create a new file name by adding an identifier to the end of an existing file name (before the file extension). Args: file_name: The original file name. identifier: The identifier to be added to the file name. Returns: The new file name with the added identifier. """ file_name = Path(file_name) file_name_identifier = file_name.parent.joinpath(file_name.stem + identifier).with_suffix(file_name.suffix) return file_name_identifier.as_posix()
[docs]def create_empty_file(file_path: str) -> None: """Create an empty file at the given path. Args: file_path: The path to the file to be created. """ open(file_path, "w").close()
[docs]def create_file_with_string(file_path: str, content: str) -> None: """Create a file at the given path and writes the given string to it. Args: file_path: The path to the file to be created. content: The string to be written to the file. """ pathlib.Path(file_path).write_text(content)
[docs]def copy_file( src_file_path: str, dest_file_path: str, force_shutil: Optional[bool] = True, keep_metadata: Optional[bool] = False ) -> str: """Copy a file from one location to another. Args: src_file_path: The path to the source file. dest_file_path: The path to the destination file. force_shutil: Whether to use `shutil` to copy the file. keep_metadata: Whether to keep source file metadata when copying. Returns: The path to the destination file. """ def _copy_file_basic_mode(src_file_path: str, dest_file_path: str) -> str: if os.path.isdir(dest_file_path): dest_file_path = os.path.join(dest_file_path, pathlib.Path(src_file_path).name) with open(src_file_path, "rb") as src, open(dest_file_path, "wb") as dest: dest.write(src.read()) return dest_file_path if not force_shutil: return _copy_file_basic_mode(src_file_path, dest_file_path) # Note shutil.copy2 might fail on Azure if file system does not support OS level copystats # Use keep_metadata=True only if needed for maximum compatibility try: copy_fn = shutil.copy2 if keep_metadata else shutil.copy return copy_fn(src_file_path, dest_file_path) except OSError as e: if keep_metadata or e.errno != 38: # OSError 38: Function not implemented raise return _copy_file_basic_mode(src_file_path, dest_file_path)
[docs]def get_full_path(path: str, create_folder: Optional[bool] = False) -> str: """Get the full path to a file or folder. Args: path: The path to the file or folder. create_folder: Whether to create the folder if it does not exist. Returns: The full path to the file or folder. """ assert path path = os.path.abspath(os.path.expanduser(os.path.expandvars(path))) if create_folder: os.makedirs(path, exist_ok=True) return path
[docs]class TemporaryFiles: """ Windows has a weird quirk where the tempfile.NamedTemporaryFile cannot be opened a second time. """ def __init__(self): self.files_to_delete = [] def __enter__(self): return self def __exit__(self, exc_type: type[BaseException], exc_val: BaseException, exc_tb: TracebackType) -> None: for name in self.files_to_delete: os.unlink(name) self.files_to_delete = []
[docs] def get_temp_file(self) -> str: result = None with tempfile.NamedTemporaryFile(delete=False) as tmp: result = tmp.name self.files_to_delete += [result] return result