Coverage for mlos_bench/mlos_bench/tunables/tunable.py: 97%
299 statements
« prev ^ index » next coverage.py v7.6.7, created at 2024-11-22 01:18 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""Tunable parameter definition."""
6import collections
7import copy
8import logging
9from typing import (
10 Any,
11 Dict,
12 Iterable,
13 List,
14 Literal,
15 Optional,
16 Sequence,
17 Tuple,
18 Type,
19 TypedDict,
20 Union,
21)
23import numpy as np
25from mlos_bench.util import nullable
27_LOG = logging.getLogger(__name__)
TunableValue = Union[int, float, Optional[str]]
"""A tunable parameter value type alias."""

TunableValueType = Union[Type[int], Type[float], Type[str]]
"""Tunable value type."""

TunableValueTypeTuple = (int, float, str, type(None))
"""
Tunable value type tuple.

For checking with isinstance().
"""

TunableValueTypeName = Literal["int", "float", "categorical"]
"""The string name of a tunable value type."""

TunableValuesDict = Dict[str, TunableValue]
"""Tunable values dictionary type."""

DistributionName = Literal["uniform", "normal", "beta"]
"""Tunable value distribution type."""
class DistributionDict(TypedDict, total=False):
    """
    A typed dict for tunable parameters' distributions.

    Mirrors the ``distribution`` sub-object of a tunable's JSON config
    (see TunableDict below); used mostly for mypy type checking.
    """

    # Name of the sampling distribution; one of "uniform", "normal", or "beta".
    # Treated as required by Tunable.__init__ when a distribution is given.
    type: DistributionName
    # Distribution-specific parameters (e.g., mean/std), or None/absent for defaults.
    params: Optional[Dict[str, float]]
class TunableDict(TypedDict, total=False):
    """
    A typed dict for tunable parameters.

    Mostly used for mypy type checking.

    These are the types expected to be received from the json config.
    """

    # Required: one of "int", "float", "categorical".
    type: TunableValueTypeName
    # Optional free-text description of the parameter.
    description: Optional[str]
    # Required: the default value; coerced to the declared type on load.
    default: TunableValue
    # Allowed values for categorical tunables (must be None for numerical ones).
    values: Optional[List[Optional[str]]]
    # Two-element [min, max] range for numerical tunables.
    range: Optional[Union[Sequence[int], Sequence[float]]]
    # Number of quantization bins for numerical tunables (> 1), if any.
    quantization_bins: Optional[int]
    # Whether to sample the numerical range on a log scale.
    log: Optional[bool]
    # Sampling distribution spec for numerical tunables.
    distribution: Optional[DistributionDict]
    # Out-of-range special values (numerical tunables only).
    special: Optional[Union[List[int], List[float]]]
    # Weights for categorical values — mutually exclusive with special_weights.
    values_weights: Optional[List[float]]
    # Weights for the special values (numerical tunables).
    special_weights: Optional[List[float]]
    # Weight of the regular (non-special) range (numerical tunables).
    range_weight: Optional[float]
    # Free-form metadata passed through to consumers (e.g., unit info).
    meta: Dict[str, Any]
class Tunable:  # pylint: disable=too-many-instance-attributes,too-many-public-methods
    """A tunable parameter definition and its current value."""

    # Maps tunable types to their corresponding Python types by name.
    # Keys must stay in sync with TunableValueTypeName.
    _DTYPE: Dict[TunableValueTypeName, TunableValueType] = {
        "int": int,
        "float": float,
        "categorical": str,
    }
    def __init__(self, name: str, config: TunableDict):
        """
        Create an instance of a new tunable parameter.

        Parameters
        ----------
        name : str
            Human-readable identifier of the tunable parameter.
        config : dict
            Python dict that represents a Tunable (e.g., deserialized from JSON)

        Raises
        ------
        ValueError
            If the name or any part of the config is invalid
            (via _sanity_check() and the `value` setter below).
        """
        if not isinstance(name, str) or "!" in name:  # TODO: Use a regex here and in JSON schema
            raise ValueError(f"Invalid name of the tunable: {name}")
        self._name = name
        self._type: TunableValueTypeName = config["type"]  # required
        if self._type not in self._DTYPE:
            raise ValueError(f"Invalid parameter type: {self._type}")
        self._description = config.get("description")
        self._default = config["default"]
        # Coerce the default to the declared type up front (None stays None for categoricals).
        self._default = self.dtype(self._default) if self._default is not None else self._default
        self._values = config.get("values")
        if self._values:
            # Categorical values are normalized to strings (None is preserved).
            self._values = [str(v) if v is not None else v for v in self._values]
        self._meta: Dict[str, Any] = config.get("meta", {})
        self._range: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None
        self._quantization_bins: Optional[int] = config.get("quantization_bins")
        self._log: Optional[bool] = config.get("log")
        self._distribution: Optional[DistributionName] = None
        self._distribution_params: Dict[str, float] = {}
        distr = config.get("distribution")
        if distr:
            self._distribution = distr["type"]  # required
            self._distribution_params = distr.get("params") or {}
        config_range = config.get("range")
        if config_range is not None:
            assert len(config_range) == 2, f"Invalid range: {config_range}"
            # Normalize the (possibly list-valued) range to a 2-tuple.
            config_range = (config_range[0], config_range[1])
        self._range = config_range
        self._special: Union[List[int], List[float]] = config.get("special") or []
        # values_weights (categorical) and special_weights (numerical) share one field;
        # which one applies is determined by the tunable's type.
        self._weights: List[float] = (
            config.get("values_weights") or config.get("special_weights") or []
        )
        self._range_weight: Optional[float] = config.get("range_weight")
        self._current_value = None
        self._sanity_check()
        # Assign via the `value` property setter so the default is coerced and validated.
        self.value = self._default
134 def _sanity_check(self) -> None:
135 """Check if the status of the Tunable is valid, and throw ValueError if it is
136 not.
137 """
138 if self.is_categorical:
139 self._sanity_check_categorical()
140 elif self.is_numerical:
141 self._sanity_check_numerical()
142 else:
143 raise ValueError(f"Invalid parameter type for tunable {self}: {self._type}")
144 if not self.is_valid(self.default):
145 raise ValueError(f"Invalid default value for tunable {self}: {self.default}")
147 def _sanity_check_categorical(self) -> None:
148 """Check if the status of the categorical Tunable is valid, and throw ValueError
149 if it is not.
150 """
151 # pylint: disable=too-complex
152 assert self.is_categorical
153 if not (self._values and isinstance(self._values, collections.abc.Iterable)):
154 raise ValueError(f"Must specify values for the categorical type tunable {self}")
155 if self._range is not None:
156 raise ValueError(f"Range must be None for the categorical type tunable {self}")
157 if len(set(self._values)) != len(self._values):
158 raise ValueError(f"Values must be unique for the categorical type tunable {self}")
159 if self._special:
160 raise ValueError(f"Categorical tunable cannot have special values: {self}")
161 if self._range_weight is not None:
162 raise ValueError(f"Categorical tunable cannot have range_weight: {self}")
163 if self._log is not None:
164 raise ValueError(f"Categorical tunable cannot have log parameter: {self}")
165 if self._quantization_bins is not None:
166 raise ValueError(f"Categorical tunable cannot have quantization parameter: {self}")
167 if self._distribution is not None:
168 raise ValueError(f"Categorical parameters do not support `distribution`: {self}")
169 if self._weights:
170 if len(self._weights) != len(self._values):
171 raise ValueError(f"Must specify weights for all values: {self}")
172 if any(w < 0 for w in self._weights):
173 raise ValueError(f"All weights must be non-negative: {self}")
175 def _sanity_check_numerical(self) -> None:
176 """Check if the status of the numerical Tunable is valid, and throw ValueError
177 if it is not.
178 """
179 # pylint: disable=too-complex,too-many-branches
180 assert self.is_numerical
181 if self._values is not None:
182 raise ValueError(f"Values must be None for the numerical type tunable {self}")
183 if not self._range or len(self._range) != 2 or self._range[0] >= self._range[1]:
184 raise ValueError(f"Invalid range for tunable {self}: {self._range}")
185 if self._quantization_bins is not None and self._quantization_bins <= 1:
186 raise ValueError(f"Number of quantization bins is <= 1: {self}")
187 if self._distribution is not None and self._distribution not in {
188 "uniform",
189 "normal",
190 "beta",
191 }:
192 raise ValueError(f"Invalid distribution: {self}")
193 if self._distribution_params and self._distribution is None:
194 raise ValueError(f"Must specify the distribution: {self}")
195 if self._weights:
196 if self._range_weight is None:
197 raise ValueError(f"Must specify weight for the range: {self}")
198 if len(self._weights) != len(self._special):
199 raise ValueError("Must specify weights for all special values {self}")
200 if any(w < 0 for w in self._weights + [self._range_weight]):
201 raise ValueError(f"All weights must be non-negative: {self}")
202 elif self._range_weight is not None:
203 raise ValueError(f"Must specify both weights and range_weight or none: {self}")
205 def __repr__(self) -> str:
206 """
207 Produce a human-readable version of the Tunable (mostly for logging).
209 Returns
210 -------
211 string : str
212 A human-readable version of the Tunable.
213 """
214 # TODO? Add weights, specials, quantization, distribution?
215 if self.is_categorical:
216 return (
217 f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}"
218 )
219 return f"{self._name}[{self._type}]({self._range}:{self._default})={self._current_value}"
221 def __eq__(self, other: object) -> bool:
222 """
223 Check if two Tunable objects are equal.
225 Parameters
226 ----------
227 other : Tunable
228 A tunable object to compare to.
230 Returns
231 -------
232 is_equal : bool
233 True if the Tunables correspond to the same parameter and have the same value and type.
234 NOTE: ranges and special values are not currently considered in the comparison.
235 """
236 if not isinstance(other, Tunable):
237 return False
238 return bool(
239 self._name == other._name
240 and self._type == other._type
241 and self._current_value == other._current_value
242 )
244 def __lt__(self, other: object) -> bool: # pylint: disable=too-many-return-statements
245 """
246 Compare the two Tunable objects. We mostly need this to create a canonical list
247 of tunable objects when hashing a TunableGroup.
249 Parameters
250 ----------
251 other : Tunable
252 A tunable object to compare to.
254 Returns
255 -------
256 is_less : bool
257 True if the current Tunable is less then the other one, False otherwise.
258 """
259 if not isinstance(other, Tunable):
260 return False
261 if self._name < other._name:
262 return True
263 if self._name == other._name and self._type < other._type:
264 return True
265 if self._name == other._name and self._type == other._type:
266 if self.is_numerical:
267 assert self._current_value is not None
268 assert other._current_value is not None
269 return bool(float(self._current_value) < float(other._current_value))
270 # else: categorical
271 if self._current_value is None:
272 return True
273 if other._current_value is None:
274 return False
275 return bool(str(self._current_value) < str(other._current_value))
276 return False
278 def copy(self) -> "Tunable":
279 """
280 Deep copy of the Tunable object.
282 Returns
283 -------
284 tunable : Tunable
285 A new Tunable object that is a deep copy of the original one.
286 """
287 return copy.deepcopy(self)
289 @property
290 def default(self) -> TunableValue:
291 """Get the default value of the tunable."""
292 return self._default
294 def is_default(self) -> TunableValue:
295 """Checks whether the currently assigned value of the tunable is at its
296 default.
297 """
298 return self._default == self._current_value
300 @property
301 def value(self) -> TunableValue:
302 """Get the current value of the tunable."""
303 return self._current_value
305 @value.setter
306 def value(self, value: TunableValue) -> TunableValue:
307 """Set the current value of the tunable."""
308 # We need this coercion for the values produced by some optimizers
309 # (e.g., scikit-optimize) and for data restored from certain storage
310 # systems (where values can be strings).
311 try:
312 if self.is_categorical and value is None:
313 coerced_value = None
314 else:
315 assert value is not None
316 coerced_value = self.dtype(value)
317 except Exception:
318 _LOG.error(
319 "Impossible conversion: %s %s <- %s %s",
320 self._type,
321 self._name,
322 type(value),
323 value,
324 )
325 raise
327 if self._type == "int" and isinstance(value, float) and value != coerced_value:
328 _LOG.error(
329 "Loss of precision: %s %s <- %s %s",
330 self._type,
331 self._name,
332 type(value),
333 value,
334 )
335 raise ValueError(f"Loss of precision: {self._name}={value}")
337 if not self.is_valid(coerced_value):
338 _LOG.error(
339 "Invalid assignment: %s %s <- %s %s",
340 self._type,
341 self._name,
342 type(value),
343 value,
344 )
345 raise ValueError(f"Invalid value for the tunable: {self._name}={value}")
347 self._current_value = coerced_value
348 return self._current_value
350 def update(self, value: TunableValue) -> bool:
351 """
352 Assign the value to the tunable. Return True if it is a new value, False
353 otherwise.
355 Parameters
356 ----------
357 value : Union[int, float, str]
358 Value to assign.
360 Returns
361 -------
362 is_updated : bool
363 True if the new value is different from the previous one, False otherwise.
364 """
365 prev_value = self._current_value
366 self.value = value
367 return prev_value != self._current_value
369 def is_valid(self, value: TunableValue) -> bool:
370 """
371 Check if the value can be assigned to the tunable.
373 Parameters
374 ----------
375 value : Union[int, float, str]
376 Value to validate.
378 Returns
379 -------
380 is_valid : bool
381 True if the value is valid, False otherwise.
382 """
383 if self.is_categorical and self._values:
384 return value in self._values
385 elif self.is_numerical and self._range:
386 if isinstance(value, (int, float)):
387 return self.in_range(value) or value in self._special
388 else:
389 raise ValueError(f"Invalid value type for tunable {self}: {value}={type(value)}")
390 else:
391 raise ValueError(f"Invalid parameter type: {self._type}")
393 def in_range(self, value: Union[int, float, str, None]) -> bool:
394 """
395 Check if the value is within the range of the tunable.
397 Do *NOT* check for special values. Return False if the tunable or value is
398 categorical or None.
399 """
400 return (
401 isinstance(value, (float, int))
402 and self.is_numerical
403 and self._range is not None
404 and bool(self._range[0] <= value <= self._range[1])
405 )
407 @property
408 def category(self) -> Optional[str]:
409 """Get the current value of the tunable as a string."""
410 if self.is_categorical:
411 return nullable(str, self._current_value)
412 else:
413 raise ValueError("Cannot get categorical values for a numerical tunable.")
415 @category.setter
416 def category(self, new_value: Optional[str]) -> Optional[str]:
417 """Set the current value of the tunable."""
418 assert self.is_categorical
419 assert isinstance(new_value, (str, type(None)))
420 self.value = new_value
421 return self.value
423 @property
424 def numerical_value(self) -> Union[int, float]:
425 """Get the current value of the tunable as a number."""
426 assert self._current_value is not None
427 if self._type == "int":
428 return int(self._current_value)
429 elif self._type == "float":
430 return float(self._current_value)
431 else:
432 raise ValueError("Cannot get numerical value for a categorical tunable.")
434 @numerical_value.setter
435 def numerical_value(self, new_value: Union[int, float]) -> Union[int, float]:
436 """Set the current numerical value of the tunable."""
437 # We need this coercion for the values produced by some optimizers
438 # (e.g., scikit-optimize) and for data restored from certain storage
439 # systems (where values can be strings).
440 assert self.is_numerical
441 self.value = new_value
442 return self.value
444 @property
445 def name(self) -> str:
446 """Get the name / string ID of the tunable."""
447 return self._name
449 @property
450 def special(self) -> Union[List[int], List[float]]:
451 """
452 Get the special values of the tunable. Return an empty list if there are none.
454 Returns
455 -------
456 special : [int] | [float]
457 A list of special values of the tunable. Can be empty.
458 """
459 return self._special
461 @property
462 def is_special(self) -> bool:
463 """
464 Check if the current value of the tunable is special.
466 Returns
467 -------
468 is_special : bool
469 True if the current value of the tunable is special, False otherwise.
470 """
471 return self.value in self._special
473 @property
474 def weights(self) -> Optional[List[float]]:
475 """
476 Get the weights of the categories or special values of the tunable. Return None
477 if there are none.
479 Returns
480 -------
481 weights : [float]
482 A list of weights or None.
483 """
484 return self._weights
486 @property
487 def range_weight(self) -> Optional[float]:
488 """
489 Get weight of the range of the numeric tunable. Return None if there are no
490 weights or a tunable is categorical.
492 Returns
493 -------
494 weight : float
495 Weight of the range or None.
496 """
497 assert self.is_numerical
498 assert self._special
499 assert self._weights
500 return self._range_weight
502 @property
503 def type(self) -> TunableValueTypeName:
504 """
505 Get the data type of the tunable.
507 Returns
508 -------
509 type : str
510 Data type of the tunable - one of {'int', 'float', 'categorical'}.
511 """
512 return self._type
514 @property
515 def dtype(self) -> TunableValueType:
516 """
517 Get the actual Python data type of the tunable.
519 This is useful for bulk conversions of the input data.
521 Returns
522 -------
523 dtype : type
524 Data type of the tunable - one of {int, float, str}.
525 """
526 return self._DTYPE[self._type]
528 @property
529 def is_categorical(self) -> bool:
530 """
531 Check if the tunable is categorical.
533 Returns
534 -------
535 is_categorical : bool
536 True if the tunable is categorical, False otherwise.
537 """
538 return self._type == "categorical"
540 @property
541 def is_numerical(self) -> bool:
542 """
543 Check if the tunable is an integer or float.
545 Returns
546 -------
547 is_int : bool
548 True if the tunable is an integer or float, False otherwise.
549 """
550 return self._type in {"int", "float"}
552 @property
553 def range(self) -> Union[Tuple[int, int], Tuple[float, float]]:
554 """
555 Get the range of the tunable if it is numerical, None otherwise.
557 Returns
558 -------
559 range : Union[Tuple[int, int], Tuple[float, float]]
560 A 2-tuple of numbers that represents the range of the tunable.
561 Numbers can be int or float, depending on the type of the tunable.
562 """
563 assert self.is_numerical
564 assert self._range is not None
565 return self._range
567 @property
568 def span(self) -> Union[int, float]:
569 """
570 Gets the span of the range.
572 Note: this does not take quantization into account.
574 Returns
575 -------
576 Union[int, float]
577 (max - min) for numerical tunables.
578 """
579 num_range = self.range
580 return num_range[1] - num_range[0]
582 @property
583 def quantization_bins(self) -> Optional[int]:
584 """
585 Get the number of quantization bins, if specified.
587 Returns
588 -------
589 quantization_bins : int | None
590 Number of quantization bins, or None.
591 """
592 if self.is_categorical:
593 return None
594 return self._quantization_bins
596 @property
597 def quantized_values(self) -> Optional[Union[Iterable[int], Iterable[float]]]:
598 """
599 Get a sequence of quanitized values for this tunable.
601 Returns
602 -------
603 Optional[Union[Iterable[int], Iterable[float]]]
604 If the Tunable is quantizable, returns a sequence of those elements,
605 else None (e.g., for unquantized float type tunables).
606 """
607 num_range = self.range
608 if self.type == "float":
609 if not self.quantization_bins:
610 return None
611 # Be sure to return python types instead of numpy types.
612 return (
613 float(x)
614 for x in np.linspace(
615 start=num_range[0],
616 stop=num_range[1],
617 num=self.quantization_bins,
618 endpoint=True,
619 )
620 )
621 assert self.type == "int", f"Unhandled tunable type: {self}"
622 return range(
623 int(num_range[0]),
624 int(num_range[1]) + 1,
625 int(self.span / (self.quantization_bins - 1)) if self.quantization_bins else 1,
626 )
628 @property
629 def cardinality(self) -> Optional[int]:
630 """
631 Gets the cardinality of elements in this tunable, or else None. (i.e., when the
632 tunable is continuous float and not quantized).
634 If the tunable has quantization set, this
636 Returns
637 -------
638 cardinality : int
639 Either the number of points in the tunable or else None.
640 """
641 if self.is_categorical:
642 return len(self.categories)
643 if self.quantization_bins:
644 return self.quantization_bins
645 if self.type == "int":
646 return int(self.span) + 1
647 return None
649 @property
650 def is_log(self) -> Optional[bool]:
651 """
652 Check if numeric tunable is log scale.
654 Returns
655 -------
656 log : bool
657 True if numeric tunable is log scale, False if linear.
658 """
659 assert self.is_numerical
660 return self._log
662 @property
663 def distribution(self) -> Optional[DistributionName]:
664 """
665 Get the name of the distribution (uniform, normal, or beta) if specified.
667 Returns
668 -------
669 distribution : str
670 Name of the distribution (uniform, normal, or beta) or None.
671 """
672 return self._distribution
674 @property
675 def distribution_params(self) -> Dict[str, float]:
676 """
677 Get the parameters of the distribution, if specified.
679 Returns
680 -------
681 distribution_params : Dict[str, float]
682 Parameters of the distribution or None.
683 """
684 assert self._distribution is not None
685 return self._distribution_params
687 @property
688 def categories(self) -> List[Optional[str]]:
689 """
690 Get the list of all possible values of a categorical tunable. Return None if the
691 tunable is not categorical.
693 Returns
694 -------
695 values : List[str]
696 List of all possible values of a categorical tunable.
697 """
698 assert self.is_categorical
699 assert self._values is not None
700 return self._values
702 @property
703 def values(self) -> Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]:
704 """
705 Gets the categories or quantized values for this tunable.
707 Returns
708 -------
709 Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]
710 Categories or quantized values.
711 """
712 if self.is_categorical:
713 return self.categories
714 assert self.is_numerical
715 return self.quantized_values
717 @property
718 def meta(self) -> Dict[str, Any]:
719 """
720 Get the tunable's metadata.
722 This is a free-form dictionary that can be used to store any additional
723 information about the tunable (e.g., the unit information).
724 """
725 return self._meta