Source code for qcodes.dataset.threading

from __future__ import annotations

# Instead of async, since there are really a very small set of things
# we want to happen simultaneously within one process (namely getting
# several parameters in parallel), we can parallelize them with threads.
# That way the things we call need not be rewritten explicitly async.
import concurrent.futures
import itertools
import logging
from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar

from qcodes.utils import RespondingThread

if TYPE_CHECKING:
    from collections.abc import Callable, Sequence
    from types import TracebackType
    from typing import Self

    from qcodes.dataset.data_set_protocol import ValuesType
    from qcodes.parameters import ParamDataType, ParameterBase

ParamMeasT: TypeAlias = "ParameterBase | Callable[[], None]"
OutType: TypeAlias = "list[tuple[ParameterBase, ValuesType]]"

T = TypeVar("T")

_LOG = logging.getLogger(__name__)


class _ParamCaller:
    """
    Callable wrapper that reads a fixed group of parameters one after another.

    Instances take no arguments when called, which lets them be handed
    directly to a thread or an executor as the unit of work.
    """

    def __init__(self, *parameters: ParameterBase):
        self._parameters = parameters

    def __call__(self) -> tuple[tuple[ParameterBase, ParamDataType], ...]:
        """
        Get every wrapped parameter in the order given at construction and
        return the ``(parameter, value)`` pairs as a tuple.
        """
        return tuple((parameter, parameter.get()) for parameter in self._parameters)

    def __repr__(self) -> str:
        joined_names = ",".join(parameter.full_name for parameter in self._parameters)
        return f"ParamCaller of {joined_names}"


def _instrument_to_param(
    params: Sequence[ParamMeasT],
) -> dict[str | None, tuple[ParameterBase, ...]]:
    """
    Group the parameters in ``params`` by their underlying instrument.

    Entries of ``params`` that are not ``ParameterBase`` instances (e.g. plain
    callables) are left out of the result.

    Args:
        params: a Sequence of measurement parameters and/or callables

    Returns:
        A mapping from the ``full_name`` of the underlying instrument to the
        tuple of parameters that belong to it, in their original relative
        order. Parameters without an underlying instrument are collected
        under the key ``None``.

    """
    # Imported here rather than at module level, as in the rest of this file.
    from qcodes.parameters import ParameterBase  # noqa: PLC0415

    grouped: dict[str | None, tuple[ParameterBase, ...]] = defaultdict(tuple)
    for candidate in params:
        if not isinstance(candidate, ParameterBase):
            continue
        instrument = candidate.underlying_instrument
        # Truthiness check (not ``is not None``) is kept on purpose.
        group_key = instrument.full_name if instrument else None
        grouped[group_key] += (candidate,)

    return grouped


def call_params_threaded(param_meas: Sequence[ParamMeasT]) -> OutType:
    """
    Function to create threads per instrument for the given set of
    measurement parameters.

    The parameters are grouped with ``_instrument_to_param``; each group is
    read sequentially inside its own ``RespondingThread``. Entries that are
    not ``ParameterBase`` instances are dropped by that grouping step and are
    therefore not called here.

    Args:
        param_meas: a Sequence of measurement parameters

    Returns:
        A list of ``(parameter, value)`` tuples, gathered thread by thread in
        the order the threads were created.

    """
    groups = _instrument_to_param(param_meas)
    workers = [
        RespondingThread(target=_ParamCaller(*group)) for group in groups.values()
    ]

    # Start every thread before collecting anything so the reads overlap.
    for worker in workers:
        worker.start()

    collected: OutType = []
    for worker in workers:
        worker_result = worker.output()
        # Narrows the type of the thread output for the type checker.
        assert worker_result is not None
        collected.extend(worker_result)

    return collected
def _call_params(param_meas: Sequence[ParamMeasT]) -> OutType:
    """
    Process the given entries one after another in the calling thread.

    ``ParameterBase`` instances are read with ``get()`` and contribute a
    ``(parameter, value)`` tuple to the result. Any other callable entry is
    called without arguments and its return value is discarded.

    Args:
        param_meas: a Sequence of measurement parameters and/or callables

    """
    from qcodes.parameters import ParameterBase  # noqa: PLC0415

    results: OutType = []
    for entry in param_meas:
        if isinstance(entry, ParameterBase):
            results.append((entry, entry.get()))
            continue
        if callable(entry):
            entry()

    return results


