Coverage for mlos_bench/mlos_bench/tests/tunables/tunables_copy_test.py: 90%

29 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-01 00:52 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5"""Unit tests for deep copy of tunable objects and groups.""" 

6 

7from mlos_bench.tunables.covariant_group import CovariantTunableGroup 

8from mlos_bench.tunables.tunable import Tunable 

9from mlos_bench.tunables.tunable_groups import TunableGroups 

10from mlos_bench.tunables.tunable_types import TunableValue 

11 

12 

def test_copy_tunable_int(tunable_int: Tunable) -> None:
    """Verify that ``Tunable.copy()`` returns an equal but independent object.

    Right after copying, the duplicate must compare equal to the source.
    Shifting the duplicate's numerical value must then make the two differ,
    proving the copy does not share state with the original.
    """
    duplicate = tunable_int.copy()
    assert tunable_int == duplicate
    # Change only the duplicate; the original tunable must keep its value.
    duplicate.numerical_value = duplicate.numerical_value + 200
    assert tunable_int != duplicate

19 

20 

def test_copy_tunable_groups(tunable_groups: TunableGroups) -> None:
    """Verify that ``TunableGroups.copy()`` yields an independent deep copy.

    Assigning a new ``vmSize`` through the clone must flag only the clone as
    updated, leave the original's update flag cleared, and make the two
    collections compare unequal.
    """
    groups_clone = tunable_groups.copy()
    assert tunable_groups == groups_clone

    # Mutate the clone only; the source collection must remain pristine.
    groups_clone["vmSize"] = "Standard_B2ms"

    assert groups_clone.is_updated()
    assert not tunable_groups.is_updated()
    assert tunable_groups != groups_clone

29 

30 

def test_copy_covariant_group(covariant_group: CovariantTunableGroup) -> None:
    """Check if deep copy works for CovariantTunableGroup object.

    Copies the group, changes the value of its first tunable in the copy
    only, and verifies that the copy is marked as updated while the original
    is not, and that the two groups no longer compare equal.
    """
    covariant_group_copy = covariant_group.copy()
    assert covariant_group == covariant_group_copy

    tunable = next(iter(covariant_group.get_tunables()))
    new_value: TunableValue
    if tunable.is_categorical:
        # Pick any category other than the current one so the value really changes.
        other_categories = [x for x in tunable.categories if x != tunable.category]
        assert other_categories, f"{tunable=} :: need at least two categories to test a change."
        new_value = other_categories[0]
    elif tunable.is_numerical:
        new_value = tunable.numerical_value + 1
    else:
        raise ValueError(f"{tunable=} :: unsupported tunable type.")

    # Update the copy only; the original group must be unaffected.
    covariant_group_copy[tunable] = new_value
    assert covariant_group_copy.is_updated()
    assert not covariant_group.is_updated()
    assert covariant_group != covariant_group_copy