Coverage for mlos_bench/mlos_bench/tests/environments/base_env_test.py: 100%

20 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

#
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#
"""
Unit tests for base environment class functionality.
"""

8 

9from typing import Dict 

10 

11import pytest 

12 

13from mlos_bench.tunables.tunable import TunableValue 

14from mlos_bench.environments.base_environment import Environment 

15 

16_GROUPS = { 

17 "group": ["a", "b"], 

18 "list": ["c", "d"], 

19 "str": "efg", 

20 "empty": [], 

21 "other": ["h", "i", "j"], 

22} 

23 

# pylint: disable=protected-access

25 

26 

def test_expand_groups() -> None:
    """
    Verify that `$name` items are replaced by the contents of the named group.

    A list-valued group is spliced in element by element, an empty group
    contributes nothing, a string-valued group contributes the string itself,
    and items without a leading `$` are passed through unchanged.
    """
    params = ["begin", "$list", "$empty", "$str", "end"]
    expected = ["begin", "c", "d", "efg", "end"]
    assert Environment._expand_groups(params, _GROUPS) == expected

34 

35 

def test_expand_groups_empty_input() -> None:
    """
    Make sure an empty list of parameters expands to an empty list.

    (Expansion of a reference to an *empty group* is covered separately
    by `test_expand_groups_empty_list`.)
    """
    assert Environment._expand_groups([], _GROUPS) == []

41 

42 

def test_expand_groups_empty_list() -> None:
    """
    Make sure that expanding a reference to an empty group yields an empty list.
    """
    # Compare against [] explicitly instead of relying on truthiness:
    # `assert not ...` would also pass for a falsy non-list result such as None.
    assert Environment._expand_groups(["$empty"], _GROUPS) == []

48 

49 

def test_expand_groups_unknown() -> None:
    """
    Check that referencing a $GROUP name missing from the mapping raises KeyError.
    """
    # The unknown reference is placed among valid group references and a
    # literal item, so expansion has to reach it mid-list.
    params = ["$list", "$UNKNOWN", "$str", "end"]
    with pytest.raises(KeyError):
        Environment._expand_groups(params, _GROUPS)

56 

57 

def test_expand_const_args() -> None:
    """
    Check that `Environment._expand_vars` substitutes `$name` references in
    const args, resolving them against the global config and against other
    const args, while leaving plain and non-string values untouched.
    """
    global_config: Dict[str, TunableValue] = {"bar": "blah"}
    const_args: Dict[str, TunableValue] = {
        "a": "b",                       # no reference: kept as-is
        "foo": "$bar/baz",              # resolved from global_config
        "1": 1,                         # non-string value: kept as-is
        "recursive": "$foo/expansion",  # refers to another const arg
    }
    expected: Dict[str, TunableValue] = {
        "a": "b",
        "foo": "blah/baz",
        "1": 1,
        "recursive": "blah/baz/expansion",
    }
    assert Environment._expand_vars(const_args, global_config) == expected