Coverage for mlos_bench/mlos_bench/tests/environments/local/local_env_test.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.5.1, created at 2024-05-05 00:36 +0000

1# 

2# Copyright (c) Microsoft Corporation. 

3# Licensed under the MIT License. 

4# 

5""" 

6Unit tests for LocalEnv benchmark environment. 

7""" 

8import pytest 

9 

10from mlos_bench.tunables.tunable_groups import TunableGroups 

11from mlos_bench.tests.environments import check_env_success 

12from mlos_bench.tests.environments.local import create_local_env 

13 

14 

def test_local_env(tunable_groups: TunableGroups) -> None:
    """
    Have a local script write benchmark results in the long (``metric,value``)
    CSV layout and check that the environment reads them back.

    Only a results file is configured; the telemetry is expected to be empty.
    """
    # The first command truncates/creates the file (>), the rest append (>>).
    write_results = [
        "echo 'metric,value' > output.csv",
        "echo 'latency,10' >> output.csv",
        "echo 'throughput,66' >> output.csv",
        "echo 'score,0.9' >> output.csv",
    ]
    local_env = create_local_env(
        tunable_groups,
        {"run": write_results, "read_results_file": "output.csv"},
    )

    check_env_success(
        local_env,
        tunable_groups,
        expected_results={"latency": 10.0, "throughput": 66.0, "score": 0.9},
        expected_telemetry=[],
    )

38 

39 

def test_local_env_service_context(tunable_groups: TunableGroups) -> None:
    """
    Check that entering the environment's context also enters the contexts of
    its Service mixins, and that leaving it exits them again.
    """
    local_env = create_local_env(tunable_groups, {"run": ["echo NA"]})
    # pylint: disable=protected-access
    service = local_env._service

    # Before entering: a service exists but none of its contexts are active.
    assert service
    assert not service._in_context
    assert not service._service_contexts

    with local_env as entered_env:
        # Inside: the environment and every service it holds are in context.
        assert entered_env._in_context
        assert service._in_context
        assert service._service_contexts  # type: ignore[unreachable] # (false positive)
        assert all(svc._in_context for svc in service._service_contexts)
        assert all(svc._in_context for svc in service._services)

    # After exiting: all service contexts have been left again.
    assert not service._in_context  # type: ignore[unreachable] # (false positive)
    assert not service._service_contexts
    assert not any(svc._in_context for svc in service._services)

60 

61 

def test_local_env_results_no_header(tunable_groups: TunableGroups) -> None:
    """
    Check that a results file lacking a header row is rejected:
    setup succeeds, but ``run()`` raises ``ValueError``.
    """
    # Data rows only -- the header line is intentionally omitted.
    headerless_rows = [
        "echo 'latency,10' > output.csv",
        "echo 'throughput,66' >> output.csv",
        "echo 'score,0.9' >> output.csv",
    ]
    local_env = create_local_env(
        tunable_groups,
        {"run": headerless_rows, "read_results_file": "output.csv"},
    )

    with local_env as env:
        setup_ok = env.setup(tunable_groups)
        assert setup_ok
        with pytest.raises(ValueError):
            env.run()

80 

81 

def test_local_env_wide(tunable_groups: TunableGroups) -> None:
    """
    Produce benchmark data in wide format and read it.

    In the wide layout the results file has a single header row of metric
    names followed by a single row of values, instead of one
    ``metric,value`` pair per row.
    """
    local_env = create_local_env(tunable_groups, {
        "run": [
            # Header row of metric names, then one row with their values.
            "echo 'latency,throughput,score' > output.csv",
            "echo '10,66,0.9' >> output.csv",
        ],
        "read_results_file": "output.csv",
    })

    check_env_success(
        local_env, tunable_groups,
        expected_results={
            # Float literals for consistency with the long-format test;
            # 10 == 10.0 in Python, so the comparison outcome is unchanged.
            "latency": 10.0,
            "throughput": 66.0,
            "score": 0.9,
        },
        expected_telemetry=[],
    )