Source code for qcodes.dataset.threading

from __future__ import annotations

# Instead of async, since there are really a very small set of things
# we want to happen simultaneously within one process (namely getting
# several parameters in parallel), we can parallelize them with threads.
# That way the things we call need not be rewritten explicitly async.
import concurrent
import concurrent.futures
import itertools
import logging
from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar

from qcodes.utils import RespondingThread

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence
    from types import TracebackType

    from qcodes.dataset.data_set_protocol import values_type
    from qcodes.parameters import ParamDataType, ParameterBase

ParamMeasT: TypeAlias = "ParameterBase | Callable[[], None]"
OutType: TypeAlias = "list[tuple[ParameterBase, values_type]]"

T = TypeVar("T")

_LOG = logging.getLogger(__name__)


class _ParamCaller:
    def __init__(self, *parameters: ParameterBase):
        self._parameters = parameters

    def __call__(self) -> tuple[tuple[ParameterBase, ParamDataType], ...]:
        output = []
        for param in self._parameters:
            output.append((param, param.get()))
        return tuple(output)

    def __repr__(self) -> str:
        names = tuple(param.full_name for param in self._parameters)
        return f"ParamCaller of {','.join(names)}"


def _instrument_to_param(
    params: Sequence[ParamMeasT],
) -> dict[str | None, tuple[ParameterBase, ...]]:
    from qcodes.parameters import ParameterBase

    real_parameters = [param for param in params if isinstance(param, ParameterBase)]

    output: dict[str | None, tuple[ParameterBase, ...]] = defaultdict(tuple)
    for param in real_parameters:
        if param.underlying_instrument:
            output[param.underlying_instrument.full_name] += (param,)
        else:
            output[None] += (param,)

    return output


[docs] def call_params_threaded(param_meas: Sequence[ParamMeasT]) -> OutType: """ Function to create threads per instrument for the given set of measurement parameters. Args: param_meas: a Sequence of measurement parameters """ inst_param_mapping = _instrument_to_param(param_meas) executors = tuple( _ParamCaller(*param_list) for param_list in inst_param_mapping.values() ) output: OutType = [] threads = [RespondingThread(target=executor) for executor in executors] for t in threads: t.start() for t in threads: thread_output = t.output() assert thread_output is not None for result in thread_output: output.append(result) return output
def _call_params(param_meas: Sequence[ParamMeasT]) -> OutType: from qcodes.parameters import ParameterBase output: OutType = [] for parameter in param_meas: if isinstance(parameter, ParameterBase): output.append((parameter, parameter.get())) elif callable(parameter): parameter() return output def process_params_meas( param_meas: Sequence[ParamMeasT], use_threads: bool | None = None ) -> OutType: from qcodes import config if use_threads is None: use_threads = config.dataset.use_threads if use_threads: return call_params_threaded(param_meas) return _call_params(param_meas) class _ParamsCallerProtocol(Protocol): def __enter__(self) -> Callable[[], OutType]: ... def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: pass
[docs] class SequentialParamsCaller(_ParamsCallerProtocol): def __init__(self, *param_meas: ParamMeasT): self._param_meas = tuple(param_meas) def __enter__(self) -> Callable[[], OutType]: return partial(_call_params, self._param_meas) def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: return None
[docs] class ThreadPoolParamsCaller(_ParamsCallerProtocol): """ Context manager for calling given parameters in a thread pool. Note that parameters that have the same underlying instrument will be called in the same thread. Usage: .. code-block:: python ... with ThreadPoolParamsCaller(p1, p2, ...) as pool_caller: ... output = pool_caller() ... # Output can be passed directly into DataSaver.add_result: # datasaver.add_result(*output) ... ... Args: param_meas: parameter or a callable without arguments max_workers: number of worker threads to create in the pool; if None, the number of worker threads will be equal to the number of unique "underlying instruments" """ def __init__(self, *param_meas: ParamMeasT, max_workers: int | None = None): self._param_callers = tuple( _ParamCaller(*param_list) for param_list in _instrument_to_param(param_meas).values() ) max_worker_threads = ( len(self._param_callers) if max_workers is None else max_workers ) thread_name_prefix = ( self.__class__.__name__ + ":" + "".join(" " + repr(pc) for pc in self._param_callers) ) self._thread_pool = concurrent.futures.ThreadPoolExecutor( max_workers=max_worker_threads, thread_name_prefix=thread_name_prefix, )
[docs] def __call__(self) -> OutType: """ Call parameters in the thread pool and return `(param, value)` tuples. """ output: OutType = list( itertools.chain.from_iterable( future.result() for future in concurrent.futures.as_completed( self._thread_pool.submit(param_caller) for param_caller in self._param_callers ) ) ) return output
def __enter__(self) -> ThreadPoolParamsCaller: self._thread_pool.__enter__() return self def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: self._thread_pool.__exit__(exc_type, exc_val, exc_tb)