Coverage for mlos_bench/mlos_bench/tests/services/config_persistence_test.py: 100%
47 statements
« prev ^ index » next coverage.py v7.5.1, created at 2024-05-06 00:35 +0000
1#
2# Copyright (c) Microsoft Corporation.
3# Licensed under the MIT License.
4#
5"""
6Unit tests for configuration persistence service.
7"""
9import os
10import sys
11import pytest
13from mlos_bench.config.schemas import ConfigSchema
14from mlos_bench.services.config_persistence import ConfigPersistenceService
15from mlos_bench.util import path_join
18if sys.version_info < (3, 9):
19 from importlib_resources import files
20else:
21 from importlib.resources import files
24# pylint: disable=redefined-outer-name
@pytest.fixture
def config_persistence_service() -> ConfigPersistenceService:
    """
    Provide a ConfigPersistenceService configured with an explicit search path.

    The explicit entries, in order, are: a directory that does not exist,
    the current working directory, and the packaged test configs.
    """
    search_path = [
        # Non-existent config path, included on purpose.
        "./non-existent-dir/test/foo/bar",
        # Current working directory.
        ".",
        # Test configs (relative to mlos_bench/tests).
        str(files("mlos_bench.tests.config").joinpath("")),
        # The stock configs (mlos_bench.config) are not listed here: the
        # service is expected to add them itself.
    ]
    return ConfigPersistenceService({"config_path": search_path})
def test_cwd_in_explicit_search_path(config_persistence_service: ConfigPersistenceService) -> None:
    """
    Verify the CWD appears exactly once in an explicit search path.

    It is expected at index 1, the slot of the fixture's "." entry.
    """
    # pylint: disable=protected-access
    search_path = config_persistence_service._config_path
    assert search_path is not None
    expected_cwd = path_join(os.getcwd(), abs_path=True)
    assert search_path.index(expected_cwd) == 1
    # Nothing after the first occurrence may repeat the CWD.
    assert expected_cwd not in search_path[2:]
def test_cwd_in_default_search_path() -> None:
    """
    Verify the CWD is placed first in the search path when no explicit
    config_path is given.

    It must appear exactly once.
    """
    # pylint: disable=protected-access
    service = ConfigPersistenceService()
    search_path = service._config_path
    assert search_path is not None
    expected_cwd = path_join(os.getcwd(), abs_path=True)
    assert search_path.index(expected_cwd) == 0
    # Nothing after the first occurrence may repeat the CWD.
    assert expected_cwd not in search_path[1:]
def test_resolve_stock_path(config_persistence_service: ConfigPersistenceService) -> None:
    """
    Check that a stock config file is resolved from the built-in config path.

    The fixture does not list ``BUILTIN_CONFIG_PATH`` explicitly. This test
    therefore also checks that the service adds it to ``config_path`` on its own.
    """
    # pylint: disable=protected-access
    assert config_persistence_service._config_path is not None
    assert ConfigPersistenceService.BUILTIN_CONFIG_PATH in config_persistence_service._config_path
    file_path = "storage/in-memory.jsonc"
    path = config_persistence_service.resolve_path(file_path)
    assert path.endswith(file_path)
    assert os.path.exists(path)
    # The resolved file must live under the built-in config directory:
    # the common prefix of the two paths is the built-in directory itself.
    assert os.path.samefile(
        ConfigPersistenceService.BUILTIN_CONFIG_PATH,
        os.path.commonpath([ConfigPersistenceService.BUILTIN_CONFIG_PATH, path]),
    )
def test_resolve_path(config_persistence_service: ConfigPersistenceService) -> None:
    """
    Resolve a relative config path via the search path.

    The result must end with the requested relative path and point at a
    file that exists.
    """
    relative = "tunable-values/tunable-values-example.jsonc"
    resolved = config_persistence_service.resolve_path(relative)
    assert resolved.endswith(relative)
    assert os.path.exists(resolved)
def test_resolve_path_fail(config_persistence_service: ConfigPersistenceService) -> None:
    """
    A file that is not found anywhere on the search path should be returned
    by ``resolve_path`` exactly as given.
    """
    missing = "foo/non-existent-config.json"
    resolved = config_persistence_service.resolve_path(missing)
    assert not os.path.exists(resolved)
    assert resolved == missing
def test_load_config(config_persistence_service: ConfigPersistenceService) -> None:
    """
    Load a tunable-values config by a path relative to ``config_path``.

    Validate it against the TUNABLE_VALUES schema and check that a non-empty
    dict comes back.
    """
    loaded = config_persistence_service.load_config(
        "tunable-values/tunable-values-example.jsonc",
        ConfigSchema.TUNABLE_VALUES,
    )
    assert loaded is not None
    assert isinstance(loaded, dict)
    assert len(loaded) >= 1