Common Packages#
APEX Utilities#
- class archai.common.apex_utils.ApexUtils(apex_config: Config)[source]#
-
- step(multi_optim: MultiOptim) None [source]#
- to_amp(model: Module, multi_optim: MultiOptim, batch_size: int) Module [source]#
- clip_grad(clip: float, model: Module, multi_optim: MultiOptim) None [source]#
Atomic File Handler#
- class archai.common.atomic_file_handler.AtomicFileHandler(filename, encoding=None, save_delay=30.0)[source]#
This class opens and writes the entire file instead of appending one line at a time
- terminator = '\n'#
- flush()[source]#
Ensure all logging output has been flushed.
This version does nothing and is intended to be implemented by subclasses.
AzureML Helper#
- archai.common.azureml_helper.get_aml_client_from_file(config_path: str | Path) MLClient [source]#
Creates an MLClient object from a workspace config file
- Parameters:
config_path (Union[str, Path]) – Path to the workspace config file
- Returns:
MLClient object
- Return type:
MLClient
- archai.common.azureml_helper.create_compute_cluster(ml_client: MLClient, compute_name: str, type: str | None = 'amlcompute', size: str | None = 'Standard_D14_v2', min_instances: int | None = 0, max_instances: int | None = 4, idle_time_before_scale_down: int | None = 180, tier: str | None = 'Dedicated', **kwargs)[source]#
Creates a compute cluster for the workspace
- Parameters:
ml_client (MLClient) – MLClient object
compute_name (str) – Name of the (CPU/GPU) compute cluster
type (str, optional) – Type of the compute cluster. Defaults to “amlcompute”.
size (str, optional) – VM Family of the compute cluster. Defaults to “Standard_D14_v2”.
min_instances (int, optional) – Minimum running nodes when there is no job running. Defaults to 0.
max_instances (int, optional) – Maximum number of nodes in the cluster. Defaults to 4.
idle_time_before_scale_down (int, optional) – How many seconds the node is allowed to keep running after job termination before scaling down. Defaults to 180.
tier (str, optional) – Dedicated or LowPriority. The latter is cheaper but there is a chance of job termination. Defaults to “Dedicated”.
- Returns:
Compute object
- Return type:
Compute
- archai.common.azureml_helper.create_environment_from_file(ml_client: MLClient, custom_env_name: str | None = 'aml-archai', description: str | None = 'Custom environment for Archai', tags: Dict[str, Any] | None = None, conda_file: str | None = 'conda.yaml', image: str | None = None, version: str | None = '0.1.0', **kwargs) Environment [source]#
Creates an environment from a conda file
- Parameters:
ml_client (MLClient) – MLClient object
custom_env_name (str, optional) – Name of the environment. Defaults to “aml-archai”.
description (str, optional) – Description of the environment. Defaults to “Custom environment for Archai”.
tags (Dict[str, Any], optional) – Tags for the environment, e.g. {“archai”: “1.0.0”}. Defaults to None.
conda_file (str, optional) – Path to the conda file. Defaults to “conda.yaml”.
image (str, optional) – Docker image for the environment. Defaults to None.
version (str, optional) – Version of the environment. Defaults to “0.1.0”.
- Returns:
Environment object
- Return type:
Environment
- archai.common.azureml_helper.download_job_output(ml_client: MLClient, job_name: str, output_name: str, download_path: str | Path | None = 'output') None [source]#
Downloads the output of a job
- Parameters:
ml_client (MLClient) – MLClient object
job_name (str) – Name of the job
output_name (str) – Named output to download
download_path (Union[str, Path], optional) – Path to download the output to. Defaults to “output”.
- Returns:
None
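Taken together, these helpers cover the typical workflow of connecting to a workspace, provisioning compute, building an environment, and retrieving results. A minimal sketch (the workspace config path, cluster/environment names, and job name are placeholders):

```python
from archai.common.azureml_helper import (
    create_compute_cluster,
    create_environment_from_file,
    download_job_output,
    get_aml_client_from_file,
)

# Connect to the workspace described by a local config.json.
ml_client = get_aml_client_from_file("config.json")

# Provision (or reuse) a compute cluster and a conda-based environment.
compute = create_compute_cluster(ml_client, "nas-cpu-cluster", size="Standard_D14_v2", max_instances=4)
env = create_environment_from_file(ml_client, custom_env_name="aml-archai", conda_file="conda.yaml")

# After a job has finished, pull down one of its named outputs.
download_job_output(ml_client, job_name="archai_search_job", output_name="search_results", download_path="output")
```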
Common#
- archai.common.common.get_tb_writer() SummaryWriterDummy | SummaryWriter [source]#
- archai.common.common.get_state() CommonState [source]#
- archai.common.common.init_from(state: CommonState) None [source]#
- archai.common.common.create_conf(config_filepath: str | None = None, param_args: list = [], use_args=True) Config [source]#
- archai.common.common.common_init(config_filepath: str | None = None, param_args: list = [], use_args=True, clean_expdir=False) Config [source]#
- archai.common.common.expdir_abspath(path: str, create=False) str [source]#
Returns full path for given relative path within experiment directory.
- archai.common.common.create_tb_writer(conf: Config, is_master=True) SummaryWriterDummy | SummaryWriter [source]#
- archai.common.common.update_envvars(conf) None [source]#
Get values from the config and put them into environment variables
Configuration#
- archai.common.config.deep_update(d: MutableMapping, u: Mapping, create_map: Callable[[], MutableMapping]) MutableMapping [source]#
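No docstring is provided, but judging by the signature this recursively merges u into d, using create_map to build any missing nested mappings. A minimal sketch (values are illustrative):

```python
from collections import OrderedDict

from archai.common.config import deep_update

base = OrderedDict({"trainer": OrderedDict({"lr": 0.1, "epochs": 50})})
override = {"trainer": {"lr": 0.025}}

# Nested keys in `override` win; keys missing from it are kept from `base`.
merged = deep_update(base, override, create_map=OrderedDict)
print(merged["trainer"])  # expected: lr=0.025, epochs=50
```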
Delimited Text#
Deprecation (Utilities)#
- archai.common.deprecation_utils.deprecated(message: str | None = None, deprecate_version: str | None = None, remove_version: str | None = None) None [source]#
Decorator to mark a function or class as deprecated.
- Parameters:
message – Message to include in the warning.
deprecate_version – Version in which the function was deprecated. If None, the version will not be included in the warning message.
remove_version – Version in which the function will be removed. If None, the version will not be included in the warning message.
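A minimal usage sketch of the decorator (the function name and version numbers below are illustrative):

```python
from archai.common.deprecation_utils import deprecated

@deprecated(message="Use `evaluate_pareto` instead.", deprecate_version="1.0.0", remove_version="1.2.0")
def evaluate_frontier(points):
    """Old entry point kept only for backwards compatibility."""
    return points

evaluate_frontier([])  # expected to emit a deprecation warning when called
```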
Distributed (Utilities)#
- archai.common.distributed_utils.init_distributed(use_cuda: bool) None [source]#
Initialize distributed backend for parallel training.
This method sets up the distributed backend for parallel training based on the specified use_cuda flag. If use_cuda is True, it initializes the distributed mode using the CUDA/NCCL backend. Otherwise, it uses the Gloo backend.
- Parameters:
use_cuda – Whether to initialize the distributed mode using the CUDA/NCCL backend.
- Raises:
AssertionError – If the distributed mode is not initialized successfully.
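A hedged sketch of the typical start-up sequence for a script launched with a distributed launcher such as torchrun; the backend choice simply mirrors the description above:

```python
import torch

from archai.common.distributed_utils import barrier, get_rank, get_world_size, init_distributed

# NCCL when CUDA is available, Gloo otherwise.
init_distributed(use_cuda=torch.cuda.is_available())

print(f"worker {get_rank()} of {get_world_size()} initialized")
barrier()  # wait until every process has reached this point before continuing
```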
- archai.common.distributed_utils.barrier() None [source]#
Synchronize all processes in the distributed backend.
This method calls the torch.distributed.barrier function if the distributed mode is available and initialized. The barrier function synchronizes all processes in the distributed backend by blocking the processes until all processes have reached this point.
- archai.common.distributed_utils.get_rank() int [source]#
Get the rank of the current process in the distributed backend.
- Returns:
The rank of the current process in the distributed backend. If the distributed mode is not available or not initialized, the returned rank will be 0.
- archai.common.distributed_utils.get_world_size() int [source]#
Get the total number of processes in the distributed backend.
- Returns:
The total number of processes in the distributed backend. If the distributed mode is not available or not initialized, the returned world size will be 1.
- archai.common.distributed_utils.all_reduce(tensor: int | float | Tensor, op: str | None = 'sum') int | float [source]#
Reduce the input tensor/value into a scalar using the specified reduction operator.
This method applies the specified reduction operator to the input tensor/value in a distributed manner. The result is a scalar value that is computed by aggregating the values from all processes in the distributed backend.
- Parameters:
tensor – Input tensor/value to be reduced.
op – Type of reduction operator. The supported operators are “sum”, “mean”, “min”, “max”, and “product”.
- Returns:
The scalar value obtained by applying the reduction operator to the input tensor/value. If the distributed mode is not available or not initialized, the input tensor/value is returned as is.
- Raises:
RuntimeError – If the specified reduction operator is not supported.
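For example, a per-process metric can be aggregated across workers as below; when distributed mode is not initialized the local value is returned unchanged, so the same code also works in single-process runs (values are illustrative):

```python
from archai.common.distributed_utils import all_reduce

local_loss = 0.42                                # value computed on this worker only
mean_loss = all_reduce(local_loss, op="mean")    # averaged over all workers
total_samples = all_reduce(128, op="sum")        # summed over all workers
```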
- archai.common.distributed_utils.sync_workers() Generator[int, None, None] [source]#
Context manager for synchronizing the processes in the distributed backend.
This context manager yields the rank of the current process in the distributed backend and synchronizes all processes on exit.
- Yields:
The rank of the current process in the distributed backend.
Example
>>> with sync_workers():
>>>     # Execute some code that should be synchronized across all processes.
>>>     pass
File-Related (Utilities)#
- archai.common.file_utils.calculate_onnx_model_size(model_path: str) float [source]#
Calculate the size of an ONNX model.
This function calculates the size of an ONNX model by reading the size of the file on disk.
- Parameters:
model_path – The path to the ONNX model on disk.
- Returns:
The size of the model in megabytes.
- archai.common.file_utils.calculate_torch_model_size(model: Module) float [source]#
Calculate the size of a PyTorch model.
This function calculates the size of a PyTorch model by saving its state dictionary to a temporary file and reading the size of the file on disk.
- Parameters:
model – The PyTorch model.
- Returns:
The size of the model in megabytes.
- archai.common.file_utils.check_available_checkpoint(folder_name: str) bool [source]#
Check if there are any available checkpoints in a given folder.
This function checks if a given folder contains any checkpoints by looking for directories that match a regular expression for checkpoint names.
- Parameters:
folder_name – The path to the folder that might contain checkpoints.
- Returns:
True if there are available checkpoints, False otherwise.
- archai.common.file_utils.create_file_name_identifier(file_name: str, identifier: str) str [source]#
Create a new file name by adding an identifier to the end of an existing file name (before the file extension).
- Parameters:
file_name – The original file name.
identifier – The identifier to be added to the file name.
- Returns:
The new file name with the added identifier.
- archai.common.file_utils.create_empty_file(file_path: str) None [source]#
Create an empty file at the given path.
- Parameters:
file_path – The path to the file to be created.
- archai.common.file_utils.create_file_with_string(file_path: str, content: str) None [source]#
Create a file at the given path and write the given string to it.
- Parameters:
file_path – The path to the file to be created.
content – The string to be written to the file.
- archai.common.file_utils.copy_file(src_file_path: str, dest_file_path: str, force_shutil: bool | None = True, keep_metadata: bool | None = False) str [source]#
Copy a file from one location to another.
- Parameters:
src_file_path – The path to the source file.
dest_file_path – The path to the destination file.
force_shutil – Whether to use shutil to copy the file.
keep_metadata – Whether to keep source file metadata when copying.
- Returns:
The path to the destination file.
- archai.common.file_utils.get_full_path(path: str, create_folder: bool | None = False) str [source]#
Get the full path to a file or folder.
- Parameters:
path – The path to the file or folder.
create_folder – Whether to create the folder if it does not exist.
- Returns:
The full path to the file or folder.
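A short sketch combining several of these helpers (the file names are placeholders, and the exact form of the identifier insertion follows the description above):

```python
from archai.common.file_utils import (
    calculate_onnx_model_size,
    check_available_checkpoint,
    create_file_name_identifier,
    get_full_path,
)

# Resolve and create the experiment folder if it does not exist yet.
exp_dir = get_full_path("~/archai_experiments/run1", create_folder=True)

# Resume only when a prior checkpoint directory is present.
if check_available_checkpoint(exp_dir):
    print("resuming from existing checkpoint")

# 'model.onnx' -> 'model_qnt.onnx' (identifier added before the extension).
quantized_path = create_file_name_identifier("model.onnx", "_qnt")
size_mb = calculate_onnx_model_size("model.onnx")
```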
ML Performance (Utilities)#
ML (Utilities)#
- archai.common.ml_utils.join_chunks(chunks: List[Tensor]) Tensor [source]#
If the batch was divided into chunks, this function joins them again
- archai.common.ml_utils.create_lr_scheduler(conf_lrs: Config, epochs: int, optimizer: Optimizer, steps_per_epoch: int | None) Tuple[_LRScheduler | None, bool] [source]#
- archai.common.ml_utils.param_size(module: Module, ignore_aux=True, only_req_grad=False)[source]#
Count all parameters, excluding auxiliary ones
- archai.common.ml_utils.accuracy(output, target, topk=(1,))[source]#
Computes the precision@k for the specified values of k
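A small sketch of the accuracy helper (the assumption that it returns one value per requested k follows the usual precision@k convention and is not stated above):

```python
import torch

from archai.common.ml_utils import accuracy

logits = torch.randn(8, 10)             # batch of 8 samples, 10 classes
targets = torch.randint(0, 10, (8,))

# Precision@1 and precision@5 for this batch (assumed: one result per k).
top1, top5 = accuracy(logits, targets, topk=(1, 5))
```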
Model Summary#
Notebook Helper#
- archai.common.notebook_helper.get_search_csv(output_path: str | Path, iteration_num: int | None = -1) DataFrame [source]#
Reads the search csv file from the output path and returns a pandas dataframe
- Parameters:
output_path (Union[str, Path]) – Path to the output directory
iteration_num (int, optional) – Search iteration to read from. Defaults to -1, which will point to the last iteration
- Returns:
Pandas dataframe with the search state
- Return type:
pd.DataFrame
- archai.common.notebook_helper.get_arch_abs_path(archid: str, downloaded_folder: str | Path, iteration_num: int | None = -1) Path [source]#
Returns the absolute path to the architecture file
- Parameters:
archid (str) – Architecture id
downloaded_folder (Union[str, Path]) – Path to the downloaded folder
iteration_num (int, optional) – Search iteration to read from. Defaults to -1, which will point to the last iteration
- Returns:
Absolute path to the architecture file
- Return type:
Path
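A minimal sketch of inspecting a finished search from a notebook (the output/downloaded folder names and the ‘archid’ column are assumptions):

```python
from archai.common.notebook_helper import get_arch_abs_path, get_search_csv

# Search state of the last iteration as a DataFrame.
df = get_search_csv("output")

# Resolve the on-disk location of the first architecture found.
arch_path = get_arch_abs_path(df.iloc[0]["archid"], "downloaded_models")
print(arch_path)
```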
Ordered Dict Logger#
- class archai.common.ordered_dict_logger.OrderedDictLogger(source: str | None = None, file_path: str | None = None, delay: float | None = 60.0)[source]#
Log and save data in a hierarchical YAML structure.
The purpose of structured logging is to store logs as key-value pairs. However, when you have loops and subroutine calls, what you need is hierarchical dictionaries, where the value for a key can itself be a dictionary. The idea is that you set one of the nodes in the tree as the current node and start logging your values. You can then use pushd to create and move to a child node, and popd to come back to the parent (see the usage sketch at the end of this section).
To implement this mechanism we use two main variables: _stack allows us to push each node onto a stack when pushd is called; each node is an OrderedDict. As a convenience, you can specify a child path in pushd, in which case the child hierarchy is created and the current node is set to the last node in the specified path. When popd is called, we go back to the original parent instead of the parent of the current node. To implement this we use the _paths variable, which stores the subpath at the time each pushd call was made.
- property root_node: OrderedDict#
Return the root node of the current stack.
- property current_node: OrderedDict#
Return the current node of the current stack.
- Raises:
RuntimeError – If a key already stores a scalar value and an attempt is made to store new information under it.
- property current_path: str#
Return the current path of the current stack.
- save() None [source]#
Save the current log data to an output file.
This method only saves to a file if a valid file_path has been provided in the constructor.
- load(file_path: str) None [source]#
Load log data from an input file.
- Parameters:
file_path – File path to load data from.
- log(obj: Dict[str, Any] | str, level: int | None = None, override_key: bool | None = True) None [source]#
Log the provided dictionary/string at the specified level.
- Parameters:
obj – Object to log.
level – Logging level.
override_key – Whether key can be overridden if it’s already in current node.
- info(obj: Dict[str, Any] | str, override_key: bool | None = True) None [source]#
Log the provided dictionary/string at the info level.
- Parameters:
obj – Object to log.
override_key – Whether key can be overridden if it’s already in current node.
- debug(obj: Dict[str, Any] | str, override_key: bool | None = True) None [source]#
Log the provided dictionary/string at the debug level.
- Parameters:
obj – Object to log.
override_key – Whether key can be overridden if it’s already in current node.
- warn(obj: Dict[str, Any] | str, override_key: bool | None = True) None [source]#
Log the provided dictionary/string at the warning level.
- Parameters:
obj – Object to log.
override_key – Whether key can be overridden if it’s already in current node.
- error(obj: Dict[str, Any] | str, override_key: bool | None = True) None [source]#
Log the provided dictionary/string at the error level.
- Parameters:
obj – Object to log.
override_key – Whether key can be overridden if it’s already in current node.
- pushd(*keys: Any) OrderedDictLogger [source]#
Push the provided keys onto the current path stack.
- Returns:
Instance of current logger.
- static set_global_instance(instance: OrderedDictLogger) None [source]#
Set a global logger instance.
- Parameters:
instance – Instance to be set globally.
- static get_global_instance() OrderedDictLogger [source]#
Get a global logger instance.
- Returns:
Global logger.
- archai.common.ordered_dict_logger.get_global_logger() OrderedDictLogger [source]#
Get a global logger instance.
This method assures that if a global logger instance does not exist, it will be created and set as the global logger instance.
- Returns:
Global logger.
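A usage sketch of the hierarchical logging described above (the file name and keys are illustrative, and the use of pushd as a context manager that returns to the parent node on exit is an assumption based on it returning the logger instance):

```python
from archai.common.ordered_dict_logger import OrderedDictLogger

logger = OrderedDictLogger(file_path="train_log.yaml", delay=0.0)
logger.info({"experiment": "darts_cifar10"})

# Log epoch metrics under a child node, then return to the parent on exit.
with logger.pushd("epoch_0"):
    logger.info({"train_loss": 1.23, "val_top1": 0.71})

logger.save()  # writes the hierarchical YAML to train_log.yaml
```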
Ordered Dict Logger (Utilities)#
- class archai.common.ordered_dict_logger_utils.RankFilter(rank: int)[source]#
A filter for logging records based on the rank of the process.
Only log records from the process with rank 0 are emitted; records from other processes are filtered out.
- archai.common.ordered_dict_logger_utils.get_console_handler() StreamHandler [source]#
Get a StreamHandler for logging to the console.
The StreamHandler can be used to log messages to the console (i.e., sys.stdout) and is configured with a formatter.
- Returns:
A StreamHandler for logging to the console.
Stopwatch#
Store#
- class archai.common.store.ArchaiStore(storage_account_name, storage_account_key, blob_container_name='models', table_name='status', partition_key='main')[source]#
ArchaiStore wraps an Azure ‘status’ Table and associated Blob Storage, providing a backing store for models and an associated table for collating the status of long-running jobs. This is actually a general-purpose utility class that could be used for anything.
The naming scheme is such that each Entity in the table has a ‘name’ property which is a simple friendly name or a guid, and this row will have an associated folder in the blob storage container with the same name where models and other peripheral files can be stored.
The ‘status’ table supports a locking concept that allows the status table to be used as a way of coordinating jobs across multiple machines where each machine grabs free work, locks that row until the work is done, uploads new files, and updates the status to ‘complete’ then unlocks that row. So this ArchaiStore can be used as the backing store for a simple distributed job scheduler.
This also has a convenient command line interface provided below.
- static parse_connection_string(storage_connection_string)[source]#
This helper method extracts the storage account name and key pair from a connection string and returns that pair in a tuple. This pair can then be used to construct an ArchaiStore object.
- get_utc_date()[source]#
This handy function can be used to put a UTC timestamp column in your entity, like a ‘model_date’ column, for example.
- get_all_status_entities(status=None, not_equal=False)[source]#
Get all status entities with an optional status column filter. For example, pass “status=complete” to find all status rows that have the status “complete”. Pass not_equal=True if you want to check that the status is not equal to the given value.
- get_status(name)[source]#
Get or create a new status entity with the given name. The returned entity is a Python dictionary where the name can be retrieved using e[‘name’]; you can then add keys to that dictionary and call update_status_entity.
- get_existing_status(name)[source]#
Find the given entity by name, and return it, or return None if the name is not found.
- get_updated_status(e)[source]#
Return an updated version of the entity by querying the table again; this way you can pick up any changes that another process may have made.
- update_status_entity(entity)[source]#
This method replaces everything in the entity store with what you have here. The entity can store strings, bools, floats, ints, and datetimes, so anything like a Python list is best serialized using json.dumps and stored as a string; you can then use json.loads to parse it later.
- merge_status_entity(entity)[source]#
This method merges everything in the entity store with what you have here, so you can add a property without clobbering any new properties other processes have added in parallel. Merge cannot delete properties; for that you have to use update_status_entity.
The entity can store strings, bools, floats, ints, and datetimes, so anything like a Python list is best serialized using json.dumps and stored as a string; you can then use json.loads to parse it later.
- update_status(name, status, priority=None)[source]#
This is a simple wrapper that gets the entity by name, and updates the status field. If you already have the entity then use update_status_entity.
- delete_status(name)[source]#
Delete the status entry with this name; note this does not delete any associated blobs. See delete_blobs for that.
- delete_status_entity(e)[source]#
Delete the given status entity; note this does not delete any associated blobs. See delete_blobs for that.
- upload_blob(folder_name, file, blob_name=None)[source]#
Upload the given file to the blob store, under the given folder name. The folder name could have multiple parts like ‘project/experiment/foo’. By default the blob will use the base file name, but you can override that with the given blob_name if you want to.
- lock(name, status)[source]#
Lock the named entity to this computer identified by platform.node() and set the status to the given status. This way you can use this ArchaiStore as a way of coordinating the parallel execution of a number of jobs, where each long-running job is allocated to a particular node in a distributed cluster using this locking mechanism. Be sure to call unlock when done, preferably in a try/finally block.
- lock_entity(e, status)[source]#
Lock the given entity to this computer identified by platform.node() and set the status to the given status. This way you can use this ArchaiStore as a way of coordinating the parallel execution of a number of jobs, where each long-running job is allocated to a particular node in a distributed cluster using this locking mechanism. Be sure to call unlock when done, preferably in a try/finally block.
- is_locked(name)[source]#
Return true if the entity exists and is locked by anyone (including this computer).
- is_locked_by_self(name)[source]#
Return true if the entity exists and is locked by this computer. This is handy if the computer restarts and wants to continue processing rows it has already claimed.
- is_locked_by_other(name)[source]#
Return true if the entity exists and is locked by some other computer. This tells the local computer not to touch this row of the table, as someone else has it.
- unlock(name)[source]#
Unlock the entity (regardless of who owns it - so use carefully, preferably only when is_locked_by_self is true).
- unlock_entity(e)[source]#
Unlock the entity (regardless of who owns it - so use carefully, preferably only when is_locked_by_self is true).
- unlock_all(node_name)[source]#
This is a sledgehammer for unlocking all entities, so use it carefully. This might be necessary if you are moving everything to a new cluster.
- reset(name, except_list=[])[source]#
This resets all properties on the given entity that are not primary keys, ‘name’ or ‘status’ and are not in the given except_list. This will not touch a node that is locked by another computer.
- upload(name, path, reset, priority=0, **kwargs)[source]#
Upload a file to the named folder in the blob store associated with this ArchaiStore and add the given named status row to our status table. It also locks the row with an ‘uploading’ status until the upload is complete, which ensures another machine does not try processing the work until the upload is finished. The path points to a file or a folder; if a folder, it uploads everything in that folder. This can also optionally reset the row, since sometimes you want to upload a new model for training and then reset all the metrics computed on the previous model. The optional priority is just added as a property on the row, which can be used by a distributed job scheduler to prioritize the work being queued up in this table.
- batch_upload(path, glob_pattern='*.onnx', override=False, reset=False, priority=0, **kwargs)[source]#
Upload all the matching files in the given path to the blob store where the status table ‘name’ will be the base name of the files found by the given non-recursive glob_pattern.
- download(name, folder, specific_file=None)[source]#
Download files from the given folder name in our associated blob container and return a list of the local paths to all downloaded files. If an optional specific_file is given, it tries to find and download that file only. The specific_file can be a regular expression like ‘*.onnx’.
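A hedged sketch of the job-coordination pattern described above (the connection-string environment variable, entity name, and property values are placeholders):

```python
import os

from archai.common.store import ArchaiStore

name, key = ArchaiStore.parse_connection_string(os.environ["STORAGE_CONNECTION_STRING"])
store = ArchaiStore(name, key, blob_container_name="models", table_name="status")

if not store.is_locked("model_001"):
    store.lock("model_001", "training")                # claim the row for this node
    try:
        store.upload_blob("model_001", "model.onnx")   # artifact lands in the row's folder
        e = store.get_status("model_001")
        e["val_acc"] = 0.91                            # scalars only; lists go through json.dumps
        store.merge_status_entity(e)
        store.update_status("model_001", "complete")
    finally:
        store.unlock("model_001")
```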
Timing#
Utilities#
- archai.common.utils.deep_update(d: MutableMapping, u: Mapping, map_type: Type[MutableMapping] = dict) MutableMapping [source]#
- archai.common.utils.append_csv_file(filepath: str, new_row: List[Tuple[str, Any]], delimiter='\t')[source]#
- archai.common.utils.download_and_extract_tar(url, download_root, extract_root=None, filename=None, md5=None, **kwargs)[source]#
- archai.common.utils.download_and_extract_zip(url, download_root, extract_root=None, filename=None, md5=None, **kwargs)[source]#
- archai.common.utils.exec_shell_command(command: str, print_command_start=True, print_command_end=True) CompletedProcess [source]#
- archai.common.utils.filepath_without_ext(filepath: str) str [source]#
Returns ‘/a/b/c/d.e’ for ‘/a/b/c/d.e.f’
- archai.common.utils.filepath_name_ext(filepath: str) str [source]#
Returns ‘d.e.f’ for ‘/a/b/c/d.e.f’
- archai.common.utils.filepath_name_only(filepath: str) str [source]#
Returns ‘d.e’ for ‘/a/b/c/d.e.f’
- archai.common.utils.change_filepath_ext(filepath: str, new_ext: str) str [source]#
Returns ‘/a/b/c/d.e.g’ for filepath=’/a/b/c/d.e.f’, new_ext=’.g’
- archai.common.utils.change_filepath_name(filepath: str, new_name: str, new_ext: str | None = None) str [source]#
Returns ‘/a/b/c/h.f’ for filepath=’/a/b/c/d.e.f’, new_name=’h’
- archai.common.utils.append_to_filename(filepath: str, name_suffix: str, new_ext: str | None = None) str [source]#
Appends name_suffix to the file name, before the file extension
- archai.common.utils.copy_file(src_file: str, dest_dir_or_file: str, preserve_metadata=False, use_shutil: bool = True) str [source]#
- archai.common.utils.is_main_process() bool [source]#
Returns True if this process was started as the main process instead of a child process during multiprocessing