Coverage for mlos_bench/mlos_bench/tunables/tunable.py: 96% (310 statements)
coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Tunable parameter definition.
"""
import copy
import collections
import logging

from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Type, TypedDict, Union

import numpy as np

from mlos_bench.util import nullable

_LOG = logging.getLogger(__name__)


"""A tunable parameter value type alias."""
TunableValue = Union[int, float, Optional[str]]

"""Tunable value type."""
TunableValueType = Union[Type[int], Type[float], Type[str]]

"""
Tunable value type tuple.
For checking with isinstance().
"""
TunableValueTypeTuple = (int, float, str, type(None))

"""The string name of a tunable value type."""
TunableValueTypeName = Literal["int", "float", "categorical"]

"""Tunable values dictionary type."""
TunableValuesDict = Dict[str, TunableValue]

"""Tunable value distribution type."""
DistributionName = Literal["uniform", "normal", "beta"]


class DistributionDict(TypedDict, total=False):
    """
    A typed dict for tunable parameters' distributions.
    """

    type: DistributionName
    params: Optional[Dict[str, float]]


class TunableDict(TypedDict, total=False):
    """
    A typed dict for tunable parameters.

    Mostly used for mypy type checking.

    These are the types expected to be received from the json config.
    """

    type: TunableValueTypeName
    description: Optional[str]
    default: TunableValue
    values: Optional[List[Optional[str]]]
    range: Optional[Union[Sequence[int], Sequence[float]]]
    quantization: Optional[Union[int, float]]
    log: Optional[bool]
    distribution: Optional[DistributionDict]
    special: Optional[Union[List[int], List[float]]]
    values_weights: Optional[List[float]]
    special_weights: Optional[List[float]]
    range_weight: Optional[float]
    meta: Dict[str, Any]

class Tunable:  # pylint: disable=too-many-instance-attributes,too-many-public-methods
    """
    A tunable parameter definition and its current value.
    """

    # Maps tunable types to their corresponding Python types by name.
    _DTYPE: Dict[TunableValueTypeName, TunableValueType] = {
        "int": int,
        "float": float,
        "categorical": str,
    }

    def __init__(self, name: str, config: TunableDict):
        """
        Create an instance of a new tunable parameter.

        Parameters
        ----------
        name : str
            Human-readable identifier of the tunable parameter.
        config : dict
            Python dict that represents a Tunable (e.g., deserialized from JSON).
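
        Examples
        --------
        A minimal illustrative sketch (the name and config values here are
        hypothetical, not from the original source):

        >>> tunable = Tunable("vm_count", {"type": "int", "default": 2, "range": [1, 10]})
        >>> tunable
        vm_count[int]((1, 10):2)=2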
98 """
99 if not isinstance(name, str) or '!' in name: # TODO: Use a regex here and in JSON schema
100 raise ValueError(f"Invalid name of the tunable: {name}")
101 self._name = name
102 self._type: TunableValueTypeName = config["type"] # required
103 if self._type not in self._DTYPE:
104 raise ValueError(f"Invalid parameter type: {self._type}")
105 self._description = config.get("description")
106 self._default = config["default"]
107 self._default = self.dtype(self._default) if self._default is not None else self._default
108 self._values = config.get("values")
109 if self._values:
110 self._values = [str(v) if v is not None else v for v in self._values]
111 self._meta: Dict[str, Any] = config.get("meta", {})
112 self._range: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None
113 self._quantization: Optional[Union[int, float]] = config.get("quantization")
114 self._log: Optional[bool] = config.get("log")
115 self._distribution: Optional[DistributionName] = None
116 self._distribution_params: Dict[str, float] = {}
117 distr = config.get("distribution")
118 if distr:
119 self._distribution = distr["type"] # required
120 self._distribution_params = distr.get("params") or {}
121 config_range = config.get("range")
122 if config_range is not None:
123 assert len(config_range) == 2, f"Invalid range: {config_range}"
124 config_range = (config_range[0], config_range[1])
125 self._range = config_range
126 self._special: Union[List[int], List[float]] = config.get("special") or []
127 self._weights: List[float] = (
128 config.get("values_weights") or config.get("special_weights") or []
129 )
130 self._range_weight: Optional[float] = config.get("range_weight")
131 self._current_value = None
132 self._sanity_check()
133 self.value = self._default
135 def _sanity_check(self) -> None:
136 """
137 Check if the status of the Tunable is valid, and throw ValueError if it is not.
138 """
139 if self.is_categorical:
140 self._sanity_check_categorical()
141 elif self.is_numerical:
142 self._sanity_check_numerical()
143 else:
144 raise ValueError(f"Invalid parameter type for tunable {self}: {self._type}")
145 if not self.is_valid(self.default):
146 raise ValueError(f"Invalid default value for tunable {self}: {self.default}")
148 def _sanity_check_categorical(self) -> None:
149 """
150 Check if the status of the categorical Tunable is valid, and throw ValueError if it is not.
151 """
152 # pylint: disable=too-complex
153 assert self.is_categorical
154 if not (self._values and isinstance(self._values, collections.abc.Iterable)):
155 raise ValueError(f"Must specify values for the categorical type tunable {self}")
156 if self._range is not None:
157 raise ValueError(f"Range must be None for the categorical type tunable {self}")
158 if len(set(self._values)) != len(self._values):
159 raise ValueError(f"Values must be unique for the categorical type tunable {self}")
160 if self._special:
161 raise ValueError(f"Categorical tunable cannot have special values: {self}")
162 if self._range_weight is not None:
163 raise ValueError(f"Categorical tunable cannot have range_weight: {self}")
164 if self._log is not None:
165 raise ValueError(f"Categorical tunable cannot have log parameter: {self}")
166 if self._quantization is not None:
167 raise ValueError(f"Categorical tunable cannot have quantization parameter: {self}")
168 if self._distribution is not None:
169 raise ValueError(f"Categorical parameters do not support `distribution`: {self}")
170 if self._weights:
171 if len(self._weights) != len(self._values):
172 raise ValueError(f"Must specify weights for all values: {self}")
173 if any(w < 0 for w in self._weights):
174 raise ValueError(f"All weights must be non-negative: {self}")

    def _sanity_check_numerical(self) -> None:
        """
        Check if the status of the numerical Tunable is valid, and throw ValueError if it is not.
        """
        # pylint: disable=too-complex,too-many-branches
        assert self.is_numerical
        if self._values is not None:
            raise ValueError(f"Values must be None for the numerical type tunable {self}")
        if not self._range or len(self._range) != 2 or self._range[0] >= self._range[1]:
            raise ValueError(f"Invalid range for tunable {self}: {self._range}")
        if self._quantization is not None:
            if self.dtype == int:
                if not isinstance(self._quantization, int):
                    raise ValueError(f"Quantization of an int param should be an int: {self}")
                if self._quantization <= 1:
                    raise ValueError(f"Number of quantization points is <= 1: {self}")
            if self.dtype == float:
                if not isinstance(self._quantization, (float, int)):
                    raise ValueError(f"Quantization of a float param should be a float or int: {self}")
                if self._quantization <= 0:
                    raise ValueError(f"Number of quantization points is <= 0: {self}")
        if self._distribution is not None and self._distribution not in {"uniform", "normal", "beta"}:
            raise ValueError(f"Invalid distribution: {self}")
        if self._distribution_params and self._distribution is None:
            raise ValueError(f"Must specify the distribution: {self}")
        if self._weights:
            if self._range_weight is None:
                raise ValueError(f"Must specify weight for the range: {self}")
            if len(self._weights) != len(self._special):
                raise ValueError(f"Must specify weights for all special values: {self}")
            if any(w < 0 for w in self._weights + [self._range_weight]):
                raise ValueError(f"All weights must be non-negative: {self}")
        elif self._range_weight is not None:
            raise ValueError(f"Must specify both weights and range_weight or none: {self}")

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the Tunable (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the Tunable.
        """
        # TODO? Add weights, specials, quantization, distribution?
        if self.is_categorical:
            return f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}"
        return f"{self._name}[{self._type}]({self._range}:{self._default})={self._current_value}"

    def __eq__(self, other: object) -> bool:
        """
        Check if two Tunable objects are equal.

        Parameters
        ----------
        other : Tunable
            A tunable object to compare to.

        Returns
        -------
        is_equal : bool
            True if the Tunables correspond to the same parameter and have the same value and type.
            NOTE: ranges and special values are not currently considered in the comparison.
        """
        if not isinstance(other, Tunable):
            return False
        return bool(
            self._name == other._name and
            self._type == other._type and
            self._current_value == other._current_value
        )

    def __lt__(self, other: object) -> bool:    # pylint: disable=too-many-return-statements
        """
        Compare the two Tunable objects. We mostly need this to create a canonical list
        of tunable objects when hashing a TunableGroup.

        Parameters
        ----------
        other : Tunable
            A tunable object to compare to.

        Returns
        -------
        is_less : bool
            True if the current Tunable is less than the other one, False otherwise.
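
        Examples
        --------
        An illustrative sketch of the name-first ordering (the config is hypothetical):

        >>> config: TunableDict = {"type": "int", "default": 1, "range": [0, 10]}
        >>> Tunable("a", config) < Tunable("b", config)
        True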
262 """
263 if not isinstance(other, Tunable):
264 return False
265 if self._name < other._name:
266 return True
267 if self._name == other._name and self._type < other._type:
268 return True
269 if self._name == other._name and self._type == other._type:
270 if self.is_numerical:
271 assert self._current_value is not None
272 assert other._current_value is not None
273 return bool(float(self._current_value) < float(other._current_value))
274 # else: categorical
275 if self._current_value is None:
276 return True
277 if other._current_value is None:
278 return False
279 return bool(str(self._current_value) < str(other._current_value))
280 return False

    def copy(self) -> "Tunable":
        """
        Deep copy of the Tunable object.

        Returns
        -------
        tunable : Tunable
            A new Tunable object that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    @property
    def default(self) -> TunableValue:
        """
        Get the default value of the tunable.
        """
        return self._default

    def is_default(self) -> bool:
        """
        Checks whether the currently assigned value of the tunable is at its default.
        """
        return self._default == self._current_value

    @property
    def value(self) -> TunableValue:
        """
        Get the current value of the tunable.
        """
        return self._current_value

    @value.setter
    def value(self, value: TunableValue) -> TunableValue:
        """
        Set the current value of the tunable.
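
        Examples
        --------
        An illustrative sketch of the type coercion (the tunable is hypothetical):

        >>> tunable = Tunable("n", {"type": "int", "default": 1, "range": [0, 10]})
        >>> tunable.value = "5"     # strings from storage are coerced to the dtype
        >>> tunable.value
        5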
317 """
318 # We need this coercion for the values produced by some optimizers
319 # (e.g., scikit-optimize) and for data restored from certain storage
320 # systems (where values can be strings).
321 try:
322 if self.is_categorical and value is None:
323 coerced_value = None
324 else:
325 assert value is not None
326 coerced_value = self.dtype(value)
327 except Exception:
328 _LOG.error("Impossible conversion: %s %s <- %s %s",
329 self._type, self._name, type(value), value)
330 raise
332 if self._type == "int" and isinstance(value, float) and value != coerced_value:
333 _LOG.error("Loss of precision: %s %s <- %s %s",
334 self._type, self._name, type(value), value)
335 raise ValueError(f"Loss of precision: {self._name}={value}")
337 if not self.is_valid(coerced_value):
338 _LOG.error("Invalid assignment: %s %s <- %s %s",
339 self._type, self._name, type(value), value)
340 raise ValueError(f"Invalid value for the tunable: {self._name}={value}")
342 self._current_value = coerced_value
343 return self._current_value

    def update(self, value: TunableValue) -> bool:
        """
        Assign the value to the tunable. Return True if it is a new value, False otherwise.

        Parameters
        ----------
        value : Union[int, float, str]
            Value to assign.

        Returns
        -------
        is_updated : bool
            True if the new value is different from the previous one, False otherwise.
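
        Examples
        --------
        An illustrative sketch (the tunable is hypothetical):

        >>> tunable = Tunable("n", {"type": "int", "default": 1, "range": [0, 10]})
        >>> tunable.update(3)   # value changes: 1 -> 3
        True
        >>> tunable.update(3)   # value unchanged
        False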
358 """
359 prev_value = self._current_value
360 self.value = value
361 return prev_value != self._current_value

    def is_valid(self, value: TunableValue) -> bool:
        """
        Check if the value can be assigned to the tunable.

        Parameters
        ----------
        value : Union[int, float, str]
            Value to validate.

        Returns
        -------
        is_valid : bool
            True if the value is valid, False otherwise.
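
        Examples
        --------
        An illustrative sketch with a special value outside the range (hypothetical):

        >>> tunable = Tunable("n", {"type": "int", "default": 1, "range": [0, 10], "special": [-1]})
        >>> tunable.is_valid(5)
        True
        >>> tunable.is_valid(-1)    # special values are valid even when out of range
        True
        >>> tunable.is_valid(99)
        False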
376 """
377 # FIXME: quantization check?
378 if self.is_categorical and self._values:
379 return value in self._values
380 elif self.is_numerical and self._range:
381 if isinstance(value, (int, float)):
382 return self.in_range(value) or value in self._special
383 else:
384 raise ValueError(f"Invalid value type for tunable {self}: {value}={type(value)}")
385 else:
386 raise ValueError(f"Invalid parameter type: {self._type}")

    def in_range(self, value: Union[int, float, str, None]) -> bool:
        """
        Check if the value is within the range of the tunable.

        Do *NOT* check for special values.
        Return False if the tunable or value is categorical or None.
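
        Examples
        --------
        An illustrative sketch (the tunable is hypothetical):

        >>> tunable = Tunable("x", {"type": "float", "default": 0.5, "range": [0.0, 1.0]})
        >>> tunable.in_range(0.25)
        True
        >>> tunable.in_range(2.0)   # out of range; special values are not considered here
        False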
393 """
394 return (
395 isinstance(value, (float, int)) and
396 self.is_numerical and
397 self._range is not None and
398 bool(self._range[0] <= value <= self._range[1])
399 )

    @property
    def category(self) -> Optional[str]:
        """
        Get the current value of the tunable as a string category.
        """
        if self.is_categorical:
            return nullable(str, self._current_value)
        else:
            raise ValueError("Cannot get categorical values for a numerical tunable.")

    @category.setter
    def category(self, new_value: Optional[str]) -> Optional[str]:
        """
        Set the current value of the tunable.
        """
        assert self.is_categorical
        assert isinstance(new_value, (str, type(None)))
        self.value = new_value
        return self.value

    @property
    def numerical_value(self) -> Union[int, float]:
        """
        Get the current value of the tunable as a number.
        """
        assert self._current_value is not None
        if self._type == "int":
            return int(self._current_value)
        elif self._type == "float":
            return float(self._current_value)
        else:
            raise ValueError("Cannot get numerical value for a categorical tunable.")

    @numerical_value.setter
    def numerical_value(self, new_value: Union[int, float]) -> Union[int, float]:
        """
        Set the current numerical value of the tunable.
        """
        # We need this coercion for the values produced by some optimizers
        # (e.g., scikit-optimize) and for data restored from certain storage
        # systems (where values can be strings).
        assert self.is_numerical
        self.value = new_value
        return self.value

    @property
    def name(self) -> str:
        """
        Get the name / string ID of the tunable.
        """
        return self._name

    @property
    def special(self) -> Union[List[int], List[float]]:
        """
        Get the special values of the tunable. Return an empty list if there are none.

        Returns
        -------
        special : [int] | [float]
            A list of special values of the tunable. Can be empty.
        """
        return self._special

    @property
    def is_special(self) -> bool:
        """
        Check if the current value of the tunable is special.

        Returns
        -------
        is_special : bool
            True if the current value of the tunable is special, False otherwise.
        """
        return self.value in self._special

    @property
    def weights(self) -> Optional[List[float]]:
        """
        Get the weights of the categories or special values of the tunable.
        Return an empty list if there are none.

        Returns
        -------
        weights : [float]
            A list of weights. Can be empty.
        """
        return self._weights

    @property
    def range_weight(self) -> Optional[float]:
        """
        Get the weight of the range of the numeric tunable.
        Only valid for numerical tunables with special values and weights.

        Returns
        -------
        weight : float
            Weight of the range.
        """
        assert self.is_numerical
        assert self._special
        assert self._weights
        return self._range_weight

    @property
    def type(self) -> TunableValueTypeName:
        """
        Get the data type of the tunable.

        Returns
        -------
        type : str
            Data type of the tunable - one of {'int', 'float', 'categorical'}.
        """
        return self._type

    @property
    def dtype(self) -> TunableValueType:
        """
        Get the actual Python data type of the tunable.

        This is useful for bulk conversions of the input data.

        Returns
        -------
        dtype : type
            Data type of the tunable - one of {int, float, str}.
        """
        return self._DTYPE[self._type]

    @property
    def is_categorical(self) -> bool:
        """
        Check if the tunable is categorical.

        Returns
        -------
        is_categorical : bool
            True if the tunable is categorical, False otherwise.
        """
        return self._type == "categorical"

    @property
    def is_numerical(self) -> bool:
        """
        Check if the tunable is an integer or float.

        Returns
        -------
        is_numerical : bool
            True if the tunable is an integer or float, False otherwise.
        """
        return self._type in {"int", "float"}

    @property
    def range(self) -> Union[Tuple[int, int], Tuple[float, float]]:
        """
        Get the range of the tunable. Only valid for numerical tunables.

        Returns
        -------
        range : (number, number)
            A 2-tuple of numbers that represents the range of the tunable.
            Numbers can be int or float, depending on the type of the tunable.
        """
        assert self.is_numerical
        assert self._range is not None
        return self._range

    @property
    def span(self) -> Union[int, float]:
        """
        Gets the span of the range.

        Note: this does not take quantization into account.

        Returns
        -------
        Union[int, float]
            (max - min) for numerical tunables.
        """
        num_range = self.range
        return num_range[1] - num_range[0]

    @property
    def quantization(self) -> Optional[Union[int, float]]:
        """
        Get the quantization factor, if specified.

        Returns
        -------
        quantization : int, float, None
            The quantization factor, or None.
        """
        if self.is_categorical:
            return None
        return self._quantization

    @property
    def quantized_values(self) -> Optional[Union[Iterable[int], Iterable[float]]]:
        """
        Get a sequence of quantized values for this tunable.

        Returns
        -------
        Optional[Union[Iterable[int], Iterable[float]]]
            If the Tunable is quantizable, returns a sequence of those elements,
            else None (e.g., for unquantized float type tunables).
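
        Examples
        --------
        An illustrative sketch (hypothetical int tunable with a quantization step of 10):

        >>> tunable = Tunable("n", {"type": "int", "default": 0, "range": [0, 100], "quantization": 10})
        >>> list(tunable.quantized_values)
        [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]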
610 """
611 num_range = self.range
612 if self.type == "float":
613 if not self._quantization:
614 return None
615 # Be sure to return python types instead of numpy types.
616 cardinality = self.cardinality
617 assert isinstance(cardinality, int)
618 return (float(x) for x in np.linspace(start=num_range[0],
619 stop=num_range[1],
620 num=cardinality,
621 endpoint=True))
622 assert self.type == "int", f"Unhandled tunable type: {self}"
623 return range(int(num_range[0]), int(num_range[1]) + 1, int(self._quantization or 1))

    @property
    def cardinality(self) -> Union[int, float]:
        """
        Gets the cardinality of elements in this tunable, or else infinity.

        If the tunable has quantization set, this is the number of quantized points
        in the range; unquantized float tunables have infinite cardinality.

        Returns
        -------
        cardinality : int, float
            Either the number of points in the tunable or else infinity.
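
        Examples
        --------
        An illustrative sketch (the tunables are hypothetical):

        >>> Tunable("n", {"type": "int", "default": 0, "range": [0, 100], "quantization": 10}).cardinality
        11
        >>> Tunable("x", {"type": "float", "default": 0.5, "range": [0.0, 1.0]}).cardinality
        inf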
636 """
637 if self.is_categorical:
638 return len(self.categories)
639 if not self.quantization and self.type == "float":
640 return np.inf
641 q_factor = self.quantization or 1
642 return int(self.span / q_factor) + 1

    @property
    def is_log(self) -> Optional[bool]:
        """
        Check if numeric tunable is log scale.

        Returns
        -------
        log : bool
            True if numeric tunable is log scale, False if linear,
            or None if not specified in the config.
        """
        assert self.is_numerical
        return self._log

    @property
    def distribution(self) -> Optional[DistributionName]:
        """
        Get the name of the distribution (uniform, normal, or beta) if specified.

        Returns
        -------
        distribution : str
            Name of the distribution (uniform, normal, or beta) or None.
        """
        return self._distribution

    @property
    def distribution_params(self) -> Dict[str, float]:
        """
        Get the parameters of the distribution, if specified.

        Returns
        -------
        distribution_params : Dict[str, float]
            Parameters of the distribution. Can be empty if none were given.
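
        Examples
        --------
        An illustrative sketch (the parameter names here are hypothetical and
        not validated by this class):

        >>> config: TunableDict = {
        ...     "type": "float", "default": 0.5, "range": [0.0, 1.0],
        ...     "distribution": {"type": "beta", "params": {"alpha": 2, "beta": 5}},
        ... }
        >>> tunable = Tunable("x", config)
        >>> tunable.distribution
        'beta'
        >>> tunable.distribution_params
        {'alpha': 2, 'beta': 5}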
678 """
679 assert self._distribution is not None
680 return self._distribution_params

    @property
    def categories(self) -> List[Optional[str]]:
        """
        Get the list of all possible values of a categorical tunable.
        Only valid for categorical tunables.

        Returns
        -------
        values : List[str]
            List of all possible values of a categorical tunable.
        """
        assert self.is_categorical
        assert self._values is not None
        return self._values

    @property
    def values(self) -> Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]:
        """
        Gets the categories or quantized values for this tunable.

        Returns
        -------
        Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]
            Categories or quantized values.
        """
        if self.is_categorical:
            return self.categories
        assert self.is_numerical
        return self.quantized_values

    @property
    def meta(self) -> Dict[str, Any]:
        """
        Get the tunable's metadata. This is a free-form dictionary that can be used to
        store any additional information about the tunable (e.g., the unit information).
        """
        return self._meta