Coverage for mlos_bench/mlos_bench/tunables/tunable.py: 96%

310 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Tunable parameter definition. 

7""" 

8import copy 

9import collections 

10import logging 

11 

12from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, Type, TypedDict, Union 

13 

14import numpy as np 

15 

16from mlos_bench.util import nullable 

17 

18_LOG = logging.getLogger(__name__) 

19 

20 

21"""A tunable parameter value type alias.""" 

22TunableValue = Union[int, float, Optional[str]] 

23 

24"""Tunable value type.""" 

25TunableValueType = Union[Type[int], Type[float], Type[str]] 

26 

27""" 

28Tunable value type tuple. 

29For checking with isinstance() 

30""" 

31TunableValueTypeTuple = (int, float, str, type(None)) 

32 

33"""The string name of a tunable value type.""" 

34TunableValueTypeName = Literal["int", "float", "categorical"] 

35 

36"""Tunable values dictionary type""" 

37TunableValuesDict = Dict[str, TunableValue] 

38 

39"""Tunable value distribution type""" 

40DistributionName = Literal["uniform", "normal", "beta"] 

41 

42 

43class DistributionDict(TypedDict, total=False): 

44 """ 

45 A typed dict for tunable parameters' distributions. 

46 """ 

47 

48 type: DistributionName 

49 params: Optional[Dict[str, float]] 

50 

51 

52class TunableDict(TypedDict, total=False): 

53 """ 

54 A typed dict for tunable parameters. 

55 

56 Mostly used for mypy type checking. 

57 

58 These are the types expected to be received from the json config. 

