Coverage for mlos_bench/mlos_bench/tests/services/config_persistence_test.py: 100%

47 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Unit tests for configuration persistence service. 

7""" 

8 

9import os 

10import sys 

11import pytest 

12 

13from mlos_bench.config.schemas import ConfigSchema 

14from mlos_bench.services.config_persistence import ConfigPersistenceService 

15from mlos_bench.util import path_join 

16 

17 

18if sys.version_info < (3, 9): 

19 from importlib_resources import files 

20else: 

21 from importlib.resources import files 

22 

23 

24# pylint: disable=redefined-outer-name 

25 

26 

@pytest.fixture
def config_persistence_service() -> ConfigPersistenceService:
    """
    Provide a ConfigPersistenceService with an explicit three-entry search path:
    a directory that does not exist, the current working directory, and the
    `mlos_bench.tests.config` package directory.
    """
    search_dirs = [
        # Intentionally non-existent directory.
        "./non-existent-dir/test/foo/bar",
        # Current working directory.
        ".",
        # Test configs (relative to mlos_bench/tests).
        str(files("mlos_bench.tests.config").joinpath("")),
        # The stock configs (mlos_bench.config) are deliberately not listed:
        # the service is expected to add them automatically.
    ]
    return ConfigPersistenceService({"config_path": search_dirs})

41 

42 

def test_cwd_in_explicit_search_path(config_persistence_service: ConfigPersistenceService) -> None:
    """
    Verify the CWD sits at index 1 of the search path and appears there only
    once. Index 1 matches the position of "." in the fixture's explicit
    `config_path` list.
    """
    # pylint: disable=protected-access
    search_path = config_persistence_service._config_path
    assert search_path is not None
    expected_cwd = path_join(os.getcwd(), abs_path=True)
    # The first occurrence must be at position 1 ...
    assert search_path.index(expected_cwd) == 1
    # ... and no duplicate may follow it.
    with pytest.raises(ValueError):
        search_path.index(expected_cwd, 2)

53 

54 

def test_cwd_in_default_search_path() -> None:
    """
    Verify that a service built with no explicit `config_path` puts the CWD
    first in its search path, and lists it only once.
    """
    # pylint: disable=protected-access
    service = ConfigPersistenceService()
    search_path = service._config_path
    assert search_path is not None
    expected_cwd = path_join(os.getcwd(), abs_path=True)
    # The CWD must be the very first entry ...
    assert search_path.index(expected_cwd) == 0
    # ... and must not appear again later in the list.
    with pytest.raises(ValueError):
        search_path.index(expected_cwd, 1)

66 

67 

def test_resolve_stock_path(config_persistence_service: ConfigPersistenceService) -> None:
    """
    Check that a stock config file resolves to a location inside the built-in
    config directory (`ConfigPersistenceService.BUILTIN_CONFIG_PATH`).

    The fixture does not list the built-in directory in its `config_path`, so
    this also checks that the service adds it to the search path by itself.
    """
    # pylint: disable=protected-access
    assert config_persistence_service._config_path is not None
    builtin_dir = ConfigPersistenceService.BUILTIN_CONFIG_PATH
    assert builtin_dir in config_persistence_service._config_path
    file_path = "storage/in-memory.jsonc"
    path = config_persistence_service.resolve_path(file_path)
    assert path.endswith(file_path)
    assert os.path.exists(path)
    # Beyond merely existing, the resolved file must live under the built-in
    # directory rather than under some other search-path entry (e.g. the CWD).
    assert os.path.samefile(
        builtin_dir,
        os.path.commonpath([builtin_dir, path]),
    )

83 

84 

def test_resolve_path(config_persistence_service: ConfigPersistenceService) -> None:
    """
    Resolve a relative file name against the fixture's search path and make
    sure the result keeps the relative suffix and points at an existing file.
    """
    relative_name = "tunable-values/tunable-values-example.jsonc"
    resolved = config_persistence_service.resolve_path(relative_name)
    assert resolved.endswith(relative_name)
    assert os.path.exists(resolved)

93 

94 

def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService) -> None:
    """
    Check that a file absent from every search-path entry comes back
    unchanged, i.e. without any `config_path` directory prepended.
    """
    missing_file = "foo/non-existent-config.json"
    resolved = config_persistence_service.resolve_path(missing_file)
    assert not os.path.exists(resolved)
    assert resolved == missing_file

103 

104 

def test_load_config(config_persistence_service: ConfigPersistenceService) -> None:
    """
    Load a tunable-values config through a path relative to `config_path`,
    validating it against the TUNABLE_VALUES schema, and check that the result
    is a non-empty dict.
    """
    config_file = "tunable-values/tunable-values-example.jsonc"
    loaded = config_persistence_service.load_config(
        config_file,
        ConfigSchema.TUNABLE_VALUES,
    )
    assert loaded is not None
    assert isinstance(loaded, dict)
    assert len(loaded) >= 1