Coverage for mlos_bench/mlos_bench/tests/environments/composite_env_service_test.py: 100%

18 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-06 00:35 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Check how the services get inherited and overridden in child environments. 

7""" 

8import os 

9 

10import pytest 

11 

12from mlos_bench.environments.composite_env import CompositeEnv 

13from mlos_bench.tunables.tunable_groups import TunableGroups 

14from mlos_bench.services.config_persistence import ConfigPersistenceService 

15from mlos_bench.services.local.local_exec import LocalExecService 

16from mlos_bench.util import path_join 

17 

18# pylint: disable=redefined-outer-name 

19 

20 

@pytest.fixture
def composite_env(tunable_groups: TunableGroups) -> CompositeEnv:
    """
    Test fixture for CompositeEnv with services included on multiple levels.

    The root environment ("Root") is given a LocalExecService whose config sets
    ``temp_dir`` to ``_test_tmp_global``. Its parent is a ConfigPersistenceService
    that searches ``../config`` relative to this test module. The root has three
    MockEnv children:

    * "Env 1 :: tmp_global" declares no services of its own.
    * "Env 2 :: tmp_other_2" includes ``mock_local_exec_service_2.jsonc``.
    * "Env 3 :: tmp_other_3" includes ``mock_local_exec_service_3.jsonc``.
    """
    mock_env_class = "mlos_bench.environments.mock_env.MockEnv"
    # Absolute path so config lookup does not depend on the CWD pytest runs from.
    config_dir = path_join(os.path.dirname(__file__), "../config", abs_path=True)

    children = [
        {
            "name": "Env 1 :: tmp_global",
            "class": mock_env_class,
        },
        {
            "name": "Env 2 :: tmp_other_2",
            "class": mock_env_class,
            "include_services": ["services/local/mock/mock_local_exec_service_2.jsonc"],
        },
        {
            "name": "Env 3 :: tmp_other_3",
            "class": mock_env_class,
            "include_services": ["services/local/mock/mock_local_exec_service_3.jsonc"],
        },
    ]

    # The config loader sits at the bottom of the service chain so that the
    # "include_services" paths above can be resolved against config_dir.
    config_loader = ConfigPersistenceService({"config_path": [config_dir]})
    root_service = LocalExecService(
        config={"temp_dir": "_test_tmp_global"},
        parent=config_loader,
    )

    return CompositeEnv(
        name="Root",
        config={"children": children},
        tunables=tunable_groups,
        service=root_service,
    )

58 

59 

def test_composite_services(composite_env: CompositeEnv) -> None:
    """
    Check that each environment gets its own instance of the services.

    For every child environment (in order), the attached service must exist and
    expose ``temp_dir_context()``. The directory that context manager yields must
    be the same file as the temp dir expected for that child.

    The expected temp dirs are relative paths, i.e. relative to the current
    working directory. Each one is removed after its check, even when the check
    fails, so that a failing run does not leave stray directories behind.
    """
    expected_temp_dirs = ("_test_tmp_global", "_test_tmp_other_2", "_test_tmp_other_3")
    for (i, path) in enumerate(expected_temp_dirs):
        service = composite_env.children[i]._service  # pylint: disable=protected-access
        assert service is not None and hasattr(service, "temp_dir_context")
        try:
            with service.temp_dir_context() as temp_dir:
                assert os.path.samefile(temp_dir, path)
        finally:
            # Cleanup must run even if the assertion above failed. The isdir()
            # guard keeps a missing directory from replacing the real failure
            # with an unrelated FileNotFoundError. os.rmdir still raises if the
            # directory is unexpectedly non-empty.
            if os.path.isdir(path):
                os.rmdir(path)