59 """ 

60 

61 type: TunableValueTypeName 

62 description: Optional[str] 

63 default: TunableValue 

64 values: Optional[List[Optional[str]]] 

65 range: Optional[Union[Sequence[int], Sequence[float]]] 

66 quantization: Optional[Union[int, float]] 

67 log: Optional[bool] 

68 distribution: Optional[DistributionDict] 

69 special: Optional[Union[List[int], List[float]]] 

70 values_weights: Optional[List[float]] 

71 special_weights: Optional[List[float]] 

72 range_weight: Optional[float] 

73 meta: Dict[str, Any] 

74 

75 

class Tunable:  # pylint: disable=too-many-instance-attributes,too-many-public-methods
    """
    A tunable parameter definition and its current value.
    """

    # Maps tunable type names (as they appear in JSON configs) to Python types.
    _DTYPE: Dict[TunableValueTypeName, TunableValueType] = {
        "int": int,
        "float": float,
        "categorical": str,
    }

    def __init__(self, name: str, config: TunableDict):
        """
        Create an instance of a new tunable parameter.

        Parameters
        ----------
        name : str
            Human-readable identifier of the tunable parameter.
        config : dict
            Python dict that represents a Tunable (e.g., deserialized from JSON)

        Raises
        ------
        ValueError
            If the name, type, range, default, or weights in the config are invalid.
        """
        if not isinstance(name, str) or '!' in name:  # TODO: Use a regex here and in JSON schema
            raise ValueError(f"Invalid name of the tunable: {name}")
        self._name = name
        self._type: TunableValueTypeName = config["type"]  # required
        if self._type not in self._DTYPE:
            raise ValueError(f"Invalid parameter type: {self._type}")
        self._description = config.get("description")
        self._default = config["default"]
        # Coerce the default to the tunable's dtype; None stays None (allowed for categoricals).
        self._default = self.dtype(self._default) if self._default is not None else self._default
        self._values = config.get("values")
        if self._values:
            # Normalize all categorical values to strings (None is preserved).
            self._values = [str(v) if v is not None else v for v in self._values]
        self._meta: Dict[str, Any] = config.get("meta", {})
        self._range: Optional[Union[Tuple[int, int], Tuple[float, float]]] = None
        self._quantization: Optional[Union[int, float]] = config.get("quantization")
        self._log: Optional[bool] = config.get("log")
        self._distribution: Optional[DistributionName] = None
        self._distribution_params: Dict[str, float] = {}
        distr = config.get("distribution")
        if distr:
            self._distribution = distr["type"]  # required
            self._distribution_params = distr.get("params") or {}
        config_range = config.get("range")
        if config_range is not None:
            # Raise ValueError (not assert) so the check survives `python -O`
            # and is consistent with the other config validation errors here.
            if len(config_range) != 2:
                raise ValueError(f"Invalid range: {config_range}")
            config_range = (config_range[0], config_range[1])
            self._range = config_range
        self._special: Union[List[int], List[float]] = config.get("special") or []
        # `values_weights` (categorical) and `special_weights` (numerical) are
        # mutually exclusive; the sanity checks below enforce the combinations.
        self._weights: List[float] = (
            config.get("values_weights") or config.get("special_weights") or []
        )
        self._range_weight: Optional[float] = config.get("range_weight")
        self._current_value = None
        self._sanity_check()
        # Assign through the property setter so the default gets validated/coerced too.
        self.value = self._default

    def _sanity_check(self) -> None:
        """
        Check if the status of the Tunable is valid, and throw ValueError if it is not.
        """
        if self.is_categorical:
            self._sanity_check_categorical()
        elif self.is_numerical:
            self._sanity_check_numerical()
        else:
            raise ValueError(f"Invalid parameter type for tunable {self}: {self._type}")
        if not self.is_valid(self.default):
            raise ValueError(f"Invalid default value for tunable {self}: {self.default}")

    def _sanity_check_categorical(self) -> None:
        """
        Check if the status of the categorical Tunable is valid, and throw ValueError if it is not.
        """
        # pylint: disable=too-complex
        assert self.is_categorical
        if not (self._values and isinstance(self._values, collections.abc.Iterable)):
            raise ValueError(f"Must specify values for the categorical type tunable {self}")
        if self._range is not None:
            raise ValueError(f"Range must be None for the categorical type tunable {self}")
        if len(set(self._values)) != len(self._values):
            raise ValueError(f"Values must be unique for the categorical type tunable {self}")
        if self._special:
            raise ValueError(f"Categorical tunable cannot have special values: {self}")
        if self._range_weight is not None:
            raise ValueError(f"Categorical tunable cannot have range_weight: {self}")
        if self._log is not None:
            raise ValueError(f"Categorical tunable cannot have log parameter: {self}")
        if self._quantization is not None:
            raise ValueError(f"Categorical tunable cannot have quantization parameter: {self}")
        if self._distribution is not None:
            raise ValueError(f"Categorical parameters do not support `distribution`: {self}")
        if self._weights:
            if len(self._weights) != len(self._values):
                raise ValueError(f"Must specify weights for all values: {self}")
            if any(w < 0 for w in self._weights):
                raise ValueError(f"All weights must be non-negative: {self}")

    def _sanity_check_numerical(self) -> None:
        """
        Check if the status of the numerical Tunable is valid, and throw ValueError if it is not.
        """
        # pylint: disable=too-complex,too-many-branches
        assert self.is_numerical
        if self._values is not None:
            raise ValueError(f"Values must be None for the numerical type tunable {self}")
        if not self._range or len(self._range) != 2 or self._range[0] >= self._range[1]:
            raise ValueError(f"Invalid range for tunable {self}: {self._range}")
        if self._quantization is not None:
            if self.dtype == int:
                if not isinstance(self._quantization, int):
                    raise ValueError(f"Quantization of a int param should be an int: {self}")
                if self._quantization <= 1:
                    raise ValueError(f"Number of quantization points is <= 1: {self}")
            if self.dtype == float:
                if not isinstance(self._quantization, (float, int)):
                    raise ValueError(f"Quantization of a float param should be a float or int: {self}")
                if self._quantization <= 0:
                    raise ValueError(f"Number of quantization points is <= 0: {self}")
        if self._distribution is not None and self._distribution not in {"uniform", "normal", "beta"}:
            raise ValueError(f"Invalid distribution: {self}")
        if self._distribution_params and self._distribution is None:
            raise ValueError(f"Must specify the distribution: {self}")
        if self._weights:
            if self._range_weight is None:
                raise ValueError(f"Must specify weight for the range: {self}")
            if len(self._weights) != len(self._special):
                # FIX: was a plain string literal missing the `f` prefix,
                # so "{self}" appeared verbatim in the error message.
                raise ValueError(f"Must specify weights for all special values {self}")
            if any(w < 0 for w in self._weights + [self._range_weight]):
                raise ValueError(f"All weights must be non-negative: {self}")
        elif self._range_weight is not None:
            raise ValueError(f"Must specify both weights and range_weight or none: {self}")

    def __repr__(self) -> str:
        """
        Produce a human-readable version of the Tunable (mostly for logging).

        Returns
        -------
        string : str
            A human-readable version of the Tunable.
        """
        # TODO? Add weights, specials, quantization, distribution?
        if self.is_categorical:
            return f"{self._name}[{self._type}]({self._values}:{self._default})={self._current_value}"
        return f"{self._name}[{self._type}]({self._range}:{self._default})={self._current_value}"

    def __eq__(self, other: object) -> bool:
        """
        Check if two Tunable objects are equal.

        Parameters
        ----------
        other : Tunable
            A tunable object to compare to.

        Returns
        -------
        is_equal : bool
            True if the Tunables correspond to the same parameter and have the same value and type.
            NOTE: ranges and special values are not currently considered in the comparison.
        """
        if not isinstance(other, Tunable):
            return False
        return bool(
            self._name == other._name and
            self._type == other._type and
            self._current_value == other._current_value
        )

    def __lt__(self, other: object) -> bool:  # pylint: disable=too-many-return-statements
        """
        Compare the two Tunable objects. We mostly need this to create a canonical list
        of tunable objects when hashing a TunableGroup.

        Parameters
        ----------
        other : Tunable
            A tunable object to compare to.

        Returns
        -------
        is_less : bool
            True if the current Tunable is less then the other one, False otherwise.
        """
        if not isinstance(other, Tunable):
            return False
        if self._name < other._name:
            return True
        if self._name == other._name and self._type < other._type:
            return True
        if self._name == other._name and self._type == other._type:
            if self.is_numerical:
                assert self._current_value is not None
                assert other._current_value is not None
                return bool(float(self._current_value) < float(other._current_value))
            # else: categorical; None sorts before any string value.
            if self._current_value is None:
                return True
            if other._current_value is None:
                return False
            return bool(str(self._current_value) < str(other._current_value))
        return False

    def copy(self) -> "Tunable":
        """
        Deep copy of the Tunable object.

        Returns
        -------
        tunable : Tunable
            A new Tunable object that is a deep copy of the original one.
        """
        return copy.deepcopy(self)

    @property
    def default(self) -> TunableValue:
        """
        Get the default value of the tunable.
        """
        return self._default

    def is_default(self) -> bool:
        """
        Checks whether the currently assigned value of the tunable is at its default.

        Returns
        -------
        bool
            True if the current value equals the default, False otherwise.
        """
        # NOTE: return annotation fixed from TunableValue to bool --
        # the comparison below always produces a boolean.
        return self._default == self._current_value

    @property
    def value(self) -> TunableValue:
        """
        Get the current value of the tunable.
        """
        return self._current_value

    @value.setter
    def value(self, value: TunableValue) -> TunableValue:
        """
        Set the current value of the tunable.
        """
        # We need this coercion for the values produced by some optimizers
        # (e.g., scikit-optimize) and for data restored from certain storage
        # systems (where values can be strings).
        try:
            if self.is_categorical and value is None:
                coerced_value = None
            else:
                assert value is not None
                coerced_value = self.dtype(value)
        except Exception:
            _LOG.error("Impossible conversion: %s %s <- %s %s",
                       self._type, self._name, type(value), value)
            raise

        if self._type == "int" and isinstance(value, float) and value != coerced_value:
            _LOG.error("Loss of precision: %s %s <- %s %s",
                       self._type, self._name, type(value), value)
            raise ValueError(f"Loss of precision: {self._name}={value}")

        if not self.is_valid(coerced_value):
            _LOG.error("Invalid assignment: %s %s <- %s %s",
                       self._type, self._name, type(value), value)
            raise ValueError(f"Invalid value for the tunable: {self._name}={value}")

        self._current_value = coerced_value
        return self._current_value

    def update(self, value: TunableValue) -> bool:
        """
        Assign the value to the tunable. Return True if it is a new value, False otherwise.

        Parameters
        ----------
        value : Union[int, float, str]
            Value to assign.

        Returns
        -------
        is_updated : bool
            True if the new value is different from the previous one, False otherwise.
        """
        prev_value = self._current_value
        self.value = value
        return prev_value != self._current_value

    def is_valid(self, value: TunableValue) -> bool:
        """
        Check if the value can be assigned to the tunable.

        Parameters
        ----------
        value : Union[int, float, str]
            Value to validate.

        Returns
        -------
        is_valid : bool
            True if the value is valid, False otherwise.
        """
        # FIXME: quantization check?
        if self.is_categorical and self._values:
            return value in self._values
        elif self.is_numerical and self._range:
            if isinstance(value, (int, float)):
                # Special values are valid even when they fall outside the range.
                return self.in_range(value) or value in self._special
            else:
                raise ValueError(f"Invalid value type for tunable {self}: {value}={type(value)}")
        else:
            raise ValueError(f"Invalid parameter type: {self._type}")

    def in_range(self, value: Union[int, float, str, None]) -> bool:
        """
        Check if the value is within the range of the tunable.
        Do *NOT* check for special values.
        Return False if the tunable or value is categorical or None.
        """
        return (
            isinstance(value, (float, int)) and
            self.is_numerical and
            self._range is not None and
            bool(self._range[0] <= value <= self._range[1])
        )

    @property
    def category(self) -> Optional[str]:
        """
        Get the current value of the tunable as a string category.
        (Docstring fixed: this property is for categoricals, not numbers.)
        """
        if self.is_categorical:
            return nullable(str, self._current_value)
        else:
            raise ValueError("Cannot get categorical values for a numerical tunable.")

    @category.setter
    def category(self, new_value: Optional[str]) -> Optional[str]:
        """
        Set the current value of the tunable.
        """
        assert self.is_categorical
        assert isinstance(new_value, (str, type(None)))
        self.value = new_value
        return self.value

    @property
    def numerical_value(self) -> Union[int, float]:
        """
        Get the current value of the tunable as a number.
        """
        assert self._current_value is not None
        if self._type == "int":
            return int(self._current_value)
        elif self._type == "float":
            return float(self._current_value)
        else:
            raise ValueError("Cannot get numerical value for a categorical tunable.")

    @numerical_value.setter
    def numerical_value(self, new_value: Union[int, float]) -> Union[int, float]:
        """
        Set the current numerical value of the tunable.
        """
        # We need this coercion for the values produced by some optimizers
        # (e.g., scikit-optimize) and for data restored from certain storage
        # systems (where values can be strings).
        assert self.is_numerical
        self.value = new_value
        return self.value

    @property
    def name(self) -> str:
        """
        Get the name / string ID of the tunable.
        """
        return self._name

    @property
    def special(self) -> Union[List[int], List[float]]:
        """
        Get the special values of the tunable. Return an empty list if there are none.

        Returns
        -------
        special : [int] | [float]
            A list of special values of the tunable. Can be empty.
        """
        return self._special

    @property
    def is_special(self) -> bool:
        """
        Check if the current value of the tunable is special.

        Returns
        -------
        is_special : bool
            True if the current value of the tunable is special, False otherwise.
        """
        return self.value in self._special

    @property
    def weights(self) -> Optional[List[float]]:
        """
        Get the weights of the categories or special values of the tunable.
        Return an empty list if there are none.
        (Docstring fixed: the backing attribute defaults to [], not None.)

        Returns
        -------
        weights : [float]
            A list of weights. Can be empty.
        """
        return self._weights

    @property
    def range_weight(self) -> Optional[float]:
        """
        Get weight of the range of the numeric tunable.
        Return None if there are no weights or a tunable is categorical.

        Returns
        -------
        weight : float
            Weight of the range or None.
        """
        assert self.is_numerical
        assert self._special
        assert self._weights
        return self._range_weight

    @property
    def type(self) -> TunableValueTypeName:
        """
        Get the data type of the tunable.

        Returns
        -------
        type : str
            Data type of the tunable - one of {'int', 'float', 'categorical'}.
        """
        return self._type

    @property
    def dtype(self) -> TunableValueType:
        """
        Get the actual Python data type of the tunable.

        This is useful for bulk conversions of the input data.

        Returns
        -------
        dtype : type
            Data type of the tunable - one of {int, float, str}.
        """
        return self._DTYPE[self._type]

    @property
    def is_categorical(self) -> bool:
        """
        Check if the tunable is categorical.

        Returns
        -------
        is_categorical : bool
            True if the tunable is categorical, False otherwise.
        """
        return self._type == "categorical"

    @property
    def is_numerical(self) -> bool:
        """
        Check if the tunable is an integer or float.

        Returns
        -------
        is_int : bool
            True if the tunable is an integer or float, False otherwise.
        """
        return self._type in {"int", "float"}

    @property
    def range(self) -> Union[Tuple[int, int], Tuple[float, float]]:
        """
        Get the range of the tunable if it is numerical, None otherwise.

        Returns
        -------
        range : (number, number)
            A 2-tuple of numbers that represents the range of the tunable.
            Numbers can be int or float, depending on the type of the tunable.
        """
        assert self.is_numerical
        assert self._range is not None
        return self._range

    @property
    def span(self) -> Union[int, float]:
        """
        Gets the span of the range.

        Note: this does not take quantization into account.

        Returns
        -------
        Union[int, float]
            (max - min) for numerical tunables.
        """
        num_range = self.range
        return num_range[1] - num_range[0]

    @property
    def quantization(self) -> Optional[Union[int, float]]:
        """
        Get the quantization factor, if specified.

        Returns
        -------
        quantization : int, float, None
            The quantization factor, or None.
        """
        if self.is_categorical:
            return None
        return self._quantization

    @property
    def quantized_values(self) -> Optional[Union[Iterable[int], Iterable[float]]]:
        """
        Get a sequence of quantized values for this tunable.

        Returns
        -------
        Optional[Union[Iterable[int], Iterable[float]]]
            If the Tunable is quantizable, returns a sequence of those elements,
            else None (e.g., for unquantized float type tunables).
        """
        num_range = self.range
        if self.type == "float":
            if not self._quantization:
                return None
            # Be sure to return python types instead of numpy types.
            cardinality = self.cardinality
            assert isinstance(cardinality, int)
            return (float(x) for x in np.linspace(start=num_range[0],
                                                  stop=num_range[1],
                                                  num=cardinality,
                                                  endpoint=True))
        assert self.type == "int", f"Unhandled tunable type: {self}"
        return range(int(num_range[0]), int(num_range[1]) + 1, int(self._quantization or 1))

    @property
    def cardinality(self) -> Union[int, float]:
        """
        Gets the cardinality of elements in this tunable, or else infinity.

        If the tunable has quantization set, the cardinality is the number of
        quantization points over the range; unquantized float tunables are
        considered infinite.  (Docstring fixed: sentence was truncated.)

        Returns
        -------
        cardinality : int, float
            Either the number of points in the tunable or else infinity.
        """
        if self.is_categorical:
            return len(self.categories)
        if not self.quantization and self.type == "float":
            return np.inf
        q_factor = self.quantization or 1
        return int(self.span / q_factor) + 1

    @property
    def is_log(self) -> Optional[bool]:
        """
        Check if numeric tunable is log scale.

        Returns
        -------
        log : bool
            True if numeric tunable is log scale, False if linear.
        """
        assert self.is_numerical
        return self._log

    @property
    def distribution(self) -> Optional[DistributionName]:
        """
        Get the name of the distribution (uniform, normal, or beta) if specified.

        Returns
        -------
        distribution : str
            Name of the distribution (uniform, normal, or beta) or None.
        """
        return self._distribution

    @property
    def distribution_params(self) -> Dict[str, float]:
        """
        Get the parameters of the distribution, if specified.

        Returns
        -------
        distribution_params : Dict[str, float]
            Parameters of the distribution or None.
        """
        assert self._distribution is not None
        return self._distribution_params

    @property
    def categories(self) -> List[Optional[str]]:
        """
        Get the list of all possible values of a categorical tunable.
        Must only be called on a categorical tunable.

        Returns
        -------
        values : List[str]
            List of all possible values of a categorical tunable.
        """
        assert self.is_categorical
        assert self._values is not None
        return self._values

    @property
    def values(self) -> Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]:
        """
        Gets the categories or quantized values for this tunable.

        Returns
        -------
        Optional[Union[Iterable[Optional[str]], Iterable[int], Iterable[float]]]
            Categories or quantized values.
        """
        if self.is_categorical:
            return self.categories
        assert self.is_numerical
        return self.quantized_values

    @property
    def meta(self) -> Dict[str, Any]:
        """
        Get the tunable's metadata. This is a free-form dictionary that can be used to
        store any additional information about the tunable (e.g., the unit information).
        """
        return self._meta