Coverage for mlos_bench/mlos_bench/tunables/tunable_types.py: 94%

64 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 00:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Helper types for :py:class:`~mlos_bench.tunables.tunable.Tunable`. 

7 

8The main class of interest to most users in this module is :py:class:`.TunableDict`, 

9which provides the typed conversions from a JSON config to a config used for 

10creating a :py:class:`~mlos_bench.tunables.tunable.Tunable`. 

11 

12The other types are mostly used for type checking and documentation purposes. 

13""" 

14 

15# NOTE: pydoctest doesn't scan variable docstrings so we put the examples in the 

16# Tunable class docstrings. 

17# These type aliases are moved here mostly to allow easier documentation reading of 

18# the Tunable class itself. 

19 

20from collections.abc import Sequence 

21from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypedDict 

22 

23if TYPE_CHECKING: 

24 # Used to allow for shorter docstring references. 

25 from mlos_bench.tunables.tunable import Tunable 

26 

27TunableValue: TypeAlias = int | float | str | None 

28"""A :py:class:`TypeAlias` for a :py:class:`~.Tunable` parameter value.""" 

29 

30TunableValueType: TypeAlias = type[int] | type[float] | type[str] 

31"""A :py:class:`TypeAlias` for :py:class:`~.Tunable` value 

32:py:attr:`data type <.Tunable.dtype>`. 

33 

34See Also 

35-------- 

36:py:attr:`Tunable.dtype <.Tunable.dtype>` : Example of accepted types. 

37""" 

38 

39TunableValueTypeTuple = (int, float, str, type(None)) 

40""" 

41Tunable value ``type`` tuple. 

42 

43Notes 

44----- 

45For checking whether a param is a :py:type:`.TunableValue` with 

46:py:func:`isinstance`. 

47""" 

48 

49TunableValueTypeName = Literal["int", "float", "categorical"] 

50""" 

51The accepted string names of a :py:class:`~.Tunable` value :py:attr:`~.Tunable.type`. 

52 

53See Also 

54-------- 

55:py:attr:`Tunable.type <.Tunable.type>` : Example of accepted type names. 

56""" 

57 

58TUNABLE_DTYPE: dict[TunableValueTypeName, TunableValueType] = { 

59 "int": int, 

60 "float": float, 

61 "categorical": str, 

62} 

63""" 

64Maps :py:class:`~.Tunable` types to their corresponding Python data types by name. 

65 

66See Also 

67-------- 

68:py:attr:`Tunable.dtype <.Tunable.dtype>` : Example of type mappings. 

69""" 

70 

71TunableValuesDict = dict[str, TunableValue] 

72"""Tunable values dictionary type.""" 

73 

74DistributionName = Literal["uniform", "normal", "beta"] 

75""" 

76The :py:attr:`~.Tunable.distribution` type names for a :py:class:`~.Tunable` value. 

77 

78See Also 

79-------- 

80:py:attr:`Tunable.distribution <.Tunable.distribution>` : 

81 Example of accepted distribution names. 

82""" 

83 

84 

class DistributionDictOpt(TypedDict, total=False):  # total=False allows for optional fields
    """
    A :py:class:`TypedDict` for a :py:class:`~.Tunable` parameter's optional
    :py:attr:`~.Tunable.distribution_params` config.

    Mostly used for type checking. These are the types expected to be received from
    the json config.

    Because the class is declared with ``total=False``, every key listed here may
    be omitted.

    Notes
    -----
    :py:class:`.DistributionDict` contains the required fields for the
    :py:attr:`Tunable.distribution <.Tunable.distribution>` parameter.

    See Also
    --------
    :py:attr:`Tunable.distribution_params <.Tunable.distribution_params>` :
        Examples of distribution parameters.
    """

    # Explicit pass-through constructor. Per its own docstring, it exists only so
    # the generated docs do not show an inherited constructor docstring.
    def __init__(self, *args, **kwargs):  # type: ignore # pylint: disable=useless-super-delegation
        """.. comment: don't inherit the docstring"""
        super().__init__(*args, **kwargs)

    # Optional key: mapping of distribution parameter name -> numeric value.
    # The annotation also admits an explicit None.
    params: dict[str, float] | None
    """
    The parameters for the distribution.

    See Also
    --------
    :py:attr:`Tunable.distribution_params <.Tunable.distribution_params>` :
        Examples of distribution parameters.
    """

117 

118 

