Coverage for mlos_bench/mlos_bench/tunables/tunable.py: 97%

299 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2024-12-20 00:44 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Tunable parameter definition.""" 

6import collections 

7import copy 

8import logging 

9from typing import ( 

10 Any, 

11 Dict, 

12 Iterable, 

13 List, 

14 Literal, 

15 Optional, 

16 Sequence, 

17 Tuple, 

18 Type, 

19 TypedDict, 

20 Union, 

21) 

22 

23import numpy as np 

24 

25from mlos_bench.util import nullable 

26 

_LOG = logging.getLogger(__name__)

TunableValue = Union[int, float, Optional[str]]
"""A tunable parameter value type alias."""

TunableValueType = Union[Type[int], Type[float], Type[str]]
"""Tunable value type."""

TunableValueTypeTuple = (int, float, str, type(None))
"""
Tunable value type tuple.

For checking with isinstance()
"""

TunableValueTypeName = Literal["int", "float", "categorical"]
"""The string name of a tunable value type."""

TunableValuesDict = Dict[str, TunableValue]
"""Tunable values dictionary type."""

DistributionName = Literal["uniform", "normal", "beta"]
"""Tunable value distribution type."""

50 

51 

class DistributionDict(TypedDict, total=False):
    """A typed dict for tunable parameters' distributions."""

    # Name of the distribution; one of "uniform", "normal", or "beta".
    type: DistributionName
    # Optional parameters of the distribution (keys depend on the distribution type).
    params: Optional[Dict[str, float]]

57 

58 

class TunableDict(TypedDict, total=False):
    """
    A typed dict for tunable parameters.

    Mostly used for mypy type checking.

    These are the types expected to be received from the json config.
    """

    # Required: one of "int", "float", or "categorical".
    type: TunableValueTypeName
    # Optional free-form description of the parameter.
    description: Optional[str]
    # Required: the default value of the tunable.
    default: TunableValue
    # Allowed values for categorical tunables.
    values: Optional[List[Optional[str]]]
    # Two-element (min, max) range for numerical tunables.
    range: Optional[Union[Sequence[int], Sequence[float]]]
    # Number of quantization bins for numerical tunables.
    quantization_bins: Optional[int]
    # Whether to use a log scale for sampling numerical tunables.
    log: Optional[bool]
    # Sampling distribution configuration for numerical tunables.
    distribution: Optional[DistributionDict]
    # Out-of-range special values for numerical tunables.
    special: Optional[Union[List[int], List[float]]]
    # Sampling weights for categorical values.
    values_weights: Optional[List[float]]
    # Sampling weights for the special values.
    special_weights: Optional[List[float]]
    # Sampling weight of the regular (non-special) range.
    range_weight: Optional[float]
    # Free-form metadata (e.g., units) passed through to consumers.
    meta: Dict[str, Any]

81 

82 

