from __future__ import annotations
# Instead of async: since there is only a very small set of things we
# want to happen simultaneously within one process (namely getting
# several parameters in parallel), we parallelize them with threads.
# That way the functions we call need not be rewritten to be async.
import concurrent.futures
import itertools
import logging
from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING, Protocol, TypeAlias, TypeVar
from qcodes.utils import RespondingThread
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from types import TracebackType
from typing import Self
from qcodes.dataset.data_set_protocol import ValuesType
from qcodes.parameters import ParamDataType, ParameterBase
# A single measurement item: either a parameter whose value is read with
# ``get``, or a callable without arguments that is invoked for its side
# effects (its return value is discarded). Spelled as a string because
# ``ParameterBase``/``Callable`` are only imported under TYPE_CHECKING.
ParamMeasT: TypeAlias = "ParameterBase | Callable[[], None]"
# Result of measuring a set of items: a list of ``(parameter, value)`` pairs.
OutType: TypeAlias = "list[tuple[ParameterBase, ValuesType]]"
# Generic type variable; not referenced in this part of the module.
T = TypeVar("T")
# Module-level logger.
_LOG = logging.getLogger(__name__)
class _ParamCaller:
def __init__(self, *parameters: ParameterBase):
self._parameters = parameters
def __call__(self) -> tuple[tuple[ParameterBase, ParamDataType], ...]:
output = []
for param in self._parameters:
output.append((param, param.get()))
return tuple(output)
def __repr__(self) -> str:
names = tuple(param.full_name for param in self._parameters)
return f"ParamCaller of {','.join(names)}"
def _instrument_to_param(
    params: Sequence[ParamMeasT],
) -> dict[str | None, tuple[ParameterBase, ...]]:
    """
    Group parameters by the full name of their underlying instrument.

    Entries of ``params`` that are not ``ParameterBase`` instances (e.g.
    plain callables) are ignored. Parameters whose ``underlying_instrument``
    is ``None`` are grouped under the key ``None``. Groups appear in the
    order in which their first parameter occurs in ``params``, and each
    group keeps the relative order of its parameters.

    Args:
        params: parameters and/or callables to group

    Returns:
        A plain dict mapping an instrument full name (or ``None``) to the
        tuple of parameters belonging to that instrument.
    """
    from qcodes.parameters import ParameterBase  # noqa: PLC0415

    # Accumulate into lists: repeated tuple concatenation would copy the
    # whole group on every append.
    grouped: dict[str | None, list[ParameterBase]] = defaultdict(list)
    for param in params:
        if not isinstance(param, ParameterBase):
            continue
        # Read the attribute once; compare with None explicitly so an
        # instrument object that happens to be falsy is not mistaken for
        # "no instrument".
        instrument = param.underlying_instrument
        key = instrument.full_name if instrument is not None else None
        grouped[key].append(param)
    # Return a plain dict so that looking up a missing key raises instead of
    # silently creating an empty group.
    return {key: tuple(group) for key, group in grouped.items()}
def call_params_threaded(param_meas: Sequence[ParamMeasT]) -> OutType:
    """
    Function to create threads per instrument for the given set of
    measurement parameters.

    Parameters sharing an underlying instrument are read sequentially in the
    same thread; parameters of different instruments are read in parallel.
    Entries of ``param_meas`` that are plain callables (not parameters) are
    called in the calling thread, in their given order, before any parameter
    is read. The sequential path (``_call_params``) also calls them, so they
    are no longer silently skipped when threads are used.

    Args:
        param_meas: a Sequence of measurement parameters and/or callables
            without arguments

    Returns:
        A list of ``(parameter, value)`` tuples, one per parameter in
        ``param_meas``. The tuples are grouped per instrument, so their
        order may differ from the order of ``param_meas``.
    """
    from qcodes.parameters import ParameterBase  # noqa: PLC0415

    # Mirror _call_params: non-parameter callables are run for side effects.
    # Running them before any thread starts means a failing callable aborts
    # the measurement without leaving reads in flight.
    for item in param_meas:
        if not isinstance(item, ParameterBase) and callable(item):
            item()

    inst_param_mapping = _instrument_to_param(param_meas)
    threads = [
        RespondingThread(target=_ParamCaller(*param_list))
        for param_list in inst_param_mapping.values()
    ]
    for thread in threads:
        thread.start()
    # Wait for every thread before collecting any result. If one thread's
    # output raises, the exception then does not reach the caller while
    # other threads are still reading their instruments.
    for thread in threads:
        thread.join()

    output: OutType = []
    for thread in threads:
        thread_output = thread.output()
        # _ParamCaller always returns a tuple; None would indicate a bug.
        assert thread_output is not None
        output.extend(thread_output)
    return output