class DistributionDict(DistributionDictOpt):
    """
    A :py:class:`TypedDict` for a :py:class:`~.Tunable` parameter's required
    :py:attr:`~.Tunable.distribution` config parameters.

    Mostly used for type checking. These are the types expected to be received from the
    json config.

    Extends :py:class:`.DistributionDictOpt` (whose optional ``params`` key it
    inherits) with the required ``type`` key.

    See Also
    --------
    :py:attr:`Tunable.distribution <.Tunable.distribution>` :
        Examples of Tunables with distributions.
    :py:attr:`Tunable.distribution_params <.Tunable.distribution_params>` :
        Examples of distribution parameters.
    """

    # Explicit pass-through constructor. Per its own docstring, it exists only so
    # the generated docs do not show an inherited constructor docstring.
    def __init__(self, *args, **kwargs):  # type: ignore # pylint: disable=useless-super-delegation
        """.. comment: don't inherit the docstring"""
        super().__init__(*args, **kwargs)

    # Required key (this class uses the default total=True). Must be one of the
    # DistributionName literals: "uniform", "normal" or "beta".
    type: DistributionName
    """
    The name of the distribution.

    See Also
    --------
    :py:attr:`Tunable.distribution <.Tunable.distribution>` :
        Examples of distribution names.
    """

148 

149 

class TunableDictOpt(TypedDict, total=False):  # total=False allows for optional fields
    """
    A :py:class:`TypedDict` for a :py:class:`~.Tunable` parameter's optional config
    parameters.

    Mostly used for type checking. These are the types expected to be received from
    the json config.

    Because the class is declared with ``total=False``, none of these keys are
    required.

    Notes
    -----
    :py:class:`TunableDict` contains the required fields for the
    :py:class:`~.Tunable` parameter.
    """

    # Explicit pass-through constructor. Per its own docstring, it exists only so
    # the generated docs do not show an inherited constructor docstring.
    def __init__(self, *args, **kwargs):  # type: ignore # pylint: disable=useless-super-delegation
        """.. comment: don't inherit the docstring"""
        super().__init__(*args, **kwargs)

    # Optional fields:

    # Free-text description; may be None.
    description: str | None
    """
    Description of the :py:class:`~.Tunable` parameter.

    See Also
    --------
    :py:attr:`Tunable.description <.Tunable.description>`
    """

    # Categorical values. Individual entries may themselves be None
    # (list[str | None]).
    values: list[str | None] | None
    """
    List of values (or categories) for a "categorical" type :py:class:`~.Tunable`
    parameter.

    A list of values is required for "categorical" type Tunables.

    See Also
    --------
    :py:attr:`Tunable.categories <.Tunable.categories>`
    :py:attr:`Tunable.values <.Tunable.values>`
    """

    # Numeric range. The two-element [min, max] shape is documented below but is
    # not enforced by the Sequence annotation.
    range: Sequence[int] | Sequence[float] | None
    """
    The range of values for an "int" or "float" type :py:class:`~.Tunable` parameter.

    Must be a sequence of two values: ``[min, max]``.

    A range is required for "int" and "float" type Tunables.

    See Also
    --------
    :py:attr:`Tunable.range <.Tunable.range>` : Examples of ranges.
    :py:attr:`Tunable.values <.Tunable.values>`
    """

    # NOTE(review): the annotation admits only ints or only floats, so the
    # ``null``/``auto`` examples in the docstring below presumably describe what a
    # special numeric value means to the target system, not literal list entries.
    # Confirm and reword if so.
    special: list[int] | list[float] | None
    """
    List of special values for an "int" or "float" type :py:class:`~.Tunable` parameter.

    These are values that are considered special by the target system (e.g.,
    ``null``, ``0``, ``-1``, ``auto``, etc.) and should be sampled with higher
    weights.

    See Also
    --------
    :py:attr:`Tunable.special <.Tunable.special>` : Examples of special values.
    """

    # Number of quantization bins (an int); may be None.
    quantization_bins: int | None
    """
    The number of quantization bins for an "int" or "float" type :py:class:`~.Tunable`
    parameter.

    See Also
    --------
    :py:attr:`Tunable.quantization_bins <.Tunable.quantization_bins>` :
        Examples of quantized Tunables.
    """

    # Log-sampling flag; None is also permitted by the annotation.
    log: bool | None
    """
    Whether to use log sampling for an "int" or "float" type :py:class:`~.Tunable`
    parameter.

    See Also
    --------
    :py:attr:`Tunable.is_log <.Tunable.is_log>`
    """

    # Nested DistributionDict: a required "type" plus optional "params".
    distribution: DistributionDict | None
    """
    Optional sampling distribution configuration for an "int" or "float" type
    :py:class:`~.Tunable` parameter.

    See Also
    --------
    :py:attr:`Tunable.distribution <.Tunable.distribution>` :
        Examples of distributions.
    :py:attr:`Tunable.distribution_params <.Tunable.distribution_params>` :
        Examples of distribution parameters.
    """

    # Per-category weights (categorical Tunables).
    values_weights: list[float] | None
    """
    Optional sampling weights for the values of a "categorical" type
    :py:class:`~.Tunable` parameter.

    See Also
    --------
    :py:attr:`Tunable.weights <.Tunable.weights>` : Examples of weighted sampling Tunables.
    """

    # Per-special-value weights (numeric Tunables).
    special_weights: list[float] | None
    """
    Optional sampling weights for the special values of an "int" or "float" type
    :py:class:`~.Tunable` parameter.

    See Also
    --------
    :py:attr:`Tunable.weights <.Tunable.weights>` : Examples of weighted sampling Tunables.
    """

    # Single weight applied to the main (non-special) range.
    range_weight: float | None
    """
    Optional sampling weight for the main ranges of an "int" or "float" type
    :py:class:`~.Tunable` parameter.

    See Also
    --------
    :py:attr:`Tunable.range_weight <.Tunable.range_weight>` :
        Examples of weighted sampling Tunables.
    """

    # Unlike the other optional keys, the annotation here does not include None.
    meta: dict[str, Any]
    """
    Free form dict to store additional metadata for the :py:class:`~.Tunable` parameter
    (e.g., unit suffix, etc.)

    See Also
    --------
    :py:attr:`Tunable.meta <.Tunable.meta>` : Examples of Tunables with metadata.
    """