class Tunable:  # pylint: disable=too-many-instance-attributes,too-many-public-methods
    """A tunable parameter definition and its current value."""

    # Maps tunable types to their corresponding Python types by name.
    _DTYPE: Dict[TunableValueTypeName, TunableValueType] = {
        "int": int,
        "float": float,
        "categorical": str,
    }

    def __init__(self, name: str, config: TunableDict):
        """
        Create an instance of a new tunable parameter.

        Parameters
        ----------
        name : str
            Human-readable identifier of the tunable parameter.
        config : dict
            Python dict that represents a Tunable (e.g., deserialized from JSON)

        Raises
        ------
        ValueError
            If the name or any part of the configuration is invalid.

        See Also
        --------
        :py:mod:`mlos_bench.tunables` : for more information on tunable parameters and
            their configuration.
        """
        if not isinstance(name, str) or "!" in name:  # TODO: Use a regex here and in JSON schema
            raise ValueError(f"Invalid name of the tunable: {name}")
        self._name = name
        self._type: TunableValueTypeName = config["type"]  # required
        if self._type not in self._DTYPE:
            raise ValueError(f"Invalid parameter type: {self._type}")
        self._description = config.get("description")
        self._default = config["default"]
        # Coerce the default to the tunable's Python type (None passes through).
        self._default = self.dtype(self._default) if self._default is not None else self._default
        self._values = config.get("values")
        if self._values:
            # Normalize categorical values to strings (None passes through).
            self._values = [str(v) if v is not None else v for v in self._values]
        self._meta: Dict[str, Any] = config.get("meta", {})
        self._range: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None
        self._quantization_bins: Optional[int] = config.get("quantization_bins")
        self._log: Optional[bool] = config.get("log")
        self._distribution: Optional[DistributionName] = None
        self._distribution_params: Dict[str, float] = {}
        distr = config.get("distribution")
        if distr:
            self._distribution = distr["type"]  # required
            self._distribution_params = distr.get("params") or {}
        config_range = config.get("range")
        if config_range is not None:
            assert len(config_range) == 2, f"Invalid range: {config_range}"
            config_range = (config_range[0], config_range[1])
            self._range = config_range
        self._special: Union[List[int], List[float]] = config.get("special") or []
        # A single weights list serves either the categorical values or the special values.
        self._weights: List[float] = (
            config.get("values_weights") or config.get("special_weights") or []
        )
        self._range_weight: Optional[float] = config.get("range_weight")
        self._current_value = None
        self._sanity_check()
        # Assign through the `value` property setter to reuse its coercion/validation.
        self.value = self._default

    def _sanity_check(self) -> None:
        """Check if the status of the Tunable is valid, and throw ValueError if it is
        not.
        """
        if self.is_categorical:
            self._sanity_check_categorical()
        elif self.is_numerical:
            self._sanity_check_numerical()
        else:
            raise ValueError(f"Invalid parameter type for tunable {self}: {self._type}")
        if not self.is_valid(self.default):
            raise ValueError(f"Invalid default value for tunable {self}: {self.default}")

    def _sanity_check_categorical(self) -> None:
        """Check if the status of the categorical Tunable is valid, and throw ValueError
        if it is not.
        """
        # pylint: disable=too-complex
        assert self.is_categorical
        if not (self._values and isinstance(self._values, collections.abc.Iterable)):
            raise ValueError(f"Must specify values for the categorical type tunable {self}")
        if self._range is not None:
            raise ValueError(f"Range must be None for the categorical type tunable {self}")
        if len(set(self._values)) != len(self._values):
            raise ValueError(f"Values must be unique for the categorical type tunable {self}")
        if self._special:
            raise ValueError(f"Categorical tunable cannot have special values: {self}")
        if self._range_weight is not None:
            raise ValueError(f"Categorical tunable cannot have range_weight: {self}")
        if self._log is not None:
            raise ValueError(f"Categorical tunable cannot have log parameter: {self}")
        if self._quantization_bins is not None:
            raise ValueError(f"Categorical tunable cannot have quantization parameter: {self}")
        if self._distribution is not None:
            raise ValueError(f"Categorical parameters do not support `distribution`: {self}")
        if self._weights:
            if len(self._weights) != len(self._values):
                raise ValueError(f"Must specify weights for all values: {self}")
            if any(w < 0 for w in self._weights):
                raise ValueError(f"All weights must be non-negative: {self}")

    def _sanity_check_numerical(self) -> None:
        """Check if the status of the numerical Tunable is valid, and throw ValueError
        if it is not.
        """
        # pylint: disable=too-complex,too-many-branches
        assert self.is_numerical
        if self._values is not None:
            raise ValueError(f"Values must be None for the numerical type tunable {self}")
        if not self._range or len(self._range) != 2 or self._range[0] >= self._range[1]:
            raise ValueError(f"Invalid range for tunable {self}: {self._range}")
        if self._quantization_bins is not None and self._quantization_bins <= 1:
            raise ValueError(f"Number of quantization bins is <= 1: {self}")
        if self._distribution is not None and self._distribution not in {
            "uniform",
            "normal",
            "beta",
        }:
            raise ValueError(f"Invalid distribution: {self}")
        if self._distribution_params and self._distribution is None:
            raise ValueError(f"Must specify the distribution: {self}")
        if self._weights:
            if self._range_weight is None:
                raise ValueError(f"Must specify weight for the range: {self}")
            if len(self._weights) != len(self._special):
                # BUGFIX: this message was missing the f-string prefix, so the
                # literal text "{self}" was emitted instead of the tunable repr.
                raise ValueError(f"Must specify weights for all special values {self}")
            if any(w < 0 for w in self._weights + [self._range_weight]):
                raise ValueError(f"All weights must be non-negative: {self}")
        elif self._range_weight is not None:
            raise ValueError(f"Must specify both weights and range_weight or none: {self}")

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the Tunable (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the Tunable.
        """
        # TODO? Add weights, specials, quantization, distribution?
        if self.is_categorical:
            return (
                f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}"
            )
        return f"{self._name}[{self._type}]({self._range}:{self._default})={self._current_value}"

    def __eq__(self, other: object) -> bool:
        """
        Check if two Tunable objects are equal.

        Parameters
        ----------
        other : Tunable
            A tunable object to compare to.

        Returns
        -------
        is_equal : bool
            True if the Tunables correspond to the same parameter and have the same value and type.
            NOTE: ranges and special values are not currently considered in the comparison.
        """
        if not isinstance(other, Tunable):
            return False
        return bool(
            self._name == other._name
            and self._type == other._type
            and self._current_value == other._current_value
        )

    def __lt__(self, other: object) -> bool:  # pylint: disable=too-many-return-statements
        """
        Compare the two Tunable objects. We mostly need this to create a canonical list
        of tunable objects when hashing a TunableGroup.

        Parameters
        ----------
        other : Tunable
            A tunable object to compare to.

        Returns
        -------
        is_less : bool
            True if the current Tunable is less then the other one, False otherwise.
        """
        if not isinstance(other, Tunable):
            return False
        if self._name < other._name:
            return True
        if self._name == other._name and self._type < other._type:
            return True
        if self._name == other._name and self._type == other._type:
            if self.is_numerical:
                assert self._current_value is not None
                assert other._current_value is not None
                return bool(float(self._current_value) < float(other._current_value))
            # else: categorical; None sorts before any string.
            if self._current_value is None:
                return True
            if other._current_value is None:
                return False
            return bool(str(self._current_value) < str(other._current_value))
        return False

    def copy(self) -> "Tunable":
        """
        Deep copy of the Tunable object.

        Returns
        -------
        tunable : Tunable
            A new Tunable object that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    @property
    def default(self) -> TunableValue:
        """Get the default value of the tunable."""
        return self._default

    def is_default(self) -> bool:
        """
        Check whether the currently assigned value of the tunable is at its default.

        Returns
        -------
        bool
            True if the current value equals the default, False otherwise.
        """
        # NOTE: previously annotated as `-> TunableValue`, but `==` always yields bool.
        return self._default == self._current_value

    @property
    def value(self) -> TunableValue:
        """Get the current value of the tunable."""
        return self._current_value

    @value.setter
    def value(self, value: TunableValue) -> TunableValue:
        """Set the current value of the tunable."""
        # We need this coercion for the values produced by some optimizers
        # (e.g., scikit-optimize) and for data restored from certain storage
        # systems (where values can be strings).
        try:
            if self.is_categorical and value is None:
                coerced_value = None
            else:
                assert value is not None
                coerced_value = self.dtype(value)
        except Exception:
            _LOG.error(
                "Impossible conversion: %s %s <- %s %s",
                self._type,
                self._name,
                type(value),
                value,
            )
            raise

        # Reject float -> int assignments that would silently drop the fraction.
        if self._type == "int" and isinstance(value, float) and value != coerced_value:
            _LOG.error(
                "Loss of precision: %s %s <- %s %s",
                self._type,
                self._name,
                type(value),
                value,
            )
            raise ValueError(f"Loss of precision: {self._name}={value}")

        if not self.is_valid(coerced_value):
            _LOG.error(
                "Invalid assignment: %s %s <- %s %s",
                self._type,
                self._name,
                type(value),
                value,
            )
            raise ValueError(f"Invalid value for the tunable: {self._name}={value}")

        self._current_value = coerced_value
        return self._current_value

    def update(self, value: TunableValue) -> bool:
        """
        Assign the value to the tunable. Return True if it is a new value, False
        otherwise.

        Parameters
        ----------
        value : Union[int, float, str]
            Value to assign.

        Returns
        -------
        is_updated : bool
            True if the new value is different from the previous one, False otherwise.
        """
        prev_value = self._current_value
        self.value = value
        return prev_value != self._current_value

    def is_valid(self, value: TunableValue) -> bool:
        """
        Check if the value can be assigned to the tunable.

        Parameters
        ----------
        value : Union[int, float, str]
            Value to validate.

        Returns
        -------
        is_valid : bool
            True if the value is valid, False otherwise.

        Raises
        ------
        ValueError
            If the value's type is incompatible with a numerical tunable,
            or the tunable itself is misconfigured.
        """
        if self.is_categorical and self._values:
            return value in self._values
        elif self.is_numerical and self._range:
            if isinstance(value, (int, float)):
                return self.in_range(value) or value in self._special
            else:
                raise ValueError(f"Invalid value type for tunable {self}: {value}={type(value)}")
        else:
            raise ValueError(f"Invalid parameter type: {self._type}")

    def in_range(self, value: Union[int, float, str, None]) -> bool:
        """
        Check if the value is within the range of the tunable.

        Do *NOT* check for special values. Return False if the tunable or value is
        categorical or None.
        """
        return (
            isinstance(value, (float, int))
            and self.is_numerical
            and self._range is not None
            and bool(self._range[0] <= value <= self._range[1])
        )

    @property
    def category(self) -> Optional[str]:
        """Get the current value of the tunable as a string."""
        if self.is_categorical:
            return nullable(str, self._current_value)
        else:
            raise ValueError("Cannot get categorical values for a numerical tunable.")

    @category.setter
    def category(self, new_value: Optional[str]) -> Optional[str]:
        """Set the current value of the tunable."""
        assert self.is_categorical
        assert isinstance(new_value, (str, type(None)))
        self.value = new_value
        return self.value

    @property
    def numerical_value(self) -> Union[int, float]:
        """Get the current value of the tunable as a number."""
        assert self._current_value is not None
        if self._type == "int":
            return int(self._current_value)
        elif self._type == "float":
            return float(self._current_value)
        else:
            raise ValueError("Cannot get numerical value for a categorical tunable.")

    @numerical_value.setter
    def numerical_value(self, new_value: Union[int, float]) -> Union[int, float]:
        """Set the current numerical value of the tunable."""
        # We need this coercion for the values produced by some optimizers
        # (e.g., scikit-optimize) and for data restored from certain storage
        # systems (where values can be strings).
        assert self.is_numerical
        self.value = new_value
        return self.value

    @property
    def name(self) -> str:
        """Get the name / string ID of the tunable."""
        return self._name

    @property
    def special(self) -> Union[List[int], List[float]]:
        """
        Get the special values of the tunable. Return an empty list if there are none.

        Returns
        -------
        special : [int] | [float]
            A list of special values of the tunable. Can be empty.
        """
        return self._special

    @property
    def is_special(self) -> bool:
        """
        Check if the current value of the tunable is special.

        Returns
        -------
        is_special : bool
            True if the current value of the tunable is special, False otherwise.
        """
        return self.value in self._special

    @property
    def weights(self) -> Optional[List[float]]:
        """
        Get the weights of the categories or special values of the tunable.

        Returns
        -------
        weights : [float]
            A list of weights; empty if no weights were specified.
        """
        # NOTE: weights default to [] in __init__, so None is never actually returned;
        # the Optional annotation is kept for backward compatibility.
        return self._weights

    @property
    def range_weight(self) -> Optional[float]:
        """
        Get weight of the range of the numeric tunable. Return None if there are no
        weights or a tunable is categorical.

        Returns
        -------
        weight : float
            Weight of the range or None.
        """
        assert self.is_numerical
        assert self._special
        assert self._weights
        return self._range_weight

    @property
    def type(self) -> TunableValueTypeName:
        """
        Get the data type of the tunable.

        Returns
        -------
        type : str
            Data type of the tunable - one of {'int', 'float', 'categorical'}.
        """
        return self._type

    @property
    def dtype(self) -> TunableValueType:
        """
        Get the actual Python data type of the tunable.

        This is useful for bulk conversions of the input data.

        Returns
        -------
        dtype : type
            Data type of the tunable - one of {int, float, str}.
        """
        return self._DTYPE[self._type]

    @property
    def is_categorical(self) -> bool:
        """
        Check if the tunable is categorical.

        Returns
        -------
        is_categorical : bool
            True if the tunable is categorical, False otherwise.
        """
        return self._type == "categorical"

    @property
    def is_numerical(self) -> bool:
        """
        Check if the tunable is an integer or float.

        Returns
        -------
        is_int : bool
            True if the tunable is an integer or float, False otherwise.
        """
        return self._type in {"int", "float"}

    @property
    def range(self) -> Union[Tuple[int, int], Tuple[float, float]]:
        """
        Get the range of the tunable. Only valid for numerical tunables
        (asserts otherwise).

        Returns
        -------
        range : Union[Tuple[int, int], Tuple[float, float]]
            A 2-tuple of numbers that represents the range of the tunable.
            Numbers can be int or float, depending on the type of the tunable.
        """
        assert self.is_numerical
        assert self._range is not None
        return self._range

    @property
    def span(self) -> Union[int, float]:
        """
        Gets the span of the range.

        Note: this does not take quantization into account.

        Returns
        -------
        Union[int, float]
            (max - min) for numerical tunables.
        """
        num_range = self.range
        return num_range[1] - num_range[0]

    @property
    def quantization_bins(self) -> Optional[int]:
        """
        Get the number of quantization bins, if specified.

        Returns
        -------
        quantization_bins : int | None
            Number of quantization bins, or None.
        """
        if self.is_categorical:
            return None
        return self._quantization_bins

    @property
    def quantized_values(self) -> Optional[Union[Iterable[int], Iterable[float]]]:
        """
        Get a sequence of quantized values for this tunable.

        Returns
        -------
        Optional[Union[Iterable[int], Iterable[float]]]
            If the Tunable is quantizable, returns a sequence of those elements,
            else None (e.g., for unquantized float type tunables).
        """
        num_range = self.range
        if self.type == "float":
            if not self.quantization_bins:
                return None
            # Be sure to return python types instead of numpy types.
            return (
                float(x)
                for x in np.linspace(
                    start=num_range[0],
                    stop=num_range[1],
                    num=self.quantization_bins,
                    endpoint=True,
                )
            )
        assert self.type == "int", f"Unhandled tunable type: {self}"
        return range(
            int(num_range[0]),
            int(num_range[1]) + 1,
            int(self.span / (self.quantization_bins - 1)) if self.quantization_bins else 1,
        )

    @property
    def cardinality(self) -> Optional[int]:
        """
        Gets the cardinality of elements in this tunable, or else None (i.e., when the
        tunable is a continuous float and not quantized).

        If the tunable has quantization set, the cardinality is the number of
        quantization bins.

        Returns
        -------
        cardinality : int
            Either the number of points in the tunable or else None.
        """
        if self.is_categorical:
            return len(self.categories)
        if self.quantization_bins:
            return self.quantization_bins
        if self.type == "int":
            return int(self.span) + 1
        return None

    @property
    def is_log(self) -> Optional[bool]:
        """
        Check if numeric tunable is log scale.

        Returns
        -------
        log : bool
            True if numeric tunable is log scale, False if linear.
        """
        assert self.is_numerical
        return self._log

    @property
    def distribution(self) -> Optional[DistributionName]:
        """
        Get the name of the distribution (uniform, normal, or beta) if specified.

        Returns
        -------
        distribution : str
            Name of the distribution (uniform, normal, or beta) or None.
        """
        return self._distribution

    @property
    def distribution_params(self) -> Dict[str, float]:
        """
        Get the parameters of the distribution, if specified.

        Returns
        -------
        distribution_params : Dict[str, float]
            Parameters of the distribution or None.
        """
        assert self._distribution is not None
        return self._distribution_params

    @property
    def categories(self) -> List[Optional[str]]:
        """
        Get the list of all possible values of a categorical tunable.
        Only valid for categorical tunables (asserts otherwise).

        Returns
        -------
        values : List[str]
            List of all possible values of a categorical tunable.
        """
        assert self.is_categorical
        assert self._values is not None
        return self._values

    @property
    def values(self) -> Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]:
        """
        Gets the categories or quantized values for this tunable.

        Returns
        -------
        Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]
            Categories or quantized values.
        """
        if self.is_categorical:
            return self.categories
        assert self.is_numerical
        return self.quantized_values

    @property
    def meta(self) -> Dict[str, Any]:
        """
        Get the tunable's metadata.

        This is a free-form dictionary that can be used to store any additional
        information about the tunable (e.g., the unit information).
        """
        return self._meta