def process_params_meas(
    param_meas: Sequence[ParamMeasT], use_threads: bool | None = None
) -> OutType:
    """
    Process the given entries either with threads or sequentially.

    Args:
        param_meas: a Sequence of measurement parameters and/or callables
        use_threads: if True, ``call_params_threaded`` is used, otherwise
            ``_call_params``; if None, the value of
            ``qcodes.config.dataset.use_threads`` decides.

    """
    from qcodes import config  # noqa: PLC0415

    threaded = config.dataset.use_threads if use_threads is None else use_threads
    runner = call_params_threaded if threaded else _call_params
    return runner(param_meas)


class _ParamsCallerProtocol(Protocol):
    """
    Interface of the params caller context managers in this module:
    entering yields a zero-argument callable that returns the
    ``(parameter, value)`` tuples.
    """

    def __enter__(self) -> Callable[[], OutType]: ...

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        pass
class SequentialParamsCaller(_ParamsCallerProtocol):
    """
    Context manager for calling given parameters one after another.

    Entering the context returns a zero-argument callable that runs
    ``_call_params`` on the entries given at construction.

    Args:
        param_meas: parameter or a callable without arguments

    """

    def __init__(self, *param_meas: ParamMeasT):
        # ``*param_meas`` already arrives as a tuple; it is stored as-is.
        self._param_meas = param_meas

    def __enter__(self) -> Callable[[], OutType]:
        caller = partial(_call_params, self._param_meas)
        return caller

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Nothing to release; exceptions from the ``with`` body propagate."""
class ThreadPoolParamsCaller(_ParamsCallerProtocol):
    """
    Context manager for calling given parameters in a thread pool.
    Note that parameters that have the same underlying instrument will be
    called in the same thread.

    Entries that are not ``ParameterBase`` instances are filtered out by
    ``_instrument_to_param`` and are therefore not called by this class.

    Usage:

        .. code-block:: python

           ...
           with ThreadPoolParamsCaller(p1, p2, ...) as pool_caller:
               ...
               output = pool_caller()
               ...
               # Output can be passed directly into DataSaver.add_result:
               # datasaver.add_result(*output)
               ...
           ...

    Args:
        param_meas: parameter or a callable without arguments
        max_workers: number of worker threads to create in the pool; if None,
            the number of worker threads will be equal to the number of
            unique "underlying instruments" (but never less than one, so
            that the pool can be created when there is nothing to call)

    """

    def __init__(self, *param_meas: ParamMeasT, max_workers: int | None = None):
        # One caller per underlying instrument; each caller reads its
        # parameters sequentially.
        self._param_callers = tuple(
            _ParamCaller(*param_list)
            for param_list in _instrument_to_param(param_meas).values()
        )

        if max_workers is None:
            # ThreadPoolExecutor raises ValueError for max_workers <= 0, which
            # used to happen when no entry was a parameter. Only the derived
            # default is clamped; an explicit value is passed through as-is.
            max_worker_threads = max(1, len(self._param_callers))
        else:
            max_worker_threads = max_workers

        thread_name_prefix = (
            self.__class__.__name__
            + ":"
            + "".join(" " + repr(pc) for pc in self._param_callers)
        )

        self._thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=max_worker_threads,
            thread_name_prefix=thread_name_prefix,
        )

    def __call__(self) -> OutType:
        """
        Call parameters in the thread pool and return `(param, value)` tuples.

        Results are gathered with ``concurrent.futures.as_completed``, so the
        groups appear in the order in which their callers finish, not
        necessarily in the order the parameters were given. An exception
        raised by a caller is re-raised here by ``future.result()``.
        """
        output: OutType = list(
            itertools.chain.from_iterable(
                future.result()
                for future in concurrent.futures.as_completed(
                    self._thread_pool.submit(param_caller)
                    for param_caller in self._param_callers
                )
            )
        )

        return output

    def __enter__(self) -> Self:
        self._thread_pool.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Delegate to the executor's own context-manager exit.
        self._thread_pool.__exit__(exc_type, exc_val, exc_tb)