293 

294 

class TunableDict(TunableDictOpt):
    """
    A :py:class:`TypedDict` for a :py:class:`~.Tunable` parameter's required config
    parameters.

    Mostly used for type checking. These are the types expected to be received from
    the json config.

    Extends :py:class:`.TunableDictOpt` (all of whose keys stay optional) with the
    required ``type`` and ``default`` keys.

    Examples
    --------
    >>> # Example values of the TunableDict
    >>> TunableDict({'type': 'int', 'default': 0, 'range': [0, 10]})
    {'type': 'int', 'default': 0, 'range': [0, 10]}

    >>> # Example values of the TunableDict with optional fields
    >>> TunableDict({'type': 'categorical', 'default': 'a', 'values': ['a', 'b']})
    {'type': 'categorical', 'default': 'a', 'values': ['a', 'b']}
    """

    # Explicit pass-through constructor. Per its own docstring, it exists only so
    # the generated docs do not show an inherited constructor docstring.
    def __init__(self, *args, **kwargs):  # type: ignore # pylint: disable=useless-super-delegation
        """.. comment: don't inherit the docstring"""
        super().__init__(*args, **kwargs)

    # Required fields

    # One of the TunableValueTypeName literals: "int", "float" or "categorical".
    type: TunableValueTypeName
    """
    The name of the type of the :py:class:`~.Tunable` parameter.

    See Also
    --------
    :py:attr:`Tunable.type <.Tunable.type>` : Examples of type names.
    """

    # The key is required, but its value may be None because TunableValue
    # includes None.
    default: TunableValue
    """
    The default value of the :py:class:`~.Tunable` parameter.

    See Also
    --------
    :py:attr:`Tunable.default <.Tunable.default>`
    """

337 

338 

def tunable_dict_from_dict(config: dict[str, Any]) -> TunableDict:
    """
    Creates a TunableDict from a regular dict.

    Every known :py:class:`.TunableDict` key is copied from ``config``. Keys
    absent from ``config`` are set to ``None``, except ``meta``, which defaults to
    an empty dict. Keys in ``config`` that :py:class:`.TunableDict` does not
    declare are dropped.

    Values are not copied: lists and dicts in the result are the same objects as
    in ``config``.

    Notes
    -----
    Mostly used for type checking while instantiating a
    :py:class:`~.Tunable` from a json config.

    Parameters
    ----------
    config : dict[str, Any]
        A regular dict that represents a :py:class:`.TunableDict`.

    Returns
    -------
    TunableDict

    Raises
    ------
    ValueError
        If ``config["type"]`` is missing, is not a string, or is not one of the
        type names in :py:data:`.TUNABLE_DTYPE`.

    Examples
    --------
    >>> # Example values of the TunableDict
    >>> import json5 as json
    >>> config = json.loads("{'type': 'int', 'default': 0, 'range': [0, 10]}")
    >>> config
    {'type': 'int', 'default': 0, 'range': [0, 10]}
    >>> typed_dict = tunable_dict_from_dict(config)
    >>> typed_dict
    {'type': 'int', 'description': None, 'default': 0, 'values': None, 'range': [0, 10], 'quantization_bins': None, 'log': None, 'distribution': None, 'special': None, 'values_weights': None, 'special_weights': None, 'range_weight': None, 'meta': {}}
    """  # pylint: disable=line-too-long # noqa: E501
    _type = config.get("type")
    # Check the type first. An unhashable value (e.g. a list from a malformed
    # config) would otherwise make the ``in`` lookup raise TypeError instead of
    # the ValueError that callers expect for an invalid type.
    if not isinstance(_type, str) or _type not in TUNABLE_DTYPE:
        raise ValueError(f"Invalid parameter type: {_type}")
    # ``meta`` is the only key with a non-None fallback; a fresh dict is created on
    # each call, so results never share a default instance.
    _meta = config.get("meta", {})
    return TunableDict(
        type=_type,
        description=config.get("description"),
        default=config.get("default"),
        values=config.get("values"),
        range=config.get("range"),
        quantization_bins=config.get("quantization_bins"),
        log=config.get("log"),
        distribution=config.get("distribution"),
        special=config.get("special"),
        values_weights=config.get("values_weights"),
        special_weights=config.get("special_weights"),
        range_weight=config.get("range_weight"),
        meta=_meta,
    )