def _call_params(param_meas: Sequence[ParamMeasT]) -> OutType:
    """
    Measure ``param_meas`` sequentially in the calling thread.

    Each entry is handled in order:
    - a ``ParameterBase`` is read with ``get`` and recorded;
    - any other callable is called and its return value discarded;
    - anything else is skipped.

    Args:
        param_meas: parameters and/or callables without arguments

    Returns:
        The ``(parameter, value)`` tuples, in the order of the parameters.
    """
    from qcodes.parameters import ParameterBase  # noqa: PLC0415

    results: OutType = []
    for item in param_meas:
        if not isinstance(item, ParameterBase):
            if callable(item):
                item()
            continue
        results.append((item, item.get()))
    return results
def process_params_meas(
    param_meas: Sequence[ParamMeasT], use_threads: bool | None = None
) -> OutType:
    """
    Measure ``param_meas`` either in per-instrument threads or sequentially.

    Args:
        param_meas: parameters and/or callables without arguments
        use_threads: whether to use threads; if None, the value of
            ``config.dataset.use_threads`` decides

    Returns:
        The list of ``(parameter, value)`` tuples.
    """
    from qcodes import config  # noqa: PLC0415

    threaded = config.dataset.use_threads if use_threads is None else use_threads
    measure = call_params_threaded if threaded else _call_params
    return measure(param_meas)
class _ParamsCallerProtocol(Protocol):
def __enter__(self) -> Callable[[], OutType]: ...
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
pass
class SequentialParamsCaller(_ParamsCallerProtocol):
    """
    Context manager that measures the given items sequentially in the
    calling thread.

    Entering the context returns a callable; each call of it reads every
    parameter and calls every plain callable, in the given order, and
    returns the list of ``(param, value)`` tuples.

    Args:
        param_meas: parameters and/or callables without arguments
    """

    def __init__(self, *param_meas: ParamMeasT):
        # ``*param_meas`` is already collected into a tuple.
        self._param_meas = param_meas

    def __enter__(self) -> Callable[[], OutType]:
        return partial(_call_params, self._param_meas)

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Nothing was acquired on enter, so there is nothing to release.
        pass
class ThreadPoolParamsCaller(_ParamsCallerProtocol):
    """
    Context manager for calling given parameters in a thread pool.

    Note that parameters that have the same underlying instrument will be
    called in the same thread. Callables without arguments in
    ``param_meas`` are called in the calling thread, in their given order,
    before the parameters are read.

    Usage:

    .. code-block:: python

        ...
        with ThreadPoolParamsCaller(p1, p2, ...) as pool_caller:
            ...
            output = pool_caller()
            ...
            # Output can be passed directly into DataSaver.add_result:
            # datasaver.add_result(*output)
            ...
        ...

    Args:
        param_meas: parameter or a callable without arguments
        max_workers: number of worker threads to create in the pool; if None,
            the number of worker threads will be equal to the number of
            unique "underlying instruments" (but at least one)
    """

    def __init__(self, *param_meas: ParamMeasT, max_workers: int | None = None):
        from qcodes.parameters import ParameterBase  # noqa: PLC0415

        self._param_callers = tuple(
            _ParamCaller(*param_list)
            for param_list in _instrument_to_param(param_meas).values()
        )
        # Callables are documented as accepted input; keep them so that
        # ``__call__`` runs them instead of silently dropping them.
        self._callables: tuple[Callable[[], None], ...] = tuple(
            item
            for item in param_meas
            if not isinstance(item, ParameterBase) and callable(item)
        )
        # ThreadPoolExecutor rejects max_workers=0, which would otherwise
        # happen when no parameters (nothing, or only callables) are given.
        if max_workers is None:
            max_worker_threads = max(1, len(self._param_callers))
        else:
            max_worker_threads = max_workers
        thread_name_prefix = (
            self.__class__.__name__
            + ":"
            + "".join(" " + repr(pc) for pc in self._param_callers)
        )
        self._thread_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=max_worker_threads,
            thread_name_prefix=thread_name_prefix,
        )

    def __call__(self) -> OutType:
        """
        Call parameters in the thread pool and return `(param, value)` tuples.

        Callables are run first, in the calling thread. The tuples are
        returned grouped per instrument, in a deterministic order (the order
        of the groups) that does not depend on which thread finishes first.
        """
        for func in self._callables:
            func()
        futures = [
            self._thread_pool.submit(param_caller)
            for param_caller in self._param_callers
        ]
        # Let every read finish before ``result()`` may re-raise a failure,
        # so no read is still running when the exception reaches the caller.
        concurrent.futures.wait(futures)
        output: OutType = list(
            itertools.chain.from_iterable(future.result() for future in futures)
        )
        return output

    def __enter__(self) -> Self:
        self._thread_pool.__enter__()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._thread_pool.__exit__(exc_type, exc_val, exc_tb)